From 1f0918c33e4e7d86cfdcc3529af09bae9802eb36 Mon Sep 17 00:00:00 2001 From: Di Xiao Date: Thu, 2 Apr 2026 11:07:07 -0700 Subject: [PATCH] Refactor XetSession commit / group CAS endpoint and auth configuration (#771) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit There's no publicly documented Xet CAS endpoint. To interact with Xet CAS, all public clients need to obtain a CAS endpoint from the same route to obtain a CAS token. Currently users need to 1. first construct a CAS token URL with respect to a certain operation ("read" or "write", targeted repo type, targeted repo, targeted revision), 2. send a request to this URL to get a CAS token and CAS endpoint, 3. use the CAS endpoint to build a `XetSession`, 4. use the `XetSession` instance and the CAS token and CAS token URL to build an upload or download group. This is a rather completed setup. This PR address this blocker by eagerly "refresh"-ing the CAS token if no CAS endpoint is provided, thus users can 1. build a `XetSession`, 2. construct a CAS token URL with respect to a certain operation ("read" or "write", targeted repo type, targeted repo, targeted revision), 3. use the `XetSession` instance and the CAS token URL to build an upload or download group. So effectively, there will be two common patterns: Pattern A: endpoint known ahead of time — no eager refresh, token_info is used as-is ``` let session = XetSessionBuilder::new().build()?; let commit = session .new_upload_commit()? .with_endpoint(cas_url) .with_token_info(token, expiry) .with_token_refresh_url(refresh_url, /*Auth headers*/) .build_blocking()?; ``` Pattern B: endpoint unknown — build call fetches it; token_info seeded from response ``` let session = XetSessionBuilder::new().build()?; let commit = session .new_upload_commit()? .with_token_refresh_url(token_refresh_url, /*Auth headers*/) .build_blocking()?; ``` Other changes: 1. `with_endpoint()` and `with_custom_headers()` configuration is moved from the `XetSession` level down to the operation level, because we can actually have multiple operations with different CAS endpoints co-exist in the same session instance. 2. Builder for different operations `XetUploadCommit`, `XetFileDownloadGroup`, `XetDownloadStreamGroup` are refactored to share common code under `struct AuthGroupBuilder`. --- Cargo.lock | 1 + ...pdate_260402_unified_auth_group_builder.md | 142 ++++ xet_client/src/cas_client/auth.rs | 19 +- xet_pkg/Cargo.toml | 1 + xet_pkg/examples/example.rs | 44 +- xet_pkg/examples/example_sync.rs | 44 +- xet_pkg/src/xet_session/auth_group_builder.rs | 100 +++ xet_pkg/src/xet_session/common.rs | 186 ++++- .../src/xet_session/download_stream_group.rs | 160 ++-- .../src/xet_session/file_download_group.rs | 311 ++++--- xet_pkg/src/xet_session/mod.rs | 55 +- xet_pkg/src/xet_session/session.rs | 183 +++-- xet_pkg/src/xet_session/upload_commit.rs | 374 +++++---- xet_pkg/tests/test_xet_session.rs | 775 ++++++++++++------ 14 files changed, 1567 insertions(+), 828 deletions(-) create mode 100644 api_changes/update_260402_unified_auth_group_builder.md create mode 100644 xet_pkg/src/xet_session/auth_group_builder.rs diff --git a/Cargo.lock b/Cargo.lock index 91f2cf55..648c377e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1963,6 +1963,7 @@ dependencies = [ "tracing", "tracing-subscriber", "ulid", + "wiremock", "xet-client", "xet-core-structures", "xet-data", diff --git a/api_changes/update_260402_unified_auth_group_builder.md b/api_changes/update_260402_unified_auth_group_builder.md new file mode 100644 index 00000000..9dc9f198 --- /dev/null +++ b/api_changes/update_260402_unified_auth_group_builder.md @@ -0,0 +1,142 @@ +# API Update: Unified auth group builder replaces per-type builder structs (2026-04-02) + +## Overview + +`UploadCommitBuilder`, `FileDownloadGroupBuilder`, and `DownloadStreamGroupBuilder` have +been replaced by `XetUploadCommitBuilder`, `XetFileDownloadGroupBuilder`, and +`XetDownloadStreamGroupBuilder`. All three share the same four configuration methods +(`with_endpoint`, `with_custom_headers`, `with_token_info`, `with_token_refresh_url`); +`build` and `build_blocking` remain type-specific. + +--- + +## Breaking Changes + +### Removed types + +- `UploadCommitBuilder` +- `FileDownloadGroupBuilder` +- `DownloadStreamGroupBuilder` + +### `with_endpoint` removed from `XetSessionBuilder` + +The CAS endpoint is no longer set on the session; it is now set per-operation on the +builder. Any call to `XetSessionBuilder::with_endpoint` must be moved to the builder +chain of each individual operation: + +```rust +// Before +let session = XetSessionBuilder::new() + .with_endpoint("https://cas.example.com") + .build()?; +let commit = session.new_upload_commit()?.build_blocking()?; + +// After +let session = XetSessionBuilder::new().build()?; +let commit = session.new_upload_commit()? + .with_endpoint("https://cas.example.com") + .build_blocking()?; +``` + +### Changed return types on `XetSession` + +All three factory methods now return the named type aliases instead of the old builder +types: + +```rust +// Before +pub fn new_upload_commit(&self) -> Result +pub fn new_file_download_group(&self) -> Result +pub fn new_download_stream_group(&self) -> Result + +// After +pub fn new_upload_commit(&self) -> Result +pub fn new_file_download_group(&self) -> Result +pub fn new_download_stream_group(&self) -> Result +``` + +--- + +## Migration Guide + +Method names and call sites are unchanged — only the concrete type used in type +annotations needs updating: + +```rust +// Before +let builder: UploadCommitBuilder = session.new_upload_commit()?; +let builder: FileDownloadGroupBuilder = session.new_file_download_group()?; +let builder: DownloadStreamGroupBuilder = session.new_download_stream_group()?; + +// After +let builder: XetUploadCommitBuilder = session.new_upload_commit()?; +let builder: XetFileDownloadGroupBuilder = session.new_file_download_group()?; +let builder: XetDownloadStreamGroupBuilder = session.new_download_stream_group()?; +``` + +In practice most call sites use method chaining and never name the builder type, so +no changes are needed: + +```rust +// Unchanged — no type annotation, chaining still works +let commit = session.new_upload_commit()? + .with_token_refresh_url(url, headers) + .with_token_info(token, expiry) + .build_blocking()?; +``` + +--- + +## New capability: `with_endpoint` and `with_custom_headers` on stream groups + +These methods were previously unavailable on `DownloadStreamGroupBuilder`. They are +now available on all three builder variants: + +```rust +let stream_group = session.new_download_stream_group()? + .with_endpoint("https://cas.example.com") + .with_custom_headers(headers) + .with_token_refresh_url(url, refresh_headers) + .build().await?; +``` + +--- + +## Endpoint resolution during `build` + +`create_translator_config` resolves the CAS endpoint in this order: + +1. `with_endpoint(...)` — used as-is if provided. +2. **If `with_endpoint` is omitted but `with_token_refresh_url` is set**, the refresher + is called once eagerly during `build` to obtain the CAS URL. The token from that + response is stored as the initial `token_info` **only if `with_token_info` was not + already called** — a pre-seeded token is preserved as-is. +3. The session's `default_cas_endpoint` from `XetConfig` — used when neither of the + above applies (local or pre-configured deployments). + +The common patterns are: + +```rust +// Pattern A: endpoint known ahead of time — no eager refresh, token_info is used as-is +session.new_upload_commit()? + .with_endpoint(cas_url) + .with_token_info(token, expiry) + .with_token_refresh_url(refresh_url, refresh_headers) + .build_blocking()?; + +// Pattern B: endpoint unknown — first build call fetches it; token_info seeded from response +session.new_upload_commit()? + .with_token_refresh_url(refresh_url, refresh_headers) + .build_blocking()?; +``` + +--- + +## Affected Files + +- `xet_pkg/src/xet_session/auth_group_builder.rs` — new shared builder implementation +- `xet_pkg/src/xet_session/common.rs` — endpoint resolution logic updated +- `xet_pkg/src/xet_session/session.rs` — factory method return types updated +- `xet_pkg/src/xet_session/upload_commit.rs` — exports `XetUploadCommitBuilder` +- `xet_pkg/src/xet_session/file_download_group.rs` — exports `XetFileDownloadGroupBuilder`; `with_endpoint`/`with_custom_headers` bug fixed +- `xet_pkg/src/xet_session/download_stream_group.rs` — exports `XetDownloadStreamGroupBuilder`; gains `with_endpoint` and `with_custom_headers` diff --git a/xet_client/src/cas_client/auth.rs b/xet_client/src/cas_client/auth.rs index 1016610f..59d91561 100644 --- a/xet_client/src/cas_client/auth.rs +++ b/xet_client/src/cas_client/auth.rs @@ -92,12 +92,8 @@ impl DirectRefreshRouteTokenRefresher { cred_helper, } } -} -#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] -#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] -impl TokenRefresher for DirectRefreshRouteTokenRefresher { - async fn refresh(&self) -> Result { + pub async fn get_cas_jwt(&self) -> Result { let client = self.client.clone(); let refresh_route = self.refresh_route.clone(); let cred_helper = self.cred_helper.clone(); @@ -122,8 +118,17 @@ impl TokenRefresher for DirectRefreshRouteTokenRefresher { req.send().await } }) - .await - .map_err(AuthError::token_refresh_failure)?; + .await?; + + Ok(jwt_info) + } +} + +#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] +#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] +impl TokenRefresher for DirectRefreshRouteTokenRefresher { + async fn refresh(&self) -> Result { + let jwt_info = self.get_cas_jwt().await.map_err(AuthError::token_refresh_failure)?; Ok((jwt_info.access_token, jwt_info.exp)) } diff --git a/xet_pkg/Cargo.toml b/xet_pkg/Cargo.toml index 14bad99a..03f683fb 100644 --- a/xet_pkg/Cargo.toml +++ b/xet_pkg/Cargo.toml @@ -47,3 +47,4 @@ smol = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread", "rt", "time", "macros"] } tracing-subscriber = { workspace = true } +wiremock = { workspace = true } diff --git a/xet_pkg/examples/example.rs b/xet_pkg/examples/example.rs index a1186293..09752a66 100644 --- a/xet_pkg/examples/example.rs +++ b/xet_pkg/examples/example.rs @@ -7,9 +7,9 @@ use std::path::PathBuf; use anyhow::Result; use clap::{Parser, Subcommand}; -use http::{HeaderMap, HeaderValue, header}; -use xet::xet_session::{Sha256Policy, XetFileDownload, XetFileMetadata, XetSessionBuilder, XetTaskState}; -use xet_client::hub_client::{self, HFRepoType, HubClient, RepoInfo}; +use xet::xet_session::{ + HeaderMap, HeaderValue, Sha256Policy, XetFileDownload, XetFileMetadata, XetSessionBuilder, XetTaskState, header, +}; #[derive(Parser)] #[clap(name = "session-demo-async", about = "XetSession async API demo")] @@ -54,25 +54,14 @@ async fn main() -> Result<()> { async fn upload_files(files: Vec, endpoint: Option) -> Result<()> { let mut hf_hub_header = HeaderMap::new(); hf_hub_header.insert(header::AUTHORIZATION, HeaderValue::from_str("Bearer [HF_WRITE_TOKEN]")?); - let hub_client = HubClient::new( - &endpoint.unwrap_or("https://huggingface.co".into()), - RepoInfo { - repo_type: HFRepoType::Model, - full_name: "user/repo".into(), - }, - Some("main".into()), - "", - None, - Some(hf_hub_header), - )?; - let token_info = hub_client.get_cas_jwt(hub_client::Operation::Upload).await?; + let endpoint = endpoint.unwrap_or("https://huggingface.co".into()); + let token_refresh_url = format!("{endpoint}/api/{}s/{}/xet-{}-token/{}", "model", "user/repo", "write", "main"); - let session = XetSessionBuilder::new().with_endpoint(token_info.cas_url).build()?; + let session = XetSessionBuilder::new().build()?; let commit = session .new_upload_commit()? - .with_token_info(token_info.access_token, token_info.exp) - //.with_token_refresh_url(token_refresh_url, hf_hub_header) // see HubClient::get_cas_jwt for how to build a token_refresh_url + .with_token_refresh_url(token_refresh_url, hf_hub_header) .build() .await?; @@ -112,25 +101,14 @@ async fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: O let mut hf_hub_header = HeaderMap::new(); hf_hub_header.insert(header::AUTHORIZATION, HeaderValue::from_str("Bearer [HF_READ_TOKEN]")?); - let hub_client = HubClient::new( - &endpoint.unwrap_or("https://huggingface.co".into()), - RepoInfo { - repo_type: HFRepoType::Model, - full_name: "user/repo".into(), - }, - Some("main".into()), - "", - None, - Some(hf_hub_header), - )?; - let token_info = hub_client.get_cas_jwt(hub_client::Operation::Download).await?; + let endpoint = endpoint.unwrap_or("https://huggingface.co".into()); + let token_refresh_url = format!("{endpoint}/api/{}s/{}/xet-{}-token/{}", "model", "user/repo", "read", "main"); - let session = XetSessionBuilder::new().with_endpoint(token_info.cas_url).build()?; + let session = XetSessionBuilder::new().build()?; let group = session .new_file_download_group()? - .with_token_info(token_info.access_token, token_info.exp) - //.with_token_refresh_url(token_refresh_url, hf_hub_header) // see HubClient::get_cas_jwt for how to build a token_refresh_url + .with_token_refresh_url(token_refresh_url, hf_hub_header) .build() .await?; diff --git a/xet_pkg/examples/example_sync.rs b/xet_pkg/examples/example_sync.rs index 7c0252dd..f26b0976 100644 --- a/xet_pkg/examples/example_sync.rs +++ b/xet_pkg/examples/example_sync.rs @@ -7,9 +7,9 @@ use std::time::Duration; use anyhow::Result; use clap::{Parser, Subcommand}; -use http::{HeaderMap, HeaderValue, header}; -use xet::xet_session::{Sha256Policy, XetFileMetadata, XetSessionBuilder, XetTaskState}; -use xet_client::hub_client::{self, HFRepoType, HubClient, RepoInfo}; +use xet::xet_session::{ + HeaderMap, HeaderValue, Sha256Policy, XetFileMetadata, XetSessionBuilder, XetTaskState, header, +}; #[derive(Parser)] #[clap(name = "session-demo", about = "XetSession API demo")] @@ -53,25 +53,14 @@ fn main() -> Result<()> { fn upload_files(files: Vec, endpoint: Option) -> Result<()> { let mut hf_hub_header = HeaderMap::new(); hf_hub_header.insert(header::AUTHORIZATION, HeaderValue::from_str("Bearer [HF_WRITE_TOKEN]")?); - let hub_client = HubClient::new( - &endpoint.unwrap_or("https://huggingface.co".into()), - RepoInfo { - repo_type: HFRepoType::Model, - full_name: "user/repo".into(), - }, - Some("main".into()), - "", - None, - Some(hf_hub_header), - )?; - let token_info = smol::block_on(async move { hub_client.get_cas_jwt(hub_client::Operation::Upload).await })?; + let endpoint = endpoint.unwrap_or("https://huggingface.co".into()); + let token_refresh_url = format!("{endpoint}/api/{}s/{}/xet-{}-token/{}", "model", "user/repo", "write", "main"); - let session = XetSessionBuilder::new().with_endpoint(token_info.cas_url).build()?; + let session = XetSessionBuilder::new().build()?; let commit = session .new_upload_commit()? - .with_token_info(token_info.access_token, token_info.exp) - //.with_token_refresh_url(token_refresh_url, hf_hub_header) // see HubClient::get_cas_jwt for how to build a token_refresh_url + .with_token_refresh_url(token_refresh_url, hf_hub_header) .build_blocking()?; let n_files = files.len(); @@ -110,25 +99,14 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option< let mut hf_hub_header = HeaderMap::new(); hf_hub_header.insert(header::AUTHORIZATION, HeaderValue::from_str("Bearer [HF_READ_TOKEN]")?); - let hub_client = HubClient::new( - &endpoint.unwrap_or("https://huggingface.co".into()), - RepoInfo { - repo_type: HFRepoType::Model, - full_name: "user/repo".into(), - }, - Some("main".into()), - "", - None, - Some(hf_hub_header), - )?; - let token_info = smol::block_on(async move { hub_client.get_cas_jwt(hub_client::Operation::Download).await })?; + let endpoint = endpoint.unwrap_or("https://huggingface.co".into()); + let token_refresh_url = format!("{endpoint}/api/{}s/{}/xet-{}-token/{}", "model", "user/repo", "read", "main"); - let session = XetSessionBuilder::new().with_endpoint(token_info.cas_url).build()?; + let session = XetSessionBuilder::new().build()?; let group = session .new_file_download_group()? - .with_token_info(token_info.access_token, token_info.exp) - //.with_token_refresh_url(token_refresh_url, hf_hub_header) // see HubClient::get_cas_jwt for how to build a token_refresh_url + .with_token_refresh_url(token_refresh_url, hf_hub_header) .build_blocking()?; // Enqueue all downloads; each starts immediately in the background. diff --git a/xet_pkg/src/xet_session/auth_group_builder.rs b/xet_pkg/src/xet_session/auth_group_builder.rs new file mode 100644 index 00000000..2c97ac28 --- /dev/null +++ b/xet_pkg/src/xet_session/auth_group_builder.rs @@ -0,0 +1,100 @@ +use std::marker::PhantomData; + +use http::HeaderMap; + +use crate::xet_session::XetSession; + +/// Per-commit/group auth and connection configuration passed to +/// [`create_translator_config`](super::common::create_translator_config) on build. +#[derive(Default)] +pub(super) struct AuthOptions { + pub(super) endpoint: Option, + pub(super) custom_headers: Option, + pub(super) token_info: Option<(String, u64)>, + pub(super) token_refresh: Option<(String, HeaderMap)>, +} + +/// Generic builder for session-scoped operation groups. +/// +/// `G` is the product type — [`XetUploadCommit`](super::upload_commit::XetUploadCommit), +/// [`XetFileDownloadGroup`](super::file_download_group::XetFileDownloadGroup), or +/// [`XetDownloadStreamGroup`](super::download_stream_group::XetDownloadStreamGroup). +/// +/// The shared configuration methods (`with_endpoint`, `with_custom_headers`, +/// `with_token_info`, `with_token_refresh_url`) are implemented once here on +/// `AuthGroupBuilder`. The `build` and `build_blocking` methods are implemented +/// separately per `G`. +/// +/// Obtain via [`XetSession::new_upload_commit`](super::session::XetSession::new_upload_commit), +/// [`XetSession::new_file_download_group`](super::session::XetSession::new_file_download_group), or +/// [`XetSession::new_download_stream_group`](super::session::XetSession::new_download_stream_group). +pub struct AuthGroupBuilder { + pub(super) session: XetSession, + pub(super) auth_options: AuthOptions, + _marker: PhantomData, +} + +impl AuthGroupBuilder { + pub(super) fn new(session: XetSession) -> Self { + Self { + session, + auth_options: Default::default(), + _marker: PhantomData, + } + } + + /// Set the Xet CAS server endpoint URL (e.g. `"https://cas.example.com"`). If this is + /// not provided but a token refresh URL is provided, during build, a request will be + /// sent to the token refresh route to fetch the CAS server endpoint. + pub fn with_endpoint(self, endpoint: impl Into) -> Self { + Self { + auth_options: AuthOptions { + endpoint: Some(endpoint.into()), + ..self.auth_options + }, + ..self + } + } + + /// Attach custom HTTP headers that are forwarded with every CAS request. + pub fn with_custom_headers(self, headers: HeaderMap) -> Self { + Self { + auth_options: AuthOptions { + custom_headers: Some(headers), + ..self.auth_options + }, + ..self + } + } + + /// Seed an initial CAS access token and its expiry as a Unix timestamp (seconds). + /// + /// If endpoint is not provided but a token refresh URL is provided, the eager refresh + /// response's token info is used only if no token was pre-seeded here. + pub fn with_token_info(self, token: impl Into, expiry: u64) -> Self { + Self { + auth_options: AuthOptions { + token_info: Some((token.into(), expiry)), + ..self.auth_options + }, + ..self + } + } + + /// Set a URL and authentication headers used to obtain a fresh CAS access token + /// whenever the current one is about to expire. + /// + /// The client issues an authenticated HTTP GET to `url` with `headers` (which should + /// include auth credentials, e.g. `Authorization: Bearer `). The endpoint + /// must return JSON: + /// `{ "accessToken": "", "exp": , "casUrl": "" }`. + pub fn with_token_refresh_url(self, url: impl Into, headers: HeaderMap) -> Self { + Self { + auth_options: AuthOptions { + token_refresh: Some((url.into(), headers)), + ..self.auth_options + }, + ..self + } + } +} diff --git a/xet_pkg/src/xet_session/common.rs b/xet_pkg/src/xet_session/common.rs index 144619e3..91449f16 100644 --- a/xet_pkg/src/xet_session/common.rs +++ b/xet_pkg/src/xet_session/common.rs @@ -1,48 +1,70 @@ use std::sync::Arc; -use http::HeaderMap; use xet_client::cas_client::auth::{DirectRefreshRouteTokenRefresher, TokenRefresher}; use xet_client::common::http_client::build_http_client; use xet_data::processing::configurations::TranslatorConfig; use super::XetSession; +use super::auth_group_builder::AuthOptions; use crate::error::XetError; -/// Builds a [`TranslatorConfig`] from the session's endpoint and shared settings, -/// combined with the per-commit/group token credentials supplied by the caller. +/// Builds a [`TranslatorConfig`] from the session defaults and the per-commit/group +/// [`AuthOptions`] supplied by the caller. /// -/// `token_info` is an optional pre-seeded `(token, expiry_unix_secs)` pair that -/// lets the CAS request use an already-known token instead of fetching one. +/// Endpoint resolution order: +/// 1. `auth_options.endpoint`, if set. +/// 2. If `token_refresh` is set but `endpoint` is not, the refresher is called once via `get_cas_jwt()` to obtain the +/// CAS URL. The resulting token is also stored as the initial `token_info` only if no `token_info` was already +/// provided. +/// 3. The session's `default_cas_endpoint` from its configuration. /// -/// `token_refresh` is an optional `(refresh_url, request_headers)` pair. When -/// present, an HTTP client is built with those headers and wrapped in a -/// [`DirectRefreshRouteTokenRefresher`] so the commit/group can fetch a fresh CAS -/// token whenever the current one is about to expire. -pub(super) fn create_translator_config( +/// `token_info` is an optional pre-seeded `(token, expiry_unix_secs)` pair. +/// +/// `token_refresh` is an optional `(refresh_url, request_headers)` pair that is +/// wrapped in a [`DirectRefreshRouteTokenRefresher`] to keep the token fresh. +/// +/// `custom_headers` are forwarded with every CAS HTTP request. +pub(super) async fn create_translator_config( session: &XetSession, - token_info: Option<(String, u64)>, - token_refresh: Option<&(String, Arc)>, + auth_options: AuthOptions, ) -> Result { let session_id = session.inner.id.to_string(); - let token_refresher: Option> = token_refresh - .map(|(url, headers)| -> Result, XetError> { - let client = build_http_client(&session_id, None, Some(headers.clone()))?; - Ok(Arc::new(DirectRefreshRouteTokenRefresher::new(url, client, None))) - }) - .transpose()?; + let AuthOptions { + mut endpoint, + custom_headers, + mut token_info, + token_refresh, + } = auth_options; - let endpoint = session - .inner - .endpoint - .clone() - .unwrap_or_else(|| session.inner.config.data.default_cas_endpoint.clone()); + // Build token refresher + let token_refresher: Option> = if let Some((url, token_refresh_headers)) = token_refresh { + let client = build_http_client(&session_id, None, Some(Arc::new(token_refresh_headers)))?; + let direct_route_refresher = DirectRefreshRouteTokenRefresher::new(url, client, None); + + // CAS endpoint is not provided but CAS token refresh endpoint is provided, we + // refresh once to get the CAS endpoint, and fill the token info if nothing is provided. + if endpoint.is_none() { + let jwt_info = direct_route_refresher.get_cas_jwt().await?; + + if token_info.is_none() { + token_info = Some((jwt_info.access_token, jwt_info.exp)); + } + endpoint = Some(jwt_info.cas_url); + } + + Some(Arc::new(direct_route_refresher)) + } else { + None + }; + + let endpoint = endpoint.unwrap_or_else(|| session.inner.config.data.default_cas_endpoint.clone()); let mut config = xet_data::processing::data_client::default_config( endpoint, token_info, token_refresher, - session.inner.custom_headers.clone(), + custom_headers.map(Arc::new), )?; if !session_id.is_empty() { @@ -51,3 +73,119 @@ pub(super) fn create_translator_config( Ok(config) } + +#[cfg(test)] +mod tests { + use http::HeaderMap; + use wiremock::matchers::{method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; + + use super::*; + use crate::xet_session::XetSessionBuilder; + use crate::xet_session::auth_group_builder::AuthOptions; + + /// Pattern A: when `endpoint` is set, the token refresh route must not be called + /// during `build` — the endpoint is used as-is and the first refresh is deferred + /// until the token actually expires. + #[tokio::test] + async fn test_endpoint_provided_skips_eager_refresh() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "casUrl": "https://should-not-be-used.example.com", + "exp": 9_999_999_999u64, + "accessToken": "should-not-be-fetched", + }))) + .expect(0) + .mount(&server) + .await; + + let refresh_url = format!("{}/token", server.uri()); + let session = XetSessionBuilder::new().build().unwrap(); + + let auth_options = AuthOptions { + endpoint: Some("https://cas.example.com".to_string()), + custom_headers: None, + token_info: None, + token_refresh: Some((refresh_url, HeaderMap::new())), + }; + + let config = create_translator_config(&session, auth_options).await.unwrap(); + + assert_eq!(config.session.endpoint, "https://cas.example.com"); + } + + /// Pattern B: when `endpoint` is not set but `token_refresh` is set, `create_translator_config` + /// calls the refresh URL exactly once, uses the returned `cas_url` as the endpoint, + /// and seeds `token_info` from the response when none was pre-supplied. + #[tokio::test] + async fn test_eager_refresh_sets_endpoint_and_token() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "casUrl": "https://cas.example.com", + "exp": 9_999_999_999u64, + "accessToken": "eagerly-fetched-token", + }))) + .expect(1) + .mount(&server) + .await; + + let refresh_url = format!("{}/token", server.uri()); + let session = XetSessionBuilder::new().build().unwrap(); + + let auth_options = AuthOptions { + endpoint: None, + custom_headers: None, + token_info: None, + token_refresh: Some((refresh_url, HeaderMap::new())), + }; + + let config = create_translator_config(&session, auth_options).await.unwrap(); + + assert_eq!(config.session.endpoint, "https://cas.example.com"); + let auth = config.session.auth.expect("auth config should be set"); + assert_eq!(auth.token, "eagerly-fetched-token"); + assert_eq!(auth.token_expiration, 9_999_999_999); + } + + /// Pattern B: when `token_info` is already provided alongside `token_refresh` but no `endpoint`, + /// the refresh is still called once to obtain the `cas_url`, but the pre-supplied + /// token is preserved — it must NOT be overwritten by the refresh response. + #[tokio::test] + async fn test_eager_refresh_preserves_existing_token_info() { + let server = MockServer::start().await; + + Mock::given(method("GET")) + .and(path("/token")) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "casUrl": "https://cas.example.com", + "exp": 9_999_999_999u64, + "accessToken": "eagerly-fetched-token", + }))) + .expect(1) + .mount(&server) + .await; + + let refresh_url = format!("{}/token", server.uri()); + let session = XetSessionBuilder::new().build().unwrap(); + + let auth_options = AuthOptions { + endpoint: None, + custom_headers: None, + token_info: Some(("pre-supplied-token".to_string(), 1_000_000_000)), + token_refresh: Some((refresh_url, HeaderMap::new())), + }; + + let config = create_translator_config(&session, auth_options).await.unwrap(); + + assert_eq!(config.session.endpoint, "https://cas.example.com"); + let auth = config.session.auth.expect("auth config should be set"); + assert_eq!(auth.token, "pre-supplied-token"); + assert_eq!(auth.token_expiration, 1_000_000_000); + } +} diff --git a/xet_pkg/src/xet_session/download_stream_group.rs b/xet_pkg/src/xet_session/download_stream_group.rs index 4f0c8e12..3eeef077 100644 --- a/xet_pkg/src/xet_session/download_stream_group.rs +++ b/xet_pkg/src/xet_session/download_stream_group.rs @@ -1,4 +1,4 @@ -//! [`XetDownloadStreamGroup`] and [`DownloadStreamGroupBuilder`] — authenticated +//! [`XetDownloadStreamGroup`] and [`AuthGroupBuilder`] — authenticated //! streaming download group management. //! //! [`XetDownloadStreamGroup`] manages a shared CAS connection pool and auth token for one @@ -13,69 +13,24 @@ use std::ops::Range; use std::sync::Arc; -use http::HeaderMap; use tracing::info; use xet_data::processing::{FileDownloadSession, XetFileInfo}; use xet_data::progress_tracking::UniqueID; +use super::auth_group_builder::{AuthGroupBuilder, AuthOptions}; use super::common::create_translator_config; use super::session::XetSession; use super::task_runtime::TaskRuntime; use super::{XetDownloadStream, XetUnorderedDownloadStream}; use crate::error::XetError; -/// Builder for [`XetDownloadStreamGroup`]. -/// -/// Obtain via [`XetSession::new_download_stream_group`], configure per-group auth -/// with [`with_token_info`](Self::with_token_info) and -/// [`with_token_refresh_url`](Self::with_token_refresh_url), then call -/// [`build`](Self::build) (async) or [`build_blocking`](Self::build_blocking) (sync). -pub struct DownloadStreamGroupBuilder { - session: XetSession, - token_info: Option<(String, u64)>, - token_refresh: Option<(String, Arc)>, -} - -impl DownloadStreamGroupBuilder { - pub(super) fn new(session: XetSession) -> Self { - Self { - session, - token_info: None, - token_refresh: None, - } - } - - /// Seed an initial CAS access token and its expiry as a Unix timestamp (seconds). - /// - /// When combined with [`with_token_refresh_url`](Self::with_token_refresh_url) this - /// avoids an extra refresh round-trip on the first request. - pub fn with_token_info(self, token: impl Into, expiry: u64) -> Self { - Self { - token_info: Some((token.into(), expiry)), - ..self - } - } - - /// Set a URL and authentication headers used to obtain a fresh CAS access token - /// whenever the current one is about to expire. - /// - /// The client issues an authenticated HTTP GET to `url` with `headers` (which should - /// include auth credentials, e.g. `Authorization: Bearer `). The endpoint - /// must return JSON: - /// `{ "accessToken": "", "exp": , "casUrl": "" }`. - pub fn with_token_refresh_url(self, url: impl Into, headers: HeaderMap) -> Self { - Self { - token_refresh: Some((url.into(), Arc::new(headers))), - ..self - } - } +pub type XetDownloadStreamGroupBuilder = AuthGroupBuilder; +impl AuthGroupBuilder { /// Create the [`XetDownloadStreamGroup`] from an async context. pub async fn build(self) -> Result { - let DownloadStreamGroupBuilder { - session, - token_info, - token_refresh, + let AuthGroupBuilder { + session, auth_options, .. } = self; let session_for_reg = session.clone(); let parent_runtime = session.inner.task_runtime.clone(); @@ -83,7 +38,7 @@ impl DownloadStreamGroupBuilder { let group = parent_runtime .bridge_async("new_download_stream_group", async move { let group_runtime = child_parent.child()?; - XetDownloadStreamGroup::new(session, group_runtime, token_info, token_refresh).await + XetDownloadStreamGroup::new(session, group_runtime, auth_options).await }) .await?; info!("New download stream group, session_id={}, group_id={}", group.session().id(), group.id()); @@ -103,17 +58,15 @@ impl DownloadStreamGroupBuilder { /// /// Panics if called from within a tokio async runtime on an Owned-mode session. pub fn build_blocking(self) -> Result { - let DownloadStreamGroupBuilder { - session, - token_info, - token_refresh, + let AuthGroupBuilder { + session, auth_options, .. } = self; let session_for_reg = session.clone(); let parent_runtime = session.inner.task_runtime.clone(); let child_parent = parent_runtime.clone(); let group = parent_runtime.bridge_sync("new_download_stream_group_blocking", async move { let group_runtime = child_parent.child()?; - XetDownloadStreamGroup::new(session, group_runtime, token_info, token_refresh).await + XetDownloadStreamGroup::new(session, group_runtime, auth_options).await })?; info!("New download stream group, session_id={}, group_id={}", group.session().id(), group.id()); session_for_reg.register_download_stream_group(&group)?; @@ -164,25 +117,23 @@ impl XetDownloadStreamGroupInner { } impl XetDownloadStreamGroup { - /// Create a new download stream group. Called by [`DownloadStreamGroupBuilder::build`]. + /// Create a new download stream group. Called by [`AuthGroupBuilder::build`]. async fn new( session: XetSession, task_runtime: Arc, - token_info: Option<(String, u64)>, - token_refresh: Option<(String, Arc)>, + auth_options: AuthOptions, ) -> Result { let group_id = UniqueID::new(); - let config = create_translator_config(&session, token_info, token_refresh.as_ref())?; + let config = create_translator_config(&session, auth_options).await?; let download_session = FileDownloadSession::new(Arc::new(config), None).await?; - Ok(Self { - inner: Arc::new(XetDownloadStreamGroupInner { - session, - group_id, - download_session, - }), - task_runtime, - }) + let inner = Arc::new(XetDownloadStreamGroupInner { + session, + group_id, + download_session, + }); + + Ok(Self { inner, task_runtime }) } /// Returns the unique ID for this stream group. @@ -319,25 +270,19 @@ impl XetDownloadStreamGroup { #[cfg(test)] mod tests { - use tempfile::{TempDir, tempdir}; + use tempfile::tempdir; use xet_data::processing::{Sha256Policy, XetFileInfo}; use super::super::session::{XetSession, XetSessionBuilder}; use super::*; - fn local_session(temp: &TempDir) -> Result> { - let cas_path = temp.path().join("cas"); - Ok(XetSessionBuilder::new() - .with_endpoint(format!("local://{}", cas_path.display())) - .build()?) - } - async fn upload_bytes( session: &XetSession, + endpoint: &str, data: &[u8], name: &str, ) -> Result> { - let commit = session.new_upload_commit()?.build().await?; + let commit = session.new_upload_commit()?.with_endpoint(endpoint).build().await?; let _handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some(name.into())) .await?; @@ -348,33 +293,46 @@ mod tests { fn upload_bytes_blocking( session: &XetSession, + endpoint: &str, data: &[u8], name: &str, ) -> Result> { - let commit = session.new_upload_commit()?.build_blocking()?; + let commit = session.new_upload_commit()?.with_endpoint(endpoint).build_blocking()?; let _handle = commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some(name.into()))?; let results = commit.commit_blocking()?; let meta = results.uploads.into_values().next().expect("one uploaded file"); Ok(meta.xet_info) } - async fn stream_group_async(session: &XetSession) -> XetDownloadStreamGroup { - session.new_download_stream_group().unwrap().build().await.unwrap() + async fn stream_group_async(session: &XetSession, endpoint: &str) -> XetDownloadStreamGroup { + session + .new_download_stream_group() + .unwrap() + .with_endpoint(endpoint) + .build() + .await + .unwrap() } - fn stream_group_sync(session: &XetSession) -> XetDownloadStreamGroup { - session.new_download_stream_group().unwrap().build_blocking().unwrap() + fn stream_group_sync(session: &XetSession, endpoint: &str) -> XetDownloadStreamGroup { + session + .new_download_stream_group() + .unwrap() + .with_endpoint(endpoint) + .build_blocking() + .unwrap() } #[tokio::test(flavor = "multi_thread")] // Async streaming download round-trip: upload, stream, verify content. async fn test_download_stream_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"Hello, streaming download!"; - let file_info = upload_bytes(&session, original, "stream.bin").await.unwrap(); + let file_info = upload_bytes(&session, &endpoint, original, "stream.bin").await.unwrap(); - let group = stream_group_async(&session).await; + let group = stream_group_async(&session, &endpoint).await; let mut stream = group.download_stream(file_info, None).await.unwrap(); let mut collected = Vec::new(); while let Some(chunk) = stream.next().await.unwrap() { @@ -387,11 +345,12 @@ mod tests { // Blocking streaming download round-trip: upload, stream, verify content. fn test_download_stream_blocking_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"Hello, blocking streaming download!"; - let file_info = upload_bytes_blocking(&session, original, "stream.bin").unwrap(); + let file_info = upload_bytes_blocking(&session, &endpoint, original, "stream.bin").unwrap(); - let group = stream_group_sync(&session); + let group = stream_group_sync(&session, &endpoint); let mut stream = group.download_stream_blocking(file_info, None).unwrap(); let mut collected = Vec::new(); @@ -405,11 +364,12 @@ mod tests { // get_progress() reports correct totals after consuming the stream. async fn test_download_stream_progress_reports_completion() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"progress tracking test data for streaming"; - let file_info = upload_bytes(&session, original, "progress.bin").await.unwrap(); + let file_info = upload_bytes(&session, &endpoint, original, "progress.bin").await.unwrap(); - let group = stream_group_async(&session).await; + let group = stream_group_async(&session, &endpoint).await; let mut stream = group.download_stream(file_info, None).await.unwrap(); let initial = stream.progress(); assert_eq!(initial.total_bytes, original.len() as u64); @@ -430,11 +390,12 @@ mod tests { // get_progress() works correctly in blocking mode. fn test_download_stream_blocking_progress_reports_completion() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"blocking progress tracking test data"; - let file_info = upload_bytes_blocking(&session, original, "progress.bin").unwrap(); + let file_info = upload_bytes_blocking(&session, &endpoint, original, "progress.bin").unwrap(); - let group = stream_group_sync(&session); + let group = stream_group_sync(&session, &endpoint); let mut stream = group.download_stream_blocking(file_info, None).unwrap(); let mut collected = Vec::new(); @@ -452,13 +413,14 @@ mod tests { // Multiple sequential streaming downloads use the same group. async fn test_download_stream_multiple_sequential() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data_a = b"first stream payload"; let data_b = b"second stream payload"; - let info_a = upload_bytes(&session, data_a, "a.bin").await.unwrap(); - let info_b = upload_bytes(&session, data_b, "b.bin").await.unwrap(); + let info_a = upload_bytes(&session, &endpoint, data_a, "a.bin").await.unwrap(); + let info_b = upload_bytes(&session, &endpoint, data_b, "b.bin").await.unwrap(); - let group = stream_group_async(&session).await; + let group = stream_group_async(&session, &endpoint).await; let mut stream_a = group.download_stream(info_a, None).await.unwrap(); let mut collected_a = Vec::new(); while let Some(chunk) = stream_a.next().await.unwrap() { diff --git a/xet_pkg/src/xet_session/file_download_group.rs b/xet_pkg/src/xet_session/file_download_group.rs index 2559f2b1..41227186 100644 --- a/xet_pkg/src/xet_session/file_download_group.rs +++ b/xet_pkg/src/xet_session/file_download_group.rs @@ -4,76 +4,31 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::{Arc, RwLock}; -use http::HeaderMap; use tracing::info; use xet_data::processing::{FileDownloadSession, XetFileInfo}; use xet_data::progress_tracking::{GroupProgressReport, UniqueID}; +use super::auth_group_builder::{AuthGroupBuilder, AuthOptions}; use super::common::create_translator_config; use super::file_download_handle::{XetDownloadReport, XetFileDownload, XetFileDownloadInner}; use super::session::XetSession; use super::task_runtime::{BackgroundTaskState, TaskRuntime, XetTaskState}; use crate::error::XetError; -/// Builder for [`XetFileDownloadGroup`]. -/// -/// Obtain via [`XetSession::new_file_download_group`], configure per-group auth -/// with [`with_token_info`](Self::with_token_info) and -/// [`with_token_refresh_url`](Self::with_token_refresh_url), then call -/// [`build`](Self::build) (async) or [`build_blocking`](Self::build_blocking) (sync). -pub struct FileDownloadGroupBuilder { - session: XetSession, - token_info: Option<(String, u64)>, - token_refresh: Option<(String, Arc)>, -} - -impl FileDownloadGroupBuilder { - pub(super) fn new(session: XetSession) -> Self { - Self { - session, - token_info: None, - token_refresh: None, - } - } - - /// Seed an initial CAS access token and its expiry as a Unix timestamp (seconds). - /// - /// When combined with [`with_token_refresh_url`](Self::with_token_refresh_url) this - /// avoids an extra refresh round-trip on the first request. - pub fn with_token_info(self, token: impl Into, expiry: u64) -> Self { - Self { - token_info: Some((token.into(), expiry)), - ..self - } - } - - /// Set a URL and authentication headers used to obtain a fresh CAS access token - /// whenever the current one is about to expire. - /// - /// The client issues an authenticated HTTP GET to `url` with `headers` (which should - /// include auth credentials, e.g. `Authorization: Bearer `). The endpoint - /// must return JSON: - /// `{ "accessToken": "", "exp": , "casUrl": "" }`. - pub fn with_token_refresh_url(self, url: impl Into, headers: HeaderMap) -> Self { - Self { - token_refresh: Some((url.into(), Arc::new(headers))), - ..self - } - } +pub type XetFileDownloadGroupBuilder = AuthGroupBuilder; +impl AuthGroupBuilder { /// Create the [`XetFileDownloadGroup`] from an async context. pub async fn build(self) -> Result { - let FileDownloadGroupBuilder { - session, - token_info, - token_refresh, + let AuthGroupBuilder { + session, auth_options, .. } = self; let parent_runtime = session.inner.task_runtime.clone(); let child_parent = parent_runtime.clone(); let group = parent_runtime .bridge_async("new_file_download_group", async move { let group_runtime = child_parent.child()?; - XetFileDownloadGroup::new(session, group_runtime, token_info, token_refresh).await + XetFileDownloadGroup::new(session, group_runtime, auth_options).await }) .await?; info!("New file download group, session_id={}, group_id={}", group.session().id(), group.id()); @@ -94,16 +49,14 @@ impl FileDownloadGroupBuilder { /// /// Panics if called from within a tokio async runtime on an Owned-mode session. pub fn build_blocking(self) -> Result { - let FileDownloadGroupBuilder { - session, - token_info, - token_refresh, + let AuthGroupBuilder { + session, auth_options, .. } = self; let parent_runtime = session.inner.task_runtime.clone(); let child_parent = parent_runtime.clone(); let group = parent_runtime.bridge_sync("new_file_download_group_blocking", async move { let group_runtime = child_parent.child()?; - XetFileDownloadGroup::new(session, group_runtime, token_info, token_refresh).await + XetFileDownloadGroup::new(session, group_runtime, auth_options).await })?; info!("New file download group, session_id={}, group_id={}", group.session().id(), group.id()); group.session().register_file_download_group(&group)?; @@ -126,9 +79,9 @@ pub struct XetDownloadGroupReport { /// API for grouping related file downloads into a single unit of work. /// /// Obtain via [`XetSession::new_file_download_group`] — configure per-group -/// auth on the returned [`FileDownloadGroupBuilder`], then call -/// [`build`](FileDownloadGroupBuilder::build) (async) or -/// [`build_blocking`](FileDownloadGroupBuilder::build_blocking) (sync). +/// auth on the returned [`AuthGroupBuilder`], then call +/// [`build`](AuthGroupBuilder::build) (async) or +/// [`build_blocking`](AuthGroupBuilder::build_blocking) (sync). /// /// Queue files with [`download_file_to_path`](Self::download_file_to_path) (they start /// downloading immediately in the background), poll progress with @@ -159,11 +112,10 @@ impl XetFileDownloadGroup { pub(super) async fn new( session: XetSession, task_runtime: Arc, - token_info: Option<(String, u64)>, - token_refresh: Option<(String, Arc)>, + auth_options: AuthOptions, ) -> Result { let group_id = UniqueID::new(); - let config = create_translator_config(&session, token_info, token_refresh.as_ref())?; + let config = create_translator_config(&session, auth_options).await?; let download_session = FileDownloadSession::new(Arc::new(config), None).await?; let inner = Arc::new(XetFileDownloadGroupInner { @@ -423,22 +375,15 @@ mod tests { use std::time::Duration; use anyhow::Result; - use tempfile::{TempDir, tempdir}; + use tempfile::tempdir; use xet_data::processing::Sha256Policy; use xet_runtime::core::RuntimeMode; use super::*; use crate::xet_session::session::{XetSession, XetSessionBuilder}; - fn local_session(temp: &TempDir) -> Result { - let cas_path = temp.path().join("cas"); - Ok(XetSessionBuilder::new() - .with_endpoint(format!("local://{}", cas_path.display())) - .build()?) - } - - async fn upload_bytes(session: &XetSession, data: &[u8], name: &str) -> Result { - let commit = session.new_upload_commit()?.build().await?; + async fn upload_bytes(session: &XetSession, endpoint: &str, data: &[u8], name: &str) -> Result { + let commit = session.new_upload_commit()?.with_endpoint(endpoint).build().await?; let _handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some(name.into())) .await?; @@ -629,12 +574,19 @@ mod tests { // Downloading a previously uploaded file produces byte-identical content at the destination. async fn test_download_file_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"Hello, download round-trip!"; - let file_info = upload_bytes(&session, original, "payload.bin").await.unwrap(); + let file_info = upload_bytes(&session, &endpoint, original, "payload.bin").await.unwrap(); let dest = temp.path().join("downloaded.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = group.download_file_to_path(file_info, dest.clone()).await.unwrap(); assert!(matches!(handle.status().unwrap(), XetTaskState::Running | XetTaskState::Completed)); group.finish().await.unwrap(); @@ -647,8 +599,15 @@ mod tests { // A download task that fails transitions to Error status. async fn test_download_status_failed_for_invalid_file_info() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = group .download_file_to_path( XetFileInfo { @@ -669,12 +628,19 @@ mod tests { // task_id returned by download_file_to_path must match the per-item progress entry id. async fn test_download_task_id_matches_progress_item_id() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"download id match"; - let file_info = upload_bytes(&session, original, "id.bin").await.unwrap(); + let file_info = upload_bytes(&session, &endpoint, original, "id.bin").await.unwrap(); let dest = temp.path().join("download_id.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = group.download_file_to_path(file_info, dest).await.unwrap(); let mut reports = HashMap::new(); @@ -695,13 +661,20 @@ mod tests { // Downloading multiple files from a single group produces correct content for each. async fn test_download_multiple_files() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data_a = b"First file content"; let data_b = b"Second file content - different"; let (file_a_info, file_b_info) = { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let _handle_a = commit .upload_bytes(data_a.to_vec(), Sha256Policy::Compute, Some("a.bin".into())) .await @@ -719,7 +692,13 @@ mod tests { let dest_a = temp.path().join("a_out.bin"); let dest_b = temp.path().join("b_out.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); group.download_file_to_path(file_a_info, dest_a.clone()).await.unwrap(); group.download_file_to_path(file_b_info, dest_b.clone()).await.unwrap(); group.finish().await.unwrap(); @@ -732,12 +711,19 @@ mod tests { // After a successful finish the aggregate download progress reflects bytes received. async fn test_download_progress_reflects_bytes_after_finish() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"download progress tracking data"; - let file_info = upload_bytes(&session, original, "prog.bin").await.unwrap(); + let file_info = upload_bytes(&session, &endpoint, original, "prog.bin").await.unwrap(); let dest = temp.path().join("out.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let progress_observer = group.clone(); group.download_file_to_path(file_info, dest).await.unwrap(); let finish_report = group.finish().await.unwrap(); @@ -767,11 +753,18 @@ mod tests { // Pattern 1: per-task result is accessible via task_id in the finish report downloads map. async fn test_download_result_accessible_via_task_id_in_finish_map() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"result via task_id in finish map"; - let file_info = upload_bytes(&session, data, "file.bin").await.unwrap(); + let file_info = upload_bytes(&session, &endpoint, data, "file.bin").await.unwrap(); let dest = temp.path().join("out.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = group.download_file_to_path(file_info, dest).await.unwrap(); let report = group.finish().await.unwrap(); let result = report @@ -785,10 +778,17 @@ mod tests { // XetFileDownload::result() returns None before finish() is called. async fn test_download_result_none_before_finish() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes(&session, b"some data", "file.bin").await.unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes(&session, &endpoint, b"some data", "file.bin").await.unwrap(); let dest = temp.path().join("out.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = group.download_file_to_path(file_info, dest).await.unwrap(); assert!(handle.result().is_none(), "result must be None before finish()"); group.finish().await.unwrap(); @@ -798,11 +798,18 @@ mod tests { // XetFileDownload::result() returns Some after finish() completes. async fn test_download_result_some_after_finish() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"download result test data"; - let file_info = upload_bytes(&session, data, "file.bin").await.unwrap(); + let file_info = upload_bytes(&session, &endpoint, data, "file.bin").await.unwrap(); let dest = temp.path().join("out.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = group.download_file_to_path(file_info.clone(), dest).await.unwrap(); group.finish().await.unwrap(); let result = handle.result().expect("result must be set after finish()"); @@ -814,11 +821,18 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_download_finish_second_call_returns_cached_result() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"download finish cache test"; - let file_info = upload_bytes(&session, data, "cache.bin").await.unwrap(); + let file_info = upload_bytes(&session, &endpoint, data, "cache.bin").await.unwrap(); let dest = temp.path().join("cache.out"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = group.download_file_to_path(file_info, dest).await.unwrap(); let first = handle.finish().await.unwrap(); let second = handle.finish().await.unwrap(); @@ -837,12 +851,19 @@ mod tests { let temp = tempdir().unwrap(); futures::executor::block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from futures executor"; let file_info = { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let _handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("test.bin".into())) .await @@ -852,7 +873,13 @@ mod tests { }; let dest = temp.path().join("out_futures.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); group.download_file_to_path(file_info, dest.clone()).await.unwrap(); group.finish().await.unwrap(); assert_eq!(std::fs::read(&dest).unwrap(), data); @@ -865,12 +892,19 @@ mod tests { let temp = tempdir().unwrap(); smol::block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from smol executor"; let file_info = { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let _handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("test.bin".into())) .await @@ -880,7 +914,13 @@ mod tests { }; let dest = temp.path().join("out_smol.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); group.download_file_to_path(file_info, dest.clone()).await.unwrap(); group.finish().await.unwrap(); assert_eq!(std::fs::read(&dest).unwrap(), data); @@ -893,12 +933,19 @@ mod tests { let temp = tempdir().unwrap(); async_std::task::block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from async-std executor"; let file_info = { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let _handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("test.bin".into())) .await @@ -908,7 +955,13 @@ mod tests { }; let dest = temp.path().join("out_async_std.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); group.download_file_to_path(file_info, dest.clone()).await.unwrap(); group.finish().await.unwrap(); assert_eq!(std::fs::read(&dest).unwrap(), data); @@ -917,8 +970,8 @@ mod tests { // ── Blocking API tests ──────────────────────────────────────────────────── - fn upload_bytes_blocking(session: &XetSession, data: &[u8], name: &str) -> Result { - let commit = session.new_upload_commit()?.build_blocking()?; + fn upload_bytes_blocking(session: &XetSession, endpoint: &str, data: &[u8], name: &str) -> Result { + let commit = session.new_upload_commit()?.with_endpoint(endpoint).build_blocking()?; let _handle = commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some(name.into()))?; let results = commit.commit_blocking()?; let meta = results.uploads.into_values().next().expect("one uploaded file"); @@ -928,12 +981,13 @@ mod tests { #[test] fn test_blocking_download_file_round_trip() -> Result<()> { let temp = tempdir()?; - let session = local_session(&temp)?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"Hello, download round-trip!"; - let file_info = upload_bytes_blocking(&session, original, "payload.bin")?; + let file_info = upload_bytes_blocking(&session, &endpoint, original, "payload.bin")?; let dest = temp.path().join("downloaded.bin"); - let group = session.new_file_download_group()?.build_blocking()?; + let group = session.new_file_download_group()?.with_endpoint(&endpoint).build_blocking()?; group.download_file_to_path_blocking(file_info, dest.clone())?; group.finish_blocking()?; @@ -944,13 +998,14 @@ mod tests { #[test] fn test_blocking_download_multiple_files() -> Result<()> { let temp = tempdir()?; - let session = local_session(&temp)?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data_a = b"First file content"; let data_b = b"Second file content - different"; let (file_a_info, file_b_info) = { - let commit = session.new_upload_commit()?.build_blocking()?; + let commit = session.new_upload_commit()?.with_endpoint(&endpoint).build_blocking()?; let _handle_a = commit.upload_bytes_blocking(data_a.to_vec(), Sha256Policy::Compute, Some("a.bin".into()))?; let _handle_b = @@ -964,7 +1019,7 @@ mod tests { let dest_a = temp.path().join("a_out.bin"); let dest_b = temp.path().join("b_out.bin"); - let group = session.new_file_download_group()?.build_blocking()?; + let group = session.new_file_download_group()?.with_endpoint(&endpoint).build_blocking()?; group.download_file_to_path_blocking(file_a_info, dest_a.clone())?; group.download_file_to_path_blocking(file_b_info, dest_b.clone())?; group.finish_blocking()?; @@ -977,12 +1032,13 @@ mod tests { #[test] fn test_blocking_download_progress_reflects_bytes_after_finish() -> Result<()> { let temp = tempdir()?; - let session = local_session(&temp)?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"download progress tracking data"; - let file_info = upload_bytes_blocking(&session, original, "prog.bin")?; + let file_info = upload_bytes_blocking(&session, &endpoint, original, "prog.bin")?; let dest = temp.path().join("out.bin"); - let group = session.new_file_download_group()?.build_blocking()?; + let group = session.new_file_download_group()?.with_endpoint(&endpoint).build_blocking()?; let progress_observer = group.clone(); group.download_file_to_path_blocking(file_info, dest)?; group.finish_blocking()?; @@ -1007,11 +1063,12 @@ mod tests { #[test] fn test_blocking_download_result_access_patterns() -> Result<()> { let temp = tempdir()?; - let session = local_session(&temp)?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"download result access patterns"; - let file_info = upload_bytes_blocking(&session, data, "file.bin")?; + let file_info = upload_bytes_blocking(&session, &endpoint, data, "file.bin")?; let dest = temp.path().join("out.bin"); - let group = session.new_file_download_group()?.build_blocking()?; + let group = session.new_file_download_group()?.with_endpoint(&endpoint).build_blocking()?; let handle = group.download_file_to_path_blocking(file_info.clone(), dest)?; // Before finish, per-task result is not available yet. @@ -1039,13 +1096,19 @@ mod tests { R: FnOnce(std::pin::Pin>>), { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); run(Box::pin(async move { let data = b"download from smol executor"; - let file_info = upload_bytes_blocking(&session, data, "test.bin").unwrap(); + let file_info = upload_bytes_blocking(&session, &endpoint, data, "test.bin").unwrap(); let dest = temp.path().join("out_smol.bin"); - let group = session.new_file_download_group().unwrap().build_blocking().unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build_blocking() + .unwrap(); group.download_file_to_path_blocking(file_info, dest.clone()).unwrap(); group.finish_blocking().unwrap(); assert_eq!(std::fs::read(&dest).unwrap(), data); diff --git a/xet_pkg/src/xet_session/mod.rs b/xet_pkg/src/xet_session/mod.rs index 38497ac2..c3457535 100644 --- a/xet_pkg/src/xet_session/mod.rs +++ b/xet_pkg/src/xet_session/mod.rs @@ -4,10 +4,10 @@ //! file operations: //! //! ```text -//! XetSession — holds runtime context and shared HTTP settings -//! ├── UploadCommitBuilder — configures per-commit auth; build() → XetUploadCommit -//! ├── FileDownloadGroupBuilder — configures per-group auth; build() → XetDownloadGroup -//! └── DownloadStreamGroupBuilder — configures per-group auth; build() → DownloadStreamGroup +//! XetSession — holds runtime context and shared HTTP settings +//! ├── AuthGroupBuilder — configures per-commit auth; build() → XetUploadCommit +//! ├── AuthGroupBuilder — configures per-group auth; build() → XetFileDownloadGroup +//! └── AuthGroupBuilder — configures per-group auth; build() → XetDownloadStreamGroup //! ``` //! //! Each [`XetSession`] holds its own runtime context and configuration, so @@ -18,11 +18,11 @@ //! //! ## Uploads //! -//! Call [`XetSession::new_upload_commit`] to obtain an [`UploadCommitBuilder`]. -//! Configure auth with [`with_token_info`](UploadCommitBuilder::with_token_info) and -//! [`with_token_refresh_url`](UploadCommitBuilder::with_token_refresh_url), then call -//! [`build`](UploadCommitBuilder::build) (async) or -//! [`build_blocking`](UploadCommitBuilder::build_blocking) (sync). +//! Call [`XetSession::new_upload_commit`] to obtain an [`AuthGroupBuilder`]. +//! Configure auth with [`with_token_info`](AuthGroupBuilder::with_token_info) and +//! [`with_token_refresh_url`](AuthGroupBuilder::with_token_refresh_url), then call +//! [`build`](AuthGroupBuilder::build) (async) or +//! [`build_blocking`](AuthGroupBuilder::build_blocking) (sync). //! Queue files with [`upload_from_path`](XetUploadCommit::upload_from_path) / //! [`upload_from_path_blocking`](XetUploadCommit::upload_from_path_blocking) or //! [`upload_bytes`](XetUploadCommit::upload_bytes) / @@ -38,9 +38,9 @@ //! //! ## File Downloads //! -//! Call [`XetSession::new_file_download_group`] to obtain a [`FileDownloadGroupBuilder`]. -//! Configure auth similarly, then call [`build`](FileDownloadGroupBuilder::build) (async) or -//! [`build_blocking`](FileDownloadGroupBuilder::build_blocking) (sync). +//! Call [`XetSession::new_file_download_group`] to obtain an [`AuthGroupBuilder`]. +//! Configure auth similarly, then call [`build`](AuthGroupBuilder::build) (async) or +//! [`build_blocking`](AuthGroupBuilder::build_blocking) (sync). //! Queue files with [`download_file_to_path`](XetFileDownloadGroup::download_file_to_path) / //! [`download_file_to_path_blocking`](XetFileDownloadGroup::download_file_to_path_blocking), //! then call [`finish`](XetFileDownloadGroup::finish) (async) or @@ -50,9 +50,9 @@ //! //! ## Streaming Downloads //! -//! Call [`XetSession::new_download_stream_group`] to obtain a [`DownloadStreamGroupBuilder`]. -//! Configure auth similarly, then call [`build`](DownloadStreamGroupBuilder::build) (async) or -//! [`build_blocking`](DownloadStreamGroupBuilder::build_blocking) (sync). +//! Call [`XetSession::new_download_stream_group`] to obtain an [`AuthGroupBuilder`]. +//! Configure auth similarly, then call [`build`](AuthGroupBuilder::build) (async) or +//! [`build_blocking`](AuthGroupBuilder::build_blocking) (sync). //! Create individual streams with //! [`download_stream`](XetDownloadStreamGroup::download_stream) / //! [`download_stream_blocking`](XetDownloadStreamGroup::download_stream_blocking) for @@ -83,7 +83,7 @@ //! //! Session-level factory methods and upload/file-download operations return //! `Result<_, `[`SessionError`]`>`. -//! Streaming operations — [`DownloadStreamGroupBuilder::build`], +//! Streaming operations — [`AuthGroupBuilder::build`] (for `XetDownloadStreamGroup`), //! [`XetDownloadStreamGroup`] methods, [`XetDownloadStream`] methods, and //! [`XetUnorderedDownloadStream`] methods — return `Result<_, XetError>`. //! [`commit`](XetUploadCommit::commit) returns a [`XetCommitReport`] containing @@ -97,11 +97,13 @@ //! ```rust,no_run //! use xet::xet_session::{Sha256Policy, XetFileInfo, XetSessionBuilder}; //! -//! let session = XetSessionBuilder::new().with_endpoint("https://cas.example.com").build()?; +//! # fn example() -> Result<(), xet::xet_session::SessionError> { +//! let session = XetSessionBuilder::new().build()?; //! -//! // Upload — configure token on the commit builder, then build_blocking +//! // Upload — configure endpoint and token on the commit builder, then build_blocking //! let commit = session //! .new_upload_commit()? +//! .with_endpoint("https://cas.example.com") //! .with_token_info("write-token", 1_700_000_000) //! .build_blocking()?; //! let handle = commit.upload_from_path_blocking("file.bin".into(), Sha256Policy::Compute)?; @@ -118,8 +120,8 @@ //! let dl_handle = group.download_file_to_path_blocking(info, "out/file.bin".into())?; //! let finish_report = group.finish_blocking()?; //! let r = finish_report.downloads.get(&dl_handle.task_id()).unwrap(); -//! -//! # Ok::<(), xet::xet_session::SessionError>(()) +//! # Ok(()) +//! # } //! ``` //! //! # Quick start — async API @@ -130,11 +132,12 @@ //! # async fn example() -> Result<(), xet::xet_session::SessionError> { //! // build() auto-detects: if inside a suitable tokio runtime, wraps it; //! // otherwise creates an owned thread pool. -//! let session = XetSessionBuilder::new().with_endpoint("https://cas.example.com").build()?; +//! let session = XetSessionBuilder::new().build()?; //! -//! // Upload — configure token on the commit builder, then build().await +//! // Upload — configure endpoint and token on the commit builder, then build().await //! let commit = session //! .new_upload_commit()? +//! .with_endpoint("https://cas.example.com") //! .with_token_info("write-token", 1_700_000_000) //! .build() //! .await?; @@ -157,6 +160,7 @@ //! # } //! ``` +mod auth_group_builder; mod common; mod download_stream_group; mod download_stream_handle; @@ -169,14 +173,15 @@ mod upload_commit; mod upload_file_handle; mod upload_stream_handle; -pub use download_stream_group::{DownloadStreamGroupBuilder, XetDownloadStreamGroup}; +pub use download_stream_group::{XetDownloadStreamGroup, XetDownloadStreamGroupBuilder}; pub use download_stream_handle::{XetDownloadStream, XetUnorderedDownloadStream}; pub use errors::SessionError; -pub use file_download_group::{FileDownloadGroupBuilder, XetDownloadGroupReport, XetFileDownloadGroup}; +pub use file_download_group::{XetDownloadGroupReport, XetFileDownloadGroup, XetFileDownloadGroupBuilder}; pub use file_download_handle::{XetDownloadReport, XetFileDownload}; +pub use http::{HeaderMap, HeaderValue, header}; pub use session::{XetSession, XetSessionBuilder}; pub use task_runtime::XetTaskState; -pub use upload_commit::{UploadCommitBuilder, XetCommitReport, XetFileMetadata, XetUploadCommit}; +pub use upload_commit::{XetCommitReport, XetFileMetadata, XetUploadCommit, XetUploadCommitBuilder}; pub use upload_file_handle::XetFileUpload; pub use upload_stream_handle::XetStreamUpload; pub use xet_data::deduplication::DeduplicationMetrics; diff --git a/xet_pkg/src/xet_session/session.rs b/xet_pkg/src/xet_session/session.rs index 1f8a89ce..7f0c53ae 100644 --- a/xet_pkg/src/xet_session/session.rs +++ b/xet_pkg/src/xet_session/session.rs @@ -3,7 +3,6 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex, Weak}; -use http::HeaderMap; use tracing::info; use ulid::Ulid; use xet_data::progress_tracking::UniqueID; @@ -11,11 +10,13 @@ use xet_runtime::RuntimeError; use xet_runtime::config::XetConfig; use xet_runtime::core::XetRuntime; -use super::download_stream_group::{DownloadStreamGroupBuilder, XetDownloadStreamGroup, XetDownloadStreamGroupInner}; +use super::download_stream_group::{ + XetDownloadStreamGroup, XetDownloadStreamGroupBuilder, XetDownloadStreamGroupInner, +}; use super::errors::SessionError; -use super::file_download_group::{FileDownloadGroupBuilder, XetFileDownloadGroup}; +use super::file_download_group::{XetFileDownloadGroup, XetFileDownloadGroupBuilder}; use super::task_runtime::{TaskRuntime, XetTaskState}; -use super::upload_commit::{UploadCommitBuilder, XetUploadCommit}; +use super::upload_commit::{XetUploadCommit, XetUploadCommitBuilder}; /// All shared state for a session. /// Lives behind `Arc` — do not use this type directly. @@ -27,10 +28,6 @@ pub struct XetSessionInner { // Only accessed through &self; no independent cloning needed. pub(super) config: XetConfig, - // CAS endpoint and shared HTTP settings (auth lives at the commit/group level) - pub(super) endpoint: Option, - pub(super) custom_headers: Option>, - // Track active upload commits and download groups. pub(super) active_upload_commits: Mutex>, pub(super) active_file_download_groups: Mutex>, @@ -54,20 +51,21 @@ pub struct XetSessionInner { /// /// ## Authentication /// -/// Auth tokens are configured per-operation on [`UploadCommitBuilder`] and -/// [`FileDownloadGroupBuilder`], not on the session itself. This lets uploads -/// and downloads use different access-level tokens from the same session: +/// Auth tokens are configured per-operation on the builder returned by each factory method, not on the +/// session itself. This lets uploads and downloads use different access-level +/// tokens from the same session: /// /// ```rust,no_run /// # use http::HeaderMap; /// # use xet::xet_session::XetSessionBuilder; -/// let session = XetSessionBuilder::new().with_endpoint("https://cas.example.com").build()?; +/// let session = XetSessionBuilder::new().build()?; /// /// // Upload token (write access) /// let mut upload_headers = HeaderMap::new(); /// upload_headers.insert("Authorization", "Bearer hub-write-token".parse().unwrap()); /// let commit = session /// .new_upload_commit()? +/// .with_endpoint("https://cas.example.com") /// .with_token_info("CAS_WRITE_JWT", 900) /// .with_token_refresh_url("https://huggingface.co/api/repos/token/write", upload_headers) /// .build_blocking()?; @@ -93,8 +91,6 @@ pub struct XetSessionInner { /// ``` pub struct XetSessionBuilder { config: XetConfig, - endpoint: Option, - custom_headers: Option>, tokio_handle: Option, } @@ -109,8 +105,6 @@ impl XetSessionBuilder { pub fn new() -> Self { Self { config: XetConfig::new(), - endpoint: None, - custom_headers: None, tokio_handle: None, } } @@ -119,28 +113,10 @@ impl XetSessionBuilder { pub fn new_with_config(config: XetConfig) -> Self { Self { config, - endpoint: None, - custom_headers: None, tokio_handle: None, } } - /// Set the Xet CAS server endpoint URL (e.g. `"https://cas.example.com"`). - pub fn with_endpoint(self, endpoint: impl Into) -> Self { - Self { - endpoint: Some(endpoint.into()), - ..self - } - } - - /// Attach custom HTTP headers that are forwarded with every CAS request. - pub fn with_custom_headers(self, headers: Arc) -> Self { - Self { - custom_headers: Some(headers), - ..self - } - } - /// Attach to an existing tokio runtime handle. /// /// If the handle meets runtime requirements (multi-thread flavor, time driver, IO driver), @@ -206,7 +182,7 @@ impl XetSessionBuilder { }, }; - let session = XetSession::new(self.config, self.endpoint, self.custom_headers, runtime); + let session = XetSession::new(self.config, runtime); info!("Session created, session_id={}", session.inner.id); Ok(session) } @@ -230,11 +206,11 @@ impl XetSessionBuilder { /// /// 1. Create a session with [`XetSessionBuilder`]. /// 2. Create operations: -/// - uploads via [`new_upload_commit`](Self::new_upload_commit) → [`UploadCommitBuilder`] → [`XetUploadCommit`] -/// - file downloads via [`new_file_download_group`](Self::new_file_download_group) → [`FileDownloadGroupBuilder`] → -/// [`XetFileDownloadGroup`] +/// - uploads via [`new_upload_commit`](Self::new_upload_commit) → [`XetUploadCommitBuilder`] → [`XetUploadCommit`] +/// - file downloads via [`new_file_download_group`](Self::new_file_download_group) → [`XetFileDownloadGroupBuilder`] +/// → [`XetFileDownloadGroup`] /// - streaming downloads via [`new_download_stream_group`](Self::new_download_stream_group) → -/// [`DownloadStreamGroupBuilder`] → [`XetDownloadStreamGroup`] +/// [`XetDownloadStreamGroupBuilder`] → [`XetDownloadStreamGroup`] /// 3. For an emergency stop, call [`XetSession::abort`]. #[derive(Clone)] pub struct XetSession { @@ -243,19 +219,12 @@ pub struct XetSession { impl XetSession { /// Low-level constructor used by [`XetSessionBuilder::build`]. - fn new( - config: XetConfig, - endpoint: Option, - custom_headers: Option>, - runtime: Arc, - ) -> Self { + fn new(config: XetConfig, runtime: Arc) -> Self { let task_runtime = TaskRuntime::new_root(runtime.clone()); Self { inner: Arc::new(XetSessionInner { runtime, config, - endpoint, - custom_headers, active_upload_commits: Mutex::new(HashMap::new()), active_file_download_groups: Mutex::new(HashMap::new()), task_runtime, @@ -265,47 +234,71 @@ impl XetSession { } } - /// Create an [`UploadCommitBuilder`] for configuring and constructing an upload commit. + /// Create a [`XetUploadCommitBuilder`] for configuring and constructing an upload commit. /// - /// Configure per-commit auth with [`with_token_info`](UploadCommitBuilder::with_token_info) - /// and [`with_token_refresh_url`](UploadCommitBuilder::with_token_refresh_url), then call - /// [`build`](UploadCommitBuilder::build) (async) or - /// [`build_blocking`](UploadCommitBuilder::build_blocking) (sync). + /// Configure the builder with any combination of: + /// - [`with_endpoint`](XetUploadCommitBuilder::with_endpoint) — CAS server URL (if omitted, resolved from the token + /// refresh response or the session default) + /// - [`with_custom_headers`](XetUploadCommitBuilder::with_custom_headers) — extra HTTP headers forwarded with every + /// CAS request + /// - [`with_token_info`](XetUploadCommitBuilder::with_token_info) — pre-seeded CAS token and expiry to skip the + /// initial refresh round-trip + /// - [`with_token_refresh_url`](XetUploadCommitBuilder::with_token_refresh_url) — URL and auth headers for + /// refreshing the CAS token + /// + /// Then call [`build`](XetUploadCommitBuilder::build) (async) or + /// [`build_blocking`](XetUploadCommitBuilder::build_blocking) (sync). /// /// Returns `Err(SessionError::UserCancelled)` if the session has been aborted. - pub fn new_upload_commit(&self) -> Result { + pub fn new_upload_commit(&self) -> Result { self.inner.task_runtime.check_state("new_upload_commit")?; - Ok(UploadCommitBuilder::new(self.clone())) + Ok(XetUploadCommitBuilder::new(self.clone())) } - /// Create a [`FileDownloadGroupBuilder`] for configuring and constructing a download group. + /// Create a [`XetFileDownloadGroupBuilder`] for configuring and constructing a file download group. /// - /// Configure per-group auth with [`with_token_info`](FileDownloadGroupBuilder::with_token_info) - /// and [`with_token_refresh_url`](FileDownloadGroupBuilder::with_token_refresh_url), then call - /// [`build`](FileDownloadGroupBuilder::build) (async) or - /// [`build_blocking`](FileDownloadGroupBuilder::build_blocking) (sync). + /// Configure the builder with any combination of: + /// - [`with_endpoint`](XetFileDownloadGroupBuilder::with_endpoint) — CAS server URL (if omitted, resolved from the + /// token refresh response or the session default) + /// - [`with_custom_headers`](XetFileDownloadGroupBuilder::with_custom_headers) — extra HTTP headers forwarded with + /// every CAS request + /// - [`with_token_info`](XetFileDownloadGroupBuilder::with_token_info) — pre-seeded CAS token and expiry to skip + /// the initial refresh round-trip + /// - [`with_token_refresh_url`](XetFileDownloadGroupBuilder::with_token_refresh_url) — URL and auth headers for + /// refreshing the CAS token + /// + /// Then call [`build`](XetFileDownloadGroupBuilder::build) (async) or + /// [`build_blocking`](XetFileDownloadGroupBuilder::build_blocking) (sync). /// /// Returns `Err(SessionError::UserCancelled)` if the session has been aborted. - pub fn new_file_download_group(&self) -> Result { + pub fn new_file_download_group(&self) -> Result { self.inner.task_runtime.check_state("new_file_download_group")?; - Ok(FileDownloadGroupBuilder::new(self.clone())) + Ok(XetFileDownloadGroupBuilder::new(self.clone())) } - /// Create a [`DownloadStreamGroupBuilder`] for configuring and constructing a download stream group. + /// Create a [`XetDownloadStreamGroupBuilder`] for configuring and constructing a download stream group. /// - /// Configure per-group auth with [`with_token_info`](DownloadStreamGroupBuilder::with_token_info) - /// and [`with_token_refresh_url`](DownloadStreamGroupBuilder::with_token_refresh_url), then call - /// [`build`](DownloadStreamGroupBuilder::build) (async) or - /// [`build_blocking`](DownloadStreamGroupBuilder::build_blocking) (sync). + /// Configure the builder with any combination of: + /// - [`with_endpoint`](XetDownloadStreamGroupBuilder::with_endpoint) — CAS server URL (if omitted, resolved from + /// the token refresh response or the session default) + /// - [`with_custom_headers`](XetDownloadStreamGroupBuilder::with_custom_headers) — extra HTTP headers forwarded + /// with every CAS request + /// - [`with_token_info`](XetDownloadStreamGroupBuilder::with_token_info) — pre-seeded CAS token and expiry to skip + /// the initial refresh round-trip + /// - [`with_token_refresh_url`](XetDownloadStreamGroupBuilder::with_token_refresh_url) — URL and auth headers for + /// refreshing the CAS token + /// + /// Then call [`build`](XetDownloadStreamGroupBuilder::build) (async) or + /// [`build_blocking`](XetDownloadStreamGroupBuilder::build_blocking) (sync). /// /// Use the resulting [`XetDownloadStreamGroup`] to create individual streams via /// [`download_stream`](XetDownloadStreamGroup::download_stream) and /// [`download_unordered_stream`](XetDownloadStreamGroup::download_unordered_stream). /// /// Returns `Err(SessionError::UserCancelled)` if the session has been aborted. - pub fn new_download_stream_group(&self) -> Result { + pub fn new_download_stream_group(&self) -> Result { self.inner.task_runtime.check_state("new_download_stream_group")?; - Ok(DownloadStreamGroupBuilder::new(self.clone())) + Ok(XetDownloadStreamGroupBuilder::new(self.clone())) } pub fn status(&self) -> Result { @@ -400,7 +393,7 @@ impl XetSession { #[cfg(test)] mod tests { - use tempfile::{TempDir, tempdir}; + use tempfile::tempdir; use xet_data::processing::{Sha256Policy, XetFileInfo}; use xet_runtime::core::{RuntimeMode, XetRuntime}; @@ -720,19 +713,13 @@ mod tests { // ── Streaming download round-trip tests ───────────────────────────────── - fn local_session(temp: &TempDir) -> Result> { - let cas_path = temp.path().join("cas"); - Ok(XetSessionBuilder::new() - .with_endpoint(format!("local://{}", cas_path.display())) - .build()?) - } - async fn upload_bytes( session: &XetSession, + endpoint: &str, data: &[u8], name: &str, ) -> Result> { - let commit = session.new_upload_commit()?.build().await?; + let commit = session.new_upload_commit()?.with_endpoint(endpoint).build().await?; let _handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some(name.into())) .await?; @@ -743,10 +730,11 @@ mod tests { fn upload_bytes_blocking( session: &XetSession, + endpoint: &str, data: &[u8], name: &str, ) -> Result> { - let commit = session.new_upload_commit()?.build_blocking()?; + let commit = session.new_upload_commit()?.with_endpoint(endpoint).build_blocking()?; let _handle = commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some(name.into()))?; let results = commit.commit_blocking()?; let meta = results.uploads.into_values().next().expect("one uploaded file"); @@ -757,13 +745,15 @@ mod tests { // Async streaming download round-trip: upload, stream, verify content. async fn test_download_stream_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"Hello, streaming download!"; - let file_info = upload_bytes(&session, original, "stream.bin").await.unwrap(); + let file_info = upload_bytes(&session, &endpoint, original, "stream.bin").await.unwrap(); let mut stream = session .new_download_stream_group() .unwrap() + .with_endpoint(&endpoint) .build() .await .unwrap() @@ -781,13 +771,15 @@ mod tests { // Blocking streaming download round-trip: upload, stream, verify content. fn test_download_stream_blocking_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"Hello, blocking streaming download!"; - let file_info = upload_bytes_blocking(&session, original, "stream.bin").unwrap(); + let file_info = upload_bytes_blocking(&session, &endpoint, original, "stream.bin").unwrap(); let mut stream = session .new_download_stream_group() .unwrap() + .with_endpoint(&endpoint) .build_blocking() .unwrap() .download_stream_blocking(file_info, None) @@ -804,13 +796,15 @@ mod tests { // progress() reports correct totals after consuming the stream. async fn test_download_stream_progress_reports_completion() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"progress tracking test data for streaming"; - let file_info = upload_bytes(&session, original, "progress.bin").await.unwrap(); + let file_info = upload_bytes(&session, &endpoint, original, "progress.bin").await.unwrap(); let mut stream = session .new_download_stream_group() .unwrap() + .with_endpoint(&endpoint) .build() .await .unwrap() @@ -836,13 +830,15 @@ mod tests { // progress() works correctly in blocking mode. fn test_download_stream_blocking_progress_reports_completion() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let original = b"blocking progress tracking test data"; - let file_info = upload_bytes_blocking(&session, original, "progress.bin").unwrap(); + let file_info = upload_bytes_blocking(&session, &endpoint, original, "progress.bin").unwrap(); let mut stream = session .new_download_stream_group() .unwrap() + .with_endpoint(&endpoint) .build_blocking() .unwrap() .download_stream_blocking(file_info, None) @@ -863,13 +859,20 @@ mod tests { // Multiple sequential streaming downloads share a single group's connection pool. async fn test_download_stream_multiple_sequential() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data_a = b"first stream payload"; let data_b = b"second stream payload"; - let info_a = upload_bytes(&session, data_a, "a.bin").await.unwrap(); - let info_b = upload_bytes(&session, data_b, "b.bin").await.unwrap(); + let info_a = upload_bytes(&session, &endpoint, data_a, "a.bin").await.unwrap(); + let info_b = upload_bytes(&session, &endpoint, data_b, "b.bin").await.unwrap(); - let group = session.new_download_stream_group().unwrap().build().await.unwrap(); + let group = session + .new_download_stream_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let mut stream_a = group.download_stream(info_a, None).await.unwrap(); let mut collected_a = Vec::new(); diff --git a/xet_pkg/src/xet_session/upload_commit.rs b/xet_pkg/src/xet_session/upload_commit.rs index 01b4a6ed..77ff9276 100644 --- a/xet_pkg/src/xet_session/upload_commit.rs +++ b/xet_pkg/src/xet_session/upload_commit.rs @@ -4,12 +4,12 @@ use std::collections::HashMap; use std::path::PathBuf; use std::sync::{Arc, Mutex, OnceLock}; -use http::HeaderMap; use tracing::{error, info}; use xet_data::deduplication::DeduplicationMetrics; use xet_data::processing::{FileUploadSession, Sha256Policy, XetFileInfo}; use xet_data::progress_tracking::{GroupProgressReport, UniqueID}; +use super::auth_group_builder::{AuthGroupBuilder, AuthOptions}; use super::common::create_translator_config; use super::session::XetSession; use super::task_runtime::{BackgroundTaskState, TaskRuntime, XetTaskState}; @@ -17,65 +17,20 @@ use super::upload_file_handle::{XetFileUpload, XetFileUploadInner}; use super::upload_stream_handle::{XetStreamUpload, XetStreamUploadInner}; use crate::error::XetError; -/// Builder for [`XetUploadCommit`]. -/// -/// Obtain via [`XetSession::new_upload_commit`], configure per-commit auth -/// with [`with_token_info`](Self::with_token_info) and -/// [`with_token_refresh_url`](Self::with_token_refresh_url), then call -/// [`build`](Self::build) (async) or [`build_blocking`](Self::build_blocking) (sync). -pub struct UploadCommitBuilder { - session: XetSession, - token_info: Option<(String, u64)>, - token_refresh: Option<(String, Arc)>, -} - -impl UploadCommitBuilder { - pub(super) fn new(session: XetSession) -> Self { - Self { - session, - token_info: None, - token_refresh: None, - } - } - - /// Seed an initial CAS access token and its expiry as a Unix timestamp (seconds). - /// - /// When combined with [`with_token_refresh_url`](Self::with_token_refresh_url) this - /// avoids an extra refresh round-trip on the first request. - pub fn with_token_info(self, token: impl Into, expiry: u64) -> Self { - Self { - token_info: Some((token.into(), expiry)), - ..self - } - } - - /// Set a URL and authentication headers used to obtain a fresh CAS access token - /// whenever the current one is about to expire. - /// - /// The client issues an authenticated HTTP GET to `url` with `headers` (which should - /// include auth credentials, e.g. `Authorization: Bearer `). The endpoint - /// must return JSON: - /// `{ "accessToken": "", "exp": , "casUrl": "" }`. - pub fn with_token_refresh_url(self, url: impl Into, headers: HeaderMap) -> Self { - Self { - token_refresh: Some((url.into(), Arc::new(headers))), - ..self - } - } +pub type XetUploadCommitBuilder = AuthGroupBuilder; +impl AuthGroupBuilder { /// Create the [`XetUploadCommit`] from an async context. pub async fn build(self) -> Result { - let UploadCommitBuilder { - session, - token_info, - token_refresh, + let AuthGroupBuilder { + session, auth_options, .. } = self; let parent_runtime = session.inner.task_runtime.clone(); let child_parent = parent_runtime.clone(); let commit = parent_runtime .bridge_async("new_upload_commit", async move { let commit_runtime = child_parent.child()?; - XetUploadCommit::new(session, commit_runtime, token_info, token_refresh).await + XetUploadCommit::new(session, commit_runtime, auth_options).await }) .await?; info!("New upload commit, session_id={}, commit_id={}", commit.session().id(), commit.id()); @@ -95,16 +50,14 @@ impl UploadCommitBuilder { /// /// Panics if called from within a tokio async runtime on an Owned-mode session. pub fn build_blocking(self) -> Result { - let UploadCommitBuilder { - session, - token_info, - token_refresh, + let AuthGroupBuilder { + session, auth_options, .. } = self; let parent_runtime = session.inner.task_runtime.clone(); let child_parent = parent_runtime.clone(); let commit = parent_runtime.bridge_sync("new_upload_commit_blocking", async move { let commit_runtime = child_parent.child()?; - XetUploadCommit::new(session, commit_runtime, token_info, token_refresh).await + XetUploadCommit::new(session, commit_runtime, auth_options).await })?; info!("New upload commit, session_id={}, commit_id={}", commit.session().id(), commit.id()); commit.session().register_upload_commit(&commit)?; @@ -157,26 +110,6 @@ pub(super) struct XetUploadCommitInner { } impl XetUploadCommitInner { - async fn new( - session: XetSession, - task_runtime: Arc, - token_info: Option<(String, u64)>, - token_refresh: Option<(String, Arc)>, - ) -> Result { - let commit_id = UniqueID::new(); - let config = create_translator_config(&session, token_info, token_refresh.as_ref())?; - let upload_session = FileUploadSession::new(Arc::new(config)).await?; - - Ok(Self { - commit_id, - session, - task_runtime, - upload_session, - file_handles: Mutex::new(Vec::new()), - stream_handles: Mutex::new(Vec::new()), - }) - } - async fn upload_from_path( self: &Arc, file_path: PathBuf, @@ -359,9 +292,9 @@ impl XetUploadCommitInner { /// API for grouping related file uploads into a single atomic commit. /// /// Obtain via [`XetSession::new_upload_commit`] — configure per-commit -/// auth on the returned [`UploadCommitBuilder`], then call -/// [`build`](UploadCommitBuilder::build) (async) or -/// [`build_blocking`](UploadCommitBuilder::build_blocking) (sync). +/// auth on the returned [`AuthGroupBuilder`], then call +/// [`build`](AuthGroupBuilder::build) (async) or +/// [`build_blocking`](AuthGroupBuilder::build_blocking) (sync). /// /// Enqueue files with [`upload_from_path`](Self::upload_from_path) / /// [`upload_from_path_blocking`](Self::upload_from_path_blocking), stream bytes @@ -396,14 +329,22 @@ impl XetUploadCommit { pub(super) async fn new( session: XetSession, task_runtime: Arc, - token_info: Option<(String, u64)>, - token_refresh: Option<(String, Arc)>, + auth_options: AuthOptions, ) -> Result { - let inner = XetUploadCommitInner::new(session, task_runtime.clone(), token_info, token_refresh).await?; - Ok(Self { - inner: Arc::new(inner), - task_runtime, - }) + let commit_id = UniqueID::new(); + let config = create_translator_config(&session, auth_options).await?; + let upload_session = FileUploadSession::new(Arc::new(config)).await?; + + let inner = Arc::new(XetUploadCommitInner { + commit_id, + session, + task_runtime: task_runtime.clone(), + upload_session, + file_handles: Mutex::new(Vec::new()), + stream_handles: Mutex::new(Vec::new()), + }); + + Ok(Self { inner, task_runtime }) } /// Unique identifier for this upload commit. @@ -607,18 +548,11 @@ mod tests { use std::sync::mpsc; use std::time::Duration; - use tempfile::{TempDir, tempdir}; + use tempfile::tempdir; use xet_runtime::core::RuntimeMode; + use super::super::session::XetSessionBuilder; use super::*; - use crate::xet_session::session::{XetSession, XetSessionBuilder}; - - fn local_session(temp: &TempDir) -> Result> { - let cas_path = temp.path().join("cas"); - Ok(XetSessionBuilder::new() - .with_endpoint(format!("local://{}", cas_path.display())) - .build()?) - } // ── Mutex guard / concurrency test ─────────────────────────────────────── @@ -626,11 +560,10 @@ mod tests { fn test_commit_blocked_while_upload_registration_holds_state_lock() -> Result<(), Box> { let temp = tempdir()?; let cas_path = temp.path().join("cas"); - let session = XetSessionBuilder::new() - .with_endpoint(format!("local://{}", cas_path.display())) - .build()?; + let endpoint = format!("local://{}", cas_path.display()); + let session = XetSessionBuilder::new().build()?; let runtime = session.inner.runtime.clone(); - let commit = session.new_upload_commit()?.build_blocking()?; + let commit = session.new_upload_commit()?.with_endpoint(&endpoint).build_blocking()?; let commit_for_thread = commit.clone(); let runtime_for_thread = runtime.clone(); @@ -837,9 +770,16 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_bytes_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"Hello, upload commit round-trip!"; - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let task_handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("hello.bin".into())) .await @@ -858,8 +798,15 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_bytes_task_id_matches_progress() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_bytes(b"id-match".to_vec(), Sha256Policy::Compute, Some("id.bin".into())) @@ -879,8 +826,15 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_handle_file_path_none_for_bytes_upload() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_bytes(b"no-path".to_vec(), Sha256Policy::Compute, Some("bytes.bin".into())) .await @@ -891,11 +845,18 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_from_path_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let src = temp.path().join("data.bin"); let data = b"file path upload content"; std::fs::write(&src, data).unwrap(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit.upload_from_path(src, Sha256Policy::Compute).await.unwrap(); commit.commit().await.unwrap(); let meta = handle.try_finish().unwrap(); @@ -908,11 +869,18 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_handle_file_path_for_path_upload() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let src = temp.path().join("path_meta.bin"); std::fs::write(&src, b"path metadata").unwrap(); let absolute = std::path::absolute(&src).unwrap(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit.upload_from_path(src, Sha256Policy::Compute).await.unwrap(); assert_eq!(handle.file_path(), Some(absolute)); } @@ -920,8 +888,15 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_bytes_sha256_policy_metadata() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let provided_sha256 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(); let compute_handle = commit @@ -955,9 +930,16 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_finish_returns_result_before_commit() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"finish before commit"; - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("early.bin".into())) .await @@ -970,8 +952,15 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_finish_second_call_returns_cached_result() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_bytes(b"idem".to_vec(), Sha256Policy::Compute, None) .await @@ -987,9 +976,16 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_finish_includes_dedup_metrics() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"dedup metrics check"; - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit.upload_bytes(data.to_vec(), Sha256Policy::Compute, None).await.unwrap(); let meta = handle.finalize_ingestion().await.unwrap(); assert_eq!(meta.dedup_metrics.total_bytes, data.len() as u64); @@ -1001,9 +997,16 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_streaming_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"streamed upload bytes"; - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_stream(Some("stream.bin".into()), Sha256Policy::Compute) .await @@ -1019,8 +1022,15 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_commit_errors_when_stream_not_finished() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_stream(Some("unfinished.bin".into()), Sha256Policy::Compute) .await @@ -1033,8 +1043,15 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_stream_finish_second_call_is_already_completed() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_stream(Some("idem.bin".into()), Sha256Policy::Compute) .await @@ -1051,8 +1068,15 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_stream_write_after_finish_errors() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_stream(Some("done.bin".into()), Sha256Policy::Compute) .await @@ -1084,8 +1108,15 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_multiple_files_in_one_commit() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let h1 = commit .upload_bytes(b"file one".to_vec(), Sha256Policy::Compute, Some("a.bin".into())) .await @@ -1109,9 +1140,16 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_progress_reflects_bytes_after_commit() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"progress tracking upload data"; - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let progress_observer = commit.clone(); commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("prog.bin".into())) @@ -1131,12 +1169,19 @@ mod tests { fn test_async_bridge_works_from_futures_executor() { let temp = tempdir().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); futures::executor::block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from non-tokio executor"; - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("test.bin".into())) .await @@ -1152,12 +1197,19 @@ mod tests { fn test_async_bridge_works_from_smol_executor() { let temp = tempdir().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); smol::block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from smol executor"; - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("test.bin".into())) .await @@ -1173,12 +1225,19 @@ mod tests { fn test_async_bridge_works_from_async_std_executor() { let temp = tempdir().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); async_std::task::block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from async-std executor"; - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("test.bin".into())) .await @@ -1195,9 +1254,10 @@ mod tests { #[test] fn test_blocking_upload_bytes_round_trip() -> Result<(), Box> { let temp = tempdir()?; - let session = local_session(&temp)?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"Hello, upload commit round-trip!"; - let commit = session.new_upload_commit()?.build_blocking()?; + let commit = session.new_upload_commit()?.with_endpoint(&endpoint).build_blocking()?; let task_handle = commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some("hello.bin".into()))?; let results = commit.commit_blocking()?; @@ -1211,11 +1271,12 @@ mod tests { #[test] fn test_blocking_upload_from_path_round_trip() -> Result<(), Box> { let temp = tempdir()?; - let session = local_session(&temp)?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); let src = temp.path().join("data.bin"); let data = b"file path upload content"; std::fs::write(&src, data)?; - let commit = session.new_upload_commit()?.build_blocking()?; + let commit = session.new_upload_commit()?.with_endpoint(&endpoint).build_blocking()?; let handle = commit.upload_from_path_blocking(src, Sha256Policy::Compute)?; commit.commit_blocking()?; let meta = handle.try_finish().unwrap(); @@ -1228,11 +1289,12 @@ mod tests { #[test] fn test_blocking_upload_result_access_patterns() -> Result<(), Box> { let temp = tempdir()?; - let session = local_session(&temp)?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"result access patterns"; let src = temp.path().join("data.bin"); std::fs::write(&src, data)?; - let commit = session.new_upload_commit()?.build_blocking()?; + let commit = session.new_upload_commit()?.with_endpoint(&endpoint).build_blocking()?; let handle = commit.upload_from_path_blocking(src, Sha256Policy::Compute)?; // Before commit, per-task result is not available yet. @@ -1256,9 +1318,10 @@ mod tests { #[test] fn test_blocking_upload_streaming_round_trip() -> Result<(), Box> { let temp = tempdir()?; - let session = local_session(&temp)?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"streamed upload bytes"; - let commit = session.new_upload_commit()?.build_blocking()?; + let commit = session.new_upload_commit()?.with_endpoint(&endpoint).build_blocking()?; let stream = commit.upload_stream_blocking(Some("stream.bin".into()), Sha256Policy::Compute)?; stream.write_blocking(data.to_vec())?; let meta = stream.finish_blocking()?; @@ -1272,8 +1335,9 @@ mod tests { #[test] fn test_blocking_upload_multiple_files_in_one_commit() -> Result<(), Box> { let temp = tempdir()?; - let session = local_session(&temp)?; - let commit = session.new_upload_commit()?.build_blocking()?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let commit = session.new_upload_commit()?.with_endpoint(&endpoint).build_blocking()?; commit.upload_bytes_blocking(b"file one".to_vec(), Sha256Policy::Compute, Some("a.bin".into()))?; commit.upload_bytes_blocking(b"file two".to_vec(), Sha256Policy::Compute, Some("b.bin".into()))?; commit.upload_bytes_blocking(b"file three".to_vec(), Sha256Policy::Compute, Some("c.bin".into()))?; @@ -1285,9 +1349,10 @@ mod tests { #[test] fn test_blocking_upload_progress_reflects_bytes_after_commit() -> Result<(), Box> { let temp = tempdir()?; - let session = local_session(&temp)?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"progress tracking upload data"; - let commit = session.new_upload_commit()?.build_blocking()?; + let commit = session.new_upload_commit()?.with_endpoint(&endpoint).build_blocking()?; let progress_observer = commit.clone(); commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some("prog.bin".into()))?; commit.commit_blocking()?; @@ -1302,8 +1367,9 @@ mod tests { #[test] fn test_blocking_upload_file_returns_handle_without_status() -> Result<(), Box> { let temp = tempdir()?; - let session = local_session(&temp)?; - let commit = session.new_upload_commit()?.build_blocking()?; + let session = XetSessionBuilder::new().build()?; + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let commit = session.new_upload_commit()?.with_endpoint(&endpoint).build_blocking()?; let handle = commit.upload_stream_blocking(Some("stream.bin".into()), Sha256Policy::Compute)?; assert!(handle.try_finish().is_none()); Ok(()) @@ -1314,11 +1380,17 @@ mod tests { R: FnOnce(std::pin::Pin>>), { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); run(Box::pin(async move { let data = b"upload from smol executor"; - let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build_blocking() + .unwrap(); let handle = commit .upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some("test.bin".into())) .unwrap(); diff --git a/xet_pkg/tests/test_xet_session.rs b/xet_pkg/tests/test_xet_session.rs index 66461878..200ca2a4 100644 --- a/xet_pkg/tests/test_xet_session.rs +++ b/xet_pkg/tests/test_xet_session.rs @@ -26,19 +26,18 @@ use xet::xet_session::{ // ── Helpers ────────────────────────────────────────────────────────────── -fn local_session(temp: &TempDir) -> Result> { - let cas_path = temp.path().join("cas"); - Ok(XetSessionBuilder::new() - .with_endpoint(format!("local://{}", cas_path.display())) - .build()?) -} - fn to_file_info(meta: &XetFileMetadata) -> XetFileInfo { meta.xet_info.clone() } -async fn upload_bytes_async(session: &XetSession, data: &[u8], name: &str) -> XetFileInfo { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); +async fn upload_bytes_async(session: &XetSession, endpoint: &str, data: &[u8], name: &str) -> XetFileInfo { + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some(name.into())) .await @@ -48,8 +47,13 @@ async fn upload_bytes_async(session: &XetSession, data: &[u8], name: &str) -> Xe file_meta.xet_info } -fn upload_bytes_sync(session: &XetSession, data: &[u8], name: &str) -> XetFileInfo { - let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); +fn upload_bytes_sync(session: &XetSession, endpoint: &str, data: &[u8], name: &str) -> XetFileInfo { + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(endpoint) + .build_blocking() + .unwrap(); let handle = commit .upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some(name.into())) .unwrap(); @@ -58,23 +62,34 @@ fn upload_bytes_sync(session: &XetSession, data: &[u8], name: &str) -> XetFileIn file_meta.xet_info } -async fn assert_roundtrip_async(session: &XetSession, temp: &TempDir, data: &[u8], name: &str) { - let file_info = upload_bytes_async(session, data, name).await; +async fn assert_roundtrip_async(session: &XetSession, endpoint: &str, temp: &TempDir, data: &[u8], name: &str) { + let file_info = upload_bytes_async(session, endpoint, data, name).await; assert_eq!(file_info.file_size(), Some(data.len() as u64)); let dest = temp.path().join(format!("{name}.out")); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(endpoint) + .build() + .await + .unwrap(); group.download_file_to_path(file_info, dest.clone()).await.unwrap(); group.finish().await.unwrap(); assert_eq!(fs::read(&dest).unwrap(), data); } -fn assert_roundtrip_sync(session: &XetSession, temp: &TempDir, data: &[u8], name: &str) { - let file_info = upload_bytes_sync(session, data, name); +fn assert_roundtrip_sync(session: &XetSession, endpoint: &str, temp: &TempDir, data: &[u8], name: &str) { + let file_info = upload_bytes_sync(session, endpoint, data, name); assert_eq!(file_info.file_size(), Some(data.len() as u64)); let dest = temp.path().join(format!("{name}.out")); - let group = session.new_file_download_group().unwrap().build_blocking().unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(endpoint) + .build_blocking() + .unwrap(); group.download_file_to_path_blocking(file_info, dest.clone()).unwrap(); group.finish_blocking().unwrap(); assert_eq!(fs::read(&dest).unwrap(), data); @@ -82,6 +97,7 @@ fn assert_roundtrip_sync(session: &XetSession, temp: &TempDir, data: &[u8], name async fn assert_upload_from_path_roundtrip_async( session: &XetSession, + endpoint: &str, temp: &TempDir, src_name: &str, dest_name: &str, @@ -91,7 +107,13 @@ async fn assert_upload_from_path_roundtrip_async( fs::write(&src, data).unwrap(); let file_meta = { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(endpoint) + .build() + .await + .unwrap(); let handle = commit.upload_from_path(src, Sha256Policy::Compute).await.unwrap(); let file_meta = handle.finalize_ingestion().await.unwrap(); commit.commit().await.unwrap(); @@ -99,7 +121,13 @@ async fn assert_upload_from_path_roundtrip_async( }; let dest = temp.path().join(dest_name); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(endpoint) + .build() + .await + .unwrap(); group .download_file_to_path(to_file_info(&file_meta), dest.clone()) .await @@ -110,6 +138,7 @@ async fn assert_upload_from_path_roundtrip_async( fn assert_upload_from_path_roundtrip_sync( session: &XetSession, + endpoint: &str, temp: &TempDir, src_name: &str, dest_name: &str, @@ -119,7 +148,12 @@ fn assert_upload_from_path_roundtrip_sync( fs::write(&src, data).unwrap(); let file_meta = { - let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(endpoint) + .build_blocking() + .unwrap(); let handle = commit.upload_from_path_blocking(src, Sha256Policy::Compute).unwrap(); let file_meta = handle.finalize_ingestion_blocking().unwrap(); commit.commit_blocking().unwrap(); @@ -127,7 +161,12 @@ fn assert_upload_from_path_roundtrip_sync( }; let dest = temp.path().join(dest_name); - let group = session.new_file_download_group().unwrap().build_blocking().unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(endpoint) + .build_blocking() + .unwrap(); group .download_file_to_path_blocking(to_file_info(&file_meta), dest.clone()) .unwrap(); @@ -210,21 +249,29 @@ fn deficient_runtime_cases() -> Vec<(&'static str, RuntimeBuilder)> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_upload_bytes_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - assert_roundtrip_async(&session, &temp, b"async upload bytes test", "bytes").await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + assert_roundtrip_async(&session, &endpoint, &temp, b"async upload bytes test", "bytes").await; } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_upload_from_path_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let src = temp.path().join("source.bin"); let data = b"upload from path integration test content"; fs::write(&src, data).unwrap(); let file_meta = { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit.upload_from_path(src, Sha256Policy::Compute).await.unwrap(); let file_meta = handle.finalize_ingestion().await.unwrap(); assert_eq!(file_meta.xet_info.file_size(), Some(data.len() as u64)); @@ -234,7 +281,13 @@ async fn async_upload_from_path_roundtrip() { }; let dest = temp.path().join("dest.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); group .download_file_to_path(file_meta.xet_info.clone(), dest.clone()) .await @@ -246,7 +299,8 @@ async fn async_upload_from_path_roundtrip() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_multiple_files_in_one_commit() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let files: Vec<(&str, &[u8])> = vec![ ("alpha.bin", b"alpha content"), @@ -255,7 +309,13 @@ async fn async_multiple_files_in_one_commit() { ]; let metas = { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let mut metas = Vec::new(); for (name, data) in &files { let h = commit @@ -268,7 +328,13 @@ async fn async_multiple_files_in_one_commit() { metas }; - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let mut dest_paths = Vec::new(); for (i, file_meta) in metas.iter().enumerate() { let dest = temp.path().join(format!("out_{i}.bin")); @@ -288,10 +354,17 @@ async fn async_multiple_files_in_one_commit() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_sha256_policy_variants() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let provided_sha256 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let h_compute = commit .upload_bytes(b"compute sha".to_vec(), Sha256Policy::Compute, Some("compute.bin".into())) @@ -322,27 +395,41 @@ async fn async_sha256_policy_variants() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_large_file_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - assert_roundtrip_async(&session, &temp, &data, "large").await; + assert_roundtrip_async(&session, &endpoint, &temp, &data, "large").await; } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_multiple_commits_and_groups() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); - let info_a = upload_bytes_async(&session, b"commit A data", "a.bin").await; - let info_b = upload_bytes_async(&session, b"commit B data", "b.bin").await; + let info_a = upload_bytes_async(&session, &endpoint, b"commit A data", "a.bin").await; + let info_b = upload_bytes_async(&session, &endpoint, b"commit B data", "b.bin").await; let dest_a = temp.path().join("a.out"); let dest_b = temp.path().join("b.out"); - let group1 = session.new_file_download_group().unwrap().build().await.unwrap(); + let group1 = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); group1.download_file_to_path(info_a, dest_a.clone()).await.unwrap(); group1.finish().await.unwrap(); - let group2 = session.new_file_download_group().unwrap().build().await.unwrap(); + let group2 = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); group2.download_file_to_path(info_b, dest_b.clone()).await.unwrap(); group2.finish().await.unwrap(); @@ -353,9 +440,16 @@ async fn async_multiple_commits_and_groups() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_task_status_transitions() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = commit .upload_bytes(b"status test".to_vec(), Sha256Policy::Compute, Some("status.bin".into())) .await @@ -372,10 +466,17 @@ async fn async_task_status_transitions() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_progress_tracking() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"progress tracking integration test data"; - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("prog.bin".into())) .await @@ -393,14 +494,21 @@ async fn async_progress_tracking() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_download_unknown_size_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"download with unknown size via xet_pkg"; - let file_info = upload_bytes_async(&session, data, "unknown_size.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, data, "unknown_size.bin").await; let hash_only = XetFileInfo::new_hash_only(file_info.hash().to_string()); let dest = temp.path().join("unknown_size.out"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); group.download_file_to_path(hash_only, dest.clone()).await.unwrap(); let report = group.finish().await.unwrap(); @@ -413,9 +521,16 @@ async fn async_download_unknown_size_roundtrip() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_download_invalid_hash_fails() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let handle = group .download_file_to_path( XetFileInfo { @@ -435,7 +550,8 @@ async fn async_download_invalid_hash_fails() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_upload_from_path_multiple_files() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let src_a = temp.path().join("src_a.bin"); let src_b = temp.path().join("src_b.bin"); @@ -443,7 +559,13 @@ async fn async_upload_from_path_multiple_files() { fs::write(&src_b, [0xCD; 8192]).unwrap(); let (info_a, info_b) = { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let ha = commit.upload_from_path(src_a, Sha256Policy::Compute).await.unwrap(); let hb = commit.upload_from_path(src_b, Sha256Policy::Compute).await.unwrap(); let info_a = ha.finalize_ingestion().await.unwrap().xet_info; @@ -454,7 +576,13 @@ async fn async_upload_from_path_multiple_files() { let dest_a = temp.path().join("dest_a.bin"); let dest_b = temp.path().join("dest_b.bin"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); group.download_file_to_path(info_a, dest_a.clone()).await.unwrap(); group.download_file_to_path(info_b, dest_b.clone()).await.unwrap(); group.finish().await.unwrap(); @@ -468,16 +596,19 @@ async fn async_upload_from_path_multiple_files() { #[test] fn blocking_upload_bytes_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - assert_roundtrip_sync(&session, &temp, b"blocking upload bytes test", "bytes"); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + assert_roundtrip_sync(&session, &endpoint, &temp, b"blocking upload bytes test", "bytes"); } #[test] fn blocking_upload_from_path_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); assert_upload_from_path_roundtrip_sync( &session, + &endpoint, &temp, "source.bin", "dest.bin", @@ -488,13 +619,19 @@ fn blocking_upload_from_path_roundtrip() { #[test] fn blocking_multiple_files_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data_a = b"blocking file A"; let data_b = b"blocking file B is longer"; let (info_a, info_b) = { - let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build_blocking() + .unwrap(); let ha = commit .upload_bytes_blocking(data_a.to_vec(), Sha256Policy::Compute, Some("a.bin".into())) .unwrap(); @@ -509,7 +646,12 @@ fn blocking_multiple_files_roundtrip() { let dest_a = temp.path().join("a.out"); let dest_b = temp.path().join("b.out"); - let group = session.new_file_download_group().unwrap().build_blocking().unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build_blocking() + .unwrap(); group.download_file_to_path_blocking(info_a, dest_a.clone()).unwrap(); group.download_file_to_path_blocking(info_b, dest_b.clone()).unwrap(); group.finish_blocking().unwrap(); @@ -521,17 +663,24 @@ fn blocking_multiple_files_roundtrip() { #[test] fn blocking_large_file_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - assert_roundtrip_sync(&session, &temp, &data, "large"); + assert_roundtrip_sync(&session, &endpoint, &temp, &data, "large"); } #[test] fn blocking_task_status_transitions() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); - let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build_blocking() + .unwrap(); let handle = commit .upload_bytes_blocking(b"status blocking".to_vec(), Sha256Policy::Compute, Some("status.bin".into())) .unwrap(); @@ -543,10 +692,16 @@ fn blocking_task_status_transitions() { #[test] fn blocking_progress_tracking() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"blocking progress tracking data"; - let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build_blocking() + .unwrap(); commit .upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some("prog.bin".into())) .unwrap(); @@ -561,19 +716,30 @@ fn blocking_progress_tracking() { #[test] fn blocking_multiple_commits_and_groups() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); - let info_a = upload_bytes_sync(&session, b"blocking commit A", "a.bin"); - let info_b = upload_bytes_sync(&session, b"blocking commit B", "b.bin"); + let info_a = upload_bytes_sync(&session, &endpoint, b"blocking commit A", "a.bin"); + let info_b = upload_bytes_sync(&session, &endpoint, b"blocking commit B", "b.bin"); let dest_a = temp.path().join("a.out"); - let group1 = session.new_file_download_group().unwrap().build_blocking().unwrap(); + let group1 = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build_blocking() + .unwrap(); group1.download_file_to_path_blocking(info_a, dest_a.clone()).unwrap(); group1.finish_blocking().unwrap(); assert_eq!(fs::read(&dest_a).unwrap(), b"blocking commit A"); let dest_b = temp.path().join("b.out"); - let group2 = session.new_file_download_group().unwrap().build_blocking().unwrap(); + let group2 = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build_blocking() + .unwrap(); group2.download_file_to_path_blocking(info_b, dest_b.clone()).unwrap(); group2.finish_blocking().unwrap(); assert_eq!(fs::read(&dest_b).unwrap(), b"blocking commit B"); @@ -591,9 +757,10 @@ fn bridge_upload_download_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let payload = format!("{tag} executor roundtrip"); - assert_roundtrip_async(&session, &temp, payload.as_bytes(), &tag).await; + assert_roundtrip_async(&session, &endpoint, &temp, payload.as_bytes(), &tag).await; }) }); } @@ -604,7 +771,8 @@ fn bridge_multiple_files() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let files: Vec<(String, Vec)> = vec![ (format!("{tag}_a.bin"), format!("{tag} A").into_bytes()), @@ -612,7 +780,13 @@ fn bridge_multiple_files() { ]; let metas = { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let mut metas = Vec::new(); for (name, data) in &files { let h = commit @@ -625,7 +799,13 @@ fn bridge_multiple_files() { metas }; - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let mut outputs = Vec::new(); for (index, file_meta) in metas.iter().enumerate() { let info = file_meta.xet_info.clone(); @@ -648,10 +828,12 @@ fn bridge_upload_from_path_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let payload = format!("{tag} upload from path"); assert_upload_from_path_roundtrip_async( &session, + &endpoint, &temp, &format!("src_{tag}.bin"), &format!("dest_{tag}.bin"), @@ -668,9 +850,10 @@ fn bridge_large_file_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - assert_roundtrip_async(&session, &temp, &data, &format!("large_{tag}")).await; + assert_roundtrip_async(&session, &endpoint, &temp, &data, &format!("large_{tag}")).await; }) }); } @@ -689,9 +872,10 @@ fn deficient_tokio_async_roundtrip_matrix() { let rt = builder(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let payload = format!("{label} async roundtrip"); - assert_roundtrip_async(&session, &temp, payload.as_bytes(), label).await; + assert_roundtrip_async(&session, &endpoint, &temp, payload.as_bytes(), label).await; }); } } @@ -701,10 +885,17 @@ fn deficient_tokio_no_drivers_multiple_files() { let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let (info_a, info_b) = { - let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + let commit = session + .new_upload_commit() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let ha = commit .upload_bytes(b"deficient A".to_vec(), Sha256Policy::Compute, Some("a.bin".into())) .await @@ -721,7 +912,13 @@ fn deficient_tokio_no_drivers_multiple_files() { let dest_a = temp.path().join("a.out"); let dest_b = temp.path().join("b.out"); - let group = session.new_file_download_group().unwrap().build().await.unwrap(); + let group = session + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); group.download_file_to_path(info_a, dest_a.clone()).await.unwrap(); group.download_file_to_path(info_b, dest_b.clone()).await.unwrap(); group.finish().await.unwrap(); @@ -736,9 +933,11 @@ fn deficient_tokio_no_drivers_upload_from_path() { let rt = build_rt_no_drivers(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); assert_upload_from_path_roundtrip_async( &session, + &endpoint, &temp, "src.bin", "dest.bin", @@ -753,9 +952,10 @@ fn deficient_tokio_no_drivers_large_file() { let rt = build_rt_no_drivers(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - assert_roundtrip_async(&session, &temp, &data, "large_deficient").await; + assert_roundtrip_async(&session, &endpoint, &temp, &data, "large_deficient").await; }); } @@ -770,10 +970,11 @@ fn deficient_tokio_handle_auto_fallback_blocking_roundtrip() { ] { let rt = builder(); let temp = tempdir().unwrap(); - let session = rt.block_on(async { local_session(&temp).unwrap() }); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let session = rt.block_on(async { XetSessionBuilder::new().build().unwrap() }); let payload = format!("{label} handle blocking roundtrip"); - assert_roundtrip_sync(&session, &temp, payload.as_bytes(), &format!("{label}_blocking")); + assert_roundtrip_sync(&session, &endpoint, &temp, payload.as_bytes(), &format!("{label}_blocking")); } } @@ -787,11 +988,12 @@ fn deficient_tokio_handle_auto_fallback_blocking_roundtrip() { fn blocking_in_non_tokio_executor_roundtrip() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let tag = executor.label().to_string(); Box::pin(async move { let payload = format!("blocking in {tag}"); - assert_roundtrip_sync(&session, &temp, payload.as_bytes(), &format!("blocking_{tag}")); + assert_roundtrip_sync(&session, &endpoint, &temp, payload.as_bytes(), &format!("blocking_{tag}")); }) }); } @@ -800,12 +1002,14 @@ fn blocking_in_non_tokio_executor_roundtrip() { fn blocking_in_non_tokio_executor_upload_from_path() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let tag = executor.label().to_string(); Box::pin(async move { let payload = format!("blocking {tag} upload from path"); assert_upload_from_path_roundtrip_sync( &session, + &endpoint, &temp, &format!("src_{tag}.bin"), &format!("dest_{tag}.bin"), @@ -903,11 +1107,12 @@ async fn async_abort_rejects_download_on_existing_group() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_duplicate_content_produces_same_hash() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"deduplication test content"; - let info1 = upload_bytes_async(&session, data, "first.bin").await; - let info2 = upload_bytes_async(&session, data, "second.bin").await; + let info1 = upload_bytes_async(&session, &endpoint, data, "first.bin").await; + let info2 = upload_bytes_async(&session, &endpoint, data, "second.bin").await; assert_eq!(info1.hash, info2.hash); assert_eq!(info1.file_size, info2.file_size); @@ -919,13 +1124,21 @@ async fn async_duplicate_content_produces_same_hash() { async fn async_separate_sessions_are_isolated() { let temp1 = tempdir().unwrap(); let temp2 = tempdir().unwrap(); - let session1 = local_session(&temp1).unwrap(); - let session2 = local_session(&temp2).unwrap(); + let session1 = XetSessionBuilder::new().build().unwrap(); + let endpoint1 = format!("local://{}", temp1.path().join("cas").display()); + let session2 = XetSessionBuilder::new().build().unwrap(); + let endpoint2 = format!("local://{}", temp2.path().join("cas").display()); - let info1 = upload_bytes_async(&session1, b"session 1 data", "s1.bin").await; + let info1 = upload_bytes_async(&session1, &endpoint1, b"session 1 data", "s1.bin").await; // Data from session1 should not be downloadable from session2 (different CAS store). - let group = session2.new_file_download_group().unwrap().build().await.unwrap(); + let group = session2 + .new_file_download_group() + .unwrap() + .with_endpoint(&endpoint2) + .build() + .await + .unwrap(); group .download_file_to_path(info1, temp2.path().join("cross.bin")) .await @@ -935,12 +1148,23 @@ async fn async_separate_sessions_are_isolated() { // ── 10. Streaming download (XetDownloadStream) ────────────────────────── -async fn async_stream_group(session: &XetSession) -> XetDownloadStreamGroup { - session.new_download_stream_group().unwrap().build().await.unwrap() +async fn async_stream_group(session: &XetSession, endpoint: &str) -> XetDownloadStreamGroup { + session + .new_download_stream_group() + .unwrap() + .with_endpoint(endpoint) + .build() + .await + .unwrap() } -fn sync_stream_group(session: &XetSession) -> XetDownloadStreamGroup { - session.new_download_stream_group().unwrap().build_blocking().unwrap() +fn sync_stream_group(session: &XetSession, endpoint: &str) -> XetDownloadStreamGroup { + session + .new_download_stream_group() + .unwrap() + .with_endpoint(endpoint) + .build_blocking() + .unwrap() } async fn collect_stream(stream: &mut XetDownloadStream) -> Vec { @@ -962,11 +1186,12 @@ fn collect_stream_blocking(stream: &mut XetDownloadStream) -> Vec { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"async streaming download roundtrip"; - let file_info = upload_bytes_async(&session, data, "stream.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, data, "stream.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, None).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, data); } @@ -974,11 +1199,12 @@ async fn async_stream_roundtrip() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_large_file() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - let file_info = upload_bytes_async(&session, &data, "large_stream.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, &data, "large_stream.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, None).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, data); } @@ -986,11 +1212,12 @@ async fn async_stream_large_file() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_progress_tracking() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"stream progress tracking integration test"; - let file_info = upload_bytes_async(&session, data, "progress_stream.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, data, "progress_stream.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, None).await.unwrap(); let initial = stream.progress(); @@ -1007,14 +1234,15 @@ async fn async_stream_progress_tracking() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_multiple_sequential() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data_a = b"stream sequential A"; let data_b = b"stream sequential B is different"; - let info_a = upload_bytes_async(&session, data_a, "seq_a.bin").await; - let info_b = upload_bytes_async(&session, data_b, "seq_b.bin").await; + let info_a = upload_bytes_async(&session, &endpoint, data_a, "seq_a.bin").await; + let info_b = upload_bytes_async(&session, &endpoint, data_b, "seq_b.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream_a = group.download_stream(info_a, None).await.unwrap(); assert_eq!(collect_stream(&mut stream_a).await, data_a); @@ -1025,11 +1253,12 @@ async fn async_stream_multiple_sequential() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_cancel_before_consuming() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"stream cancel test data"; - let file_info = upload_bytes_async(&session, data, "cancel_stream.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, data, "cancel_stream.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, None).await.unwrap(); stream.cancel(); assert!(stream.next().await.unwrap().is_none()); @@ -1037,8 +1266,7 @@ async fn async_stream_cancel_before_consuming() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_aborted_session() { - let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); session.abort().unwrap(); let result = session.new_download_stream_group(); assert!(matches!(result, Err(SessionError::UserCancelled(_)))); @@ -1047,11 +1275,18 @@ async fn async_stream_aborted_session() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_abort_cancels_active_stream() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - let file_info = upload_bytes_async(&session, &data, "abort_stream.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, &data, "abort_stream.bin").await; - let group = session.new_download_stream_group().unwrap().build().await.unwrap(); + let group = session + .new_download_stream_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let mut stream = group.download_stream(file_info, None).await.unwrap(); session.abort().unwrap(); @@ -1066,11 +1301,12 @@ async fn async_stream_abort_cancels_active_stream() { #[test] fn blocking_stream_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"blocking streaming download roundtrip"; - let file_info = upload_bytes_sync(&session, data, "stream.bin"); + let file_info = upload_bytes_sync(&session, &endpoint, data, "stream.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_stream_blocking(file_info, None).unwrap(); assert_eq!(collect_stream_blocking(&mut stream), data); } @@ -1078,11 +1314,12 @@ fn blocking_stream_roundtrip() { #[test] fn blocking_stream_large_file() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - let file_info = upload_bytes_sync(&session, &data, "large_stream.bin"); + let file_info = upload_bytes_sync(&session, &endpoint, &data, "large_stream.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_stream_blocking(file_info, None).unwrap(); assert_eq!(collect_stream_blocking(&mut stream), data); } @@ -1090,11 +1327,12 @@ fn blocking_stream_large_file() { #[test] fn blocking_stream_progress_tracking() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"blocking stream progress integration test"; - let file_info = upload_bytes_sync(&session, data, "progress_stream.bin"); + let file_info = upload_bytes_sync(&session, &endpoint, data, "progress_stream.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_stream_blocking(file_info, None).unwrap(); let _ = collect_stream_blocking(&mut stream); @@ -1106,14 +1344,15 @@ fn blocking_stream_progress_tracking() { #[test] fn blocking_stream_multiple_sequential() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data_a = b"blocking stream seq A"; let data_b = b"blocking stream seq B is longer"; - let info_a = upload_bytes_sync(&session, data_a, "seq_a.bin"); - let info_b = upload_bytes_sync(&session, data_b, "seq_b.bin"); + let info_a = upload_bytes_sync(&session, &endpoint, data_a, "seq_a.bin"); + let info_b = upload_bytes_sync(&session, &endpoint, data_b, "seq_b.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream_a = group.download_stream_blocking(info_a, None).unwrap(); assert_eq!(collect_stream_blocking(&mut stream_a), data_a); @@ -1131,8 +1370,10 @@ fn blocking_stream_aborted_session() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn external_mode_blocking_stream_returns_wrong_mode() { + let temp = tempdir().unwrap(); let session = XetSessionBuilder::new().build().unwrap(); - let group = async_stream_group(&session).await; + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let group = async_stream_group(&session, &endpoint).await; let result = group.download_stream_blocking( XetFileInfo { hash: "abc".to_string(), @@ -1150,11 +1391,13 @@ fn bridge_stream_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let payload = format!("{tag} stream roundtrip"); - let file_info = upload_bytes_async(&session, payload.as_bytes(), &format!("{tag}_stream.bin")).await; + let file_info = + upload_bytes_async(&session, &endpoint, payload.as_bytes(), &format!("{tag}_stream.bin")).await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, None).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, payload.as_bytes()); }) @@ -1167,11 +1410,13 @@ fn deficient_tokio_stream_roundtrip() { let rt = builder(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let payload = format!("{label} deficient stream"); - let file_info = upload_bytes_async(&session, payload.as_bytes(), &format!("{label}_stream.bin")).await; + let file_info = + upload_bytes_async(&session, &endpoint, payload.as_bytes(), &format!("{label}_stream.bin")).await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, None).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, payload.as_bytes()); }); @@ -1182,13 +1427,14 @@ fn deficient_tokio_stream_roundtrip() { fn blocking_stream_in_non_tokio_executor() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let tag = executor.label().to_string(); Box::pin(async move { let payload = format!("blocking stream in {tag}"); - let file_info = upload_bytes_sync(&session, payload.as_bytes(), &format!("{tag}_stream.bin")); + let file_info = upload_bytes_sync(&session, &endpoint, payload.as_bytes(), &format!("{tag}_stream.bin")); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_stream_blocking(file_info, None).unwrap(); assert_eq!(collect_stream_blocking(&mut stream), payload.as_bytes()); }) @@ -1224,11 +1470,12 @@ fn collect_unordered_stream_blocking(stream: &mut XetUnorderedDownloadStream, ex #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"async unordered streaming download roundtrip"; - let file_info = upload_bytes_async(&session, data, "unordered.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, data, "unordered.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, None).await.unwrap(); assert_eq!(collect_unordered_stream(&mut stream, data.len()).await, data); } @@ -1236,11 +1483,12 @@ async fn async_unordered_stream_roundtrip() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_large_file() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - let file_info = upload_bytes_async(&session, &data, "large_unordered.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, &data, "large_unordered.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, None).await.unwrap(); assert_eq!(collect_unordered_stream(&mut stream, data.len()).await, data); } @@ -1248,11 +1496,12 @@ async fn async_unordered_stream_large_file() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_progress_tracking() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"unordered stream progress tracking integration test"; - let file_info = upload_bytes_async(&session, data, "progress_unordered.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, data, "progress_unordered.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, None).await.unwrap(); let initial = stream.progress(); @@ -1269,11 +1518,12 @@ async fn async_unordered_stream_progress_tracking() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_cancel_before_consuming() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"unordered stream cancel test data"; - let file_info = upload_bytes_async(&session, data, "cancel_unordered.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, data, "cancel_unordered.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, None).await.unwrap(); stream.cancel(); assert!(stream.next().await.unwrap().is_none()); @@ -1281,8 +1531,7 @@ async fn async_unordered_stream_cancel_before_consuming() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_aborted_session() { - let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); session.abort().unwrap(); let result = session.new_download_stream_group(); assert!(matches!(result, Err(SessionError::UserCancelled(_)))); @@ -1291,11 +1540,18 @@ async fn async_unordered_stream_aborted_session() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_abort_cancels_active_stream() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - let file_info = upload_bytes_async(&session, &data, "abort_unordered_stream.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, &data, "abort_unordered_stream.bin").await; - let group = session.new_download_stream_group().unwrap().build().await.unwrap(); + let group = session + .new_download_stream_group() + .unwrap() + .with_endpoint(&endpoint) + .build() + .await + .unwrap(); let mut stream = group.download_unordered_stream(file_info, None).await.unwrap(); session.abort().unwrap(); @@ -1310,11 +1566,12 @@ async fn async_unordered_stream_abort_cancels_active_stream() { #[test] fn blocking_unordered_stream_roundtrip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"blocking unordered streaming download roundtrip"; - let file_info = upload_bytes_sync(&session, data, "unordered.bin"); + let file_info = upload_bytes_sync(&session, &endpoint, data, "unordered.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_unordered_stream_blocking(file_info, None).unwrap(); assert_eq!(collect_unordered_stream_blocking(&mut stream, data.len()), data); } @@ -1322,11 +1579,12 @@ fn blocking_unordered_stream_roundtrip() { #[test] fn blocking_unordered_stream_large_file() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - let file_info = upload_bytes_sync(&session, &data, "large_unordered.bin"); + let file_info = upload_bytes_sync(&session, &endpoint, &data, "large_unordered.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_unordered_stream_blocking(file_info, None).unwrap(); assert_eq!(collect_unordered_stream_blocking(&mut stream, data.len()), data); } @@ -1334,11 +1592,12 @@ fn blocking_unordered_stream_large_file() { #[test] fn blocking_unordered_stream_progress_tracking() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data = b"blocking unordered stream progress integration test"; - let file_info = upload_bytes_sync(&session, data, "progress_unordered.bin"); + let file_info = upload_bytes_sync(&session, &endpoint, data, "progress_unordered.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_unordered_stream_blocking(file_info, None).unwrap(); let _ = collect_unordered_stream_blocking(&mut stream, data.len()); @@ -1357,8 +1616,10 @@ fn blocking_unordered_stream_aborted_session() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn external_mode_blocking_unordered_stream_returns_wrong_mode() { + let temp = tempdir().unwrap(); let session = XetSessionBuilder::new().build().unwrap(); - let group = async_stream_group(&session).await; + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let group = async_stream_group(&session, &endpoint).await; let result = group.download_unordered_stream_blocking( XetFileInfo { hash: "abc".to_string(), @@ -1376,11 +1637,13 @@ fn bridge_unordered_stream_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let payload = format!("{tag} unordered stream roundtrip"); - let file_info = upload_bytes_async(&session, payload.as_bytes(), &format!("{tag}_unordered.bin")).await; + let file_info = + upload_bytes_async(&session, &endpoint, payload.as_bytes(), &format!("{tag}_unordered.bin")).await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, None).await.unwrap(); assert_eq!(collect_unordered_stream(&mut stream, payload.len()).await, payload.as_bytes()); }) @@ -1393,11 +1656,13 @@ fn deficient_tokio_unordered_stream_roundtrip() { let rt = builder(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let payload = format!("{label} deficient unordered stream"); - let file_info = upload_bytes_async(&session, payload.as_bytes(), &format!("{label}_unordered.bin")).await; + let file_info = + upload_bytes_async(&session, &endpoint, payload.as_bytes(), &format!("{label}_unordered.bin")).await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, None).await.unwrap(); assert_eq!(collect_unordered_stream(&mut stream, payload.len()).await, payload.as_bytes()); }); @@ -1408,13 +1673,14 @@ fn deficient_tokio_unordered_stream_roundtrip() { fn blocking_unordered_stream_in_non_tokio_executor() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let tag = executor.label().to_string(); Box::pin(async move { let payload = format!("blocking unordered stream in {tag}"); - let file_info = upload_bytes_sync(&session, payload.as_bytes(), &format!("{tag}_unordered.bin")); + let file_info = upload_bytes_sync(&session, &endpoint, payload.as_bytes(), &format!("{tag}_unordered.bin")); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_unordered_stream_blocking(file_info, None).unwrap(); assert_eq!(collect_unordered_stream_blocking(&mut stream, payload.len()), payload.as_bytes()); }) @@ -1436,10 +1702,11 @@ const RANGE_TEST_DATA: &[u8; 256] = &{ #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_middle() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "range.bin").await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, "range.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, Some(64..192)).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, &RANGE_TEST_DATA[64..192]); } @@ -1447,10 +1714,11 @@ async fn async_stream_range_middle() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_from_start() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "range_start.bin").await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, "range_start.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, Some(0..100)).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, &RANGE_TEST_DATA[..100]); } @@ -1458,10 +1726,11 @@ async fn async_stream_range_from_start() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_to_end() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "range_end.bin").await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, "range_end.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, Some(200..256)).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, &RANGE_TEST_DATA[200..]); } @@ -1469,10 +1738,11 @@ async fn async_stream_range_to_end() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_full() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "range_full.bin").await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, "range_full.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, Some(0..256)).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, RANGE_TEST_DATA.as_slice()); } @@ -1480,10 +1750,11 @@ async fn async_stream_range_full() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_progress() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "range_progress.bin").await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, "range_progress.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, Some(50..150)).await.unwrap(); let initial = stream.progress(); @@ -1500,10 +1771,11 @@ async fn async_stream_range_progress() { #[test] fn blocking_stream_range_middle() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, "range.bin"); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_sync(&session, &endpoint, RANGE_TEST_DATA, "range.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_stream_blocking(file_info, Some(64..192)).unwrap(); assert_eq!(collect_stream_blocking(&mut stream), &RANGE_TEST_DATA[64..192]); } @@ -1511,10 +1783,11 @@ fn blocking_stream_range_middle() { #[test] fn blocking_stream_range_progress() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, "range_progress.bin"); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_sync(&session, &endpoint, RANGE_TEST_DATA, "range_progress.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_stream_blocking(file_info, Some(10..110)).unwrap(); let _ = collect_stream_blocking(&mut stream); @@ -1526,10 +1799,11 @@ fn blocking_stream_range_progress() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_range_middle() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "unord_range.bin").await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, "unord_range.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, Some(64..192)).await.unwrap(); assert_eq!(collect_unordered_stream(&mut stream, 128).await, &RANGE_TEST_DATA[64..192]); } @@ -1537,10 +1811,11 @@ async fn async_unordered_stream_range_middle() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_range_from_start() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "unord_range_start.bin").await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, "unord_range_start.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, Some(0..100)).await.unwrap(); assert_eq!(collect_unordered_stream(&mut stream, 100).await, &RANGE_TEST_DATA[..100]); } @@ -1548,10 +1823,11 @@ async fn async_unordered_stream_range_from_start() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_range_to_end() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "unord_range_end.bin").await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, "unord_range_end.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, Some(200..256)).await.unwrap(); assert_eq!(collect_unordered_stream(&mut stream, 56).await, &RANGE_TEST_DATA[200..]); } @@ -1559,10 +1835,11 @@ async fn async_unordered_stream_range_to_end() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_range_progress() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "unord_range_progress.bin").await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, "unord_range_progress.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, Some(50..150)).await.unwrap(); let initial = stream.progress(); @@ -1579,10 +1856,11 @@ async fn async_unordered_stream_range_progress() { #[test] fn blocking_unordered_stream_range_middle() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, "unord_range.bin"); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_sync(&session, &endpoint, RANGE_TEST_DATA, "unord_range.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_unordered_stream_blocking(file_info, Some(64..192)).unwrap(); assert_eq!(collect_unordered_stream_blocking(&mut stream, 128), &RANGE_TEST_DATA[64..192]); } @@ -1590,10 +1868,11 @@ fn blocking_unordered_stream_range_middle() { #[test] fn blocking_unordered_stream_range_progress() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, "unord_range_progress.bin"); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = upload_bytes_sync(&session, &endpoint, RANGE_TEST_DATA, "unord_range_progress.bin"); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_unordered_stream_blocking(file_info, Some(10..110)).unwrap(); let _ = collect_unordered_stream_blocking(&mut stream, 100); @@ -1608,10 +1887,12 @@ fn bridge_stream_range_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, &format!("{tag}_range_stream.bin")).await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = + upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, &format!("{tag}_range_stream.bin")).await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, Some(30..200)).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, &RANGE_TEST_DATA[30..200]); }) @@ -1624,10 +1905,12 @@ fn bridge_unordered_stream_range_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, &format!("{tag}_range_unord.bin")).await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = + upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, &format!("{tag}_range_unord.bin")).await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, Some(30..200)).await.unwrap(); assert_eq!(collect_unordered_stream(&mut stream, 170).await, &RANGE_TEST_DATA[30..200]); }) @@ -1640,10 +1923,12 @@ fn deficient_tokio_stream_range_roundtrip() { let rt = builder(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, &format!("{label}_range_stream.bin")).await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = + upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, &format!("{label}_range_stream.bin")).await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, Some(40..180)).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, &RANGE_TEST_DATA[40..180]); }); @@ -1656,10 +1941,12 @@ fn deficient_tokio_unordered_stream_range_roundtrip() { let rt = builder(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = local_session(&temp).unwrap(); - let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, &format!("{label}_range_unord.bin")).await; + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); + let file_info = + upload_bytes_async(&session, &endpoint, RANGE_TEST_DATA, &format!("{label}_range_unord.bin")).await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, Some(40..180)).await.unwrap(); assert_eq!(collect_unordered_stream(&mut stream, 140).await, &RANGE_TEST_DATA[40..180]); }); @@ -1670,12 +1957,13 @@ fn deficient_tokio_unordered_stream_range_roundtrip() { fn blocking_stream_range_in_non_tokio_executor() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let tag = executor.label().to_string(); Box::pin(async move { - let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, &format!("{tag}_range_stream.bin")); + let file_info = upload_bytes_sync(&session, &endpoint, RANGE_TEST_DATA, &format!("{tag}_range_stream.bin")); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_stream_blocking(file_info, Some(20..220)).unwrap(); assert_eq!(collect_stream_blocking(&mut stream), &RANGE_TEST_DATA[20..220]); }) @@ -1686,12 +1974,13 @@ fn blocking_stream_range_in_non_tokio_executor() { fn blocking_unordered_stream_range_in_non_tokio_executor() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let tag = executor.label().to_string(); Box::pin(async move { - let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, &format!("{tag}_range_unord.bin")); + let file_info = upload_bytes_sync(&session, &endpoint, RANGE_TEST_DATA, &format!("{tag}_range_unord.bin")); - let group = sync_stream_group(&session); + let group = sync_stream_group(&session, &endpoint); let mut stream = group.download_unordered_stream_blocking(file_info, Some(20..220)).unwrap(); assert_eq!(collect_unordered_stream_blocking(&mut stream, 200), &RANGE_TEST_DATA[20..220]); }) @@ -1701,11 +1990,12 @@ fn blocking_unordered_stream_range_in_non_tokio_executor() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_large_file() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - let file_info = upload_bytes_async(&session, &data, "range_large.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, &data, "range_large.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_stream(file_info, Some(10000..50000)).await.unwrap(); assert_eq!(collect_stream(&mut stream).await, &data[10000..50000]); } @@ -1713,11 +2003,12 @@ async fn async_stream_range_large_file() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_range_large_file() { let temp = tempdir().unwrap(); - let session = local_session(&temp).unwrap(); + let session = XetSessionBuilder::new().build().unwrap(); + let endpoint = format!("local://{}", temp.path().join("cas").display()); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); - let file_info = upload_bytes_async(&session, &data, "range_large_unord.bin").await; + let file_info = upload_bytes_async(&session, &endpoint, &data, "range_large_unord.bin").await; - let group = async_stream_group(&session).await; + let group = async_stream_group(&session, &endpoint).await; let mut stream = group.download_unordered_stream(file_info, Some(10000..50000)).await.unwrap(); assert_eq!(collect_unordered_stream(&mut stream, 40000).await, &data[10000..50000]); }