Support XetSession in async context (#694)

`XetSession` always created its own tokio runtime via
`XetRuntime::new_with_config`, and calling `external_run_async_task`
panics when already inside a tokio context. This blocked embedding the
session in async Rust frameworks.

Core strategy:

 - `RuntimeMode` enum — 
`Owned` (session created its own thread pool via
`XetSessionBuilder::build` or `XetSessionBuilder::build_async` when
outside tokio context. Both `_blocking` and async methods are supported.
Async methods use an internal `bridge_to_owned` bridge that routes
futures onto the owned thread pool, so they work from any executor
(tokio, smol, async-std))
vs
`External` (session wraps a caller-supplied tokio handle via
`XetSessionBuilder::with_tokio_handle` or
`XetSessionBuilder::build_async` when inside qualified tokio context.
Only async methods may be called; `_blocking` methods return
`SessionError::WrongRuntimeMode`. No second thread pool is created).
- `XetRuntime::bridge_to_owned` — a new bridge that routes a future onto
the owned tokio thread pool from any executor (smol, async-std,
futures::executor, non-qualified tokio runtime) by delivering the result
via a `tokio::sync::oneshot` channel that can be polled by any async
executor.
- Async public API — `UploadCommit` and `DownloadGroup` methods
(`upload_from_path`, `upload_bytes`, `upload_file`, `commit`, `finish`)
are now async fn. Factory methods `XetSession::new_upload_commit` and
`new_download_group` are async.
Example:
```
let session = XetSessionBuilder::new().build_async().await?;
// Upload
 let commit = session.new_upload_commit().await?;
 let handle = commit.upload_from_path("file.bin".into()).await?;
 let results = commit.commit().await?;

 // Download
 let group = session.new_download_group().await?;
 let info = XetFileInfo {
     hash: ...,
     file_size: ...,
 };
 let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?;
 let finish_results = group.finish().await?;
```

- Sync wrappers — New `UploadCommitSync` / `DownloadGroupSync` in
`xet_session/sync/` expose a fully blocking API for sync Rust and Python
(PyO3) callers. Returned by `new_upload_commit_blocking()` and
`new_download_group_blocking()`.
Example:
```
let session = XetSessionBuilder::new().build()?;
// Upload
let commit = session.new_upload_commit_blocking()?;
 let handle = commit.upload_from_path("file.bin".into())?;
 let results = commit.commit()?;
 let m = results.values().next().unwrap().as_ref().as_ref().unwrap();

// Download
 let group = session.new_download_group_blocking()?;
 let info = XetFileInfo {
     hash: ...,
     file_size: ...,
 };
 let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?;
 let finish_results = group.finish()?;
```



Additional fixes: `download_file_to_path` and `upload_from_path` now
canonicalize paths with `std::path::absolute` before enqueuing; task
status is only overwritten when still `Running`, preventing a race with
concurrent abort().

Fix XET-891

---------

Co-authored-by: Hoyt Koepke <hoytak@huggingface.co>
This commit is contained in:
Di Xiao
2026-03-13 14:57:20 -07:00
committed by GitHub
parent 3390bdc716
commit e701aeddac
16 changed files with 2727 additions and 933 deletions

274
Cargo.lock generated
View File

@@ -160,17 +160,109 @@ dependencies = [
"serde_json",
]
[[package]]
name = "async-channel"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35"
dependencies = [
"concurrent-queue",
"event-listener 2.5.3",
"futures-core",
]
[[package]]
name = "async-channel"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "924ed96dd52d1b75e9c1a3e6275715fd320f5f9439fb5a4a11fa51f4221158d2"
dependencies = [
"concurrent-queue",
"event-listener-strategy",
"futures-core",
"pin-project-lite",
]
[[package]]
name = "async-executor"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c96bf972d85afc50bf5ab8fe2d54d1586b4e0b46c97c50a0c9e71e2f7bcd812a"
dependencies = [
"async-task",
"concurrent-queue",
"fastrand",
"futures-lite",
"pin-project-lite",
"slab",
]
[[package]]
name = "async-fs"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8034a681df4aed8b8edbd7fbe472401ecf009251c8b40556b304567052e294c5"
dependencies = [
"async-lock",
"blocking",
"futures-lite",
]
[[package]]
name = "async-global-executor"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05b1b633a2115cd122d73b955eadd9916c18c8f510ec9cd1686404c60ad1c29c"
dependencies = [
"async-channel 2.5.0",
"async-executor",
"async-io",
"async-lock",
"blocking",
"futures-lite",
"once_cell",
]
[[package]]
name = "async-io"
version = "2.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "456b8a8feb6f42d237746d4b3e9a178494627745c3c56c6ea55d92ba50d026fc"
dependencies = [
"autocfg",
"cfg-if 1.0.4",
"concurrent-queue",
"futures-io",
"futures-lite",
"parking",
"polling",
"rustix",
"slab",
"windows-sys 0.61.2",
]
[[package]]
name = "async-lock"
version = "3.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "290f7f2596bd5b78a9fec8088ccd89180d7f9f55b94b0576823bbbdc72ee8311"
dependencies = [
"event-listener",
"event-listener 5.4.1",
"event-listener-strategy",
"pin-project-lite",
]
[[package]]
name = "async-net"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b948000fad4873c1c9339d60f2623323a0cfd3816e5181033c6a5cb68b2accf7"
dependencies = [
"async-io",
"blocking",
"futures-lite",
]
[[package]]
name = "async-object-pool"
version = "0.2.0"
@@ -178,9 +270,77 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1ac0219111eb7bb7cb76d4cf2cb50c598e7ae549091d3616f9e95442c18486f"
dependencies = [
"async-lock",
"event-listener",
"event-listener 5.4.1",
]
[[package]]
name = "async-process"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc50921ec0055cdd8a16de48773bfeec5c972598674347252c0399676be7da75"
dependencies = [
"async-channel 2.5.0",
"async-io",
"async-lock",
"async-signal",
"async-task",
"blocking",
"cfg-if 1.0.4",
"event-listener 5.4.1",
"futures-lite",
"rustix",
]
[[package]]
name = "async-signal"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43c070bbf59cd3570b6b2dd54cd772527c7c3620fce8be898406dd3ed6adc64c"
dependencies = [
"async-io",
"async-lock",
"atomic-waker",
"cfg-if 1.0.4",
"futures-core",
"futures-io",
"rustix",
"signal-hook-registry",
"slab",
"windows-sys 0.61.2",
]
[[package]]
name = "async-std"
version = "1.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c8e079a4ab67ae52b7403632e4618815d6db36d2a010cfe41b02c1b1578f93b"
dependencies = [
"async-channel 1.9.0",
"async-global-executor",
"async-io",
"async-lock",
"crossbeam-utils",
"futures-channel",
"futures-core",
"futures-io",
"futures-lite",
"gloo-timers",
"kv-log-macro",
"log",
"memchr",
"once_cell",
"pin-project-lite",
"pin-utils",
"slab",
"wasm-bindgen-futures",
]
[[package]]
name = "async-task"
version = "4.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de"
[[package]]
name = "async-trait"
version = "0.1.89"
@@ -383,6 +543,19 @@ dependencies = [
"generic-array 0.14.7",
]
[[package]]
name = "blocking"
version = "1.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e83f8d02be6967315521be875afa792a316e28d57b5a2d401897e2a7921b7f21"
dependencies = [
"async-channel 2.5.0",
"async-task",
"futures-io",
"futures-lite",
"piper",
]
[[package]]
name = "blowfish"
version = "0.9.1"
@@ -1103,6 +1276,12 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "event-listener"
version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
[[package]]
name = "event-listener"
version = "5.4.1"
@@ -1120,7 +1299,7 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93"
dependencies = [
"event-listener",
"event-listener 5.4.1",
"pin-project-lite",
]
@@ -1264,6 +1443,19 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]]
name = "futures-lite"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
dependencies = [
"fastrand",
"futures-core",
"futures-io",
"parking",
"pin-project-lite",
]
[[package]]
name = "futures-macro"
version = "0.3.32"
@@ -1466,6 +1658,18 @@ dependencies = [
"xet-runtime",
]
[[package]]
name = "gloo-timers"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbb143cf96099802033e0d4f4963b19fd2e0b728bcf076cd9cf7f6634f092994"
dependencies = [
"futures-channel",
"futures-core",
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "group"
version = "0.13.0"
@@ -1669,15 +1873,19 @@ name = "hf-xet"
version = "1.4.0"
dependencies = [
"anyhow",
"async-std",
"async-trait",
"clap",
"futures",
"http",
"serde",
"serde_json",
"serial_test",
"smol",
"tempfile",
"thiserror 2.0.18",
"tokio",
"tracing",
"tracing-subscriber",
"ulid",
"xet-client",
@@ -2178,6 +2386,15 @@ version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a"
[[package]]
name = "kv-log-macro"
version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0de8b303297635ad57c9f5059fd9cee7a47f8e8daa09df0fcd07dd39fb22977f"
dependencies = [
"log",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
@@ -2357,6 +2574,9 @@ name = "log"
version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
dependencies = [
"value-bag",
]
[[package]]
name = "lru-slab"
@@ -3006,6 +3226,17 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "piper"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c835479a4443ded371d6c535cbfd8d31ad92c5d23ae9770a61bc155e4992a3c1"
dependencies = [
"atomic-waker",
"fastrand",
"futures-io",
]
[[package]]
name = "pkcs1"
version = "0.7.5"
@@ -3050,6 +3281,20 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "polling"
version = "3.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218"
dependencies = [
"cfg-if 1.0.4",
"concurrent-queue",
"hermit-abi",
"pin-project-lite",
"rustix",
"windows-sys 0.61.2",
]
[[package]]
name = "poly1305"
version = "0.8.0"
@@ -4238,6 +4483,23 @@ version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "smol"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a33bd3e260892199c3ccfc487c88b2da2265080acb316cd920da72fdfd7c599f"
dependencies = [
"async-channel 2.5.0",
"async-executor",
"async-fs",
"async-io",
"async-lock",
"async-net",
"async-process",
"blocking",
"futures-lite",
]
[[package]]
name = "socket2"
version = "0.6.3"
@@ -5013,6 +5275,12 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "value-bag"
version = "1.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ba6f5989077681266825251a52748b8c1d8a4ad098cc37e440103d0ea717fc0"
[[package]]
name = "vcpkg"
version = "0.2.15"

View File

@@ -32,8 +32,9 @@ debug = 1
[workspace.dependencies]
anyhow = "1"
axum = "0.8"
async-std = "1"
async-trait = "0.1"
axum = "0.8"
base64 = "0.22"
bincode = "1.3"
bitflags = { version = "2.10", features = ["serde"] }
@@ -50,7 +51,6 @@ csv = "1"
ctor = "0.6"
derivative = "2.2"
dirs = "6.0"
human-bandwidth = "0.1"
duration-str = "0.19"
futures = "0.3"
futures-util = "0.3"
@@ -63,6 +63,7 @@ half = "2.7"
heapify = "0.2"
heed = "0.22"
http = "1"
human-bandwidth = "0.1"
hyper = "1.8"
hyper-util = "0.1"
itertools = "0.14"
@@ -95,6 +96,7 @@ serde_repr = "0.1"
sha2 = "0.10"
shell-words = "1.1"
shellexpand = "3.1"
smol = "2"
static_assertions = "1.1"
statrs = "0.18"
sysinfo = "0.38"

View File

@@ -20,12 +20,16 @@ xet-data = { version = "1.4.0", path = "../xet_data" }
async-trait = { workspace = true }
http = { workspace = true }
tokio = { workspace = true }
tokio = { workspace = true, features = ["net"] }
ulid = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
serde = { workspace = true, features = ["derive"] }
[dev-dependencies]
async-std = { workspace = true }
futures = { workspace = true }
smol = { workspace = true }
tempfile = { workspace = true }
serial_test = { workspace = true }
anyhow = { workspace = true }

View File

@@ -1,9 +1,9 @@
//! Session-based upload/download example.
//! Async session-based upload/download example.
//!
//! Shows the three-level hierarchy: XetSession → UploadCommit/DownloadGroup → files.
//! Mirror of `example.rs` using the async API (`UploadCommit` / `DownloadGroup`).
//! Requires an async runtime — here provided by `#[tokio::main]`.
use std::path::PathBuf;
use std::time::Duration;
use anyhow::Result;
use clap::{Parser, Subcommand};
@@ -12,7 +12,7 @@ use xet::xet_session::{
};
#[derive(Parser)]
#[clap(name = "session-demo", about = "XetSession API demo")]
#[clap(name = "session-demo-async", about = "XetSession async API demo")]
struct Cli {
#[clap(subcommand)]
command: Command,
@@ -37,52 +37,53 @@ enum Command {
},
}
fn main() -> Result<()> {
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let cli = Cli::parse();
match cli.command {
Command::Upload { files, endpoint } => upload_files(files, endpoint),
Command::Upload { files, endpoint } => upload_files(files, endpoint).await,
Command::Download {
metadata_file,
output_dir,
endpoint,
} => download_files(metadata_file, output_dir, endpoint),
} => download_files(metadata_file, output_dir, endpoint).await,
}
}
fn upload_files(files: Vec<PathBuf>, endpoint: Option<String>) -> Result<()> {
async fn upload_files(files: Vec<PathBuf>, endpoint: Option<String>) -> Result<()> {
let mut builder = XetSessionBuilder::new();
if let Some(ep) = endpoint {
builder = builder.with_endpoint(ep);
}
let session = builder.build()?;
let commit = session.new_upload_commit()?;
let session = builder.build_async().await?;
let commit = session.new_upload_commit().await?;
// Enqueue all uploads; each starts immediately in the background.
let n_files = files.len();
let handles: Vec<UploadTaskHandle> = files
.iter()
.map(|f| commit.upload_from_path(f.clone()))
.collect::<Result<_, _>>()?;
let mut handles = Vec::with_capacity(n_files);
for f in &files {
handles.push(commit.upload_from_path(f.clone()).await?);
}
// Spawn a task to print progress; the main thread blocks in commit() below.
// Spawn a task to print progress while the main task awaits commit().
let commit_for_progress = commit.clone();
std::thread::spawn(move || {
tokio::spawn(async move {
loop {
if let Ok(snapshot) = commit_for_progress.get_progress() {
let p = snapshot.total();
let done = handles
.iter()
.filter(|h| matches!(h.status(), Ok(TaskStatus::Completed)))
.filter(|h: &&UploadTaskHandle| matches!(h.status(), Ok(TaskStatus::Completed)))
.count();
println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes);
}
std::thread::sleep(Duration::from_millis(100));
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
});
// Block until all uploads finish and metadata is finalized.
let results = commit.commit()?;
// Await until all uploads finish and metadata is finalized.
let results = commit.commit().await?;
for m in results.values().filter_map(|m| m.as_ref().as_ref().ok()) {
println!(" {} -> {} ({} bytes)", m.tracking_name.as_deref().unwrap_or("?"), m.hash, m.file_size);
@@ -98,7 +99,7 @@ fn upload_files(files: Vec<PathBuf>, endpoint: Option<String>) -> Result<()> {
Ok(())
}
fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<String>) -> Result<()> {
async fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<String>) -> Result<()> {
let metadata: Vec<FileMetadata> = serde_json::from_str(&std::fs::read_to_string(metadata_file)?)?;
std::fs::create_dir_all(&output_dir)?;
@@ -106,28 +107,26 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<
if let Some(ep) = endpoint {
builder = builder.with_endpoint(ep);
}
let session = builder.build()?;
let group = session.new_download_group()?;
let session = builder.build_async().await?;
let group = session.new_download_group().await?;
// Enqueue all downloads; each starts immediately in the background.
let n_files = metadata.len();
let handles: Vec<DownloadTaskHandle> = metadata
.iter()
.map(|m| {
let dest = output_dir.join(m.tracking_name.as_deref().unwrap_or("file"));
group.download_file_to_path(
XetFileInfo {
hash: m.hash.clone(),
file_size: m.file_size,
},
dest,
)
})
.collect::<Result<_, _>>()?;
let mut handles: Vec<DownloadTaskHandle> = Vec::with_capacity(n_files);
for m in &metadata {
let dest = output_dir.join(m.tracking_name.as_deref().unwrap_or("file"));
handles.push(group.download_file_to_path(
XetFileInfo {
hash: m.hash.clone(),
file_size: m.file_size,
},
dest,
)?);
}
// Spawn a task to print progress; the main thread blocks in finish() below.
// Spawn a task to print progress while the main task awaits finish().
let group_for_progress = group.clone();
std::thread::spawn(move || {
tokio::spawn(async move {
loop {
if let Ok(snapshot) = group_for_progress.get_progress() {
let p = snapshot.total();
@@ -137,12 +136,12 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<
.count();
println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes);
}
std::thread::sleep(Duration::from_millis(100));
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
}
});
// Block until all downloads finish.
let results = group.finish()?;
// Await until all downloads finish.
let results = group.finish().await?;
for (_task_id, result) in &results {
if let Ok(r) = result.as_ref() {

View File

@@ -0,0 +1,150 @@
//! Session-based upload/download example.
//!
//! Shows the three-level hierarchy: XetSession → UploadCommit/DownloadGroup → files.
use std::path::PathBuf;
use std::time::Duration;
use anyhow::Result;
use clap::{Parser, Subcommand};
use xet::xet_session::{FileMetadata, TaskStatus, XetFileInfo, XetSessionBuilder};
#[derive(Parser)]
#[clap(name = "session-demo", about = "XetSession API demo")]
struct Cli {
#[clap(subcommand)]
command: Command,
}
#[derive(Subcommand)]
enum Command {
/// Upload files and save metadata to upload_metadata.json
Upload {
#[clap(required = true)]
files: Vec<PathBuf>,
#[clap(long)]
endpoint: Option<String>,
},
/// Download files from metadata saved by the upload subcommand
Download {
metadata_file: PathBuf,
#[clap(short, long, default_value = "./downloads")]
output_dir: PathBuf,
#[clap(long)]
endpoint: Option<String>,
},
}
fn main() -> Result<()> {
tracing_subscriber::fmt::init();
let cli = Cli::parse();
match cli.command {
Command::Upload { files, endpoint } => upload_files(files, endpoint),
Command::Download {
metadata_file,
output_dir,
endpoint,
} => download_files(metadata_file, output_dir, endpoint),
}
}
fn upload_files(files: Vec<PathBuf>, endpoint: Option<String>) -> Result<()> {
let mut builder = XetSessionBuilder::new();
if let Some(ep) = endpoint {
builder = builder.with_endpoint(ep);
}
let session = builder.build()?;
let commit = session.new_upload_commit_blocking()?;
// Enqueue all uploads; each starts immediately in the background.
let n_files = files.len();
let mut handles = Vec::with_capacity(n_files);
for f in &files {
handles.push(commit.upload_from_path(f.clone())?);
}
// Spawn a task to print progress; the main thread blocks in commit() below.
let commit_for_progress = commit.clone();
std::thread::spawn(move || {
loop {
if let Ok(snapshot) = commit_for_progress.get_progress() {
let p = snapshot.total();
let done = handles
.iter()
.filter(|h| matches!(h.status(), Ok(TaskStatus::Completed)))
.count();
println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes);
}
std::thread::sleep(Duration::from_millis(100));
}
});
// Block until all uploads finish and metadata is finalized.
let results = commit.commit()?;
for m in results.values().filter_map(|m| m.as_ref().as_ref().ok()) {
println!(" {} -> {} ({} bytes)", m.tracking_name.as_deref().unwrap_or("?"), m.hash, m.file_size);
}
// Persist metadata so it can be passed to the `download` subcommand.
let metadata: Vec<_> = results
.into_values()
.filter_map(|m| m.as_ref().as_ref().ok().cloned())
.collect();
std::fs::write("upload_metadata.json", serde_json::to_string_pretty(&metadata)?)?;
Ok(())
}
fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<String>) -> Result<()> {
let metadata: Vec<FileMetadata> = serde_json::from_str(&std::fs::read_to_string(metadata_file)?)?;
std::fs::create_dir_all(&output_dir)?;
let mut builder = XetSessionBuilder::new();
if let Some(ep) = endpoint {
builder = builder.with_endpoint(ep);
}
let session = builder.build()?;
let group = session.new_download_group_blocking()?;
// Enqueue all downloads; each starts immediately in the background.
let n_files = metadata.len();
let mut handles = Vec::with_capacity(n_files);
for m in &metadata {
let dest = output_dir.join(m.tracking_name.as_deref().unwrap_or("file"));
handles.push(group.download_file_to_path(
XetFileInfo {
hash: m.hash.clone(),
file_size: m.file_size,
},
dest,
)?);
}
// Spawn a task to print progress; the main thread blocks in finish() below.
let group_for_progress = group.clone();
std::thread::spawn(move || {
loop {
if let Ok(snapshot) = group_for_progress.get_progress() {
let p = snapshot.total();
let done = handles
.iter()
.filter(|h| matches!(h.status(), Ok(TaskStatus::Completed)))
.count();
println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes);
}
std::thread::sleep(Duration::from_millis(100));
}
});
// Block until all downloads finish.
let results = group.finish()?;
for (_task_id, result) in &results {
if let Ok(r) = result.as_ref() {
println!(" {} ({} bytes)", r.dest_path.display(), r.file_info.file_size);
}
}
Ok(())
}

View File

@@ -62,12 +62,21 @@ pub enum XetError {
/// Catch-all for unexpected internal errors (panics, lock poison, bugs).
#[error("Internal error: {0}")]
Internal(String),
/// Caller invoked a method that is incompatible with the session's runtime mode.
#[error("Wrong runtime mode: {0}")]
WrongRuntimeMode(String),
}
impl XetError {
pub fn other(msg: impl std::fmt::Display) -> Self {
Self::Internal(msg.to_string())
}
pub fn wrong_mode(msg: impl std::fmt::Display) -> Self {
Self::WrongRuntimeMode(msg.to_string())
}
fn from_runtime_error_ref(re: &RuntimeError) -> Self {
match re {
RuntimeError::TaskCanceled(_) => XetError::Cancelled(re.to_string()),

View File

@@ -3,7 +3,7 @@ use xet_data::processing::configurations::TranslatorConfig;
use super::{SessionError, XetSession};
// Helper function to create TranslatorConfig
pub(crate) fn create_translator_config(session: &XetSession) -> Result<TranslatorConfig, SessionError> {
pub(super) fn create_translator_config(session: &XetSession) -> Result<TranslatorConfig, SessionError> {
let endpoint = session
.endpoint
.clone()
@@ -21,7 +21,7 @@ pub(crate) fn create_translator_config(session: &XetSession) -> Result<Translato
/// State of the upload commit and download group
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) enum GroupState {
pub(super) enum GroupState {
Alive,
Finished,
Aborted,

View File

@@ -14,11 +14,15 @@ use super::errors::SessionError;
use super::progress::{DownloadTaskHandle, GroupProgress, ProgressSnapshot, TaskHandle, TaskStatus};
use super::session::XetSession;
/// Groups related file downloads into a single unit of work.
/// Async API for grouping related file downloads into a single unit of work.
///
/// Obtain via [`XetSession::new_download_group`] from an `async` context.
/// For sync / non-async code use [`DownloadGroupSync`] from
/// [`XetSession::new_download_group_blocking`] instead.
///
/// Queue files with [`download_file_to_path`](Self::download_file_to_path) (they start
/// downloading immediately in the background), poll progress with
/// [`get_progress`](Self::get_progress), then call
/// [`get_progress`](Self::get_progress), then `await`
/// [`finish`](Self::finish) to wait for all downloads to complete.
///
/// # Cloning
@@ -31,6 +35,8 @@ use super::session::XetSession;
/// Methods return [`SessionError::Aborted`] if the parent session has been
/// aborted, and [`SessionError::AlreadyFinished`] if
/// [`finish`](Self::finish) has already been called.
///
/// [`DownloadGroupSync`]: crate::xet_session::sync::DownloadGroupSync
#[derive(Clone)]
pub struct DownloadGroup {
inner: Arc<DownloadGroupInner>,
@@ -44,17 +50,14 @@ impl std::ops::Deref for DownloadGroup {
}
impl DownloadGroup {
/// Create a new download group
pub(crate) fn new(session: XetSession) -> Result<Self, SessionError> {
/// Create a new download group from an **async** context. Initialisation logic shared by the sync and async
/// constructors.
pub(super) async fn new(session: XetSession) -> Result<Self, SessionError> {
let group_id = Ulid::new();
let progress = Arc::new(GroupProgress::new());
let progress_clone = progress.clone();
let config = create_translator_config(&session)?;
let download_session = session.runtime.external_run_async_task(async move {
let progress_updater = progress_clone as Arc<dyn xet_data::progress_tracking::TrackingProgressUpdater>;
FileDownloadSession::new(Arc::new(config), Some(progress_updater)).await
})??;
let progress_updater = progress.clone() as Arc<dyn xet_data::progress_tracking::TrackingProgressUpdater>;
let download_session = FileDownloadSession::new(Arc::new(config), Some(progress_updater)).await?;
let inner = Arc::new(DownloadGroupInner {
group_id,
@@ -69,23 +72,26 @@ impl DownloadGroup {
}
/// Get the group ID.
pub(crate) fn id(&self) -> Ulid {
pub(super) fn id(&self) -> Ulid {
self.group_id
}
/// Abort this download group.
pub(crate) fn abort(&self) -> Result<(), SessionError> {
pub(super) fn abort(&self) -> Result<(), SessionError> {
self.inner.abort()
}
// ===== Public synchronous methods =====
/// Returns the runtime used by this group.
pub(super) fn runtime(&self) -> &XetRuntime {
&self.inner.session.runtime
}
/// Queue a file for download to `dest_path`, starting the transfer immediately if system resource permits.
///
/// # Parameters
///
/// * `file_info` Content-addressed hash and size returned by a previous
/// [`UploadCommit::commit`](crate::UploadCommit::commit).
/// [`UploadCommit::commit`](crate::xet_session::UploadCommit::commit).
/// * `dest_path` Local path where the downloaded file will be written. Parent directories are created
/// automatically.
///
@@ -103,16 +109,11 @@ impl DownloadGroup {
dest_path: PathBuf,
) -> Result<DownloadTaskHandle, SessionError> {
self.session.check_alive()?;
self.inner.start_download_file_to_path(file_info, dest_path)
}
/// Returns `true` if [`finish`](Self::finish) has been called and completed.
#[cfg(test)]
fn is_finished(&self) -> bool {
match self.state.lock() {
Ok(state) => *state == GroupState::Finished,
Err(_) => false,
}
// Use the absolute path in case the process current working directory changes
// while the task is queued.
let absolute_path = std::path::absolute(dest_path)?;
self.inner.start_download_file_to_path(file_info, absolute_path)
}
/// Return a snapshot of progress for every queued download.
@@ -122,11 +123,9 @@ impl DownloadGroup {
/// Wait for all downloads to complete and return their results.
///
/// Blocks until every queued download finishes (or fails). Returns a
/// `HashMap` keyed by task ID (the [`Ulid`] returned by
/// [`download_file_to_path`](Self::download_file_to_path)), where each
/// value is [`DownloadResult`] (= `Arc<Result<`[`DownloadedFile`]`,
/// `[`SessionError`](crate::SessionError)`>>`). A single failed download
/// Returns a `HashMap` keyed by task ID where each value is
/// [`DownloadResult`] (= `Arc<Result<`[`DownloadedFile`]`,
/// [`SessionError`](crate::SessionError)`>>`). A single failed download
/// does not prevent the others from being collected.
///
/// Per-task results can also be read directly from the
@@ -134,13 +133,21 @@ impl DownloadGroup {
/// [`result`](DownloadTaskHandle::result) after this method returns.
///
/// Consumes `self` — subsequent calls on any clone will return
/// [`SessionError::AlreadyFinished`] (or a channel-closed error if the
/// background worker has already exited).
pub fn finish(self) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
/// [`SessionError::AlreadyFinished`].
pub async fn finish(self) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
let inner = self.inner.clone();
self.session
.runtime
.external_run_async_task(async move { inner.handle_finish().await })?
.dispatch("finish", async move { inner.handle_finish().await })
.await?
}
/// Returns `true` if [`finish`](Self::finish) has been called and completed.
#[cfg(test)]
fn is_finished(&self) -> bool {
match self.state.lock() {
Ok(state) => *state == GroupState::Finished,
Err(_) => false,
}
}
}
@@ -152,7 +159,7 @@ impl DownloadGroup {
pub type DownloadResult = Arc<Result<DownloadedFile, SessionError>>;
/// Handle for a single download task tracked internally by DownloadGroup.
pub(crate) struct InnerDownloadTaskHandle {
struct InnerDownloadTaskHandle {
status: Arc<Mutex<TaskStatus>>,
dest_path: PathBuf,
join_handle: JoinHandle<Result<XetFileInfo, SessionError>>,
@@ -217,7 +224,11 @@ impl DownloadGroupInner {
} else {
TaskStatus::Failed
};
*status.lock()? = new_status;
// Only overwrite if still Running — abort() may have set Cancelled concurrently.
let mut s = status.lock()?;
if matches!(*s, TaskStatus::Running) {
*s = new_status;
}
Ok(XetFileInfo {
hash: file_info.hash,
@@ -273,8 +284,8 @@ impl DownloadGroupInner {
Ok(task_handle)
}
/// Handle a `Finish` command from the public API.
async fn handle_finish(self: &Arc<Self>) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
/// Join all active download tasks and mark the group as finished.
async fn handle_finish(&self) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
// Mark as not accepting new tasks
{
let mut state_guard = self.state.lock()?;
@@ -356,23 +367,30 @@ pub struct DownloadedFile {
#[cfg(test)]
mod tests {
use std::sync::mpsc;
use std::time::Duration;
use tempfile::{TempDir, tempdir};
use super::*;
use crate::xet_session::progress::UploadTaskHandle;
use crate::xet_session::session::XetSession;
use crate::xet_session::session::{RuntimeMode, XetSession, XetSessionBuilder};
fn local_session(temp: &TempDir) -> Result<XetSession, Box<dyn std::error::Error>> {
async fn local_session(temp: &TempDir) -> Result<XetSession, Box<dyn std::error::Error>> {
let cas_path = temp.path().join("cas");
Ok(XetSession::new(Some(format!("local://{}", cas_path.display())), None, None, None)?)
Ok(XetSessionBuilder::new()
.with_endpoint(format!("local://{}", cas_path.display()))
.build_async()
.await?)
}
fn upload_bytes(session: &XetSession, data: &[u8], name: &str) -> Result<XetFileInfo, Box<dyn std::error::Error>> {
let commit = session.new_upload_commit()?;
let handle = commit.upload_bytes(data.to_vec(), Some(name.into()))?;
let results = commit.commit()?;
async fn upload_bytes(
session: &XetSession,
data: &[u8],
name: &str,
) -> Result<XetFileInfo, Box<dyn std::error::Error>> {
let commit = session.new_upload_commit().await?;
let handle = commit.upload_bytes(data.to_vec(), Some(name.into())).await?;
let results = commit.commit().await?;
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
Ok(XetFileInfo {
hash: meta.hash.clone(),
@@ -380,288 +398,6 @@ mod tests {
})
}
// ── Identity ─────────────────────────────────────────────────────────────
#[test]
// Two download groups created from the same session have distinct IDs.
fn test_group_has_unique_id() -> Result<(), Box<dyn std::error::Error>> {
let session = XetSession::new(None, None, None, None)?;
let g1 = session.new_download_group()?;
let g2 = session.new_download_group()?;
assert_ne!(g1.id(), g2.id());
Ok(())
}
// ── Initial state ────────────────────────────────────────────────────────
#[test]
// A fresh group has all-zero aggregate progress.
fn test_get_progress_empty_initially() -> Result<(), Box<dyn std::error::Error>> {
let session = XetSession::new(None, None, None, None)?;
let group = session.new_download_group()?;
let snapshot = group.get_progress()?;
let total = snapshot.total();
assert_eq!(total.total_bytes, 0);
assert_eq!(total.total_bytes_completed, 0);
Ok(())
}
// ── Finish lifecycle ─────────────────────────────────────────────────────
#[test]
// An empty finish succeeds and returns an empty result set.
fn test_finish_empty_succeeds() -> Result<(), Box<dyn std::error::Error>> {
let session = XetSession::new(None, None, None, None)?;
let group = session.new_download_group()?;
let results = group.finish()?;
assert!(results.is_empty());
Ok(())
}
#[test]
// finish() transitions the group into the Finished state.
fn test_finish_marks_as_finished() -> Result<(), Box<dyn std::error::Error>> {
let session = XetSession::new(None, None, None, None)?;
let group = session.new_download_group()?;
let group_clone = group.clone();
group.finish().unwrap();
assert!(group_clone.is_finished());
Ok(())
}
#[test]
// A second finish() call on any clone returns AlreadyFinished.
fn test_second_finish_fails() -> Result<(), Box<dyn std::error::Error>> {
let session = XetSession::new(None, None, None, None)?;
let g1 = session.new_download_group()?;
let g2 = g1.clone();
g1.finish()?;
let err = g2.finish().unwrap_err();
assert!(matches!(err, SessionError::AlreadyFinished | SessionError::Internal(_)));
Ok(())
}
#[test]
// finish() unregisters the group from the session's active set.
fn test_finish_unregisters_from_session() -> Result<(), Box<dyn std::error::Error>> {
let session = XetSession::new(None, None, None, None)?;
let group = session.new_download_group()?;
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
group.finish().unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 0);
Ok(())
}
// ── Guards ───────────────────────────────────────────────────────────────
#[test]
// download_file_to_path returns Aborted when the parent session has been aborted.
fn test_download_file_on_aborted_session_returns_error() -> Result<(), Box<dyn std::error::Error>> {
let session = XetSession::new(None, None, None, None)?;
let group = session.new_download_group()?;
session.abort().unwrap();
let err = group
.download_file_to_path(
XetFileInfo {
hash: "abc123".to_string(),
file_size: 1024,
},
PathBuf::from("dest.bin"),
)
.unwrap_err();
assert!(matches!(err, SessionError::Aborted));
Ok(())
}
#[test]
// download_file_to_path after finish returns AlreadyFinished.
fn test_download_file_after_finish_fails() -> Result<(), Box<dyn std::error::Error>> {
let session = XetSession::new(None, None, None, None)?;
let g1 = session.new_download_group()?;
let g2 = g1.clone();
g1.finish()?;
let err = g2
.download_file_to_path(
XetFileInfo {
hash: "abc123".to_string(),
file_size: 1024,
},
PathBuf::from("dest.bin"),
)
.unwrap_err();
assert!(matches!(err, SessionError::AlreadyFinished));
Ok(())
}
#[test]
// download_file_to_path on a directly-aborted group returns Aborted (not AlreadyFinished).
fn test_download_file_on_aborted_group_returns_aborted() -> Result<(), Box<dyn std::error::Error>> {
let session = XetSession::new(None, None, None, None)?;
let group = session.new_download_group()?;
group.abort()?;
let err = group
.download_file_to_path(
XetFileInfo {
hash: "abc123".to_string(),
file_size: 1024,
},
PathBuf::from("dest.bin"),
)
.unwrap_err();
assert!(matches!(err, SessionError::Aborted));
Ok(())
}
// ── Independence ─────────────────────────────────────────────────────────
#[test]
// Finishing one group does not affect the state of another from the same session.
fn test_two_groups_are_independent() -> Result<(), Box<dyn std::error::Error>> {
let session = XetSession::new(None, None, None, None)?;
let g1 = session.new_download_group()?;
let g2 = session.new_download_group()?;
g1.finish()?;
assert!(!g2.is_finished());
Ok(())
}
// ── Round-trip tests ─────────────────────────────────────────────────────
#[test]
// Downloading a previously uploaded file produces byte-identical content at the destination.
fn test_download_file_round_trip() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let original = b"Hello, download round-trip!";
let file_info = upload_bytes(&session, original, "payload.bin")?;
let dest = temp.path().join("downloaded.bin");
let group = session.new_download_group()?;
group.download_file_to_path(file_info, dest.clone())?;
group.finish()?;
assert_eq!(std::fs::read(&dest)?, original);
Ok(())
}
#[test]
// Downloading multiple files from a single group produces correct content for each.
fn test_download_multiple_files() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data_a = b"First file content";
let data_b = b"Second file content - different";
// Upload both files; capture handles so results can be retrieved by task_id.
let commit = session.new_upload_commit()?;
let handle_a = commit.upload_bytes(data_a.to_vec(), Some("a.bin".into()))?;
let handle_b = commit.upload_bytes(data_b.to_vec(), Some("b.bin".into()))?;
let results = commit.commit()?;
let to_file_info = |handle: &UploadTaskHandle| -> XetFileInfo {
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
XetFileInfo {
hash: meta.hash.clone(),
file_size: meta.file_size,
}
};
let dest_a = temp.path().join("a_out.bin");
let dest_b = temp.path().join("b_out.bin");
let group = session.new_download_group()?;
group.download_file_to_path(to_file_info(&handle_a), dest_a.clone())?;
group.download_file_to_path(to_file_info(&handle_b), dest_b.clone())?;
group.finish()?;
assert_eq!(std::fs::read(&dest_a)?, data_a);
assert_eq!(std::fs::read(&dest_b)?, data_b);
Ok(())
}
#[test]
// After a successful finish the aggregate download progress reflects bytes received.
fn test_download_progress_reflects_bytes_after_finish() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let original = b"download progress tracking data";
let file_info = upload_bytes(&session, original, "prog.bin")?;
let dest = temp.path().join("out.bin");
let group = session.new_download_group()?;
let progress_observer = group.clone();
group.download_file_to_path(file_info, dest)?;
group.finish()?;
std::thread::sleep(
session
.runtime
.config()
.data
.progress_update_interval
.saturating_add(Duration::from_secs(1)),
);
let snapshot = progress_observer.get_progress()?;
assert!(snapshot.total().total_bytes_completed > 0);
Ok(())
}
// ── Per-task result access patterns ──────────────────────────────────────
//
// After finish() completes there are two equivalent ways to retrieve a
// per-task DownloadResult:
//
// 1. HashMap lookup: `finish_results.get(&handle.task_id)`
// 2. Direct handle: `handle.result()` (on DownloadTaskHandle)
#[test]
// Pattern 1: per-task result is accessible via task_id in the finish() HashMap.
fn test_download_result_accessible_via_task_id_in_finish_map() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data = b"result via task_id in finish map";
let file_info = upload_bytes(&session, data, "file.bin")?;
let dest = temp.path().join("out.bin");
let group = session.new_download_group()?;
let handle = group.download_file_to_path(file_info, dest)?;
let results = group.finish()?;
let result = results.get(&handle.task_id).expect("task_id must be present in results");
assert_eq!(result.as_ref().as_ref().unwrap().file_info.file_size, data.len() as u64);
Ok(())
}
#[test]
// DownloadTaskHandle::result() returns None before finish() is called.
fn test_download_result_none_before_finish() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let file_info = upload_bytes(&session, b"some data", "file.bin")?;
let dest = temp.path().join("out.bin");
let group = session.new_download_group()?;
let handle = group.download_file_to_path(file_info, dest)?;
assert!(handle.result().is_none(), "result must be None before finish()");
group.finish()?;
Ok(())
}
#[test]
// DownloadTaskHandle::result() returns Some after finish() completes.
fn test_download_result_some_after_finish() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data = b"download result test data";
let file_info = upload_bytes(&session, data, "file.bin")?;
let dest = temp.path().join("out.bin");
let group = session.new_download_group()?;
let handle = group.download_file_to_path(file_info.clone(), dest)?;
group.finish()?;
let result = handle.result().expect("result must be set after finish()");
let dl = result.as_ref().as_ref().unwrap();
assert_eq!(dl.file_info.file_size, data.len() as u64);
assert_eq!(dl.file_info.hash, file_info.hash);
Ok(())
}
// ── Mutex guard / concurrency test ───────────────────────────────────────
//
// `download_file_to_path` holds `self.state` for its entire execution so
@@ -673,27 +409,26 @@ mod tests {
#[test]
// finish() must block while download_file_to_path() holds the state lock.
fn test_finish_blocked_while_download_registration_holds_state_lock() -> Result<(), Box<dyn std::error::Error>> {
use std::sync::mpsc;
let session = XetSession::new(None, None, None, None)?;
let group = session.new_download_group()?;
let session = XetSessionBuilder::new().build()?;
let runtime = session.runtime.clone();
// Create DownloadGroup directly so we can access its private state field
// (accessible here because mod tests is a submodule of download_group).
let group = runtime.external_run_async_task(DownloadGroup::new(session.clone()))??;
let group_for_thread = group.clone();
let runtime_for_thread = runtime.clone();
// Simulate download_file_to_path() holding the state lock mid-registration.
let guard = group.inner.state.lock().unwrap();
let (done_tx, done_rx) = mpsc::channel::<()>();
let join_handle = std::thread::spawn(move || {
let _ = group_for_thread.finish(); // must block until guard is dropped
let _ = runtime_for_thread.external_run_async_task(async move { group_for_thread.finish().await });
let _ = done_tx.send(());
});
// Give the spawned thread enough time to reach the state-lock acquisition
// inside finish() and block there.
std::thread::sleep(Duration::from_millis(50));
assert!(done_rx.try_recv().is_err(), "finish() should be blocked while state lock is held");
// Release the lock — simulates the enqueue method completing its registration.
drop(guard);
assert!(
@@ -703,4 +438,349 @@ mod tests {
let _ = join_handle.join();
Ok(())
}
// ── Identity ─────────────────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// Two download groups created from the same session have distinct IDs.
async fn test_group_has_unique_id() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let g1 = session.new_download_group().await.unwrap();
let g2 = session.new_download_group().await.unwrap();
assert_ne!(g1.id(), g2.id());
}
// ── Initial state ────────────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// A fresh group has all-zero aggregate progress.
async fn test_get_progress_empty_initially() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
let snapshot = group.get_progress().unwrap();
let total = snapshot.total();
assert_eq!(total.total_bytes, 0);
assert_eq!(total.total_bytes_completed, 0);
}
// ── Finish lifecycle ─────────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// An empty finish succeeds and returns an empty result set.
async fn test_finish_empty_succeeds() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
let results = group.finish().await.unwrap();
assert!(results.is_empty());
}
#[tokio::test(flavor = "multi_thread")]
// finish() transitions the group into the Finished state.
async fn test_finish_marks_as_finished() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
let group_clone = group.clone();
group.finish().await.unwrap();
assert!(group_clone.is_finished());
}
#[tokio::test(flavor = "multi_thread")]
// A second finish() call on any clone returns AlreadyFinished.
async fn test_second_finish_fails() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let g1 = session.new_download_group().await.unwrap();
let g2 = g1.clone();
g1.finish().await.unwrap();
let err = g2.finish().await.unwrap_err();
assert!(matches!(err, SessionError::AlreadyFinished | SessionError::Internal(_)));
}
#[tokio::test(flavor = "multi_thread")]
// finish() unregisters the group from the session's active set.
async fn test_finish_unregisters_from_session() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
group.finish().await.unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 0);
}
// ── Guards ───────────────────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// download_file_to_path returns Aborted when the parent session has been aborted.
async fn test_download_file_on_aborted_session_returns_error() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
session.abort().unwrap();
let err = group
.download_file_to_path(
XetFileInfo {
hash: "abc123".to_string(),
file_size: 1024,
},
std::path::PathBuf::from("dest.bin"),
)
.unwrap_err();
assert!(matches!(err, SessionError::Aborted));
}
#[tokio::test(flavor = "multi_thread")]
// download_file_to_path after finish returns AlreadyFinished.
async fn test_download_file_after_finish_fails() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let g1 = session.new_download_group().await.unwrap();
let g2 = g1.clone();
g1.finish().await.unwrap();
let err = g2
.download_file_to_path(
XetFileInfo {
hash: "abc123".to_string(),
file_size: 1024,
},
std::path::PathBuf::from("dest.bin"),
)
.unwrap_err();
assert!(matches!(err, SessionError::AlreadyFinished));
}
#[tokio::test(flavor = "multi_thread")]
// download_file_to_path on a directly-aborted group returns Aborted.
async fn test_download_file_on_aborted_group_returns_aborted() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
group.abort().unwrap();
let err = group
.download_file_to_path(
XetFileInfo {
hash: "abc123".to_string(),
file_size: 1024,
},
std::path::PathBuf::from("dest.bin"),
)
.unwrap_err();
assert!(matches!(err, SessionError::Aborted));
}
// ── Independence ─────────────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// Finishing one group does not affect the state of another from the same session.
async fn test_two_groups_are_independent() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let g1 = session.new_download_group().await.unwrap();
let g2 = session.new_download_group().await.unwrap();
g1.finish().await.unwrap();
assert!(!g2.is_finished());
}
// ── Round-trip tests ─────────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// Downloading a previously uploaded file produces byte-identical content at the destination.
async fn test_download_file_round_trip() {
let temp = tempdir().unwrap();
let session = local_session(&temp).await.unwrap();
let original = b"Hello, download round-trip!";
let file_info = upload_bytes(&session, original, "payload.bin").await.unwrap();
let dest = temp.path().join("downloaded.bin");
let group = session.new_download_group().await.unwrap();
group.download_file_to_path(file_info, dest.clone()).unwrap();
group.finish().await.unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), original);
}
#[tokio::test(flavor = "multi_thread")]
// Downloading multiple files from a single group produces correct content for each.
async fn test_download_multiple_files() {
let temp = tempdir().unwrap();
let session = local_session(&temp).await.unwrap();
let data_a = b"First file content";
let data_b = b"Second file content - different";
let commit = session.new_upload_commit().await.unwrap();
let handle_a = commit.upload_bytes(data_a.to_vec(), Some("a.bin".into())).await.unwrap();
let handle_b = commit.upload_bytes(data_b.to_vec(), Some("b.bin".into())).await.unwrap();
let results = commit.commit().await.unwrap();
let to_file_info = |handle: &crate::xet_session::progress::UploadTaskHandle| -> XetFileInfo {
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
XetFileInfo {
hash: meta.hash.clone(),
file_size: meta.file_size,
}
};
let dest_a = temp.path().join("a_out.bin");
let dest_b = temp.path().join("b_out.bin");
let group = session.new_download_group().await.unwrap();
group.download_file_to_path(to_file_info(&handle_a), dest_a.clone()).unwrap();
group.download_file_to_path(to_file_info(&handle_b), dest_b.clone()).unwrap();
group.finish().await.unwrap();
assert_eq!(std::fs::read(&dest_a).unwrap(), data_a);
assert_eq!(std::fs::read(&dest_b).unwrap(), data_b);
}
#[tokio::test(flavor = "multi_thread")]
// After a successful finish the aggregate download progress reflects bytes received.
async fn test_download_progress_reflects_bytes_after_finish() {
let temp = tempdir().unwrap();
let session = local_session(&temp).await.unwrap();
let original = b"download progress tracking data";
let file_info = upload_bytes(&session, original, "prog.bin").await.unwrap();
let dest = temp.path().join("out.bin");
let group = session.new_download_group().await.unwrap();
let progress_observer = group.clone();
group.download_file_to_path(file_info, dest).unwrap();
group.finish().await.unwrap();
tokio::time::sleep(
session
.runtime
.config()
.data
.progress_update_interval
.saturating_add(Duration::from_secs(1)),
)
.await;
let snapshot = progress_observer.get_progress().unwrap();
assert!(snapshot.total().total_bytes_completed > 0);
}
// ── Per-task result access patterns ──────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// Pattern 1: per-task result is accessible via task_id in the finish() HashMap.
async fn test_download_result_accessible_via_task_id_in_finish_map() {
let temp = tempdir().unwrap();
let session = local_session(&temp).await.unwrap();
let data = b"result via task_id in finish map";
let file_info = upload_bytes(&session, data, "file.bin").await.unwrap();
let dest = temp.path().join("out.bin");
let group = session.new_download_group().await.unwrap();
let handle = group.download_file_to_path(file_info, dest).unwrap();
let results = group.finish().await.unwrap();
let result = results.get(&handle.task_id).expect("task_id must be present in results");
assert_eq!(result.as_ref().as_ref().unwrap().file_info.file_size, data.len() as u64);
}
#[tokio::test(flavor = "multi_thread")]
// DownloadTaskHandle::result() returns None before finish() is called.
async fn test_download_result_none_before_finish() {
let temp = tempdir().unwrap();
let session = local_session(&temp).await.unwrap();
let file_info = upload_bytes(&session, b"some data", "file.bin").await.unwrap();
let dest = temp.path().join("out.bin");
let group = session.new_download_group().await.unwrap();
let handle = group.download_file_to_path(file_info, dest).unwrap();
assert!(handle.result().is_none(), "result must be None before finish()");
group.finish().await.unwrap();
}
#[tokio::test(flavor = "multi_thread")]
// DownloadTaskHandle::result() returns Some after finish() completes.
async fn test_download_result_some_after_finish() {
let temp = tempdir().unwrap();
let session = local_session(&temp).await.unwrap();
let data = b"download result test data";
let file_info = upload_bytes(&session, data, "file.bin").await.unwrap();
let dest = temp.path().join("out.bin");
let group = session.new_download_group().await.unwrap();
let handle = group.download_file_to_path(file_info.clone(), dest).unwrap();
group.finish().await.unwrap();
let result = handle.result().expect("result must be set after finish()");
let dl = result.as_ref().as_ref().unwrap();
assert_eq!(dl.file_info.file_size, data.len() as u64);
assert_eq!(dl.file_info.hash, file_info.hash);
}
// ── Non-tokio executor (Owned-mode bridge) ────────────────────────────────
#[test]
// build_async() falls back to Owned mode from a futures executor, and the
// bridge correctly routes download through the owned thread pool while the
// future is driven by the caller's executor (futures::block_on).
fn test_async_bridge_works_from_futures_executor() {
let temp = tempdir().unwrap();
futures::executor::block_on(async {
let session = local_session(&temp).await.unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::Owned);
let data = b"hello from futures executor";
let commit = session.new_upload_commit().await.unwrap();
let handle = commit.upload_bytes(data.to_vec(), Some("test.bin".into())).await.unwrap();
let results = commit.commit().await.unwrap();
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
let file_info = XetFileInfo {
hash: meta.hash.clone(),
file_size: meta.file_size,
};
let dest = temp.path().join("out_futures.bin");
let group = session.new_download_group().await.unwrap();
group.download_file_to_path(file_info, dest.clone()).unwrap();
group.finish().await.unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), data);
});
}
#[test]
// Same as above but driven by the smol executor.
fn test_async_bridge_works_from_smol_executor() {
let temp = tempdir().unwrap();
smol::block_on(async {
let session = local_session(&temp).await.unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::Owned);
let data = b"hello from smol executor";
let commit = session.new_upload_commit().await.unwrap();
let handle = commit.upload_bytes(data.to_vec(), Some("test.bin".into())).await.unwrap();
let results = commit.commit().await.unwrap();
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
let file_info = XetFileInfo {
hash: meta.hash.clone(),
file_size: meta.file_size,
};
let dest = temp.path().join("out_smol.bin");
let group = session.new_download_group().await.unwrap();
group.download_file_to_path(file_info, dest.clone()).unwrap();
group.finish().await.unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), data);
});
}
#[test]
// Same as above but driven by the async-std executor.
fn test_async_bridge_works_from_async_std_executor() {
let temp = tempdir().unwrap();
async_std::task::block_on(async {
let session = local_session(&temp).await.unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::Owned);
let data = b"hello from async-std executor";
let commit = session.new_upload_commit().await.unwrap();
let handle = commit.upload_bytes(data.to_vec(), Some("test.bin".into())).await.unwrap();
let results = commit.commit().await.unwrap();
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
let file_info = XetFileInfo {
hash: meta.hash.clone(),
file_size: meta.file_size,
};
let dest = temp.path().join("out_async_std.bin");
let group = session.new_download_group().await.unwrap();
group.download_file_to_path(file_info, dest.clone()).unwrap();
group.finish().await.unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), data);
});
}
}

View File

@@ -4,46 +4,57 @@
//! file operations:
//!
//! ```text
//! XetSession — owns the tokio runtime and authentication credentials
//! XetSession — holds runtime context and authentication credentials
//! ├── UploadCommit — groups related uploads; finalised with commit()
//! └── DownloadGroup — groups related downloads; finalised with finish()
//! ```
//!
//! Each [`XetSession`] owns its own tokio runtime and configuration, so
//! Each [`XetSession`] holds its own runtime context and configuration, so
//! multiple sessions with different endpoints or credentials can coexist in
//! the same process. Cloning a session, commit, or group is cheap — all
//! clones share the same underlying state via `Arc`.
//!
//! ## Uploads
//!
//! Create an [`UploadCommit`] with [`XetSession::new_upload_commit`], queue
//! files with [`upload_from_path`](UploadCommit::upload_from_path) or
//! [`upload_bytes`](UploadCommit::upload_bytes), then call
//! [`commit`](UploadCommit::commit) to wait for all transfers to finish and
//! receive a `HashMap<Ulid, `[`UploadResult`]`>` keyed by task ID
//! (`UploadResult` = `Arc<Result<`[`FileMetadata`]`, `[`SessionError`]`>>`).
//! Per-task results can also be read directly from the returned
//! [`UploadTaskHandle`] via [`result`](UploadTaskHandle::result) after
//! `commit()` returns.
//! For **sync** callers: create an [`UploadCommitSync`] with
//! [`XetSession::new_upload_commit_blocking`], queue files with
//! [`upload_from_path`](UploadCommitSync::upload_from_path) or
//! [`upload_bytes`](UploadCommitSync::upload_bytes), then call
//! [`commit`](UploadCommitSync::commit) to block until all transfers finish and
//! receive a `HashMap<Ulid, `[`UploadResult`]`>` keyed by task ID.
//!
//! For **async** callers: create an [`UploadCommit`] with
//! [`XetSession::new_upload_commit`], queue files the same way, then
//! `await` [`commit`](UploadCommit::commit).
//!
//! `UploadResult` = `Arc<Result<`[`FileMetadata`]`, `[`SessionError`]`>>`.
//! Per-task results can also be read from the returned [`UploadTaskHandle`]
//! via [`result`](UploadTaskHandle::result) after `commit()` returns.
//!
//! ## Downloads
//!
//! Create a [`DownloadGroup`] with [`XetSession::new_download_group`], queue
//! files with [`download_file_to_path`](DownloadGroup::download_file_to_path),
//! then call [`finish`](DownloadGroup::finish) to wait for all transfers and
//! receive a `HashMap<Ulid, `[`DownloadResult`]`>` keyed by task ID
//! (`DownloadResult` = `Arc<Result<`[`DownloadedFile`]`, `[`SessionError`]`>>`).
//! Per-task results can also be read directly from the returned
//! [`DownloadTaskHandle`] via [`result`](DownloadTaskHandle::result) after
//! `finish()` returns.
//! For **sync** callers: create a [`DownloadGroupSync`] with
//! [`XetSession::new_download_group_blocking`], queue files with
//! [`download_file_to_path`](DownloadGroupSync::download_file_to_path), then
//! call [`finish`](DownloadGroupSync::finish) to block until all transfers
//! complete and receive a `HashMap<Ulid, `[`DownloadResult`]`>` keyed by task ID.
//!
//! For **async** callers: create a [`DownloadGroup`] with
//! [`XetSession::new_download_group`], queue files the same way, then
//! `await` [`finish`](DownloadGroup::finish).
//!
//! `DownloadResult` = `Arc<Result<`[`DownloadedFile`]`, `[`SessionError`]`>>`.
//! Per-task results can also be read from the returned [`DownloadTaskHandle`]
//! via [`result`](DownloadTaskHandle::result) after `finish()` returns.
//!
//! ## Progress tracking
//!
//! Both [`UploadCommit`] and [`DownloadGroup`] expose
//! [`get_progress`](UploadCommit::get_progress), which returns a
//! All four types ([`UploadCommit`], [`UploadCommitSync`], [`DownloadGroup`],
//! [`DownloadGroupSync`]) expose `get_progress()`, which returns a
//! [`ProgressSnapshot`] without acquiring a lock on the calling thread
//! (useful for Python bindings that must release the GIL). Poll it from a
//! background thread while the main thread blocks in `commit()` / `finish()`.
//! background thread/task while the main thread/task blocks in
//! `commit()` / `finish()`.
//!
//! ## Error handling
//!
@@ -53,50 +64,81 @@
//! `HashMap<Ulid, `[`DownloadResult`]`>` keyed by task ID, so a single failed
//! file does not discard all others.
//!
//! # Quick start
//! # Quick start — sync API
//!
//! ```rust,no_run
//! use xet::xet_session::{XetFileInfo, XetSessionBuilder};
//!
//! // 1. Build a session
//! // 1. Build a session — sync (non-async) context only.
//! // For async code call build_async().await instead.
//! let session = XetSessionBuilder::new()
//! .with_endpoint("https://cas.example.com".into())
//! .with_token_info("my-token".into(), 1_700_000_000)
//! .build()?;
//!
//! // 2. Upload
//! let commit = session.new_upload_commit()?;
//! // 2. Upload — use the _blocking factory; returns UploadCommitSync
//! let commit = session.new_upload_commit_blocking()?;
//! let handle = commit.upload_from_path("file.bin".into())?;
//! commit.commit()?;
//! // Access result directly from the handle (populated by commit())
//! // UploadResult = Arc<Result<FileMetadata, SessionError>>
//! let m = handle.result().unwrap(); // Option<UploadResult>
//! let m = m.as_ref().as_ref().unwrap(); // &FileMetadata
//! let results = commit.commit()?;
//! let m = results.values().next().unwrap().as_ref().as_ref().unwrap();
//!
//! // 3. Download
//! let group = session.new_download_group()?;
//! // 3. Download — use the _blocking factory; returns DownloadGroupSync
//! let group = session.new_download_group_blocking()?;
//! let info = XetFileInfo {
//! hash: m.hash.clone(),
//! file_size: m.file_size,
//! };
//! let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?;
//! let finish_results = group.finish()?;
//! // Pattern 1: look up by task ID in the returned HashMap
//! // DownloadResult = Arc<Result<DownloadedFile, SessionError>>
//! let r1 = finish_results.get(&dl_handle.task_id).unwrap(); // &DownloadResult
//! let r1 = r1.as_ref().as_ref().unwrap(); // &DownloadedFile
//! // Pattern 2: read directly from the handle (populated by finish())
//! let r2 = dl_handle.result().unwrap(); // DownloadResult
//! let r2 = r2.as_ref().as_ref().unwrap(); // &DownloadedFile
//! let r = finish_results.get(&dl_handle.task_id).unwrap().as_ref().as_ref().unwrap();
//!
//! # Ok::<(), xet::xet_session::SessionError>(())
//! ```
//!
//! # Quick start — async API
//!
//! ```rust,no_run
//! use xet::xet_session::{XetFileInfo, XetSessionBuilder};
//!
//! # async fn example() -> Result<(), xet::xet_session::SessionError> {
//! // 1. Build a session. build_async() auto-detects the executor:
//! // - tokio (multi-thread): wraps the caller's handle, no second thread pool.
//! // - non-tokio (smol, async-std, etc.): creates an owned thread pool.
//! let session = XetSessionBuilder::new()
//! .with_endpoint("https://cas.example.com".into())
//! .with_token_info("my-token".into(), 1_700_000_000)
//! .build_async()
//! .await?;
//!
//! // 2. Upload — use the async factory; returns UploadCommit
//! let commit = session.new_upload_commit().await?;
//! let handle = commit.upload_from_path("file.bin".into()).await?;
//! // UploadResult = Arc<Result<FileMetadata, SessionError>>
//! let results = commit.commit().await?;
//! let m = results.values().next().unwrap().as_ref().as_ref().unwrap();
//!
//! // 3. Download — use the async factory; returns DownloadGroup
//! let group = session.new_download_group().await?;
//! let info = XetFileInfo {
//! hash: m.hash.clone(),
//! file_size: m.file_size,
//! };
//! let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?;
//! let finish_results = group.finish().await?;
//! // DownloadResult = Arc<Result<DownloadedFile, SessionError>>
//! let r = finish_results.get(&dl_handle.task_id).unwrap().as_ref().as_ref().unwrap();
//! # Ok(())
//! # }
//! ```
mod common;
mod download_group;
mod errors;
mod progress;
mod session;
pub mod sync;
mod upload_commit;
pub use download_group::{DownloadGroup, DownloadResult, DownloadedFile};
@@ -105,6 +147,7 @@ pub use progress::{
DownloadTaskHandle, FileProgress, ProgressSnapshot, TaskHandle, TaskStatus, TotalProgressSnapshot, UploadTaskHandle,
};
pub use session::{XetSession, XetSessionBuilder};
pub use sync::{DownloadGroupSync, UploadCommitSync};
pub use upload_commit::{FileMetadata, UploadCommit, UploadResult};
pub use xet_data::processing::XetFileInfo;
// Re-export XetConfig for convenience

View File

@@ -30,16 +30,16 @@ pub enum TaskStatus {
#[derive(Debug)]
pub struct TaskHandle {
pub(crate) status: Option<Arc<Mutex<TaskStatus>>>,
pub(crate) group_progress: Arc<GroupProgress>,
pub(super) status: Option<Arc<Mutex<TaskStatus>>>,
pub(super) group_progress: Arc<GroupProgress>,
/// Id of the task, can be used to retrive per-task progress and result.
pub task_id: Ulid,
}
#[derive(Debug)]
pub struct UploadTaskHandle {
pub(crate) inner: TaskHandle,
pub(crate) result: Arc<OnceLock<UploadResult>>,
pub(super) inner: TaskHandle,
pub(super) result: Arc<OnceLock<UploadResult>>,
}
impl Deref for UploadTaskHandle {
@@ -52,8 +52,8 @@ impl Deref for UploadTaskHandle {
#[derive(Debug)]
pub struct DownloadTaskHandle {
pub(crate) inner: TaskHandle,
pub(crate) result: Arc<OnceLock<DownloadResult>>,
pub(super) inner: TaskHandle,
pub(super) result: Arc<OnceLock<DownloadResult>>,
}
impl Deref for DownloadTaskHandle {
@@ -107,7 +107,7 @@ impl ProgressSnapshot {
}
}
/// Snapshot of aggregate progress returned by [`TaskProgress::total`].
/// Snapshot of aggregate progress.
#[derive(Clone, Debug, Default)]
pub struct TotalProgressSnapshot {
/// Total bytes known to process (includes deduplicated bytes).
@@ -124,7 +124,7 @@ pub struct TotalProgressSnapshot {
pub total_transfer_bytes_completion_rate: Option<f64>,
}
/// Snapshot of a single file's progress returned by [`TaskProgress::files`].
/// Snapshot of a single file's progress.
#[derive(Clone, Debug)]
pub struct FileProgress {
/// File name as reported by the data layer.

View File

@@ -1,16 +1,22 @@
//! XetSession - manages runtime and configuration
use std::collections::HashMap;
use std::future::Future;
use std::pin::pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Waker};
use http::HeaderMap;
use tracing::info;
use ulid::Ulid;
use xet_client::cas_client::auth::TokenRefresher;
use xet_runtime::RuntimeError;
use xet_runtime::config::XetConfig;
use xet_runtime::core::XetRuntime;
use super::download_group::DownloadGroup;
use super::errors::SessionError;
use super::sync::{DownloadGroupSync, UploadCommitSync};
use super::upload_commit::UploadCommit;
/// Session state
@@ -19,49 +25,131 @@ enum SessionState {
Aborted,
}
/// Whether the session owns its tokio runtime or inherits an external one.
///
/// - **`Owned`**: session created its own thread pool via [`XetSessionBuilder::build`] or
/// [`XetSessionBuilder::build_async`] (outside tokio). Both `_blocking` and async methods are supported. Async
/// methods use an internal `bridge_to_owned` bridge that routes futures onto the owned thread pool, so they work from
/// any executor (tokio, smol, async-std).
///
/// - **`External`**: session wraps a caller-provided tokio handle via [`XetSessionBuilder::with_tokio_handle`] or
/// [`XetSessionBuilder::build_async`] (tokio context). Only async methods may be called; `_blocking` methods return
/// [`SessionError::WrongRuntimeMode`]. No second thread pool is created.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(super) enum RuntimeMode {
Owned,
External,
}
/// All shared state for a session.
/// Lives behind `Arc<XetSessionInner>` — do not use this type directly.
#[doc(hidden)]
pub struct XetSessionInner {
// Independently cloned by background tasks, so needs its own Arc.
pub(crate) runtime: Arc<XetRuntime>,
pub(super) runtime: Arc<XetRuntime>,
/// Whether the session owns its runtime or wraps an external tokio handle.
pub(super) runtime_mode: RuntimeMode,
// Only accessed through &self; no independent cloning needed.
pub(crate) config: XetConfig,
pub(super) config: XetConfig,
// CAS endpoint and auth (shared by all upload commits/download groups)
pub(crate) endpoint: Option<String>,
pub(crate) token_info: Option<(String, u64)>,
pub(crate) token_refresher: Option<Arc<dyn TokenRefresher>>,
pub(crate) custom_headers: Option<Arc<HeaderMap>>,
pub(super) endpoint: Option<String>,
pub(super) token_info: Option<(String, u64)>,
pub(super) token_refresher: Option<Arc<dyn TokenRefresher>>,
pub(super) custom_headers: Option<Arc<HeaderMap>>,
// Track active upload commits and download groups.
pub(crate) active_upload_commits: Mutex<HashMap<Ulid, UploadCommit>>,
pub(crate) active_download_groups: Mutex<HashMap<Ulid, DownloadGroup>>,
pub(super) active_upload_commits: Mutex<HashMap<Ulid, UploadCommit>>,
pub(super) active_download_groups: Mutex<HashMap<Ulid, DownloadGroup>>,
// Session state
state: Mutex<SessionState>,
pub(crate) id: Ulid,
pub(super) id: Ulid,
}
/// Probe whether a tokio runtime handle meets the requirements for External mode.
///
/// Checks three things:
/// 1. **Multi-threaded flavor** (non-WASM only).
/// 2. **Time driver** — required for timeouts, retry backoff, and progress intervals.
/// 3. **IO driver** — required for all network I/O via `reqwest`/`hyper`.
///
/// Driver availability is probed by entering the handle's context and polling a
/// driver-dependent future once inside `catch_unwind`. Tokio panics synchronously
/// on the first poll when a driver is absent, so the result is immediate — no
/// spawning or blocking required.
///
/// **Fragility note:** this probing technique relies on tokio panicking
/// synchronously on the first poll of `tokio::time::sleep` /
/// `tokio::net::TcpListener::bind` when the corresponding driver (time / IO)
/// is absent. This is undocumented internal behavior validated against
/// tokio 1.x. If a future tokio version returns an error instead of
/// panicking, this function will incorrectly accept a runtime missing drivers.
///
/// Returns `true` if all requirements are met, `false` otherwise.
fn handle_meets_session_requirements(handle: &tokio::runtime::Handle) -> bool {
// Non-WASM: require a multi-threaded runtime.
#[cfg(not(target_family = "wasm"))]
if matches!(handle.runtime_flavor(), tokio::runtime::RuntimeFlavor::CurrentThread) {
return false;
}
let _guard = handle.enter();
let waker = Waker::noop();
let mut cx = Context::from_waker(waker);
let has_time = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut sleep = pin!(tokio::time::sleep(std::time::Duration::ZERO));
let _ = sleep.as_mut().poll(&mut cx);
}))
.is_ok();
let has_io = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut bind = pin!(tokio::net::TcpListener::bind("127.0.0.1:0"));
let _ = bind.as_mut().poll(&mut cx);
}))
.is_ok();
has_time && has_io
}
/// Builder for [`XetSession`].
///
/// All fields are optional; call [`build`](XetSessionBuilder::build) when done.
/// All fields are optional; call [`build`](XetSessionBuilder::build) (sync) or
/// [`build_async`](XetSessionBuilder::build_async) (async) when done.
///
/// ```rust,no_run
/// # use xet::xet_session::XetSessionBuilder;
/// // Sync context — session owns its runtime:
/// let session = XetSessionBuilder::new()
/// .with_endpoint("https://cas.example.com".into())
/// .with_token_info("my-token".into(), 1_700_000_000)
/// .build()?;
/// # Ok::<(), xet::xet_session::SessionError>(())
/// ```
///
/// ```rust,no_run
/// # use xet::xet_session::XetSessionBuilder;
/// # async fn example() -> Result<(), xet::xet_session::SessionError> {
/// // Async context — wraps the caller's tokio handle (External mode) if inside tokio,
/// // or creates an owned runtime (Owned mode) if called from a non-tokio executor:
/// let session = XetSessionBuilder::new()
/// .with_endpoint("https://cas.example.com".into())
/// .with_token_info("my-token".into(), 1_700_000_000)
/// .build_async()
/// .await?;
/// # Ok(())
/// # }
/// ```
pub struct XetSessionBuilder {
config: XetConfig,
endpoint: Option<String>,
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
custom_headers: Option<Arc<HeaderMap>>,
tokio_handle: Option<tokio::runtime::Handle>,
}
impl Default for XetSessionBuilder {
@@ -79,6 +167,7 @@ impl XetSessionBuilder {
token_info: None,
token_refresher: None,
custom_headers: None,
tokio_handle: None,
}
}
@@ -90,6 +179,7 @@ impl XetSessionBuilder {
token_info: None,
token_refresher: None,
custom_headers: None,
tokio_handle: None,
}
}
@@ -125,15 +215,72 @@ impl XetSessionBuilder {
}
}
/// Attach to an existing tokio runtime handle.
///
/// If the handle meets session requirements (multi-thread flavor, time driver, IO driver),
/// the session will wrap it — no second thread pool is created (External mode). Only async
/// methods (`new_upload_commit`, `new_download_group`) may be called; `_blocking` variants
/// will return [`SessionError::WrongRuntimeMode`].
///
/// If the handle does **not** meet requirements (e.g. `current_thread` flavor or missing
/// drivers), it is silently ignored and [`build`](Self::build) will fall back to creating
/// an owned thread pool (Owned mode) instead.
///
/// Use [`build_async`](Self::build_async) as a convenient alternative when building from
/// within a tokio async context.
pub fn with_tokio_handle(self, handle: tokio::runtime::Handle) -> Self {
let accept = handle_meets_session_requirements(&handle);
if !accept {
info!("supplied tokio handle rejected (missing drivers or wrong flavor); falling back to Owned mode");
}
Self {
tokio_handle: accept.then_some(handle),
..self
}
}
/// Build and automatically attach to the current runtime.
///
/// Despite being `async`, this method resolves synchronously (no internal
/// `.await` points). It is declared `async` so callers in an async context
/// can use it naturally alongside `tokio::runtime::Handle::try_current()`
/// detection.
///
/// - **Tokio context** with a suitable runtime (multi-thread, time + IO drivers): wraps the caller's handle via
/// [`with_tokio_handle`](Self::with_tokio_handle) — External mode.
/// - **Tokio context** with an unsuitable runtime (e.g. `current_thread`): handle is discarded by
/// `with_tokio_handle`; falls back to an owned thread pool — Owned mode.
/// - **Non-tokio context** (smol, async-std, etc.): creates an owned thread pool — Owned mode; async methods use an
/// internal bridge compatible with any executor.
pub async fn build_async(self) -> Result<XetSession, SessionError> {
match tokio::runtime::Handle::try_current() {
Ok(handle) => self.with_tokio_handle(handle).build(),
Err(_) => self.build(),
}
}
/// Consume the builder and create a [`XetSession`].
///
/// - If a valid tokio handle was previously set via [`with_tokio_handle`](Self::with_tokio_handle), the session
/// wraps that handle (External mode) — no second thread pool is created.
/// - Otherwise, creates an owned thread pool (Owned mode); async methods use an internal bridge and work from any
/// executor, and `_blocking` methods are available.
///
/// For async contexts, prefer [`build_async`](Self::build_async).
pub fn build(self) -> Result<XetSession, SessionError> {
XetSession::new_with_config(
let (runtime, mode) = match self.tokio_handle {
Some(handle) => (XetRuntime::from_external_with_config(handle, self.config.clone()), RuntimeMode::External),
None => (XetRuntime::new_with_config(self.config.clone())?, RuntimeMode::Owned),
};
Ok(XetSession::new(
self.config,
self.endpoint,
self.token_info,
self.token_refresher,
self.custom_headers,
)
runtime,
mode,
))
}
}
@@ -167,33 +314,20 @@ impl std::ops::Deref for XetSession {
}
impl XetSession {
/// Create a session with default [`XetConfig`] — used by tests only.
/// In production code, use [`XetSessionBuilder`] instead.
#[cfg(test)]
pub(crate) fn new(
endpoint: Option<String>,
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
custom_headers: Option<Arc<HeaderMap>>,
) -> Result<Self, SessionError> {
Self::new_with_config(XetConfig::new(), endpoint, token_info, token_refresher, custom_headers)
}
/// Internal constructor called by [`XetSessionBuilder::build`].
pub(crate) fn new_with_config(
/// Low-level constructor used by [`XetSessionBuilder::build`].
fn new(
config: XetConfig,
endpoint: Option<String>,
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
custom_headers: Option<Arc<HeaderMap>>,
) -> Result<Self, SessionError> {
let runtime = XetRuntime::new_with_config(config.clone())?;
let session_id = Ulid::new();
Ok(Self {
runtime: Arc<XetRuntime>,
runtime_mode: RuntimeMode,
) -> Self {
Self {
inner: Arc::new(XetSessionInner {
runtime,
runtime_mode,
config,
endpoint,
token_info,
@@ -202,45 +336,151 @@ impl XetSession {
active_upload_commits: Mutex::new(HashMap::new()),
active_download_groups: Mutex::new(HashMap::new()),
state: Mutex::new(SessionState::Alive),
id: session_id,
id: Ulid::new(),
}),
})
}
}
/// Run a future on the appropriate runtime for this session.
///
/// In External mode the future is awaited directly on the caller's executor.
/// In Owned mode the future is bridged onto the owned thread pool via
/// [`XetRuntime::bridge_to_owned`].
pub(super) async fn dispatch<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
match self.runtime_mode {
RuntimeMode::External => Ok(fut.await),
RuntimeMode::Owned => self.runtime.bridge_to_owned(task_name, fut).await,
}
}
/// Create a new [`UploadCommit`] that groups related file uploads.
///
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
pub fn new_upload_commit(&self) -> Result<UploadCommit, SessionError> {
let state = self.state.lock()?;
if matches!(*state, SessionState::Aborted) {
return Err(SessionError::Aborted);
///
/// # Note
///
/// This is an `async fn` and must be `.await`ed. For sync Rust or Python (PyO3) callers,
/// use [`new_upload_commit_blocking`](Self::new_upload_commit_blocking).
pub async fn new_upload_commit(&self) -> Result<UploadCommit, SessionError> {
// Check state before the async init; drop the guard so it is not held across .await.
{
let state = self.state.lock()?;
if matches!(*state, SessionState::Aborted) {
return Err(SessionError::Aborted);
}
}
let commit = UploadCommit::new(self.clone())?;
let session = self.clone();
let commit = self
.dispatch("new_upload_commit", async move { UploadCommit::new(session).await })
.await??;
// Register the commit
// Register the commit (sync insertion, safe in any executor context)
self.active_upload_commits.lock()?.insert(commit.id(), commit.clone());
Ok(commit)
}
/// Create a new [`UploadCommit`] from a **sync** (non-async) context.
///
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
/// Returns `Err(SessionError::WrongRuntimeMode)` if the session was built with
/// [`XetSessionBuilder::with_tokio_handle`] / [`XetSessionBuilder::build_async`].
///
/// # Panics
///
/// Panics if called from within a **tokio** async runtime (tokio sets a thread-local
/// context that `Handle::block_on` detects and panics on). Non-tokio executors (smol,
/// async-std, `futures::executor`) do not set this context, so calling from those is
/// safe — it blocks the executor thread until the task completes. Use
/// [`new_upload_commit`](Self::new_upload_commit) from async contexts instead.
pub fn new_upload_commit_blocking(&self) -> Result<UploadCommitSync, SessionError> {
if matches!(self.runtime_mode, RuntimeMode::External) {
return Err(SessionError::wrong_mode(
"new_upload_commit_blocking() cannot be called on a session built with \
with_tokio_handle() / build_async(); use new_upload_commit().await instead",
));
}
{
let state = self.state.lock()?;
if matches!(*state, SessionState::Aborted) {
return Err(SessionError::Aborted);
}
}
let sync_commit = UploadCommitSync::new(self.clone())?;
self.active_upload_commits
.lock()?
.insert(sync_commit.inner.id(), sync_commit.inner.clone());
Ok(sync_commit)
}
/// Create a new [`DownloadGroup`] that groups related file downloads.
///
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
pub fn new_download_group(&self) -> Result<DownloadGroup, SessionError> {
let state = self.state.lock()?;
if matches!(*state, SessionState::Aborted) {
return Err(SessionError::Aborted);
///
/// # Note
///
/// This is an `async fn` and must be `.await`ed. For sync Rust or Python (PyO3) callers,
/// use [`new_download_group_blocking`](Self::new_download_group_blocking).
pub async fn new_download_group(&self) -> Result<DownloadGroup, SessionError> {
// Check state before the async init; drop the guard so it is not held across .await.
{
let state = self.state.lock()?;
if matches!(*state, SessionState::Aborted) {
return Err(SessionError::Aborted);
}
}
let group = DownloadGroup::new(self.clone())?;
let session = self.clone();
let group = self
.dispatch("new_download_group", async move { DownloadGroup::new(session).await })
.await??;
// Register the group
// Register the group (sync insertion, safe in any executor context)
self.active_download_groups.lock()?.insert(group.id(), group.clone());
Ok(group)
}
/// Create a new [`DownloadGroup`] from a **sync** (non-async) context.
///
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
/// Returns `Err(SessionError::WrongRuntimeMode)` if the session was built with
/// [`XetSessionBuilder::with_tokio_handle`] / [`XetSessionBuilder::build_async`].
///
/// # Panics
///
/// Panics if called from within a **tokio** async runtime (tokio sets a thread-local
/// context that `Handle::block_on` detects and panics on). Non-tokio executors (smol,
/// async-std, `futures::executor`) do not set this context, so calling from those is
/// safe — it blocks the executor thread until the task completes. Use
/// [`new_download_group`](Self::new_download_group) from async contexts instead.
pub fn new_download_group_blocking(&self) -> Result<DownloadGroupSync, SessionError> {
if matches!(self.runtime_mode, RuntimeMode::External) {
return Err(SessionError::wrong_mode(
"new_download_group_blocking() cannot be called on a session built with \
with_tokio_handle() / build_async(); use new_download_group().await instead",
));
}
{
let state = self.state.lock()?;
if matches!(*state, SessionState::Aborted) {
return Err(SessionError::Aborted);
}
}
let sync_group = DownloadGroupSync::new(self.clone())?;
self.active_download_groups
.lock()?
.insert(sync_group.inner.id(), sync_group.inner.clone());
Ok(sync_group)
}
/// Abort the session - cancel all currently running tasks
///
/// This performs a SIGINT-style shutdown, aborting all active upload and download tasks.
@@ -266,19 +506,19 @@ impl XetSession {
Ok(())
}
pub(crate) fn check_alive(&self) -> Result<(), SessionError> {
pub(super) fn check_alive(&self) -> Result<(), SessionError> {
if matches!(*self.state.lock()?, SessionState::Aborted) {
return Err(SessionError::Aborted);
}
Ok(())
}
pub(crate) fn finish_upload_commit(&self, commit_id: Ulid) -> Result<(), SessionError> {
pub(super) fn finish_upload_commit(&self, commit_id: Ulid) -> Result<(), SessionError> {
self.active_upload_commits.lock()?.remove(&commit_id);
Ok(())
}
pub(crate) fn finish_download_group(&self, group_id: Ulid) -> Result<(), SessionError> {
pub(super) fn finish_download_group(&self, group_id: Ulid) -> Result<(), SessionError> {
self.active_download_groups.lock()?.remove(&group_id);
Ok(())
}
@@ -288,16 +528,12 @@ impl XetSession {
mod tests {
use super::*;
fn make_session() -> XetSession {
XetSession::new(None, None, None, None).expect("Failed to create session")
}
// ── Identity ─────────────────────────────────────────────────────────────
#[test]
// A clone refers to the same inner Arc, so their session IDs must match.
fn test_session_clone_shares_state() {
let s1 = make_session();
let s1 = XetSessionBuilder::new().build().unwrap();
let s2 = s1.clone();
assert_eq!(s1.id, s2.id);
}
@@ -305,8 +541,8 @@ mod tests {
#[test]
// Two independently created sessions have distinct IDs.
fn test_two_sessions_have_distinct_ids() {
let s1 = make_session();
let s2 = make_session();
let s1 = XetSessionBuilder::new().build().unwrap();
let s2 = XetSessionBuilder::new().build().unwrap();
assert_ne!(s1.id, s2.id);
}
@@ -315,36 +551,36 @@ mod tests {
#[test]
// After abort, check_alive returns Aborted.
fn test_check_alive_after_abort() {
let session = make_session();
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let err = session.check_alive().unwrap_err();
assert!(matches!(err, SessionError::Aborted));
}
#[test]
// new_upload_commit on an aborted session returns Aborted.
// new_upload_commit_blocking on an aborted session returns Aborted.
fn test_new_upload_commit_after_abort_returns_aborted() {
let session = make_session();
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let err = session.new_upload_commit().err().unwrap();
let err = session.new_upload_commit_blocking().err().unwrap();
assert!(matches!(err, SessionError::Aborted));
}
#[test]
// new_download_group on an aborted session returns Aborted.
// new_download_group_blocking on an aborted session returns Aborted.
fn test_new_download_group_after_abort_returns_aborted() {
let session = make_session();
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let err = session.new_download_group().err().unwrap();
let err = session.new_download_group_blocking().err().unwrap();
assert!(matches!(err, SessionError::Aborted));
}
#[test]
// Aborting a session clears all registered upload commits.
fn test_abort_clears_active_upload_commits() {
let session = make_session();
let _c1 = session.new_upload_commit().unwrap();
let _c2 = session.new_upload_commit().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let _c1 = session.new_upload_commit_blocking().unwrap();
let _c2 = session.new_upload_commit_blocking().unwrap();
session.abort().unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 0);
}
@@ -352,8 +588,8 @@ mod tests {
#[test]
// Aborting a session clears all registered download groups.
fn test_abort_clears_active_download_groups() {
let session = make_session();
let _g1 = session.new_download_group().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let _g1 = session.new_download_group_blocking().unwrap();
session.abort().unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 0);
}
@@ -363,16 +599,16 @@ mod tests {
#[test]
// A new upload commit is registered in the session's active set.
fn test_new_upload_commit_registers_in_session() {
let session = make_session();
let _commit = session.new_upload_commit().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let _commit = session.new_upload_commit_blocking().unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
}
#[test]
// A new download group is registered in the session's active set.
fn test_new_download_group_registers_in_session() {
let session = make_session();
let _group = session.new_download_group().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let _group = session.new_download_group_blocking().unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
}
@@ -381,32 +617,176 @@ mod tests {
#[test]
// finish_upload_commit removes only the specified commit, leaving others intact.
fn test_finish_upload_commit_removes_only_that_commit() {
let session = make_session();
let c1 = session.new_upload_commit().unwrap();
let _c2 = session.new_upload_commit().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let c1 = session.new_upload_commit_blocking().unwrap();
let _c2 = session.new_upload_commit_blocking().unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 2);
session.finish_upload_commit(c1.id()).unwrap();
session.finish_upload_commit(c1.inner.id()).unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
}
#[test]
// finish_download_group removes only the specified group, leaving others intact.
fn test_finish_download_group_removes_only_that_group() {
let session = make_session();
let g1 = session.new_download_group().unwrap();
let _g2 = session.new_download_group().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let g1 = session.new_download_group_blocking().unwrap();
let _g2 = session.new_download_group_blocking().unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 2);
session.finish_download_group(g1.id()).unwrap();
session.finish_download_group(g1.inner.id()).unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
}
#[test]
// finish_upload_commit on an unknown ID is a no-op (no error, no change).
fn test_finish_upload_commit_with_unknown_id_is_noop() {
let session = make_session();
let _c1 = session.new_upload_commit().unwrap();
let session = XetSessionBuilder::new().build().unwrap();
let _c1 = session.new_upload_commit_blocking().unwrap();
let unknown_id = ulid::Ulid::new();
assert!(session.finish_upload_commit(unknown_id).is_ok());
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
}
// ── Async abort behavior ──────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// new_upload_commit / new_download_group on an aborted session both return Aborted.
async fn test_async_new_after_abort_returns_aborted() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
session.abort().unwrap();
let commit_err = session.new_upload_commit().await.err().unwrap();
let group_err = session.new_download_group().await.err().unwrap();
assert!(matches!(commit_err, SessionError::Aborted));
assert!(matches!(group_err, SessionError::Aborted));
}
#[tokio::test(flavor = "multi_thread")]
// Aborting a session clears all active upload commits and download groups.
async fn test_async_abort_clears_active_commits_and_groups() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let _c1 = session.new_upload_commit().await.unwrap();
let _c2 = session.new_upload_commit().await.unwrap();
let _g1 = session.new_download_group().await.unwrap();
session.abort().unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 0);
assert_eq!(session.active_download_groups.lock().unwrap().len(), 0);
}
// ── Async registration ────────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// A new upload commit and a new download group are each registered in the
// session's active set, and concurrent creation registers both.
async fn test_async_new_registers_in_session() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let _commit = session.new_upload_commit().await.unwrap();
let _group = session.new_download_group().await.unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
}
// ── Async deregistration ──────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// Finishing one upload commit / download group removes only that one,
// leaving the other still registered.
async fn test_async_finish_removes_only_that_item() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let c1 = session.new_upload_commit().await.unwrap();
let _c2 = session.new_upload_commit().await.unwrap();
let g1 = session.new_download_group().await.unwrap();
let _g2 = session.new_download_group().await.unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 2);
assert_eq!(session.active_download_groups.lock().unwrap().len(), 2);
session.finish_upload_commit(c1.id()).unwrap();
session.finish_download_group(g1.id()).unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
}
// ── handle_meets_session_requirements ────────────────────────────────────
#[test]
// A multi-thread runtime with enable_all() meets all requirements.
fn test_handle_multi_thread_all_features_returns_true() {
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
assert!(handle_meets_session_requirements(rt.handle()));
}
#[test]
#[cfg(not(target_family = "wasm"))]
// A current_thread runtime is rejected even when enable_all() is set.
fn test_handle_current_thread_returns_false() {
let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
assert!(!handle_meets_session_requirements(rt.handle()));
}
#[test]
// A multi-thread runtime with no drivers enabled returns false.
fn test_handle_without_any_driver_returns_false() {
let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
assert!(!handle_meets_session_requirements(rt.handle()));
}
#[test]
// A multi-thread runtime with only enable_time() is missing the IO driver.
fn test_handle_without_io_driver_returns_false() {
let rt = tokio::runtime::Builder::new_multi_thread().enable_time().build().unwrap();
assert!(!handle_meets_session_requirements(rt.handle()));
}
#[test]
// A multi-thread runtime with only enable_io() is missing the time driver.
fn test_handle_without_time_driver_returns_false() {
let rt = tokio::runtime::Builder::new_multi_thread().enable_io().build().unwrap();
assert!(!handle_meets_session_requirements(rt.handle()));
}
// ── External-mode _blocking guard ────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// new_upload_commit_blocking returns WrongRuntimeMode on an External-mode session.
async fn test_new_upload_commit_blocking_errors_in_external_mode() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::External);
let err = session.new_upload_commit_blocking().err().unwrap();
assert!(matches!(err, SessionError::WrongRuntimeMode(_)));
}
#[tokio::test(flavor = "multi_thread")]
// new_download_group_blocking returns WrongRuntimeMode on an External-mode session.
async fn test_new_download_group_blocking_errors_in_external_mode() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::External);
let err = session.new_download_group_blocking().err().unwrap();
assert!(matches!(err, SessionError::WrongRuntimeMode(_)));
}
// ── Owned-mode _blocking panic guard ─────────────────────────────────────
#[test]
// new_upload_commit_blocking panics when called from within a tokio runtime on an
// Owned-mode session: external_run_async_task calls handle.block_on(), which panics
// because tokio sets a thread-local runtime context that it detects and rejects.
fn test_new_upload_commit_blocking_panics_in_async_context() {
let session = XetSessionBuilder::new().build().unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::Owned);
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
rt.block_on(async { session.new_upload_commit_blocking() })
}));
assert!(result.is_err(), "new_upload_commit_blocking() must panic when called from async");
}
#[test]
// new_download_group_blocking panics when called from within a tokio runtime on an
// Owned-mode session: same mechanism as the upload variant above.
fn test_new_download_group_blocking_panics_in_async_context() {
let session = XetSessionBuilder::new().build().unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::Owned);
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
rt.block_on(async { session.new_download_group_blocking() })
}));
assert!(result.is_err(), "new_download_group_blocking() must panic when called from async");
}
}

View File

@@ -0,0 +1,286 @@
//! Sync-context download group wrapper.
//!
//! [`DownloadGroupSync`] is obtained from [`XetSession::new_download_group_blocking`] and
//! provides a fully blocking API suitable for sync Rust or Python (PyO3) callers.
//! For async Rust use [`DownloadGroup`] instead.
use std::collections::HashMap;
use std::path::PathBuf;
use ulid::Ulid;
use xet_data::processing::XetFileInfo;
use super::super::download_group::{DownloadGroup, DownloadResult};
use super::super::errors::SessionError;
use super::super::progress::{DownloadTaskHandle, ProgressSnapshot};
use super::super::session::XetSession;
/// Sync-context handle for grouping related file downloads.
///
/// Obtained from [`XetSession::new_download_group_blocking`]. All methods block
/// the calling thread — **do not use from within an async runtime** (it will panic).
/// For async Rust code use [`DownloadGroup`] from [`XetSession::new_download_group`].
///
/// # Cloning
///
/// Cloning is cheap — it simply increments an atomic reference count.
/// All clones share the same background worker and task state.
#[derive(Clone)]
pub struct DownloadGroupSync {
pub(in super::super) inner: DownloadGroup,
}
impl DownloadGroupSync {
/// Create a new download group from a **sync** (non-async) context.
///
/// # Panics
///
/// Panics if called from within an async runtime — use
/// [`XetSession::new_download_group`] instead.
pub(in super::super) fn new(session: XetSession) -> Result<Self, SessionError> {
let group = session.runtime.external_run_async_task(DownloadGroup::new(session.clone()))??;
Ok(Self { inner: group })
}
/// Queue a file for download. See [`DownloadGroup::download_file_to_path`] for full documentation.
pub fn download_file_to_path(
&self,
file_info: XetFileInfo,
dest_path: PathBuf,
) -> Result<DownloadTaskHandle, SessionError> {
self.inner.download_file_to_path(file_info, dest_path)
}
/// Return a snapshot of progress for every queued download.
pub fn get_progress(&self) -> Result<ProgressSnapshot, SessionError> {
self.inner.get_progress()
}
/// Wait for all downloads to complete and return their results.
///
/// See [`DownloadGroup::finish`] for full documentation.
///
/// # Panics
///
/// Panics if called from within an async runtime. Use [`DownloadGroup::finish`] instead.
pub fn finish(self) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
let group = self.inner.clone();
self.inner.runtime().external_run_async_task(group.finish())?
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use tempfile::{TempDir, tempdir};
use super::*;
use crate::xet_session::progress::UploadTaskHandle;
use crate::xet_session::session::{XetSession, XetSessionBuilder};
fn local_session(temp: &TempDir) -> Result<XetSession, Box<dyn std::error::Error>> {
let cas_path = temp.path().join("cas");
Ok(XetSessionBuilder::new()
.with_endpoint(format!("local://{}", cas_path.display()))
.build()?)
}
fn upload_bytes(session: &XetSession, data: &[u8], name: &str) -> Result<XetFileInfo, Box<dyn std::error::Error>> {
let commit = session.new_upload_commit_blocking()?;
let handle = commit.upload_bytes(data.to_vec(), Some(name.into()))?;
let results = commit.commit()?;
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
Ok(XetFileInfo {
hash: meta.hash.clone(),
file_size: meta.file_size,
})
}
// ── Round-trip tests ─────────────────────────────────────────────────────
#[test]
// Downloading a previously uploaded file produces byte-identical content at the destination.
fn test_download_file_round_trip() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let original = b"Hello, download round-trip!";
let file_info = upload_bytes(&session, original, "payload.bin")?;
let dest = temp.path().join("downloaded.bin");
let group = session.new_download_group_blocking()?;
group.download_file_to_path(file_info, dest.clone())?;
group.finish()?;
assert_eq!(std::fs::read(&dest)?, original);
Ok(())
}
#[test]
// Downloading multiple files from a single group produces correct content for each.
fn test_download_multiple_files() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data_a = b"First file content";
let data_b = b"Second file content - different";
let commit = session.new_upload_commit_blocking()?;
let handle_a = commit.upload_bytes(data_a.to_vec(), Some("a.bin".into()))?;
let handle_b = commit.upload_bytes(data_b.to_vec(), Some("b.bin".into()))?;
let results = commit.commit()?;
let to_file_info = |handle: &UploadTaskHandle| -> XetFileInfo {
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
XetFileInfo {
hash: meta.hash.clone(),
file_size: meta.file_size,
}
};
let dest_a = temp.path().join("a_out.bin");
let dest_b = temp.path().join("b_out.bin");
let group = session.new_download_group_blocking()?;
group.download_file_to_path(to_file_info(&handle_a), dest_a.clone())?;
group.download_file_to_path(to_file_info(&handle_b), dest_b.clone())?;
group.finish()?;
assert_eq!(std::fs::read(&dest_a)?, data_a);
assert_eq!(std::fs::read(&dest_b)?, data_b);
Ok(())
}
#[test]
// After a successful finish the aggregate download progress reflects bytes received.
fn test_download_progress_reflects_bytes_after_finish() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let original = b"download progress tracking data";
let file_info = upload_bytes(&session, original, "prog.bin")?;
let dest = temp.path().join("out.bin");
let group = session.new_download_group_blocking()?;
let progress_observer = group.clone();
group.download_file_to_path(file_info, dest)?;
group.finish()?;
std::thread::sleep(
session
.runtime
.config()
.data
.progress_update_interval
.saturating_add(Duration::from_secs(1)),
);
let snapshot = progress_observer.get_progress()?;
assert!(snapshot.total().total_bytes_completed > 0);
Ok(())
}
// ── Per-task result access patterns ──────────────────────────────────────
#[test]
// Pattern 1: per-task result is accessible via task_id in the finish() HashMap.
fn test_download_result_accessible_via_task_id_in_finish_map() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data = b"result via task_id in finish map";
let file_info = upload_bytes(&session, data, "file.bin")?;
let dest = temp.path().join("out.bin");
let group = session.new_download_group_blocking()?;
let handle = group.download_file_to_path(file_info, dest)?;
let results = group.finish()?;
let result = results.get(&handle.task_id).expect("task_id must be present in results");
assert_eq!(result.as_ref().as_ref().unwrap().file_info.file_size, data.len() as u64);
Ok(())
}
#[test]
// DownloadTaskHandle::result() returns None before finish() is called.
fn test_download_result_none_before_finish() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let file_info = upload_bytes(&session, b"some data", "file.bin")?;
let dest = temp.path().join("out.bin");
let group = session.new_download_group_blocking()?;
let handle = group.download_file_to_path(file_info, dest)?;
assert!(handle.result().is_none(), "result must be None before finish()");
group.finish()?;
Ok(())
}
#[test]
// DownloadTaskHandle::result() returns Some after finish() completes.
fn test_download_result_some_after_finish() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data = b"download result test data";
let file_info = upload_bytes(&session, data, "file.bin")?;
let dest = temp.path().join("out.bin");
let group = session.new_download_group_blocking()?;
let handle = group.download_file_to_path(file_info.clone(), dest)?;
group.finish()?;
let result = handle.result().unwrap();
let dl = result.as_ref().as_ref().unwrap();
assert_eq!(dl.file_info.file_size, data.len() as u64);
assert_eq!(dl.file_info.hash, file_info.hash);
Ok(())
}
// ── Non-tokio executor (no-panic + round-trip) ────────────────────────────
//
// smol, async-std, and futures::executor do not set tokio's thread-local
// runtime context, so Handle::block_on inside external_run_async_task does
// not panic — it just blocks the calling executor thread.
#[test]
// new_download_group_blocking completes a full upload+download round-trip inside smol.
fn test_download_blocking_round_trip_in_smol() {
let temp = tempdir().unwrap();
let session = local_session(&temp).unwrap();
smol::block_on(async {
let data = b"download from smol executor";
let file_info = upload_bytes(&session, data, "test.bin").unwrap();
let dest = temp.path().join("out_smol.bin");
let group = session.new_download_group_blocking().unwrap();
group.download_file_to_path(file_info, dest.clone()).unwrap();
group.finish().unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), data);
});
}
#[test]
// new_download_group_blocking completes a full upload+download round-trip inside futures::executor.
fn test_download_blocking_round_trip_in_futures_executor() {
let temp = tempdir().unwrap();
let session = local_session(&temp).unwrap();
futures::executor::block_on(async {
let data = b"download from futures executor";
let file_info = upload_bytes(&session, data, "test.bin").unwrap();
let dest = temp.path().join("out_futures.bin");
let group = session.new_download_group_blocking().unwrap();
group.download_file_to_path(file_info, dest.clone()).unwrap();
group.finish().unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), data);
});
}
#[test]
// new_download_group_blocking completes a full upload+download round-trip inside async-std.
fn test_download_blocking_round_trip_in_async_std() {
let temp = tempdir().unwrap();
let session = local_session(&temp).unwrap();
async_std::task::block_on(async {
let data = b"download from async-std executor";
let file_info = upload_bytes(&session, data, "test.bin").unwrap();
let dest = temp.path().join("out_async_std.bin");
let group = session.new_download_group_blocking().unwrap();
group.download_file_to_path(file_info, dest.clone()).unwrap();
group.finish().unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), data);
});
}
}

View File

@@ -0,0 +1,5 @@
pub mod download_group_sync;
pub mod upload_commit_sync;
pub use download_group_sync::DownloadGroupSync;
pub use upload_commit_sync::UploadCommitSync;

View File

@@ -0,0 +1,315 @@
//! Sync-context upload commit wrapper.
//!
//! [`UploadCommitSync`] is obtained from [`XetSession::new_upload_commit_blocking`] and
//! provides a fully blocking API suitable for sync Rust or Python (PyO3) callers.
//! For async Rust use [`UploadCommit`] instead.
use std::collections::HashMap;
use std::path::PathBuf;
use ulid::Ulid;
use xet_data::processing::SingleFileCleaner;
use super::super::errors::SessionError;
use super::super::progress::{ProgressSnapshot, TaskHandle, UploadTaskHandle};
use super::super::session::XetSession;
use super::super::upload_commit::{UploadCommit, UploadResult};
/// Sync-context handle for grouping related file uploads.
///
/// Obtained from [`XetSession::new_upload_commit_blocking`]. All methods block
/// the calling thread — **do not use from within an async runtime** (it will panic).
/// For async Rust code use [`UploadCommit`] from [`XetSession::new_upload_commit`].
///
/// # Cloning
///
/// Cloning is cheap — it simply increments an atomic reference count.
/// All clones share the same upload session and task state.
#[derive(Clone)]
pub struct UploadCommitSync {
pub(in super::super) inner: UploadCommit,
}
impl UploadCommitSync {
/// Create a new upload commit from a **sync** (non-async) context.
///
/// # Panics
///
/// Panics if called from within an async runtime — use
/// [`XetSession::new_upload_commit`] instead.
pub(in super::super) fn new(session: XetSession) -> Result<Self, SessionError> {
let commit = session.runtime.external_run_async_task(UploadCommit::new(session.clone()))??;
Ok(Self { inner: commit })
}
/// Queue a file for upload. See [`UploadCommit::upload_from_path`] for full documentation.
pub fn upload_from_path(&self, file_path: PathBuf) -> Result<UploadTaskHandle, SessionError> {
self.inner.session.check_alive()?;
let commit_inner = self.inner.inner.clone();
self.inner
.runtime()
.external_run_async_task(async move { commit_inner.start_upload_file_from_path(file_path).await })?
}
/// Queue raw bytes for upload. See [`UploadCommit::upload_bytes`] for full documentation.
pub fn upload_bytes(
&self,
bytes: Vec<u8>,
tracking_name: Option<String>,
) -> Result<UploadTaskHandle, SessionError> {
self.inner.session.check_alive()?;
let commit_inner = self.inner.inner.clone();
self.inner
.runtime()
.external_run_async_task(async move { commit_inner.start_upload_bytes(bytes, tracking_name).await })?
}
/// Begin an incremental file upload, returning a [`SingleFileCleaner`] to stream bytes.
///
/// This is the sync-context equivalent of [`UploadCommit::upload_file`].
///
/// # Parameters
///
/// - `file_name`: optional name used for progress/telemetry reporting.
/// - `file_size`: expected size in bytes (used for progress tracking; `0` is valid if unknown).
///
/// # Panics
///
/// Panics if called from within an async runtime. Use [`UploadCommit::upload_file`] instead.
pub fn upload_file(
&self,
file_name: Option<String>,
file_size: u64,
) -> Result<(TaskHandle, SingleFileCleaner), SessionError> {
self.inner.session.check_alive()?;
let commit_inner = self.inner.clone();
self.inner
.runtime()
.external_run_async_task(async move { commit_inner.start_upload_file(file_name, file_size).await })?
}
/// Return a snapshot of progress for every queued upload.
pub fn get_progress(&self) -> Result<ProgressSnapshot, SessionError> {
self.inner.get_progress()
}
/// Wait for all uploads to complete and push metadata to the CAS server.
///
/// Returns a `HashMap` keyed by task ID. See [`UploadCommit::commit`] for full documentation.
///
/// # Panics
///
/// Panics if called from within an async runtime. Use [`UploadCommit::commit`] instead.
pub fn commit(self) -> Result<HashMap<Ulid, UploadResult>, SessionError> {
let commit = self.inner.clone();
self.inner.runtime().external_run_async_task(commit.commit())?
}
}
#[cfg(test)]
mod tests {
use tempfile::{TempDir, tempdir};
use crate::xet_session::session::{XetSession, XetSessionBuilder};
fn local_session(temp: &TempDir) -> Result<XetSession, Box<dyn std::error::Error>> {
let cas_path = temp.path().join("cas");
Ok(XetSessionBuilder::new()
.with_endpoint(format!("local://{}", cas_path.display()))
.build()?)
}
// ── Round-trip tests ─────────────────────────────────────────────────────
#[test]
// Uploading raw bytes and committing returns a non-empty hash and the correct file size.
fn test_upload_bytes_round_trip() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data = b"Hello, upload commit round-trip!";
let commit = session.new_upload_commit_blocking()?;
let task_handle = commit.upload_bytes(data.to_vec(), Some("hello.bin".into()))?;
let results = commit.commit()?;
assert_eq!(results.len(), 1);
let meta = results.get(&task_handle.task_id).unwrap().as_ref().as_ref().unwrap();
assert_eq!(meta.file_size, data.len() as u64);
assert!(!meta.hash.is_empty());
Ok(())
}
#[test]
// Uploading a file from disk and committing returns the correct file size.
fn test_upload_from_path_round_trip() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let src = temp.path().join("data.bin");
let data = b"file path upload content";
std::fs::write(&src, data)?;
let commit = session.new_upload_commit_blocking()?;
let handle = commit.upload_from_path(src)?;
commit.commit()?;
let meta = handle.result().unwrap();
let meta = meta.as_ref().as_ref().unwrap();
assert_eq!(meta.file_size, data.len() as u64);
assert!(!meta.hash.is_empty());
Ok(())
}
// ── Per-task result access patterns ──────────────────────────────────────
#[test]
// UploadTaskHandle::result() returns None before commit() is called.
fn test_upload_result_none_before_commit() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let src = temp.path().join("data.bin");
std::fs::write(&src, b"content")?;
let commit = session.new_upload_commit_blocking()?;
let handle = commit.upload_from_path(src)?;
assert!(handle.result().is_none(), "result must be None before commit()");
commit.commit()?;
Ok(())
}
#[test]
// Pattern 1: per-task result is accessible via task_id in the commit() HashMap.
fn test_upload_result_accessible_via_task_id_in_commit_map() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data = b"result via task_id";
let src = temp.path().join("data.bin");
std::fs::write(&src, data)?;
let commit = session.new_upload_commit_blocking()?;
let handle = commit.upload_from_path(src)?;
let results = commit.commit()?;
let result = results.get(&handle.task_id).expect("task_id must be present in results");
assert_eq!(result.as_ref().as_ref().unwrap().file_size, data.len() as u64);
Ok(())
}
#[test]
// Pattern 2: per-task result is accessible directly from the UploadTaskHandle after commit().
fn test_upload_result_accessible_via_handle_after_commit() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data = b"result via handle";
let src = temp.path().join("data.bin");
std::fs::write(&src, data)?;
let commit = session.new_upload_commit_blocking()?;
let handle = commit.upload_from_path(src)?;
commit.commit()?;
let result = handle.result().expect("result must be set after commit");
assert_eq!(result.as_ref().as_ref().unwrap().file_size, data.len() as u64);
Ok(())
}
#[test]
// Streaming upload via upload_file + SingleFileCleaner.
fn test_upload_streaming_round_trip() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data = b"streamed upload bytes";
let runtime = session.runtime.clone();
let commit = session.new_upload_commit_blocking()?;
let (_handle, mut cleaner) = commit.upload_file(Some("stream.bin".into()), data.len() as u64)?;
let (hash, file_size) = runtime.external_run_async_task(async move {
cleaner.add_data(data).await.unwrap();
let (xfi, _) = cleaner.finish().await.unwrap();
(xfi.hash, xfi.file_size)
})?;
let results = commit.commit()?;
assert!(results.is_empty());
assert_eq!(file_size, data.len() as u64);
assert!(!hash.is_empty());
Ok(())
}
#[test]
// Uploading multiple blobs in one commit returns one result per upload.
fn test_upload_multiple_files_in_one_commit() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let commit = session.new_upload_commit_blocking()?;
commit.upload_bytes(b"file one".to_vec(), Some("a.bin".into()))?;
commit.upload_bytes(b"file two".to_vec(), Some("b.bin".into()))?;
commit.upload_bytes(b"file three".to_vec(), Some("c.bin".into()))?;
let results = commit.commit()?;
assert_eq!(results.len(), 3);
Ok(())
}
#[test]
// After a successful commit the aggregate progress reflects bytes processed.
fn test_upload_progress_reflects_bytes_after_commit() -> Result<(), Box<dyn std::error::Error>> {
let temp = tempdir()?;
let session = local_session(&temp)?;
let data = b"progress tracking upload data";
let commit = session.new_upload_commit_blocking()?;
let progress_observer = commit.clone();
commit.upload_bytes(data.to_vec(), Some("prog.bin".into()))?;
commit.commit()?;
let snapshot = progress_observer.get_progress()?;
assert!(snapshot.total().total_bytes_completed > 0);
Ok(())
}
// ── Non-tokio executor (no-panic + round-trip) ────────────────────────────
//
// smol, async-std, and futures::executor do not set tokio's thread-local
// runtime context, so Handle::block_on inside external_run_async_task does
// not panic — it just blocks the calling executor thread.
#[test]
// new_upload_commit_blocking completes a full upload round-trip inside smol.
fn test_upload_blocking_round_trip_in_smol() {
let temp = tempdir().unwrap();
let session = local_session(&temp).unwrap();
smol::block_on(async {
let data = b"upload from smol executor";
let commit = session.new_upload_commit_blocking().unwrap();
let handle = commit.upload_bytes(data.to_vec(), Some("test.bin".into())).unwrap();
let results = commit.commit().unwrap();
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
assert_eq!(meta.file_size, data.len() as u64);
assert!(!meta.hash.is_empty());
});
}
#[test]
// new_upload_commit_blocking completes a full upload round-trip inside futures::executor.
fn test_upload_blocking_round_trip_in_futures_executor() {
let temp = tempdir().unwrap();
let session = local_session(&temp).unwrap();
futures::executor::block_on(async {
let data = b"upload from futures executor";
let commit = session.new_upload_commit_blocking().unwrap();
let handle = commit.upload_bytes(data.to_vec(), Some("test.bin".into())).unwrap();
let results = commit.commit().unwrap();
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
assert_eq!(meta.file_size, data.len() as u64);
assert!(!meta.hash.is_empty());
});
}
#[test]
// new_upload_commit_blocking completes a full upload round-trip inside async-std.
fn test_upload_blocking_round_trip_in_async_std() {
let temp = tempdir().unwrap();
let session = local_session(&temp).unwrap();
async_std::task::block_on(async {
let data = b"upload from async-std executor";
let commit = session.new_upload_commit_blocking().unwrap();
let handle = commit.upload_bytes(data.to_vec(), Some("test.bin".into())).unwrap();
let results = commit.commit().unwrap();
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
assert_eq!(meta.file_size, data.len() as u64);
assert!(!meta.hash.is_empty());
});
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,13 +1,17 @@
use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Display;
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, OnceLock};
use std::sync::{Arc, LazyLock, OnceLock, Weak};
use futures::FutureExt;
use reqwest::Client;
use tokio::runtime::{Builder as TokioRuntimeBuilder, Handle as TokioRuntimeHandle, Runtime as TokioRuntime};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tracing::{debug, info};
use tracing::{debug, info, warn};
use super::XetCommon;
use crate::config::XetConfig;
@@ -139,6 +143,10 @@ pub struct XetRuntime {
// System monitor instance if enabled, monitor starts on initiation
#[cfg(not(target_family = "wasm"))]
system_monitor: Option<SystemMonitor>,
/// Set only for External-mode instances created via `from_external_with_config`.
/// Used to deregister from `EXTERNAL_RUNTIME_REGISTRY` on drop.
external_handle_id: Option<tokio::runtime::Id>,
}
// Use thread-local references to the runtime that are set on initialization among all
@@ -148,6 +156,16 @@ thread_local! {
static THREAD_RUNTIME_REF: RefCell<Option<(u32, Arc<XetRuntime>)>> = const { RefCell::new(None) };
}
// Registry for External-mode runtimes created via from_external_with_config.
// Keyed by tokio runtime ID so current_if_exists() can find the right XetRuntime
// (with the correct XetConfig and XetCommon) when called from the caller's tokio threads,
// where THREAD_RUNTIME_REF is never set.
//
// Uses std::sync (not tokio::sync) because the registry must be accessible from non-async
// contexts such as Drop impls and sync builder methods.
static EXTERNAL_RUNTIME_REGISTRY: LazyLock<std::sync::RwLock<HashMap<tokio::runtime::Id, Weak<XetRuntime>>>> =
LazyLock::new(|| std::sync::RwLock::new(HashMap::new()));
impl XetRuntime {
/// Return the current threadpool that the current worker thread uses. Will fail if
/// called from a thread that is not spawned from the current runtime.
@@ -166,16 +184,25 @@ impl XetRuntime {
#[inline]
pub fn current_if_exists() -> Option<Arc<Self>> {
// 1. Thread-local: set by on_thread_start in new_with_config (Owned mode).
let maybe_rt = THREAD_RUNTIME_REF.with_borrow(|rt| rt.clone());
if let Some((pid, rt)) = maybe_rt
&& pid == std::process::id()
{
return Some(rt);
}
if let Ok(tokio_rt) = TokioRuntimeHandle::try_current() {
Some(Self::from_external(tokio_rt))
// 2. Handle registry: set by from_external_with_config (External mode). Returns the XetRuntime with the correct
// XetConfig and XetCommon for this runtime.
if let Ok(handle) = TokioRuntimeHandle::try_current() {
if let Ok(reg) = EXTERNAL_RUNTIME_REGISTRY.read()
&& let Some(weak) = reg.get(&handle.id())
&& let Some(rt) = weak.upgrade()
{
return Some(rt);
}
// Fallback: no XetSession owns this handle; create a bare default-config wrapper.
Some(Self::from_external(handle))
} else {
None
}
@@ -209,6 +236,7 @@ impl XetRuntime {
})
.flatten(),
config: Arc::new(config),
external_handle_id: None,
});
// Each thread in each of the tokio worker threads holds a reference to the runtime handling
@@ -263,6 +291,48 @@ impl XetRuntime {
Ok(rt)
}
/// Wrap an existing tokio [`TokioRuntimeHandle`] with a [`XetRuntime`] using the provided
/// [`XetConfig`]. No new thread pool is created; `spawn()` calls will schedule work on the
/// runtime that owns `rt_handle`.
///
/// The resulting `XetRuntime` is registered in `EXTERNAL_RUNTIME_REGISTRY` so that
/// [`XetRuntime::current()`] called from tasks running on `rt_handle`'s threads will return
/// this instance (with the correct config and shared `XetCommon`) rather than a default
/// throwaway. The entry is removed when the last `Arc<XetRuntime>` is dropped.
pub fn from_external_with_config(rt_handle: TokioRuntimeHandle, config: XetConfig) -> Arc<Self> {
let id = rt_handle.id();
let rt = Arc::new(Self {
runtime: std::sync::RwLock::new(None),
handle_ref: rt_handle.into(),
external_executor_count: 0.into(),
sigint_shutdown: false.into(),
common: XetCommon::new(&config),
#[cfg(not(target_family = "wasm"))]
system_monitor: config
.system_monitor
.enabled
.then(|| {
SystemMonitor::follow_process(
config.system_monitor.sample_interval,
config.system_monitor.log_path.clone(),
)
.ok()
})
.flatten(),
config: Arc::new(config),
external_handle_id: Some(id),
});
if let Ok(mut reg) = EXTERNAL_RUNTIME_REGISTRY.write() {
reg.insert(id, Arc::downgrade(&rt));
} else {
warn!("EXTERNAL_RUNTIME_REGISTRY poisoned; skipping registration");
}
rt
}
/// Wrap an existing tokio [`TokioRuntimeHandle`] with a [`XetRuntime`] using a default
/// [`XetConfig`]. Prefer [`from_external_with_config`](Self::from_external_with_config) when
/// you have a config available.
pub fn from_external(rt_handle: TokioRuntimeHandle) -> Arc<Self> {
let config = XetConfig::new();
Arc::new(Self {
@@ -284,6 +354,7 @@ impl XetRuntime {
})
.flatten(),
config: Arc::new(config),
external_handle_id: None,
})
}
@@ -422,6 +493,43 @@ impl XetRuntime {
self.handle().spawn(future)
}
/// Bridge a future onto this runtime's `hf-xet-*` thread pool from any async context,
/// including non-tokio executors (smol, async-std, `futures::executor::block_on`).
///
/// Unlike [`external_run_async_task`](Self::external_run_async_task) which **blocks**
/// the calling thread, this method returns a future that any executor can poll.
/// The result is delivered via a `tokio::sync::oneshot` channel whose receiver only
/// requires a `std::task::Waker`, making it compatible with every standard executor.
///
/// Returns `Err(RuntimeError::TaskPanic)` if the spawned future panics, or
/// `Err(RuntimeError::TaskCanceled)` if the runtime shuts down before the result
/// can be delivered.
pub async fn bridge_to_owned<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = oneshot::channel();
self.spawn(async move {
let result = AssertUnwindSafe(fut).catch_unwind().await;
let _ = tx.send(result);
});
match rx.await {
Ok(Ok(value)) => Ok(value),
Ok(Err(panic_payload)) => {
let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
format!("{task_name}: {s}")
} else if let Some(s) = panic_payload.downcast_ref::<String>() {
format!("{task_name}: {s}")
} else {
format!("{task_name}: <unknown panic>")
};
Err(RuntimeError::TaskPanic(msg))
},
Err(_) => Err(RuntimeError::TaskCanceled(task_name.to_string())),
}
}
/// Spawn a blocking task on the runtime's blocking thread pool. The task runs with this
/// runtime stored in thread-local storage so [`XetRuntime::current()`] works inside `f`.
///
@@ -447,6 +555,16 @@ impl XetRuntime {
}
}
impl Drop for XetRuntime {
fn drop(&mut self) {
if let Some(id) = &self.external_handle_id
&& let Ok(mut reg) = EXTERNAL_RUNTIME_REGISTRY.write()
{
reg.remove(id);
}
}
}
impl Display for XetRuntime {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Need to be careful that this doesn't acquire locks eagerly, as this function can be called
@@ -492,4 +610,28 @@ mod tests {
let same = rt.external_run_async_task(async { jh.await.unwrap() }).unwrap();
assert!(same);
}
/// current_if_exists() must return the session-owned XetRuntime (with the correct config)
/// when called from tasks on an External-mode runtime, not a default-config throwaway.
#[test]
fn test_current_if_exists_sees_external_runtime_config() {
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
let mut config = XetConfig::new();
config.data.default_cas_endpoint = "https://test-endpoint.example.com".into();
let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), config);
// current_if_exists() from within the runtime must find the registered entry.
tokio_rt.block_on(async {
let found = XetRuntime::current_if_exists().expect("should find a runtime");
assert!(Arc::ptr_eq(&found, &xet_rt), "must be the same XetRuntime instance");
assert_eq!(found.config().data.default_cas_endpoint, "https://test-endpoint.example.com");
});
// After drop the entry is removed; current_if_exists() falls back to a default wrapper.
drop(xet_rt);
tokio_rt.block_on(async {
let found = XetRuntime::current_if_exists().expect("should still find a runtime");
assert_ne!(found.config().data.default_cas_endpoint, "https://test-endpoint.example.com");
});
}
}