mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
Support XetSession in async context (#694)
Previously, `XetSession` always created its own tokio runtime via
`XetRuntime::new_with_config`, and calling `external_run_async_task`
panicked when already inside a tokio context. This blocked embedding the
session in async Rust frameworks.
Core strategy:
- `RuntimeMode` enum — two variants:
  - `Owned`: the session creates its own thread pool. Selected by
    `XetSessionBuilder::build`, or by `XetSessionBuilder::build_async` when
    called outside a tokio context. Both `_blocking` and async methods are
    supported. Async methods use an internal `bridge_to_owned` bridge that
    routes futures onto the owned thread pool, so they work from any
    executor (tokio, smol, async-std).
  - `External`: the session wraps a caller-supplied tokio handle. Selected
    by `XetSessionBuilder::with_tokio_handle`, or by
    `XetSessionBuilder::build_async` when called inside a qualified tokio
    context. Only async methods may be called; `_blocking` methods return
    `SessionError::WrongRuntimeMode`. No second thread pool is created.
- `XetRuntime::bridge_to_owned` — a new bridge that runs a future on the
owned tokio thread pool when called from any executor (smol, async-std,
futures::executor, non-qualified tokio runtime). The result is delivered
back through a `tokio::sync::oneshot` channel that can be polled by any
async executor.
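- Async public API — `UploadCommit` and `DownloadGroup` methods
(`upload_from_path`, `upload_bytes`, `upload_file`, `commit`, `finish`)
are now `async fn`s. The factory methods `XetSession::new_upload_commit`
and `XetSession::new_download_group` are async as well. Note that
`download_file_to_path` is still called without `.await`, as shown in the
example below.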
Example:
```
let session = XetSessionBuilder::new().build_async().await?;
// Upload
let commit = session.new_upload_commit().await?;
let handle = commit.upload_from_path("file.bin".into()).await?;
let results = commit.commit().await?;
// Download
let group = session.new_download_group().await?;
let info = XetFileInfo {
hash: ...,
file_size: ...,
};
let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?;
let finish_results = group.finish().await?;
```
- Sync wrappers — New `UploadCommitSync` / `DownloadGroupSync` types in
`xet_session/sync/` expose a fully blocking API for sync Rust and Python
(PyO3) callers. They are returned by `new_upload_commit_blocking()` and
`new_download_group_blocking()`, respectively.
Example:
```
let session = XetSessionBuilder::new().build()?;
// Upload
let commit = session.new_upload_commit_blocking()?;
let handle = commit.upload_from_path("file.bin".into())?;
let results = commit.commit()?;
let m = results.values().next().unwrap().as_ref().as_ref().unwrap();
// Download
let group = session.new_download_group_blocking()?;
let info = XetFileInfo {
hash: ...,
file_size: ...,
};
let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?;
let finish_results = group.finish()?;
```
Additional fixes: `download_file_to_path` and `upload_from_path` now
convert paths to absolute form with `std::path::absolute` (which does not
resolve symlinks or access the filesystem) before enqueuing; task status
is only overwritten when it is still `Running`, preventing a race with a
concurrent `abort()`.
Fix XET-891
---------
Co-authored-by: Hoyt Koepke <hoytak@huggingface.co>
This commit is contained in:
274
Cargo.lock
generated
274
Cargo.lock
generated
@@ -160,17 +160,109 @@ dependencies = [
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-channel"
|
||||
version = "1.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35"
|
||||
dependencies = [
|
||||
"concurrent-queue",
|
||||
"event-listener 2.5.3",
|
||||
"futures-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-channel"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "924ed96dd52d1b75e9c1a3e6275715fd320f5f9439fb5a4a11fa51f4221158d2"
|
||||
dependencies = [
|
||||
"concurrent-queue",
|
||||
"event-listener-strategy",
|
||||
"futures-core",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-executor"
|
||||
version = "1.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c96bf972d85afc50bf5ab8fe2d54d1586b4e0b46c97c50a0c9e71e2f7bcd812a"
|
||||
dependencies = [
|
||||
"async-task",
|
||||
"concurrent-queue",
|
||||
"fastrand",
|
||||
"futures-lite",
|
||||
"pin-project-lite",
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-fs"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8034a681df4aed8b8edbd7fbe472401ecf009251c8b40556b304567052e294c5"
|
||||
dependencies = [
|
||||
"async-lock",
|
||||
"blocking",
|
||||
"futures-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-global-executor"
|
||||
version = "2.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05b1b633a2115cd122d73b955eadd9916c18c8f510ec9cd1686404c60ad1c29c"
|
||||
dependencies = [
|
||||
"async-channel 2.5.0",
|
||||
"async-executor",
|
||||
"async-io",
|
||||
"async-lock",
|
||||
"blocking",
|
||||
"futures-lite",
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-io"
|
||||
version = "2.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "456b8a8feb6f42d237746d4b3e9a178494627745c3c56c6ea55d92ba50d026fc"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"cfg-if 1.0.4",
|
||||
"concurrent-queue",
|
||||
"futures-io",
|
||||
"futures-lite",
|
||||
"parking",
|
||||
"polling",
|
||||
"rustix",
|
||||
"slab",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-lock"
|
||||
version = "3.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "290f7f2596bd5b78a9fec8088ccd89180d7f9f55b94b0576823bbbdc72ee8311"
|
||||
dependencies = [
|
||||
"event-listener",
|
||||
"event-listener 5.4.1",
|
||||
"event-listener-strategy",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-net"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b948000fad4873c1c9339d60f2623323a0cfd3816e5181033c6a5cb68b2accf7"
|
||||
dependencies = [
|
||||
"async-io",
|
||||
"blocking",
|
||||
"futures-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-object-pool"
|
||||
version = "0.2.0"
|
||||
@@ -178,9 +270,77 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1ac0219111eb7bb7cb76d4cf2cb50c598e7ae549091d3616f9e95442c18486f"
|
||||
dependencies = [
|
||||
"async-lock",
|
||||
"event-listener",
|
||||
"event-listener 5.4.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-process"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc50921ec0055cdd8a16de48773bfeec5c972598674347252c0399676be7da75"
|
||||
dependencies = [
|
||||
"async-channel 2.5.0",
|
||||
"async-io",
|
||||
"async-lock",
|
||||
"async-signal",
|
||||
"async-task",
|
||||
"blocking",
|
||||
"cfg-if 1.0.4",
|
||||
"event-listener 5.4.1",
|
||||
"futures-lite",
|
||||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-signal"
|
||||
version = "0.2.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43c070bbf59cd3570b6b2dd54cd772527c7c3620fce8be898406dd3ed6adc64c"
|
||||
dependencies = [
|
||||
"async-io",
|
||||
"async-lock",
|
||||
"atomic-waker",
|
||||
"cfg-if 1.0.4",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"rustix",
|
||||
"signal-hook-registry",
|
||||
"slab",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-std"
|
||||
version = "1.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2c8e079a4ab67ae52b7403632e4618815d6db36d2a010cfe41b02c1b1578f93b"
|
||||
dependencies = [
|
||||
"async-channel 1.9.0",
|
||||
"async-global-executor",
|
||||
"async-io",
|
||||
"async-lock",
|
||||
"crossbeam-utils",
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-lite",
|
||||
"gloo-timers",
|
||||
"kv-log-macro",
|
||||
"log",
|
||||
"memchr",
|
||||
"once_cell",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"slab",
|
||||
"wasm-bindgen-futures",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-task"
|
||||
version = "4.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de"
|
||||
|
||||
[[package]]
|
||||
name = "async-trait"
|
||||
version = "0.1.89"
|
||||
@@ -383,6 +543,19 @@ dependencies = [
|
||||
"generic-array 0.14.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blocking"
|
||||
version = "1.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e83f8d02be6967315521be875afa792a316e28d57b5a2d401897e2a7921b7f21"
|
||||
dependencies = [
|
||||
"async-channel 2.5.0",
|
||||
"async-task",
|
||||
"futures-io",
|
||||
"futures-lite",
|
||||
"piper",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "blowfish"
|
||||
version = "0.9.1"
|
||||
@@ -1103,6 +1276,12 @@ dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "event-listener"
|
||||
version = "2.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0"
|
||||
|
||||
[[package]]
|
||||
name = "event-listener"
|
||||
version = "5.4.1"
|
||||
@@ -1120,7 +1299,7 @@ version = "0.5.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93"
|
||||
dependencies = [
|
||||
"event-listener",
|
||||
"event-listener 5.4.1",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
@@ -1264,6 +1443,19 @@ version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
||||
|
||||
[[package]]
|
||||
name = "futures-lite"
|
||||
version = "2.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad"
|
||||
dependencies = [
|
||||
"fastrand",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"parking",
|
||||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.32"
|
||||
@@ -1466,6 +1658,18 @@ dependencies = [
|
||||
"xet-runtime",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gloo-timers"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbb143cf96099802033e0d4f4963b19fd2e0b728bcf076cd9cf7f6634f092994"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "group"
|
||||
version = "0.13.0"
|
||||
@@ -1669,15 +1873,19 @@ name = "hf-xet"
|
||||
version = "1.4.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-std",
|
||||
"async-trait",
|
||||
"clap",
|
||||
"futures",
|
||||
"http",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serial_test",
|
||||
"smol",
|
||||
"tempfile",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"tracing-subscriber",
|
||||
"ulid",
|
||||
"xet-client",
|
||||
@@ -2178,6 +2386,15 @@ version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e037a2e1d8d5fdbd49b16a4ea09d5d6401c1f29eca5ff29d03d3824dba16256a"
|
||||
|
||||
[[package]]
|
||||
name = "kv-log-macro"
|
||||
version = "1.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0de8b303297635ad57c9f5059fd9cee7a47f8e8daa09df0fcd07dd39fb22977f"
|
||||
dependencies = [
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
@@ -2357,6 +2574,9 @@ name = "log"
|
||||
version = "0.4.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
|
||||
dependencies = [
|
||||
"value-bag",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lru-slab"
|
||||
@@ -3006,6 +3226,17 @@ version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "piper"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c835479a4443ded371d6c535cbfd8d31ad92c5d23ae9770a61bc155e4992a3c1"
|
||||
dependencies = [
|
||||
"atomic-waker",
|
||||
"fastrand",
|
||||
"futures-io",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pkcs1"
|
||||
version = "0.7.5"
|
||||
@@ -3050,6 +3281,20 @@ version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
|
||||
|
||||
[[package]]
|
||||
name = "polling"
|
||||
version = "3.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.4",
|
||||
"concurrent-queue",
|
||||
"hermit-abi",
|
||||
"pin-project-lite",
|
||||
"rustix",
|
||||
"windows-sys 0.61.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "poly1305"
|
||||
version = "0.8.0"
|
||||
@@ -4238,6 +4483,23 @@ version = "1.15.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
|
||||
|
||||
[[package]]
|
||||
name = "smol"
|
||||
version = "2.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a33bd3e260892199c3ccfc487c88b2da2265080acb316cd920da72fdfd7c599f"
|
||||
dependencies = [
|
||||
"async-channel 2.5.0",
|
||||
"async-executor",
|
||||
"async-fs",
|
||||
"async-io",
|
||||
"async-lock",
|
||||
"async-net",
|
||||
"async-process",
|
||||
"blocking",
|
||||
"futures-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.6.3"
|
||||
@@ -5013,6 +5275,12 @@ version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
|
||||
|
||||
[[package]]
|
||||
name = "value-bag"
|
||||
version = "1.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ba6f5989077681266825251a52748b8c1d8a4ad098cc37e440103d0ea717fc0"
|
||||
|
||||
[[package]]
|
||||
name = "vcpkg"
|
||||
version = "0.2.15"
|
||||
|
||||
@@ -32,8 +32,9 @@ debug = 1
|
||||
|
||||
[workspace.dependencies]
|
||||
anyhow = "1"
|
||||
axum = "0.8"
|
||||
async-std = "1"
|
||||
async-trait = "0.1"
|
||||
axum = "0.8"
|
||||
base64 = "0.22"
|
||||
bincode = "1.3"
|
||||
bitflags = { version = "2.10", features = ["serde"] }
|
||||
@@ -50,7 +51,6 @@ csv = "1"
|
||||
ctor = "0.6"
|
||||
derivative = "2.2"
|
||||
dirs = "6.0"
|
||||
human-bandwidth = "0.1"
|
||||
duration-str = "0.19"
|
||||
futures = "0.3"
|
||||
futures-util = "0.3"
|
||||
@@ -63,6 +63,7 @@ half = "2.7"
|
||||
heapify = "0.2"
|
||||
heed = "0.22"
|
||||
http = "1"
|
||||
human-bandwidth = "0.1"
|
||||
hyper = "1.8"
|
||||
hyper-util = "0.1"
|
||||
itertools = "0.14"
|
||||
@@ -95,6 +96,7 @@ serde_repr = "0.1"
|
||||
sha2 = "0.10"
|
||||
shell-words = "1.1"
|
||||
shellexpand = "3.1"
|
||||
smol = "2"
|
||||
static_assertions = "1.1"
|
||||
statrs = "0.18"
|
||||
sysinfo = "0.38"
|
||||
|
||||
@@ -20,12 +20,16 @@ xet-data = { version = "1.4.0", path = "../xet_data" }
|
||||
|
||||
async-trait = { workspace = true }
|
||||
http = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio = { workspace = true, features = ["net"] }
|
||||
ulid = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
serde = { workspace = true, features = ["derive"] }
|
||||
|
||||
[dev-dependencies]
|
||||
async-std = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
smol = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
serial_test = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
//! Session-based upload/download example.
|
||||
//! Async session-based upload/download example.
|
||||
//!
|
||||
//! Shows the three-level hierarchy: XetSession → UploadCommit/DownloadGroup → files.
|
||||
//! Mirror of `example.rs` using the async API (`UploadCommit` / `DownloadGroup`).
|
||||
//! Requires an async runtime — here provided by `#[tokio::main]`.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, Subcommand};
|
||||
@@ -12,7 +12,7 @@ use xet::xet_session::{
|
||||
};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[clap(name = "session-demo", about = "XetSession API demo")]
|
||||
#[clap(name = "session-demo-async", about = "XetSession async API demo")]
|
||||
struct Cli {
|
||||
#[clap(subcommand)]
|
||||
command: Command,
|
||||
@@ -37,52 +37,53 @@ enum Command {
|
||||
},
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
let cli = Cli::parse();
|
||||
match cli.command {
|
||||
Command::Upload { files, endpoint } => upload_files(files, endpoint),
|
||||
Command::Upload { files, endpoint } => upload_files(files, endpoint).await,
|
||||
Command::Download {
|
||||
metadata_file,
|
||||
output_dir,
|
||||
endpoint,
|
||||
} => download_files(metadata_file, output_dir, endpoint),
|
||||
} => download_files(metadata_file, output_dir, endpoint).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn upload_files(files: Vec<PathBuf>, endpoint: Option<String>) -> Result<()> {
|
||||
async fn upload_files(files: Vec<PathBuf>, endpoint: Option<String>) -> Result<()> {
|
||||
let mut builder = XetSessionBuilder::new();
|
||||
if let Some(ep) = endpoint {
|
||||
builder = builder.with_endpoint(ep);
|
||||
}
|
||||
let session = builder.build()?;
|
||||
let commit = session.new_upload_commit()?;
|
||||
let session = builder.build_async().await?;
|
||||
let commit = session.new_upload_commit().await?;
|
||||
|
||||
// Enqueue all uploads; each starts immediately in the background.
|
||||
let n_files = files.len();
|
||||
let handles: Vec<UploadTaskHandle> = files
|
||||
.iter()
|
||||
.map(|f| commit.upload_from_path(f.clone()))
|
||||
.collect::<Result<_, _>>()?;
|
||||
let mut handles = Vec::with_capacity(n_files);
|
||||
for f in &files {
|
||||
handles.push(commit.upload_from_path(f.clone()).await?);
|
||||
}
|
||||
|
||||
// Spawn a task to print progress; the main thread blocks in commit() below.
|
||||
// Spawn a task to print progress while the main task awaits commit().
|
||||
let commit_for_progress = commit.clone();
|
||||
std::thread::spawn(move || {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
if let Ok(snapshot) = commit_for_progress.get_progress() {
|
||||
let p = snapshot.total();
|
||||
let done = handles
|
||||
.iter()
|
||||
.filter(|h| matches!(h.status(), Ok(TaskStatus::Completed)))
|
||||
.filter(|h: &&UploadTaskHandle| matches!(h.status(), Ok(TaskStatus::Completed)))
|
||||
.count();
|
||||
println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes);
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Block until all uploads finish and metadata is finalized.
|
||||
let results = commit.commit()?;
|
||||
// Await until all uploads finish and metadata is finalized.
|
||||
let results = commit.commit().await?;
|
||||
|
||||
for m in results.values().filter_map(|m| m.as_ref().as_ref().ok()) {
|
||||
println!(" {} -> {} ({} bytes)", m.tracking_name.as_deref().unwrap_or("?"), m.hash, m.file_size);
|
||||
@@ -98,7 +99,7 @@ fn upload_files(files: Vec<PathBuf>, endpoint: Option<String>) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<String>) -> Result<()> {
|
||||
async fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<String>) -> Result<()> {
|
||||
let metadata: Vec<FileMetadata> = serde_json::from_str(&std::fs::read_to_string(metadata_file)?)?;
|
||||
std::fs::create_dir_all(&output_dir)?;
|
||||
|
||||
@@ -106,28 +107,26 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<
|
||||
if let Some(ep) = endpoint {
|
||||
builder = builder.with_endpoint(ep);
|
||||
}
|
||||
let session = builder.build()?;
|
||||
let group = session.new_download_group()?;
|
||||
let session = builder.build_async().await?;
|
||||
let group = session.new_download_group().await?;
|
||||
|
||||
// Enqueue all downloads; each starts immediately in the background.
|
||||
let n_files = metadata.len();
|
||||
let handles: Vec<DownloadTaskHandle> = metadata
|
||||
.iter()
|
||||
.map(|m| {
|
||||
let dest = output_dir.join(m.tracking_name.as_deref().unwrap_or("file"));
|
||||
group.download_file_to_path(
|
||||
XetFileInfo {
|
||||
hash: m.hash.clone(),
|
||||
file_size: m.file_size,
|
||||
},
|
||||
dest,
|
||||
)
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
let mut handles: Vec<DownloadTaskHandle> = Vec::with_capacity(n_files);
|
||||
for m in &metadata {
|
||||
let dest = output_dir.join(m.tracking_name.as_deref().unwrap_or("file"));
|
||||
handles.push(group.download_file_to_path(
|
||||
XetFileInfo {
|
||||
hash: m.hash.clone(),
|
||||
file_size: m.file_size,
|
||||
},
|
||||
dest,
|
||||
)?);
|
||||
}
|
||||
|
||||
// Spawn a task to print progress; the main thread blocks in finish() below.
|
||||
// Spawn a task to print progress while the main task awaits finish().
|
||||
let group_for_progress = group.clone();
|
||||
std::thread::spawn(move || {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
if let Ok(snapshot) = group_for_progress.get_progress() {
|
||||
let p = snapshot.total();
|
||||
@@ -137,12 +136,12 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<
|
||||
.count();
|
||||
println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes);
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Block until all downloads finish.
|
||||
let results = group.finish()?;
|
||||
// Await until all downloads finish.
|
||||
let results = group.finish().await?;
|
||||
|
||||
for (_task_id, result) in &results {
|
||||
if let Ok(r) = result.as_ref() {
|
||||
|
||||
150
xet_pkg/examples/example_sync.rs
Normal file
150
xet_pkg/examples/example_sync.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
//! Session-based upload/download example.
|
||||
//!
|
||||
//! Shows the three-level hierarchy: XetSession → UploadCommit/DownloadGroup → files.
|
||||
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, Subcommand};
|
||||
use xet::xet_session::{FileMetadata, TaskStatus, XetFileInfo, XetSessionBuilder};
|
||||
|
||||
#[derive(Parser)]
|
||||
#[clap(name = "session-demo", about = "XetSession API demo")]
|
||||
struct Cli {
|
||||
#[clap(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
|
||||
#[derive(Subcommand)]
|
||||
enum Command {
|
||||
/// Upload files and save metadata to upload_metadata.json
|
||||
Upload {
|
||||
#[clap(required = true)]
|
||||
files: Vec<PathBuf>,
|
||||
#[clap(long)]
|
||||
endpoint: Option<String>,
|
||||
},
|
||||
/// Download files from metadata saved by the upload subcommand
|
||||
Download {
|
||||
metadata_file: PathBuf,
|
||||
#[clap(short, long, default_value = "./downloads")]
|
||||
output_dir: PathBuf,
|
||||
#[clap(long)]
|
||||
endpoint: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
tracing_subscriber::fmt::init();
|
||||
let cli = Cli::parse();
|
||||
match cli.command {
|
||||
Command::Upload { files, endpoint } => upload_files(files, endpoint),
|
||||
Command::Download {
|
||||
metadata_file,
|
||||
output_dir,
|
||||
endpoint,
|
||||
} => download_files(metadata_file, output_dir, endpoint),
|
||||
}
|
||||
}
|
||||
|
||||
fn upload_files(files: Vec<PathBuf>, endpoint: Option<String>) -> Result<()> {
|
||||
let mut builder = XetSessionBuilder::new();
|
||||
if let Some(ep) = endpoint {
|
||||
builder = builder.with_endpoint(ep);
|
||||
}
|
||||
let session = builder.build()?;
|
||||
let commit = session.new_upload_commit_blocking()?;
|
||||
|
||||
// Enqueue all uploads; each starts immediately in the background.
|
||||
let n_files = files.len();
|
||||
let mut handles = Vec::with_capacity(n_files);
|
||||
for f in &files {
|
||||
handles.push(commit.upload_from_path(f.clone())?);
|
||||
}
|
||||
|
||||
// Spawn a task to print progress; the main thread blocks in commit() below.
|
||||
let commit_for_progress = commit.clone();
|
||||
std::thread::spawn(move || {
|
||||
loop {
|
||||
if let Ok(snapshot) = commit_for_progress.get_progress() {
|
||||
let p = snapshot.total();
|
||||
let done = handles
|
||||
.iter()
|
||||
.filter(|h| matches!(h.status(), Ok(TaskStatus::Completed)))
|
||||
.count();
|
||||
println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes);
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
}
|
||||
});
|
||||
|
||||
// Block until all uploads finish and metadata is finalized.
|
||||
let results = commit.commit()?;
|
||||
|
||||
for m in results.values().filter_map(|m| m.as_ref().as_ref().ok()) {
|
||||
println!(" {} -> {} ({} bytes)", m.tracking_name.as_deref().unwrap_or("?"), m.hash, m.file_size);
|
||||
}
|
||||
|
||||
// Persist metadata so it can be passed to the `download` subcommand.
|
||||
let metadata: Vec<_> = results
|
||||
.into_values()
|
||||
.filter_map(|m| m.as_ref().as_ref().ok().cloned())
|
||||
.collect();
|
||||
std::fs::write("upload_metadata.json", serde_json::to_string_pretty(&metadata)?)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<String>) -> Result<()> {
|
||||
let metadata: Vec<FileMetadata> = serde_json::from_str(&std::fs::read_to_string(metadata_file)?)?;
|
||||
std::fs::create_dir_all(&output_dir)?;
|
||||
|
||||
let mut builder = XetSessionBuilder::new();
|
||||
if let Some(ep) = endpoint {
|
||||
builder = builder.with_endpoint(ep);
|
||||
}
|
||||
let session = builder.build()?;
|
||||
let group = session.new_download_group_blocking()?;
|
||||
|
||||
// Enqueue all downloads; each starts immediately in the background.
|
||||
let n_files = metadata.len();
|
||||
let mut handles = Vec::with_capacity(n_files);
|
||||
for m in &metadata {
|
||||
let dest = output_dir.join(m.tracking_name.as_deref().unwrap_or("file"));
|
||||
handles.push(group.download_file_to_path(
|
||||
XetFileInfo {
|
||||
hash: m.hash.clone(),
|
||||
file_size: m.file_size,
|
||||
},
|
||||
dest,
|
||||
)?);
|
||||
}
|
||||
|
||||
// Spawn a task to print progress; the main thread blocks in finish() below.
|
||||
let group_for_progress = group.clone();
|
||||
std::thread::spawn(move || {
|
||||
loop {
|
||||
if let Ok(snapshot) = group_for_progress.get_progress() {
|
||||
let p = snapshot.total();
|
||||
let done = handles
|
||||
.iter()
|
||||
.filter(|h| matches!(h.status(), Ok(TaskStatus::Completed)))
|
||||
.count();
|
||||
println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes);
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(100));
|
||||
}
|
||||
});
|
||||
|
||||
// Block until all downloads finish.
|
||||
let results = group.finish()?;
|
||||
|
||||
for (_task_id, result) in &results {
|
||||
if let Ok(r) = result.as_ref() {
|
||||
println!(" {} ({} bytes)", r.dest_path.display(), r.file_info.file_size);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -62,12 +62,21 @@ pub enum XetError {
|
||||
/// Catch-all for unexpected internal errors (panics, lock poison, bugs).
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
|
||||
/// Caller invoked a method that is incompatible with the session's runtime mode.
|
||||
#[error("Wrong runtime mode: {0}")]
|
||||
WrongRuntimeMode(String),
|
||||
}
|
||||
|
||||
impl XetError {
|
||||
pub fn other(msg: impl std::fmt::Display) -> Self {
|
||||
Self::Internal(msg.to_string())
|
||||
}
|
||||
|
||||
pub fn wrong_mode(msg: impl std::fmt::Display) -> Self {
|
||||
Self::WrongRuntimeMode(msg.to_string())
|
||||
}
|
||||
|
||||
fn from_runtime_error_ref(re: &RuntimeError) -> Self {
|
||||
match re {
|
||||
RuntimeError::TaskCanceled(_) => XetError::Cancelled(re.to_string()),
|
||||
|
||||
@@ -3,7 +3,7 @@ use xet_data::processing::configurations::TranslatorConfig;
|
||||
use super::{SessionError, XetSession};
|
||||
|
||||
// Helper function to create TranslatorConfig
|
||||
pub(crate) fn create_translator_config(session: &XetSession) -> Result<TranslatorConfig, SessionError> {
|
||||
pub(super) fn create_translator_config(session: &XetSession) -> Result<TranslatorConfig, SessionError> {
|
||||
let endpoint = session
|
||||
.endpoint
|
||||
.clone()
|
||||
@@ -21,7 +21,7 @@ pub(crate) fn create_translator_config(session: &XetSession) -> Result<Translato
|
||||
|
||||
/// State of the upload commit and download group
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
pub(crate) enum GroupState {
|
||||
pub(super) enum GroupState {
|
||||
Alive,
|
||||
Finished,
|
||||
Aborted,
|
||||
|
||||
@@ -14,11 +14,15 @@ use super::errors::SessionError;
|
||||
use super::progress::{DownloadTaskHandle, GroupProgress, ProgressSnapshot, TaskHandle, TaskStatus};
|
||||
use super::session::XetSession;
|
||||
|
||||
/// Groups related file downloads into a single unit of work.
|
||||
/// Async API for grouping related file downloads into a single unit of work.
|
||||
///
|
||||
/// Obtain via [`XetSession::new_download_group`] from an `async` context.
|
||||
/// For sync / non-async code use [`DownloadGroupSync`] from
|
||||
/// [`XetSession::new_download_group_blocking`] instead.
|
||||
///
|
||||
/// Queue files with [`download_file_to_path`](Self::download_file_to_path) (they start
|
||||
/// downloading immediately in the background), poll progress with
|
||||
/// [`get_progress`](Self::get_progress), then call
|
||||
/// [`get_progress`](Self::get_progress), then `await`
|
||||
/// [`finish`](Self::finish) to wait for all downloads to complete.
|
||||
///
|
||||
/// # Cloning
|
||||
@@ -31,6 +35,8 @@ use super::session::XetSession;
|
||||
/// Methods return [`SessionError::Aborted`] if the parent session has been
|
||||
/// aborted, and [`SessionError::AlreadyFinished`] if
|
||||
/// [`finish`](Self::finish) has already been called.
|
||||
///
|
||||
/// [`DownloadGroupSync`]: crate::xet_session::sync::DownloadGroupSync
|
||||
#[derive(Clone)]
|
||||
pub struct DownloadGroup {
|
||||
inner: Arc<DownloadGroupInner>,
|
||||
@@ -44,17 +50,14 @@ impl std::ops::Deref for DownloadGroup {
|
||||
}
|
||||
|
||||
impl DownloadGroup {
|
||||
/// Create a new download group
|
||||
pub(crate) fn new(session: XetSession) -> Result<Self, SessionError> {
|
||||
/// Create a new download group from an **async** context. Initialisation logic shared by the sync and async
|
||||
/// constructors.
|
||||
pub(super) async fn new(session: XetSession) -> Result<Self, SessionError> {
|
||||
let group_id = Ulid::new();
|
||||
|
||||
let progress = Arc::new(GroupProgress::new());
|
||||
let progress_clone = progress.clone();
|
||||
let config = create_translator_config(&session)?;
|
||||
let download_session = session.runtime.external_run_async_task(async move {
|
||||
let progress_updater = progress_clone as Arc<dyn xet_data::progress_tracking::TrackingProgressUpdater>;
|
||||
FileDownloadSession::new(Arc::new(config), Some(progress_updater)).await
|
||||
})??;
|
||||
let progress_updater = progress.clone() as Arc<dyn xet_data::progress_tracking::TrackingProgressUpdater>;
|
||||
let download_session = FileDownloadSession::new(Arc::new(config), Some(progress_updater)).await?;
|
||||
|
||||
let inner = Arc::new(DownloadGroupInner {
|
||||
group_id,
|
||||
@@ -69,23 +72,26 @@ impl DownloadGroup {
|
||||
}
|
||||
|
||||
/// Get the group ID.
|
||||
pub(crate) fn id(&self) -> Ulid {
|
||||
pub(super) fn id(&self) -> Ulid {
|
||||
self.group_id
|
||||
}
|
||||
|
||||
/// Abort this download group.
|
||||
pub(crate) fn abort(&self) -> Result<(), SessionError> {
|
||||
pub(super) fn abort(&self) -> Result<(), SessionError> {
|
||||
self.inner.abort()
|
||||
}
|
||||
|
||||
// ===== Public synchronous methods =====
|
||||
/// Returns the runtime used by this group.
|
||||
pub(super) fn runtime(&self) -> &XetRuntime {
|
||||
&self.inner.session.runtime
|
||||
}
|
||||
|
||||
/// Queue a file for download to `dest_path`, starting the transfer immediately if system resources permit.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `file_info` – Content-addressed hash and size returned by a previous
|
||||
/// [`UploadCommit::commit`](crate::UploadCommit::commit).
|
||||
/// [`UploadCommit::commit`](crate::xet_session::UploadCommit::commit).
|
||||
/// * `dest_path` – Local path where the downloaded file will be written. Parent directories are created
|
||||
/// automatically.
|
||||
///
|
||||
@@ -103,16 +109,11 @@ impl DownloadGroup {
|
||||
dest_path: PathBuf,
|
||||
) -> Result<DownloadTaskHandle, SessionError> {
|
||||
self.session.check_alive()?;
|
||||
self.inner.start_download_file_to_path(file_info, dest_path)
|
||||
}
|
||||
|
||||
/// Returns `true` if [`finish`](Self::finish) has been called and completed.
|
||||
#[cfg(test)]
|
||||
fn is_finished(&self) -> bool {
|
||||
match self.state.lock() {
|
||||
Ok(state) => *state == GroupState::Finished,
|
||||
Err(_) => false,
|
||||
}
|
||||
// Use the absolute path in case the process current working directory changes
|
||||
// while the task is queued.
|
||||
let absolute_path = std::path::absolute(dest_path)?;
|
||||
self.inner.start_download_file_to_path(file_info, absolute_path)
|
||||
}
|
||||
|
||||
/// Return a snapshot of progress for every queued download.
|
||||
@@ -122,11 +123,9 @@ impl DownloadGroup {
|
||||
|
||||
/// Wait for all downloads to complete and return their results.
|
||||
///
|
||||
/// Blocks until every queued download finishes (or fails). Returns a
|
||||
/// `HashMap` keyed by task ID (the [`Ulid`] returned by
|
||||
/// [`download_file_to_path`](Self::download_file_to_path)), where each
|
||||
/// value is [`DownloadResult`] (= `Arc<Result<`[`DownloadedFile`]`,
|
||||
/// `[`SessionError`](crate::SessionError)`>>`). A single failed download
|
||||
/// Returns a `HashMap` keyed by task ID where each value is
|
||||
/// [`DownloadResult`] (= `Arc<Result<`[`DownloadedFile`]`,
|
||||
/// [`SessionError`](crate::SessionError)`>>`). A single failed download
|
||||
/// does not prevent the others from being collected.
|
||||
///
|
||||
/// Per-task results can also be read directly from the
|
||||
@@ -134,13 +133,21 @@ impl DownloadGroup {
|
||||
/// [`result`](DownloadTaskHandle::result) after this method returns.
|
||||
///
|
||||
/// Consumes `self` — subsequent calls on any clone will return
|
||||
/// [`SessionError::AlreadyFinished`] (or a channel-closed error if the
|
||||
/// background worker has already exited).
|
||||
pub fn finish(self) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
|
||||
/// [`SessionError::AlreadyFinished`].
|
||||
pub async fn finish(self) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
|
||||
let inner = self.inner.clone();
|
||||
self.session
|
||||
.runtime
|
||||
.external_run_async_task(async move { inner.handle_finish().await })?
|
||||
.dispatch("finish", async move { inner.handle_finish().await })
|
||||
.await?
|
||||
}
|
||||
|
||||
/// Test-only probe: `true` when the group state equals [`GroupState::Finished`].
///
/// If the state lock cannot be acquired (the `Err` case), this reports `false`
/// instead of panicking.
#[cfg(test)]
fn is_finished(&self) -> bool {
    self.state
        .lock()
        .is_ok_and(|state| *state == GroupState::Finished)
}
|
||||
}
|
||||
|
||||
@@ -152,7 +159,7 @@ impl DownloadGroup {
|
||||
pub type DownloadResult = Arc<Result<DownloadedFile, SessionError>>;
|
||||
|
||||
/// Handle for a single download task tracked internally by DownloadGroup.
|
||||
pub(crate) struct InnerDownloadTaskHandle {
|
||||
struct InnerDownloadTaskHandle {
|
||||
status: Arc<Mutex<TaskStatus>>,
|
||||
dest_path: PathBuf,
|
||||
join_handle: JoinHandle<Result<XetFileInfo, SessionError>>,
|
||||
@@ -217,7 +224,11 @@ impl DownloadGroupInner {
|
||||
} else {
|
||||
TaskStatus::Failed
|
||||
};
|
||||
*status.lock()? = new_status;
|
||||
// Only overwrite if still Running — abort() may have set Cancelled concurrently.
|
||||
let mut s = status.lock()?;
|
||||
if matches!(*s, TaskStatus::Running) {
|
||||
*s = new_status;
|
||||
}
|
||||
|
||||
Ok(XetFileInfo {
|
||||
hash: file_info.hash,
|
||||
@@ -273,8 +284,8 @@ impl DownloadGroupInner {
|
||||
Ok(task_handle)
|
||||
}
|
||||
|
||||
/// Handle a `Finish` command from the public API.
|
||||
async fn handle_finish(self: &Arc<Self>) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
|
||||
/// Join all active download tasks and mark the group as finished.
|
||||
async fn handle_finish(&self) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
|
||||
// Mark as not accepting new tasks
|
||||
{
|
||||
let mut state_guard = self.state.lock()?;
|
||||
@@ -356,23 +367,30 @@ pub struct DownloadedFile {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::sync::mpsc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tempfile::{TempDir, tempdir};
|
||||
|
||||
use super::*;
|
||||
use crate::xet_session::progress::UploadTaskHandle;
|
||||
use crate::xet_session::session::XetSession;
|
||||
use crate::xet_session::session::{RuntimeMode, XetSession, XetSessionBuilder};
|
||||
|
||||
fn local_session(temp: &TempDir) -> Result<XetSession, Box<dyn std::error::Error>> {
|
||||
async fn local_session(temp: &TempDir) -> Result<XetSession, Box<dyn std::error::Error>> {
|
||||
let cas_path = temp.path().join("cas");
|
||||
Ok(XetSession::new(Some(format!("local://{}", cas_path.display())), None, None, None)?)
|
||||
Ok(XetSessionBuilder::new()
|
||||
.with_endpoint(format!("local://{}", cas_path.display()))
|
||||
.build_async()
|
||||
.await?)
|
||||
}
|
||||
|
||||
fn upload_bytes(session: &XetSession, data: &[u8], name: &str) -> Result<XetFileInfo, Box<dyn std::error::Error>> {
|
||||
let commit = session.new_upload_commit()?;
|
||||
let handle = commit.upload_bytes(data.to_vec(), Some(name.into()))?;
|
||||
let results = commit.commit()?;
|
||||
async fn upload_bytes(
|
||||
session: &XetSession,
|
||||
data: &[u8],
|
||||
name: &str,
|
||||
) -> Result<XetFileInfo, Box<dyn std::error::Error>> {
|
||||
let commit = session.new_upload_commit().await?;
|
||||
let handle = commit.upload_bytes(data.to_vec(), Some(name.into())).await?;
|
||||
let results = commit.commit().await?;
|
||||
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
|
||||
Ok(XetFileInfo {
|
||||
hash: meta.hash.clone(),
|
||||
@@ -380,288 +398,6 @@ mod tests {
|
||||
})
|
||||
}
|
||||
|
||||
// ── Identity ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// Each download group created from one session is assigned its own ID.
#[test]
fn test_group_has_unique_id() -> Result<(), Box<dyn std::error::Error>> {
    let xet = XetSession::new(None, None, None, None)?;
    let first = xet.new_download_group()?;
    let second = xet.new_download_group()?;
    assert_ne!(first.id(), second.id());
    Ok(())
}
|
||||
|
||||
// ── Initial state ────────────────────────────────────────────────────────
|
||||
|
||||
/// Before anything is queued, the aggregate progress counters are all zero.
#[test]
fn test_get_progress_empty_initially() -> Result<(), Box<dyn std::error::Error>> {
    let xet = XetSession::new(None, None, None, None)?;
    let fresh_group = xet.new_download_group()?;
    let progress = fresh_group.get_progress()?;
    let totals = progress.total();
    assert_eq!(totals.total_bytes, 0);
    assert_eq!(totals.total_bytes_completed, 0);
    Ok(())
}
|
||||
|
||||
// ── Finish lifecycle ─────────────────────────────────────────────────────
|
||||
|
||||
/// Finishing a group with nothing queued succeeds and yields an empty result map.
#[test]
fn test_finish_empty_succeeds() -> Result<(), Box<dyn std::error::Error>> {
    let xet = XetSession::new(None, None, None, None)?;
    let empty_group = xet.new_download_group()?;
    let results = empty_group.finish()?;
    assert!(results.is_empty());
    Ok(())
}
|
||||
|
||||
/// After `finish()` returns, a clone of the group observes the Finished state.
#[test]
fn test_finish_marks_as_finished() -> Result<(), Box<dyn std::error::Error>> {
    let xet = XetSession::new(None, None, None, None)?;
    let original = xet.new_download_group()?;
    // Keep a clone: `finish()` consumes the handle it is called on.
    let group_clone = original.clone();
    original.finish().unwrap();
    assert!(group_clone.is_finished());
    Ok(())
}
|
||||
|
||||
/// Calling `finish()` again through a clone fails with `AlreadyFinished`
/// (an `Internal` error is also accepted by this test).
#[test]
fn test_second_finish_fails() -> Result<(), Box<dyn std::error::Error>> {
    let xet = XetSession::new(None, None, None, None)?;
    let first = xet.new_download_group()?;
    let second = first.clone();
    first.finish()?;
    let repeat_attempt = second.finish();
    let err = repeat_attempt.unwrap_err();
    assert!(matches!(err, SessionError::AlreadyFinished | SessionError::Internal(_)));
    Ok(())
}
|
||||
|
||||
/// The session tracks one active group after creation and none after `finish()`.
#[test]
fn test_finish_unregisters_from_session() -> Result<(), Box<dyn std::error::Error>> {
    let xet = XetSession::new(None, None, None, None)?;
    let group = xet.new_download_group()?;
    // Re-reads the session's active-group set each time it is called.
    let active_count = || xet.active_download_groups.lock().unwrap().len();
    assert_eq!(active_count(), 1);
    group.finish().unwrap();
    assert_eq!(active_count(), 0);
    Ok(())
}
|
||||
|
||||
// ── Guards ───────────────────────────────────────────────────────────────
|
||||
|
||||
/// Once the parent session is aborted, queuing a download on an existing group
/// fails with `SessionError::Aborted`.
#[test]
fn test_download_file_on_aborted_session_returns_error() -> Result<(), Box<dyn std::error::Error>> {
    let xet = XetSession::new(None, None, None, None)?;
    let group = xet.new_download_group()?;
    xet.abort().unwrap();

    let request = XetFileInfo {
        hash: "abc123".to_string(),
        file_size: 1024,
    };
    let destination = PathBuf::from("dest.bin");
    let err = group.download_file_to_path(request, destination).unwrap_err();
    assert!(matches!(err, SessionError::Aborted));
    Ok(())
}
|
||||
|
||||
/// Queuing a download through a clone after `finish()` fails with `AlreadyFinished`.
#[test]
fn test_download_file_after_finish_fails() -> Result<(), Box<dyn std::error::Error>> {
    let xet = XetSession::new(None, None, None, None)?;
    let finished = xet.new_download_group()?;
    let leftover = finished.clone();
    finished.finish()?;

    let request = XetFileInfo {
        hash: "abc123".to_string(),
        file_size: 1024,
    };
    let destination = PathBuf::from("dest.bin");
    let err = leftover.download_file_to_path(request, destination).unwrap_err();
    assert!(matches!(err, SessionError::AlreadyFinished));
    Ok(())
}
|
||||
|
||||
/// Aborting the group itself makes later download requests fail with `Aborted`,
/// not `AlreadyFinished`.
#[test]
fn test_download_file_on_aborted_group_returns_aborted() -> Result<(), Box<dyn std::error::Error>> {
    let xet = XetSession::new(None, None, None, None)?;
    let group = xet.new_download_group()?;
    group.abort()?;

    let request = XetFileInfo {
        hash: "abc123".to_string(),
        file_size: 1024,
    };
    let destination = PathBuf::from("dest.bin");
    let err = group.download_file_to_path(request, destination).unwrap_err();
    assert!(matches!(err, SessionError::Aborted));
    Ok(())
}
|
||||
|
||||
// ── Independence ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Finishing one group leaves a sibling group from the same session unfinished.
#[test]
fn test_two_groups_are_independent() -> Result<(), Box<dyn std::error::Error>> {
    let xet = XetSession::new(None, None, None, None)?;
    let finished_group = xet.new_download_group()?;
    let g2 = xet.new_download_group()?;
    finished_group.finish()?;
    assert!(!g2.is_finished());
    Ok(())
}
|
||||
|
||||
// ── Round-trip tests ─────────────────────────────────────────────────────
|
||||
|
||||
/// Uploads a payload, downloads it through a group, and checks that the bytes
/// written to the destination match the original.
#[test]
fn test_download_file_round_trip() -> Result<(), Box<dyn std::error::Error>> {
    let temp = tempdir()?;
    let xet = local_session(&temp)?;
    let payload = b"Hello, download round-trip!";
    let uploaded = upload_bytes(&xet, payload, "payload.bin")?;

    let target = temp.path().join("downloaded.bin");
    let group = xet.new_download_group()?;
    group.download_file_to_path(uploaded, target.clone())?;
    group.finish()?;

    let on_disk = std::fs::read(&target)?;
    assert_eq!(on_disk, payload);
    Ok(())
}
|
||||
|
||||
/// Two files uploaded in one commit are downloaded through a single group, and
/// each destination ends up with the matching content.
#[test]
fn test_download_multiple_files() -> Result<(), Box<dyn std::error::Error>> {
    let temp = tempdir()?;
    let xet = local_session(&temp)?;

    let data_a = b"First file content";
    let data_b = b"Second file content - different";

    // Keep both upload handles so the commit results can be looked up by task_id.
    let commit = xet.new_upload_commit()?;
    let handle_a = commit.upload_bytes(data_a.to_vec(), Some("a.bin".into()))?;
    let handle_b = commit.upload_bytes(data_b.to_vec(), Some("b.bin".into()))?;
    let results = commit.commit()?;

    let file_info_for = |task: &UploadTaskHandle| -> XetFileInfo {
        let meta = results.get(&task.task_id).unwrap().as_ref().as_ref().unwrap();
        let hash = meta.hash.clone();
        let file_size = meta.file_size;
        XetFileInfo { hash, file_size }
    };

    let dest_a = temp.path().join("a_out.bin");
    let dest_b = temp.path().join("b_out.bin");
    let group = xet.new_download_group()?;
    group.download_file_to_path(file_info_for(&handle_a), dest_a.clone())?;
    group.download_file_to_path(file_info_for(&handle_b), dest_b.clone())?;
    group.finish()?;

    assert_eq!(std::fs::read(&dest_a)?, data_a);
    assert_eq!(std::fs::read(&dest_b)?, data_b);
    Ok(())
}
|
||||
|
||||
/// After a download finishes, the group's aggregate progress reports a
/// non-zero number of completed bytes.
#[test]
fn test_download_progress_reflects_bytes_after_finish() -> Result<(), Box<dyn std::error::Error>> {
    let temp = tempdir()?;
    let session = local_session(&temp)?;
    let payload = b"download progress tracking data";
    let uploaded = upload_bytes(&session, payload, "prog.bin")?;

    let target = temp.path().join("out.bin");
    let group = session.new_download_group()?;
    let progress_observer = group.clone();
    group.download_file_to_path(uploaded, target)?;
    group.finish()?;

    // Wait one configured progress-update interval plus a second before sampling.
    let settle_time = session
        .runtime
        .config()
        .data
        .progress_update_interval
        .saturating_add(Duration::from_secs(1));
    std::thread::sleep(settle_time);

    let snapshot = progress_observer.get_progress()?;
    assert!(snapshot.total().total_bytes_completed > 0);
    Ok(())
}
|
||||
|
||||
// ── Per-task result access patterns ──────────────────────────────────────
|
||||
//
|
||||
// After finish() completes there are two equivalent ways to retrieve a
|
||||
// per-task DownloadResult:
|
||||
//
|
||||
// 1. HashMap lookup: `finish_results.get(&handle.task_id)`
|
||||
// 2. Direct handle: `handle.result()` (on DownloadTaskHandle)
|
||||
|
||||
/// Pattern 1: the map returned by `finish()` holds the task's result under the
/// handle's `task_id`.
#[test]
fn test_download_result_accessible_via_task_id_in_finish_map() -> Result<(), Box<dyn std::error::Error>> {
    let temp = tempdir()?;
    let xet = local_session(&temp)?;
    let data = b"result via task_id in finish map";
    let uploaded = upload_bytes(&xet, data, "file.bin")?;
    let target = temp.path().join("out.bin");

    let group = xet.new_download_group()?;
    let task = group.download_file_to_path(uploaded, target)?;
    let finished = group.finish()?;

    let entry = finished.get(&task.task_id).expect("task_id must be present in results");
    let downloaded = entry.as_ref().as_ref().unwrap();
    assert_eq!(downloaded.file_info.file_size, data.len() as u64);
    Ok(())
}
|
||||
|
||||
/// `DownloadTaskHandle::result()` is `None` until `finish()` has been called.
#[test]
fn test_download_result_none_before_finish() -> Result<(), Box<dyn std::error::Error>> {
    let temp = tempdir()?;
    let xet = local_session(&temp)?;
    let uploaded = upload_bytes(&xet, b"some data", "file.bin")?;
    let target = temp.path().join("out.bin");

    let group = xet.new_download_group()?;
    let task = group.download_file_to_path(uploaded, target)?;
    assert!(task.result().is_none(), "result must be None before finish()");
    group.finish()?;
    Ok(())
}
|
||||
|
||||
/// After `finish()`, `DownloadTaskHandle::result()` is `Some` and carries the
/// size and hash of the file that was requested.
#[test]
fn test_download_result_some_after_finish() -> Result<(), Box<dyn std::error::Error>> {
    let temp = tempdir()?;
    let xet = local_session(&temp)?;
    let data = b"download result test data";
    let file_info = upload_bytes(&xet, data, "file.bin")?;
    let target = temp.path().join("out.bin");

    let group = xet.new_download_group()?;
    let task = group.download_file_to_path(file_info.clone(), target)?;
    group.finish()?;

    let stored = task.result().expect("result must be set after finish()");
    let downloaded = stored.as_ref().as_ref().unwrap();
    assert_eq!(downloaded.file_info.file_size, data.len() as u64);
    assert_eq!(downloaded.file_info.hash, file_info.hash);
    Ok(())
}
|
||||
|
||||
// ── Mutex guard / concurrency test ───────────────────────────────────────
|
||||
//
|
||||
// `download_file_to_path` holds `self.state` for its entire execution so
|
||||
@@ -673,27 +409,26 @@ mod tests {
|
||||
#[test]
|
||||
// finish() must block while download_file_to_path() holds the state lock.
|
||||
fn test_finish_blocked_while_download_registration_holds_state_lock() -> Result<(), Box<dyn std::error::Error>> {
|
||||
use std::sync::mpsc;
|
||||
|
||||
let session = XetSession::new(None, None, None, None)?;
|
||||
let group = session.new_download_group()?;
|
||||
let session = XetSessionBuilder::new().build()?;
|
||||
let runtime = session.runtime.clone();
|
||||
// Create DownloadGroup directly so we can access its private state field
|
||||
// (accessible here because mod tests is a submodule of download_group).
|
||||
let group = runtime.external_run_async_task(DownloadGroup::new(session.clone()))??;
|
||||
let group_for_thread = group.clone();
|
||||
let runtime_for_thread = runtime.clone();
|
||||
|
||||
// Simulate download_file_to_path() holding the state lock mid-registration.
|
||||
let guard = group.inner.state.lock().unwrap();
|
||||
|
||||
let (done_tx, done_rx) = mpsc::channel::<()>();
|
||||
let join_handle = std::thread::spawn(move || {
|
||||
let _ = group_for_thread.finish(); // must block until guard is dropped
|
||||
let _ = runtime_for_thread.external_run_async_task(async move { group_for_thread.finish().await });
|
||||
let _ = done_tx.send(());
|
||||
});
|
||||
|
||||
// Give the spawned thread enough time to reach the state-lock acquisition
|
||||
// inside finish() and block there.
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
assert!(done_rx.try_recv().is_err(), "finish() should be blocked while state lock is held");
|
||||
|
||||
// Release the lock — simulates the enqueue method completing its registration.
|
||||
drop(guard);
|
||||
|
||||
assert!(
|
||||
@@ -703,4 +438,349 @@ mod tests {
|
||||
let _ = join_handle.join();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Identity ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// Two download groups created from the same session have distinct IDs.
|
||||
async fn test_group_has_unique_id() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let g1 = session.new_download_group().await.unwrap();
|
||||
let g2 = session.new_download_group().await.unwrap();
|
||||
assert_ne!(g1.id(), g2.id());
|
||||
}
|
||||
|
||||
// ── Initial state ────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// A fresh group has all-zero aggregate progress.
|
||||
async fn test_get_progress_empty_initially() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
let snapshot = group.get_progress().unwrap();
|
||||
let total = snapshot.total();
|
||||
assert_eq!(total.total_bytes, 0);
|
||||
assert_eq!(total.total_bytes_completed, 0);
|
||||
}
|
||||
|
||||
// ── Finish lifecycle ─────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// An empty finish succeeds and returns an empty result set.
|
||||
async fn test_finish_empty_succeeds() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
let results = group.finish().await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// finish() transitions the group into the Finished state.
|
||||
async fn test_finish_marks_as_finished() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
let group_clone = group.clone();
|
||||
group.finish().await.unwrap();
|
||||
assert!(group_clone.is_finished());
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// A second finish() call on any clone returns AlreadyFinished.
|
||||
async fn test_second_finish_fails() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let g1 = session.new_download_group().await.unwrap();
|
||||
let g2 = g1.clone();
|
||||
g1.finish().await.unwrap();
|
||||
let err = g2.finish().await.unwrap_err();
|
||||
assert!(matches!(err, SessionError::AlreadyFinished | SessionError::Internal(_)));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// finish() unregisters the group from the session's active set.
|
||||
async fn test_finish_unregisters_from_session() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
|
||||
group.finish().await.unwrap();
|
||||
assert_eq!(session.active_download_groups.lock().unwrap().len(), 0);
|
||||
}
|
||||
|
||||
// ── Guards ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// download_file_to_path returns Aborted when the parent session has been aborted.
|
||||
async fn test_download_file_on_aborted_session_returns_error() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
session.abort().unwrap();
|
||||
let err = group
|
||||
.download_file_to_path(
|
||||
XetFileInfo {
|
||||
hash: "abc123".to_string(),
|
||||
file_size: 1024,
|
||||
},
|
||||
std::path::PathBuf::from("dest.bin"),
|
||||
)
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, SessionError::Aborted));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// download_file_to_path after finish returns AlreadyFinished.
|
||||
async fn test_download_file_after_finish_fails() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let g1 = session.new_download_group().await.unwrap();
|
||||
let g2 = g1.clone();
|
||||
g1.finish().await.unwrap();
|
||||
let err = g2
|
||||
.download_file_to_path(
|
||||
XetFileInfo {
|
||||
hash: "abc123".to_string(),
|
||||
file_size: 1024,
|
||||
},
|
||||
std::path::PathBuf::from("dest.bin"),
|
||||
)
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, SessionError::AlreadyFinished));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// download_file_to_path on a directly-aborted group returns Aborted.
|
||||
async fn test_download_file_on_aborted_group_returns_aborted() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
group.abort().unwrap();
|
||||
let err = group
|
||||
.download_file_to_path(
|
||||
XetFileInfo {
|
||||
hash: "abc123".to_string(),
|
||||
file_size: 1024,
|
||||
},
|
||||
std::path::PathBuf::from("dest.bin"),
|
||||
)
|
||||
.unwrap_err();
|
||||
assert!(matches!(err, SessionError::Aborted));
|
||||
}
|
||||
|
||||
// ── Independence ─────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// Finishing one group does not affect the state of another from the same session.
|
||||
async fn test_two_groups_are_independent() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let g1 = session.new_download_group().await.unwrap();
|
||||
let g2 = session.new_download_group().await.unwrap();
|
||||
g1.finish().await.unwrap();
|
||||
assert!(!g2.is_finished());
|
||||
}
|
||||
|
||||
// ── Round-trip tests ─────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// Downloading a previously uploaded file produces byte-identical content at the destination.
|
||||
async fn test_download_file_round_trip() {
|
||||
let temp = tempdir().unwrap();
|
||||
let session = local_session(&temp).await.unwrap();
|
||||
let original = b"Hello, download round-trip!";
|
||||
let file_info = upload_bytes(&session, original, "payload.bin").await.unwrap();
|
||||
|
||||
let dest = temp.path().join("downloaded.bin");
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
group.download_file_to_path(file_info, dest.clone()).unwrap();
|
||||
group.finish().await.unwrap();
|
||||
|
||||
assert_eq!(std::fs::read(&dest).unwrap(), original);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// Downloading multiple files from a single group produces correct content for each.
|
||||
async fn test_download_multiple_files() {
|
||||
let temp = tempdir().unwrap();
|
||||
let session = local_session(&temp).await.unwrap();
|
||||
|
||||
let data_a = b"First file content";
|
||||
let data_b = b"Second file content - different";
|
||||
|
||||
let commit = session.new_upload_commit().await.unwrap();
|
||||
let handle_a = commit.upload_bytes(data_a.to_vec(), Some("a.bin".into())).await.unwrap();
|
||||
let handle_b = commit.upload_bytes(data_b.to_vec(), Some("b.bin".into())).await.unwrap();
|
||||
let results = commit.commit().await.unwrap();
|
||||
|
||||
let to_file_info = |handle: &crate::xet_session::progress::UploadTaskHandle| -> XetFileInfo {
|
||||
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
|
||||
XetFileInfo {
|
||||
hash: meta.hash.clone(),
|
||||
file_size: meta.file_size,
|
||||
}
|
||||
};
|
||||
|
||||
let dest_a = temp.path().join("a_out.bin");
|
||||
let dest_b = temp.path().join("b_out.bin");
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
group.download_file_to_path(to_file_info(&handle_a), dest_a.clone()).unwrap();
|
||||
group.download_file_to_path(to_file_info(&handle_b), dest_b.clone()).unwrap();
|
||||
group.finish().await.unwrap();
|
||||
|
||||
assert_eq!(std::fs::read(&dest_a).unwrap(), data_a);
|
||||
assert_eq!(std::fs::read(&dest_b).unwrap(), data_b);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// After a successful finish the aggregate download progress reflects bytes received.
|
||||
async fn test_download_progress_reflects_bytes_after_finish() {
|
||||
let temp = tempdir().unwrap();
|
||||
let session = local_session(&temp).await.unwrap();
|
||||
let original = b"download progress tracking data";
|
||||
let file_info = upload_bytes(&session, original, "prog.bin").await.unwrap();
|
||||
|
||||
let dest = temp.path().join("out.bin");
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
let progress_observer = group.clone();
|
||||
group.download_file_to_path(file_info, dest).unwrap();
|
||||
group.finish().await.unwrap();
|
||||
|
||||
tokio::time::sleep(
|
||||
session
|
||||
.runtime
|
||||
.config()
|
||||
.data
|
||||
.progress_update_interval
|
||||
.saturating_add(Duration::from_secs(1)),
|
||||
)
|
||||
.await;
|
||||
let snapshot = progress_observer.get_progress().unwrap();
|
||||
assert!(snapshot.total().total_bytes_completed > 0);
|
||||
}
|
||||
|
||||
// ── Per-task result access patterns ──────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// Pattern 1: per-task result is accessible via task_id in the finish() HashMap.
|
||||
async fn test_download_result_accessible_via_task_id_in_finish_map() {
|
||||
let temp = tempdir().unwrap();
|
||||
let session = local_session(&temp).await.unwrap();
|
||||
let data = b"result via task_id in finish map";
|
||||
let file_info = upload_bytes(&session, data, "file.bin").await.unwrap();
|
||||
let dest = temp.path().join("out.bin");
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
let handle = group.download_file_to_path(file_info, dest).unwrap();
|
||||
let results = group.finish().await.unwrap();
|
||||
let result = results.get(&handle.task_id).expect("task_id must be present in results");
|
||||
assert_eq!(result.as_ref().as_ref().unwrap().file_info.file_size, data.len() as u64);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// DownloadTaskHandle::result() returns None before finish() is called.
|
||||
async fn test_download_result_none_before_finish() {
|
||||
let temp = tempdir().unwrap();
|
||||
let session = local_session(&temp).await.unwrap();
|
||||
let file_info = upload_bytes(&session, b"some data", "file.bin").await.unwrap();
|
||||
let dest = temp.path().join("out.bin");
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
let handle = group.download_file_to_path(file_info, dest).unwrap();
|
||||
assert!(handle.result().is_none(), "result must be None before finish()");
|
||||
group.finish().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// DownloadTaskHandle::result() returns Some after finish() completes.
|
||||
async fn test_download_result_some_after_finish() {
|
||||
let temp = tempdir().unwrap();
|
||||
let session = local_session(&temp).await.unwrap();
|
||||
let data = b"download result test data";
|
||||
let file_info = upload_bytes(&session, data, "file.bin").await.unwrap();
|
||||
let dest = temp.path().join("out.bin");
|
||||
let group = session.new_download_group().await.unwrap();
|
||||
let handle = group.download_file_to_path(file_info.clone(), dest).unwrap();
|
||||
group.finish().await.unwrap();
|
||||
let result = handle.result().expect("result must be set after finish()");
|
||||
let dl = result.as_ref().as_ref().unwrap();
|
||||
assert_eq!(dl.file_info.file_size, data.len() as u64);
|
||||
assert_eq!(dl.file_info.hash, file_info.hash);
|
||||
}
|
||||
|
||||
// ── Non-tokio executor (Owned-mode bridge) ────────────────────────────────
|
||||
|
||||
// Outside any tokio context, build_async() falls back to Owned mode. The round trip
// below is driven by futures::executor::block_on while the bridge routes the actual
// upload/download work through the session's owned thread pool.
#[test]
fn test_async_bridge_works_from_futures_executor() {
    let temp = tempdir().unwrap();

    let round_trip = async {
        let session = local_session(&temp).await.unwrap();
        assert_eq!(session.runtime_mode, RuntimeMode::Owned);

        // Upload a small payload and rebuild its XetFileInfo from the commit results.
        let payload = b"hello from futures executor";
        let commit = session.new_upload_commit().await.unwrap();
        let upload = commit
            .upload_bytes(payload.to_vec(), Some("test.bin".into()))
            .await
            .unwrap();
        let committed = commit.commit().await.unwrap();
        let meta = committed.get(&upload.task_id).unwrap().as_ref().as_ref().unwrap();
        let file_info = XetFileInfo {
            hash: meta.hash.clone(),
            file_size: meta.file_size,
        };

        // Download it again and compare the bytes on disk with what was uploaded.
        let dest = temp.path().join("out_futures.bin");
        let group = session.new_download_group().await.unwrap();
        group.download_file_to_path(file_info, dest.clone()).unwrap();
        group.finish().await.unwrap();

        assert_eq!(std::fs::read(&dest).unwrap(), payload);
    };

    futures::executor::block_on(round_trip);
}
|
||||
|
||||
// Owned-mode bridge round trip, this time with smol as the driving executor.
#[test]
fn test_async_bridge_works_from_smol_executor() {
    let temp = tempdir().unwrap();

    let round_trip = async {
        let session = local_session(&temp).await.unwrap();
        assert_eq!(session.runtime_mode, RuntimeMode::Owned);

        // Upload a small payload and rebuild its XetFileInfo from the commit results.
        let payload = b"hello from smol executor";
        let commit = session.new_upload_commit().await.unwrap();
        let upload = commit
            .upload_bytes(payload.to_vec(), Some("test.bin".into()))
            .await
            .unwrap();
        let committed = commit.commit().await.unwrap();
        let meta = committed.get(&upload.task_id).unwrap().as_ref().as_ref().unwrap();
        let file_info = XetFileInfo {
            hash: meta.hash.clone(),
            file_size: meta.file_size,
        };

        // Download it again and compare the bytes on disk with what was uploaded.
        let dest = temp.path().join("out_smol.bin");
        let group = session.new_download_group().await.unwrap();
        group.download_file_to_path(file_info, dest.clone()).unwrap();
        group.finish().await.unwrap();

        assert_eq!(std::fs::read(&dest).unwrap(), payload);
    };

    smol::block_on(round_trip);
}
|
||||
|
||||
// Owned-mode bridge round trip, this time with async-std as the driving executor.
#[test]
fn test_async_bridge_works_from_async_std_executor() {
    let temp = tempdir().unwrap();

    let round_trip = async {
        let session = local_session(&temp).await.unwrap();
        assert_eq!(session.runtime_mode, RuntimeMode::Owned);

        // Upload a small payload and rebuild its XetFileInfo from the commit results.
        let payload = b"hello from async-std executor";
        let commit = session.new_upload_commit().await.unwrap();
        let upload = commit
            .upload_bytes(payload.to_vec(), Some("test.bin".into()))
            .await
            .unwrap();
        let committed = commit.commit().await.unwrap();
        let meta = committed.get(&upload.task_id).unwrap().as_ref().as_ref().unwrap();
        let file_info = XetFileInfo {
            hash: meta.hash.clone(),
            file_size: meta.file_size,
        };

        // Download it again and compare the bytes on disk with what was uploaded.
        let dest = temp.path().join("out_async_std.bin");
        let group = session.new_download_group().await.unwrap();
        group.download_file_to_path(file_info, dest.clone()).unwrap();
        group.finish().await.unwrap();

        assert_eq!(std::fs::read(&dest).unwrap(), payload);
    };

    async_std::task::block_on(round_trip);
}
|
||||
}
|
||||
|
||||
@@ -4,46 +4,57 @@
|
||||
//! file operations:
|
||||
//!
|
||||
//! ```text
|
||||
//! XetSession — owns the tokio runtime and authentication credentials
|
||||
//! XetSession — holds runtime context and authentication credentials
|
||||
//! ├── UploadCommit — groups related uploads; finalised with commit()
|
||||
//! └── DownloadGroup — groups related downloads; finalised with finish()
|
||||
//! ```
|
||||
//!
|
||||
//! Each [`XetSession`] owns its own tokio runtime and configuration, so
|
||||
//! Each [`XetSession`] holds its own runtime context and configuration, so
|
||||
//! multiple sessions with different endpoints or credentials can coexist in
|
||||
//! the same process. Cloning a session, commit, or group is cheap — all
|
||||
//! clones share the same underlying state via `Arc`.
|
||||
//!
|
||||
//! ## Uploads
|
||||
//!
|
||||
//! Create an [`UploadCommit`] with [`XetSession::new_upload_commit`], queue
|
||||
//! files with [`upload_from_path`](UploadCommit::upload_from_path) or
|
||||
//! [`upload_bytes`](UploadCommit::upload_bytes), then call
|
||||
//! [`commit`](UploadCommit::commit) to wait for all transfers to finish and
|
||||
//! receive a `HashMap<Ulid, `[`UploadResult`]`>` keyed by task ID
|
||||
//! (`UploadResult` = `Arc<Result<`[`FileMetadata`]`, `[`SessionError`]`>>`).
|
||||
//! Per-task results can also be read directly from the returned
|
||||
//! [`UploadTaskHandle`] via [`result`](UploadTaskHandle::result) after
|
||||
//! `commit()` returns.
|
||||
//! For **sync** callers: create an [`UploadCommitSync`] with
|
||||
//! [`XetSession::new_upload_commit_blocking`], queue files with
|
||||
//! [`upload_from_path`](UploadCommitSync::upload_from_path) or
|
||||
//! [`upload_bytes`](UploadCommitSync::upload_bytes), then call
|
||||
//! [`commit`](UploadCommitSync::commit) to block until all transfers finish and
|
||||
//! receive a `HashMap<Ulid, `[`UploadResult`]`>` keyed by task ID.
|
||||
//!
|
||||
//! For **async** callers: create an [`UploadCommit`] with
|
||||
//! [`XetSession::new_upload_commit`], queue files the same way, then
|
||||
//! `await` [`commit`](UploadCommit::commit).
|
||||
//!
|
||||
//! `UploadResult` = `Arc<Result<`[`FileMetadata`]`, `[`SessionError`]`>>`.
|
||||
//! Per-task results can also be read from the returned [`UploadTaskHandle`]
|
||||
//! via [`result`](UploadTaskHandle::result) after `commit()` returns.
|
||||
//!
|
||||
//! ## Downloads
|
||||
//!
|
||||
//! Create a [`DownloadGroup`] with [`XetSession::new_download_group`], queue
|
||||
//! files with [`download_file_to_path`](DownloadGroup::download_file_to_path),
|
||||
//! then call [`finish`](DownloadGroup::finish) to wait for all transfers and
|
||||
//! receive a `HashMap<Ulid, `[`DownloadResult`]`>` keyed by task ID
|
||||
//! (`DownloadResult` = `Arc<Result<`[`DownloadedFile`]`, `[`SessionError`]`>>`).
|
||||
//! Per-task results can also be read directly from the returned
|
||||
//! [`DownloadTaskHandle`] via [`result`](DownloadTaskHandle::result) after
|
||||
//! `finish()` returns.
|
||||
//! For **sync** callers: create a [`DownloadGroupSync`] with
|
||||
//! [`XetSession::new_download_group_blocking`], queue files with
|
||||
//! [`download_file_to_path`](DownloadGroupSync::download_file_to_path), then
|
||||
//! call [`finish`](DownloadGroupSync::finish) to block until all transfers
|
||||
//! complete and receive a `HashMap<Ulid, `[`DownloadResult`]`>` keyed by task ID.
|
||||
//!
|
||||
//! For **async** callers: create a [`DownloadGroup`] with
|
||||
//! [`XetSession::new_download_group`], queue files the same way, then
|
||||
//! `await` [`finish`](DownloadGroup::finish).
|
||||
//!
|
||||
//! `DownloadResult` = `Arc<Result<`[`DownloadedFile`]`, `[`SessionError`]`>>`.
|
||||
//! Per-task results can also be read from the returned [`DownloadTaskHandle`]
|
||||
//! via [`result`](DownloadTaskHandle::result) after `finish()` returns.
|
||||
//!
|
||||
//! ## Progress tracking
|
||||
//!
|
||||
//! Both [`UploadCommit`] and [`DownloadGroup`] expose
|
||||
//! [`get_progress`](UploadCommit::get_progress), which returns a
|
||||
//! All four types ([`UploadCommit`], [`UploadCommitSync`], [`DownloadGroup`],
|
||||
//! [`DownloadGroupSync`]) expose `get_progress()`, which returns a
|
||||
//! [`ProgressSnapshot`] without acquiring a lock on the calling thread
|
||||
//! (useful for Python bindings that must release the GIL). Poll it from a
|
||||
//! background thread while the main thread blocks in `commit()` / `finish()`.
|
||||
//! background thread/task while the main thread/task blocks in
|
||||
//! `commit()` / `finish()`.
|
||||
//!
|
||||
//! ## Error handling
|
||||
//!
|
||||
@@ -53,50 +64,81 @@
|
||||
//! `HashMap<Ulid, `[`DownloadResult`]`>` keyed by task ID, so a single failed
|
||||
//! file does not discard all others.
|
||||
//!
|
||||
//! # Quick start
|
||||
//! # Quick start — sync API
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use xet::xet_session::{XetFileInfo, XetSessionBuilder};
|
||||
//!
|
||||
//! // 1. Build a session
|
||||
//! // 1. Build a session — sync (non-async) context only.
|
||||
//! // For async code call build_async().await instead.
|
||||
//! let session = XetSessionBuilder::new()
|
||||
//! .with_endpoint("https://cas.example.com".into())
|
||||
//! .with_token_info("my-token".into(), 1_700_000_000)
|
||||
//! .build()?;
|
||||
//!
|
||||
//! // 2. Upload
|
||||
//! let commit = session.new_upload_commit()?;
|
||||
//! // 2. Upload — use the _blocking factory; returns UploadCommitSync
|
||||
//! let commit = session.new_upload_commit_blocking()?;
|
||||
//! let handle = commit.upload_from_path("file.bin".into())?;
|
||||
//! commit.commit()?;
|
||||
//! // Access result directly from the handle (populated by commit())
|
||||
//! // UploadResult = Arc<Result<FileMetadata, SessionError>>
|
||||
//! let m = handle.result().unwrap(); // Option<UploadResult>
|
||||
//! let m = m.as_ref().as_ref().unwrap(); // &FileMetadata
|
||||
//! let results = commit.commit()?;
|
||||
//! let m = results.values().next().unwrap().as_ref().as_ref().unwrap();
|
||||
//!
|
||||
//! // 3. Download
|
||||
//! let group = session.new_download_group()?;
|
||||
//! // 3. Download — use the _blocking factory; returns DownloadGroupSync
|
||||
//! let group = session.new_download_group_blocking()?;
|
||||
//! let info = XetFileInfo {
|
||||
//! hash: m.hash.clone(),
|
||||
//! file_size: m.file_size,
|
||||
//! };
|
||||
//! let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?;
|
||||
//! let finish_results = group.finish()?;
|
||||
//! // Pattern 1: look up by task ID in the returned HashMap
|
||||
//! // DownloadResult = Arc<Result<DownloadedFile, SessionError>>
|
||||
//! let r1 = finish_results.get(&dl_handle.task_id).unwrap(); // &DownloadResult
|
||||
//! let r1 = r1.as_ref().as_ref().unwrap(); // &DownloadedFile
|
||||
//! // Pattern 2: read directly from the handle (populated by finish())
|
||||
//! let r2 = dl_handle.result().unwrap(); // DownloadResult
|
||||
//! let r2 = r2.as_ref().as_ref().unwrap(); // &DownloadedFile
|
||||
//! let r = finish_results.get(&dl_handle.task_id).unwrap().as_ref().as_ref().unwrap();
|
||||
//!
|
||||
//! # Ok::<(), xet::xet_session::SessionError>(())
|
||||
//! ```
|
||||
//!
|
||||
//! # Quick start — async API
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use xet::xet_session::{XetFileInfo, XetSessionBuilder};
|
||||
//!
|
||||
//! # async fn example() -> Result<(), xet::xet_session::SessionError> {
|
||||
//! // 1. Build a session. build_async() auto-detects the executor:
|
||||
//! // - tokio (multi-thread): wraps the caller's handle, no second thread pool.
|
||||
//! // - non-tokio (smol, async-std, etc.): creates an owned thread pool.
|
||||
//! let session = XetSessionBuilder::new()
|
||||
//! .with_endpoint("https://cas.example.com".into())
|
||||
//! .with_token_info("my-token".into(), 1_700_000_000)
|
||||
//! .build_async()
|
||||
//! .await?;
|
||||
//!
|
||||
//! // 2. Upload — use the async factory; returns UploadCommit
|
||||
//! let commit = session.new_upload_commit().await?;
|
||||
//! let handle = commit.upload_from_path("file.bin".into()).await?;
|
||||
//! // UploadResult = Arc<Result<FileMetadata, SessionError>>
|
||||
//! let results = commit.commit().await?;
|
||||
//! let m = results.values().next().unwrap().as_ref().as_ref().unwrap();
|
||||
//!
|
||||
//! // 3. Download — use the async factory; returns DownloadGroup
|
||||
//! let group = session.new_download_group().await?;
|
||||
//! let info = XetFileInfo {
|
||||
//! hash: m.hash.clone(),
|
||||
//! file_size: m.file_size,
|
||||
//! };
|
||||
//! let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?;
|
||||
//! let finish_results = group.finish().await?;
|
||||
//! // DownloadResult = Arc<Result<DownloadedFile, SessionError>>
|
||||
//! let r = finish_results.get(&dl_handle.task_id).unwrap().as_ref().as_ref().unwrap();
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
mod common;
|
||||
mod download_group;
|
||||
mod errors;
|
||||
mod progress;
|
||||
mod session;
|
||||
pub mod sync;
|
||||
mod upload_commit;
|
||||
|
||||
pub use download_group::{DownloadGroup, DownloadResult, DownloadedFile};
|
||||
@@ -105,6 +147,7 @@ pub use progress::{
|
||||
DownloadTaskHandle, FileProgress, ProgressSnapshot, TaskHandle, TaskStatus, TotalProgressSnapshot, UploadTaskHandle,
|
||||
};
|
||||
pub use session::{XetSession, XetSessionBuilder};
|
||||
pub use sync::{DownloadGroupSync, UploadCommitSync};
|
||||
pub use upload_commit::{FileMetadata, UploadCommit, UploadResult};
|
||||
pub use xet_data::processing::XetFileInfo;
|
||||
// Re-export XetConfig for convenience
|
||||
|
||||
@@ -30,16 +30,16 @@ pub enum TaskStatus {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TaskHandle {
|
||||
pub(crate) status: Option<Arc<Mutex<TaskStatus>>>,
|
||||
pub(crate) group_progress: Arc<GroupProgress>,
|
||||
pub(super) status: Option<Arc<Mutex<TaskStatus>>>,
|
||||
pub(super) group_progress: Arc<GroupProgress>,
|
||||
/// Id of the task, can be used to retrieve per-task progress and result.
|
||||
pub task_id: Ulid,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct UploadTaskHandle {
|
||||
pub(crate) inner: TaskHandle,
|
||||
pub(crate) result: Arc<OnceLock<UploadResult>>,
|
||||
pub(super) inner: TaskHandle,
|
||||
pub(super) result: Arc<OnceLock<UploadResult>>,
|
||||
}
|
||||
|
||||
impl Deref for UploadTaskHandle {
|
||||
@@ -52,8 +52,8 @@ impl Deref for UploadTaskHandle {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DownloadTaskHandle {
|
||||
pub(crate) inner: TaskHandle,
|
||||
pub(crate) result: Arc<OnceLock<DownloadResult>>,
|
||||
pub(super) inner: TaskHandle,
|
||||
pub(super) result: Arc<OnceLock<DownloadResult>>,
|
||||
}
|
||||
|
||||
impl Deref for DownloadTaskHandle {
|
||||
@@ -107,7 +107,7 @@ impl ProgressSnapshot {
|
||||
}
|
||||
}
|
||||
|
||||
/// Snapshot of aggregate progress returned by [`TaskProgress::total`].
|
||||
/// Snapshot of aggregate progress.
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct TotalProgressSnapshot {
|
||||
/// Total bytes known to process (includes deduplicated bytes).
|
||||
@@ -124,7 +124,7 @@ pub struct TotalProgressSnapshot {
|
||||
pub total_transfer_bytes_completion_rate: Option<f64>,
|
||||
}
|
||||
|
||||
/// Snapshot of a single file's progress returned by [`TaskProgress::files`].
|
||||
/// Snapshot of a single file's progress.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct FileProgress {
|
||||
/// File name as reported by the data layer.
|
||||
|
||||
@@ -1,16 +1,22 @@
|
||||
//! XetSession - manages runtime and configuration
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::pin;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::task::{Context, Waker};
|
||||
|
||||
use http::HeaderMap;
|
||||
use tracing::info;
|
||||
use ulid::Ulid;
|
||||
use xet_client::cas_client::auth::TokenRefresher;
|
||||
use xet_runtime::RuntimeError;
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
|
||||
use super::download_group::DownloadGroup;
|
||||
use super::errors::SessionError;
|
||||
use super::sync::{DownloadGroupSync, UploadCommitSync};
|
||||
use super::upload_commit::UploadCommit;
|
||||
|
||||
/// Session state
|
||||
@@ -19,49 +25,131 @@ enum SessionState {
|
||||
Aborted,
|
||||
}
|
||||
|
||||
/// Whether the session owns its tokio runtime or inherits an external one.
|
||||
///
|
||||
/// - **`Owned`**: session created its own thread pool via [`XetSessionBuilder::build`] or
|
||||
/// [`XetSessionBuilder::build_async`] (outside tokio). Both `_blocking` and async methods are supported. Async
|
||||
/// methods use an internal `bridge_to_owned` bridge that routes futures onto the owned thread pool, so they work from
|
||||
/// any executor (tokio, smol, async-std).
|
||||
///
|
||||
/// - **`External`**: session wraps a caller-provided tokio handle via [`XetSessionBuilder::with_tokio_handle`] or
|
||||
/// [`XetSessionBuilder::build_async`] (tokio context). Only async methods may be called; `_blocking` methods return
|
||||
/// [`SessionError::WrongRuntimeMode`]. No second thread pool is created.
|
||||
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
|
||||
pub(super) enum RuntimeMode {
|
||||
Owned,
|
||||
External,
|
||||
}
|
||||
|
||||
/// All shared state for a session.
|
||||
/// Lives behind `Arc<XetSessionInner>` — do not use this type directly.
|
||||
#[doc(hidden)]
|
||||
pub struct XetSessionInner {
|
||||
// Independently cloned by background tasks, so needs its own Arc.
|
||||
pub(crate) runtime: Arc<XetRuntime>,
|
||||
pub(super) runtime: Arc<XetRuntime>,
|
||||
|
||||
/// Whether the session owns its runtime or wraps an external tokio handle.
|
||||
pub(super) runtime_mode: RuntimeMode,
|
||||
|
||||
// Only accessed through &self; no independent cloning needed.
|
||||
pub(crate) config: XetConfig,
|
||||
pub(super) config: XetConfig,
|
||||
|
||||
// CAS endpoint and auth (shared by all upload commits/download groups)
|
||||
pub(crate) endpoint: Option<String>,
|
||||
pub(crate) token_info: Option<(String, u64)>,
|
||||
pub(crate) token_refresher: Option<Arc<dyn TokenRefresher>>,
|
||||
pub(crate) custom_headers: Option<Arc<HeaderMap>>,
|
||||
pub(super) endpoint: Option<String>,
|
||||
pub(super) token_info: Option<(String, u64)>,
|
||||
pub(super) token_refresher: Option<Arc<dyn TokenRefresher>>,
|
||||
pub(super) custom_headers: Option<Arc<HeaderMap>>,
|
||||
|
||||
// Track active upload commits and download groups.
|
||||
pub(crate) active_upload_commits: Mutex<HashMap<Ulid, UploadCommit>>,
|
||||
pub(crate) active_download_groups: Mutex<HashMap<Ulid, DownloadGroup>>,
|
||||
pub(super) active_upload_commits: Mutex<HashMap<Ulid, UploadCommit>>,
|
||||
pub(super) active_download_groups: Mutex<HashMap<Ulid, DownloadGroup>>,
|
||||
|
||||
// Session state
|
||||
state: Mutex<SessionState>,
|
||||
pub(crate) id: Ulid,
|
||||
pub(super) id: Ulid,
|
||||
}
|
||||
|
||||
/// Probe whether a tokio runtime handle meets the requirements for External mode.
|
||||
///
|
||||
/// Checks three things:
|
||||
/// 1. **Multi-threaded flavor** (non-WASM only).
|
||||
/// 2. **Time driver** — required for timeouts, retry backoff, and progress intervals.
|
||||
/// 3. **IO driver** — required for all network I/O via `reqwest`/`hyper`.
|
||||
///
|
||||
/// Driver availability is probed by entering the handle's context and polling a
|
||||
/// driver-dependent future once inside `catch_unwind`. Tokio panics synchronously
|
||||
/// on the first poll when a driver is absent, so the result is immediate — no
|
||||
/// spawning or blocking required.
|
||||
///
|
||||
/// **Fragility note:** this probing technique relies on tokio panicking
|
||||
/// synchronously on the first poll of `tokio::time::sleep` /
|
||||
/// `tokio::net::TcpListener::bind` when the corresponding driver (time / IO)
|
||||
/// is absent. This is undocumented internal behavior validated against
|
||||
/// tokio 1.x. If a future tokio version returns an error instead of
|
||||
/// panicking, this function will incorrectly accept a runtime missing drivers.
|
||||
///
|
||||
/// Returns `true` if all requirements are met, `false` otherwise.
|
||||
fn handle_meets_session_requirements(handle: &tokio::runtime::Handle) -> bool {
|
||||
// Non-WASM: require a multi-threaded runtime.
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
if matches!(handle.runtime_flavor(), tokio::runtime::RuntimeFlavor::CurrentThread) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let _guard = handle.enter();
|
||||
let waker = Waker::noop();
|
||||
let mut cx = Context::from_waker(waker);
|
||||
|
||||
let has_time = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let mut sleep = pin!(tokio::time::sleep(std::time::Duration::ZERO));
|
||||
let _ = sleep.as_mut().poll(&mut cx);
|
||||
}))
|
||||
.is_ok();
|
||||
|
||||
let has_io = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
let mut bind = pin!(tokio::net::TcpListener::bind("127.0.0.1:0"));
|
||||
let _ = bind.as_mut().poll(&mut cx);
|
||||
}))
|
||||
.is_ok();
|
||||
|
||||
has_time && has_io
|
||||
}
|
||||
|
||||
/// Builder for [`XetSession`].
|
||||
///
|
||||
/// All fields are optional; call [`build`](XetSessionBuilder::build) when done.
|
||||
/// All fields are optional; call [`build`](XetSessionBuilder::build) (sync) or
|
||||
/// [`build_async`](XetSessionBuilder::build_async) (async) when done.
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// # use xet::xet_session::XetSessionBuilder;
|
||||
/// // Sync context — session owns its runtime:
|
||||
/// let session = XetSessionBuilder::new()
|
||||
/// .with_endpoint("https://cas.example.com".into())
|
||||
/// .with_token_info("my-token".into(), 1_700_000_000)
|
||||
/// .build()?;
|
||||
/// # Ok::<(), xet::xet_session::SessionError>(())
|
||||
/// ```
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// # use xet::xet_session::XetSessionBuilder;
|
||||
/// # async fn example() -> Result<(), xet::xet_session::SessionError> {
|
||||
/// // Async context — wraps the caller's tokio handle (External mode) if inside tokio,
|
||||
/// // or creates an owned runtime (Owned mode) if called from a non-tokio executor:
|
||||
/// let session = XetSessionBuilder::new()
|
||||
/// .with_endpoint("https://cas.example.com".into())
|
||||
/// .with_token_info("my-token".into(), 1_700_000_000)
|
||||
/// .build_async()
|
||||
/// .await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub struct XetSessionBuilder {
|
||||
config: XetConfig,
|
||||
endpoint: Option<String>,
|
||||
token_info: Option<(String, u64)>,
|
||||
token_refresher: Option<Arc<dyn TokenRefresher>>,
|
||||
custom_headers: Option<Arc<HeaderMap>>,
|
||||
tokio_handle: Option<tokio::runtime::Handle>,
|
||||
}
|
||||
|
||||
impl Default for XetSessionBuilder {
|
||||
@@ -79,6 +167,7 @@ impl XetSessionBuilder {
|
||||
token_info: None,
|
||||
token_refresher: None,
|
||||
custom_headers: None,
|
||||
tokio_handle: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -90,6 +179,7 @@ impl XetSessionBuilder {
|
||||
token_info: None,
|
||||
token_refresher: None,
|
||||
custom_headers: None,
|
||||
tokio_handle: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -125,15 +215,72 @@ impl XetSessionBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Attach to an existing tokio runtime handle.
|
||||
///
|
||||
/// If the handle meets session requirements (multi-thread flavor, time driver, IO driver),
|
||||
/// the session will wrap it — no second thread pool is created (External mode). Only async
|
||||
/// methods (`new_upload_commit`, `new_download_group`) may be called; `_blocking` variants
|
||||
/// will return [`SessionError::WrongRuntimeMode`].
|
||||
///
|
||||
/// If the handle does **not** meet requirements (e.g. `current_thread` flavor or missing
|
||||
/// drivers), it is silently ignored and [`build`](Self::build) will fall back to creating
|
||||
/// an owned thread pool (Owned mode) instead.
|
||||
///
|
||||
/// Use [`build_async`](Self::build_async) as a convenient alternative when building from
|
||||
/// within a tokio async context.
|
||||
pub fn with_tokio_handle(self, handle: tokio::runtime::Handle) -> Self {
|
||||
let accept = handle_meets_session_requirements(&handle);
|
||||
if !accept {
|
||||
info!("supplied tokio handle rejected (missing drivers or wrong flavor); falling back to Owned mode");
|
||||
}
|
||||
Self {
|
||||
tokio_handle: accept.then_some(handle),
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
/// Build and automatically attach to the current runtime.
|
||||
///
|
||||
/// Despite being `async`, this method resolves synchronously (no internal
|
||||
/// `.await` points). It is declared `async` so callers in an async context
|
||||
/// can use it naturally alongside `tokio::runtime::Handle::try_current()`
|
||||
/// detection.
|
||||
///
|
||||
/// - **Tokio context** with a suitable runtime (multi-thread, time + IO drivers): wraps the caller's handle via
|
||||
/// [`with_tokio_handle`](Self::with_tokio_handle) — External mode.
|
||||
/// - **Tokio context** with an unsuitable runtime (e.g. `current_thread`): handle is discarded by
|
||||
/// `with_tokio_handle`; falls back to an owned thread pool — Owned mode.
|
||||
/// - **Non-tokio context** (smol, async-std, etc.): creates an owned thread pool — Owned mode; async methods use an
|
||||
/// internal bridge compatible with any executor.
|
||||
pub async fn build_async(self) -> Result<XetSession, SessionError> {
|
||||
match tokio::runtime::Handle::try_current() {
|
||||
Ok(handle) => self.with_tokio_handle(handle).build(),
|
||||
Err(_) => self.build(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Consume the builder and create a [`XetSession`].
|
||||
///
|
||||
/// - If a valid tokio handle was previously set via [`with_tokio_handle`](Self::with_tokio_handle), the session
|
||||
/// wraps that handle (External mode) — no second thread pool is created.
|
||||
/// - Otherwise, creates an owned thread pool (Owned mode); async methods use an internal bridge and work from any
|
||||
/// executor, and `_blocking` methods are available.
|
||||
///
|
||||
/// For async contexts, prefer [`build_async`](Self::build_async).
|
||||
pub fn build(self) -> Result<XetSession, SessionError> {
|
||||
XetSession::new_with_config(
|
||||
let (runtime, mode) = match self.tokio_handle {
|
||||
Some(handle) => (XetRuntime::from_external_with_config(handle, self.config.clone()), RuntimeMode::External),
|
||||
None => (XetRuntime::new_with_config(self.config.clone())?, RuntimeMode::Owned),
|
||||
};
|
||||
Ok(XetSession::new(
|
||||
self.config,
|
||||
self.endpoint,
|
||||
self.token_info,
|
||||
self.token_refresher,
|
||||
self.custom_headers,
|
||||
)
|
||||
runtime,
|
||||
mode,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,33 +314,20 @@ impl std::ops::Deref for XetSession {
|
||||
}
|
||||
|
||||
impl XetSession {
|
||||
/// Create a session with default [`XetConfig`] — used by tests only.
|
||||
/// In production code, use [`XetSessionBuilder`] instead.
|
||||
#[cfg(test)]
|
||||
pub(crate) fn new(
|
||||
endpoint: Option<String>,
|
||||
token_info: Option<(String, u64)>,
|
||||
token_refresher: Option<Arc<dyn TokenRefresher>>,
|
||||
custom_headers: Option<Arc<HeaderMap>>,
|
||||
) -> Result<Self, SessionError> {
|
||||
Self::new_with_config(XetConfig::new(), endpoint, token_info, token_refresher, custom_headers)
|
||||
}
|
||||
|
||||
/// Internal constructor called by [`XetSessionBuilder::build`].
|
||||
pub(crate) fn new_with_config(
|
||||
/// Low-level constructor used by [`XetSessionBuilder::build`].
|
||||
fn new(
|
||||
config: XetConfig,
|
||||
endpoint: Option<String>,
|
||||
token_info: Option<(String, u64)>,
|
||||
token_refresher: Option<Arc<dyn TokenRefresher>>,
|
||||
custom_headers: Option<Arc<HeaderMap>>,
|
||||
) -> Result<Self, SessionError> {
|
||||
let runtime = XetRuntime::new_with_config(config.clone())?;
|
||||
|
||||
let session_id = Ulid::new();
|
||||
|
||||
Ok(Self {
|
||||
runtime: Arc<XetRuntime>,
|
||||
runtime_mode: RuntimeMode,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: Arc::new(XetSessionInner {
|
||||
runtime,
|
||||
runtime_mode,
|
||||
config,
|
||||
endpoint,
|
||||
token_info,
|
||||
@@ -202,45 +336,151 @@ impl XetSession {
|
||||
active_upload_commits: Mutex::new(HashMap::new()),
|
||||
active_download_groups: Mutex::new(HashMap::new()),
|
||||
state: Mutex::new(SessionState::Alive),
|
||||
id: session_id,
|
||||
id: Ulid::new(),
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Run a future on the appropriate runtime for this session.
|
||||
///
|
||||
/// In External mode the future is awaited directly on the caller's executor.
|
||||
/// In Owned mode the future is bridged onto the owned thread pool via
|
||||
/// [`XetRuntime::bridge_to_owned`].
|
||||
pub(super) async fn dispatch<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
|
||||
where
|
||||
F: Future<Output = T> + Send + 'static,
|
||||
T: Send + 'static,
|
||||
{
|
||||
match self.runtime_mode {
|
||||
RuntimeMode::External => Ok(fut.await),
|
||||
RuntimeMode::Owned => self.runtime.bridge_to_owned(task_name, fut).await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new [`UploadCommit`] that groups related file uploads.
|
||||
///
|
||||
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
|
||||
pub fn new_upload_commit(&self) -> Result<UploadCommit, SessionError> {
|
||||
let state = self.state.lock()?;
|
||||
if matches!(*state, SessionState::Aborted) {
|
||||
return Err(SessionError::Aborted);
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This is an `async fn` and must be `.await`ed. For sync Rust or Python (PyO3) callers,
|
||||
/// use [`new_upload_commit_blocking`](Self::new_upload_commit_blocking).
|
||||
pub async fn new_upload_commit(&self) -> Result<UploadCommit, SessionError> {
|
||||
// Check state before the async init; drop the guard so it is not held across .await.
|
||||
{
|
||||
let state = self.state.lock()?;
|
||||
if matches!(*state, SessionState::Aborted) {
|
||||
return Err(SessionError::Aborted);
|
||||
}
|
||||
}
|
||||
|
||||
let commit = UploadCommit::new(self.clone())?;
|
||||
let session = self.clone();
|
||||
let commit = self
|
||||
.dispatch("new_upload_commit", async move { UploadCommit::new(session).await })
|
||||
.await??;
|
||||
|
||||
// Register the commit
|
||||
// Register the commit (sync insertion, safe in any executor context)
|
||||
self.active_upload_commits.lock()?.insert(commit.id(), commit.clone());
|
||||
|
||||
Ok(commit)
|
||||
}
|
||||
|
||||
/// Create a new [`UploadCommit`] from a **sync** (non-async) context.
|
||||
///
|
||||
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
|
||||
/// Returns `Err(SessionError::WrongRuntimeMode)` if the session was built with
|
||||
/// [`XetSessionBuilder::with_tokio_handle`] / [`XetSessionBuilder::build_async`].
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if called from within a **tokio** async runtime (tokio sets a thread-local
|
||||
/// context that `Handle::block_on` detects and panics on). Non-tokio executors (smol,
|
||||
/// async-std, `futures::executor`) do not set this context, so calling from those is
|
||||
/// safe — it blocks the executor thread until the task completes. Use
|
||||
/// [`new_upload_commit`](Self::new_upload_commit) from async contexts instead.
|
||||
pub fn new_upload_commit_blocking(&self) -> Result<UploadCommitSync, SessionError> {
|
||||
if matches!(self.runtime_mode, RuntimeMode::External) {
|
||||
return Err(SessionError::wrong_mode(
|
||||
"new_upload_commit_blocking() cannot be called on a session built with \
|
||||
with_tokio_handle() / build_async(); use new_upload_commit().await instead",
|
||||
));
|
||||
}
|
||||
{
|
||||
let state = self.state.lock()?;
|
||||
if matches!(*state, SessionState::Aborted) {
|
||||
return Err(SessionError::Aborted);
|
||||
}
|
||||
}
|
||||
|
||||
let sync_commit = UploadCommitSync::new(self.clone())?;
|
||||
self.active_upload_commits
|
||||
.lock()?
|
||||
.insert(sync_commit.inner.id(), sync_commit.inner.clone());
|
||||
Ok(sync_commit)
|
||||
}
|
||||
|
||||
/// Create a new [`DownloadGroup`] that groups related file downloads.
|
||||
///
|
||||
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
|
||||
pub fn new_download_group(&self) -> Result<DownloadGroup, SessionError> {
|
||||
let state = self.state.lock()?;
|
||||
if matches!(*state, SessionState::Aborted) {
|
||||
return Err(SessionError::Aborted);
|
||||
///
|
||||
/// # Note
|
||||
///
|
||||
/// This is an `async fn` and must be `.await`ed. For sync Rust or Python (PyO3) callers,
|
||||
/// use [`new_download_group_blocking`](Self::new_download_group_blocking).
|
||||
pub async fn new_download_group(&self) -> Result<DownloadGroup, SessionError> {
|
||||
// Check state before the async init; drop the guard so it is not held across .await.
|
||||
{
|
||||
let state = self.state.lock()?;
|
||||
if matches!(*state, SessionState::Aborted) {
|
||||
return Err(SessionError::Aborted);
|
||||
}
|
||||
}
|
||||
|
||||
let group = DownloadGroup::new(self.clone())?;
|
||||
let session = self.clone();
|
||||
let group = self
|
||||
.dispatch("new_download_group", async move { DownloadGroup::new(session).await })
|
||||
.await??;
|
||||
|
||||
// Register the group
|
||||
// Register the group (sync insertion, safe in any executor context)
|
||||
self.active_download_groups.lock()?.insert(group.id(), group.clone());
|
||||
|
||||
Ok(group)
|
||||
}
|
||||
|
||||
/// Create a new [`DownloadGroup`] from a **sync** (non-async) context.
|
||||
///
|
||||
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
|
||||
/// Returns `Err(SessionError::WrongRuntimeMode)` if the session was built with
|
||||
/// [`XetSessionBuilder::with_tokio_handle`] / [`XetSessionBuilder::build_async`].
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if called from within a **tokio** async runtime (tokio sets a thread-local
|
||||
/// context that `Handle::block_on` detects and panics on). Non-tokio executors (smol,
|
||||
/// async-std, `futures::executor`) do not set this context, so calling from those is
|
||||
/// safe — it blocks the executor thread until the task completes. Use
|
||||
/// [`new_download_group`](Self::new_download_group) from async contexts instead.
|
||||
pub fn new_download_group_blocking(&self) -> Result<DownloadGroupSync, SessionError> {
|
||||
if matches!(self.runtime_mode, RuntimeMode::External) {
|
||||
return Err(SessionError::wrong_mode(
|
||||
"new_download_group_blocking() cannot be called on a session built with \
|
||||
with_tokio_handle() / build_async(); use new_download_group().await instead",
|
||||
));
|
||||
}
|
||||
{
|
||||
let state = self.state.lock()?;
|
||||
if matches!(*state, SessionState::Aborted) {
|
||||
return Err(SessionError::Aborted);
|
||||
}
|
||||
}
|
||||
|
||||
let sync_group = DownloadGroupSync::new(self.clone())?;
|
||||
self.active_download_groups
|
||||
.lock()?
|
||||
.insert(sync_group.inner.id(), sync_group.inner.clone());
|
||||
Ok(sync_group)
|
||||
}
|
||||
|
||||
/// Abort the session - cancel all currently running tasks
|
||||
///
|
||||
/// This performs a SIGINT-style shutdown, aborting all active upload and download tasks.
|
||||
@@ -266,19 +506,19 @@ impl XetSession {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn check_alive(&self) -> Result<(), SessionError> {
|
||||
pub(super) fn check_alive(&self) -> Result<(), SessionError> {
|
||||
if matches!(*self.state.lock()?, SessionState::Aborted) {
|
||||
return Err(SessionError::Aborted);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn finish_upload_commit(&self, commit_id: Ulid) -> Result<(), SessionError> {
|
||||
pub(super) fn finish_upload_commit(&self, commit_id: Ulid) -> Result<(), SessionError> {
|
||||
self.active_upload_commits.lock()?.remove(&commit_id);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn finish_download_group(&self, group_id: Ulid) -> Result<(), SessionError> {
|
||||
pub(super) fn finish_download_group(&self, group_id: Ulid) -> Result<(), SessionError> {
|
||||
self.active_download_groups.lock()?.remove(&group_id);
|
||||
Ok(())
|
||||
}
|
||||
@@ -288,16 +528,12 @@ impl XetSession {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn make_session() -> XetSession {
|
||||
XetSession::new(None, None, None, None).expect("Failed to create session")
|
||||
}
|
||||
|
||||
// ── Identity ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
// A clone refers to the same inner Arc, so their session IDs must match.
|
||||
fn test_session_clone_shares_state() {
|
||||
let s1 = make_session();
|
||||
let s1 = XetSessionBuilder::new().build().unwrap();
|
||||
let s2 = s1.clone();
|
||||
assert_eq!(s1.id, s2.id);
|
||||
}
|
||||
@@ -305,8 +541,8 @@ mod tests {
|
||||
#[test]
|
||||
// Two independently created sessions have distinct IDs.
|
||||
fn test_two_sessions_have_distinct_ids() {
|
||||
let s1 = make_session();
|
||||
let s2 = make_session();
|
||||
let s1 = XetSessionBuilder::new().build().unwrap();
|
||||
let s2 = XetSessionBuilder::new().build().unwrap();
|
||||
assert_ne!(s1.id, s2.id);
|
||||
}
|
||||
|
||||
@@ -315,36 +551,36 @@ mod tests {
|
||||
#[test]
|
||||
// After abort, check_alive returns Aborted.
|
||||
fn test_check_alive_after_abort() {
|
||||
let session = make_session();
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
session.abort().unwrap();
|
||||
let err = session.check_alive().unwrap_err();
|
||||
assert!(matches!(err, SessionError::Aborted));
|
||||
}
|
||||
|
||||
#[test]
|
||||
// new_upload_commit on an aborted session returns Aborted.
|
||||
// new_upload_commit_blocking on an aborted session returns Aborted.
|
||||
fn test_new_upload_commit_after_abort_returns_aborted() {
|
||||
let session = make_session();
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
session.abort().unwrap();
|
||||
let err = session.new_upload_commit().err().unwrap();
|
||||
let err = session.new_upload_commit_blocking().err().unwrap();
|
||||
assert!(matches!(err, SessionError::Aborted));
|
||||
}
|
||||
|
||||
#[test]
|
||||
// new_download_group on an aborted session returns Aborted.
|
||||
// new_download_group_blocking on an aborted session returns Aborted.
|
||||
fn test_new_download_group_after_abort_returns_aborted() {
|
||||
let session = make_session();
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
session.abort().unwrap();
|
||||
let err = session.new_download_group().err().unwrap();
|
||||
let err = session.new_download_group_blocking().err().unwrap();
|
||||
assert!(matches!(err, SessionError::Aborted));
|
||||
}
|
||||
|
||||
#[test]
|
||||
// Aborting a session clears all registered upload commits.
|
||||
fn test_abort_clears_active_upload_commits() {
|
||||
let session = make_session();
|
||||
let _c1 = session.new_upload_commit().unwrap();
|
||||
let _c2 = session.new_upload_commit().unwrap();
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
let _c1 = session.new_upload_commit_blocking().unwrap();
|
||||
let _c2 = session.new_upload_commit_blocking().unwrap();
|
||||
session.abort().unwrap();
|
||||
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 0);
|
||||
}
|
||||
@@ -352,8 +588,8 @@ mod tests {
|
||||
#[test]
|
||||
// Aborting a session clears all registered download groups.
|
||||
fn test_abort_clears_active_download_groups() {
|
||||
let session = make_session();
|
||||
let _g1 = session.new_download_group().unwrap();
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
let _g1 = session.new_download_group_blocking().unwrap();
|
||||
session.abort().unwrap();
|
||||
assert_eq!(session.active_download_groups.lock().unwrap().len(), 0);
|
||||
}
|
||||
@@ -363,16 +599,16 @@ mod tests {
|
||||
#[test]
|
||||
// A new upload commit is registered in the session's active set.
|
||||
fn test_new_upload_commit_registers_in_session() {
|
||||
let session = make_session();
|
||||
let _commit = session.new_upload_commit().unwrap();
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
let _commit = session.new_upload_commit_blocking().unwrap();
|
||||
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
// A new download group is registered in the session's active set.
|
||||
fn test_new_download_group_registers_in_session() {
|
||||
let session = make_session();
|
||||
let _group = session.new_download_group().unwrap();
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
let _group = session.new_download_group_blocking().unwrap();
|
||||
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
@@ -381,32 +617,176 @@ mod tests {
|
||||
#[test]
|
||||
// finish_upload_commit removes only the specified commit, leaving others intact.
|
||||
fn test_finish_upload_commit_removes_only_that_commit() {
|
||||
let session = make_session();
|
||||
let c1 = session.new_upload_commit().unwrap();
|
||||
let _c2 = session.new_upload_commit().unwrap();
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
let c1 = session.new_upload_commit_blocking().unwrap();
|
||||
let _c2 = session.new_upload_commit_blocking().unwrap();
|
||||
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 2);
|
||||
session.finish_upload_commit(c1.id()).unwrap();
|
||||
session.finish_upload_commit(c1.inner.id()).unwrap();
|
||||
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
// finish_download_group removes only the specified group, leaving others intact.
|
||||
fn test_finish_download_group_removes_only_that_group() {
|
||||
let session = make_session();
|
||||
let g1 = session.new_download_group().unwrap();
|
||||
let _g2 = session.new_download_group().unwrap();
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
let g1 = session.new_download_group_blocking().unwrap();
|
||||
let _g2 = session.new_download_group_blocking().unwrap();
|
||||
assert_eq!(session.active_download_groups.lock().unwrap().len(), 2);
|
||||
session.finish_download_group(g1.id()).unwrap();
|
||||
session.finish_download_group(g1.inner.id()).unwrap();
|
||||
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
// finish_upload_commit on an unknown ID is a no-op (no error, no change).
|
||||
fn test_finish_upload_commit_with_unknown_id_is_noop() {
|
||||
let session = make_session();
|
||||
let _c1 = session.new_upload_commit().unwrap();
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
let _c1 = session.new_upload_commit_blocking().unwrap();
|
||||
let unknown_id = ulid::Ulid::new();
|
||||
assert!(session.finish_upload_commit(unknown_id).is_ok());
|
||||
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
// ── Async abort behavior ──────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// new_upload_commit / new_download_group on an aborted session both return Aborted.
|
||||
async fn test_async_new_after_abort_returns_aborted() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
session.abort().unwrap();
|
||||
let commit_err = session.new_upload_commit().await.err().unwrap();
|
||||
let group_err = session.new_download_group().await.err().unwrap();
|
||||
assert!(matches!(commit_err, SessionError::Aborted));
|
||||
assert!(matches!(group_err, SessionError::Aborted));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// Aborting a session clears all active upload commits and download groups.
|
||||
async fn test_async_abort_clears_active_commits_and_groups() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let _c1 = session.new_upload_commit().await.unwrap();
|
||||
let _c2 = session.new_upload_commit().await.unwrap();
|
||||
let _g1 = session.new_download_group().await.unwrap();
|
||||
session.abort().unwrap();
|
||||
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 0);
|
||||
assert_eq!(session.active_download_groups.lock().unwrap().len(), 0);
|
||||
}
|
||||
|
||||
// ── Async registration ────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// A new upload commit and a new download group are each registered in the
|
||||
// session's active set, and concurrent creation registers both.
|
||||
async fn test_async_new_registers_in_session() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let _commit = session.new_upload_commit().await.unwrap();
|
||||
let _group = session.new_download_group().await.unwrap();
|
||||
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
|
||||
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
// ── Async deregistration ──────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// Finishing one upload commit / download group removes only that one,
|
||||
// leaving the other still registered.
|
||||
async fn test_async_finish_removes_only_that_item() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
let c1 = session.new_upload_commit().await.unwrap();
|
||||
let _c2 = session.new_upload_commit().await.unwrap();
|
||||
let g1 = session.new_download_group().await.unwrap();
|
||||
let _g2 = session.new_download_group().await.unwrap();
|
||||
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 2);
|
||||
assert_eq!(session.active_download_groups.lock().unwrap().len(), 2);
|
||||
session.finish_upload_commit(c1.id()).unwrap();
|
||||
session.finish_download_group(g1.id()).unwrap();
|
||||
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
|
||||
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
// ── handle_meets_session_requirements ────────────────────────────────────
|
||||
|
||||
#[test]
// A multi-thread runtime configured with enable_all() is accepted.
fn test_handle_multi_thread_all_features_returns_true() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_all();
    let runtime = builder.build().unwrap();

    let accepted = handle_meets_session_requirements(runtime.handle());
    assert!(accepted);
}
|
||||
|
||||
#[test]
#[cfg(not(target_family = "wasm"))]
// enable_all() is not enough: a current_thread runtime is still rejected.
fn test_handle_current_thread_returns_false() {
    let mut builder = tokio::runtime::Builder::new_current_thread();
    builder.enable_all();
    let runtime = builder.build().unwrap();

    let accepted = handle_meets_session_requirements(runtime.handle());
    assert!(!accepted);
}
|
||||
|
||||
#[test]
// A multi-thread runtime built without enabling any driver is rejected.
fn test_handle_without_any_driver_returns_false() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    let runtime = builder.build().unwrap();

    let accepted = handle_meets_session_requirements(runtime.handle());
    assert!(!accepted);
}
|
||||
|
||||
#[test]
// Only enable_time() is set, so the IO driver is missing and the handle is rejected.
fn test_handle_without_io_driver_returns_false() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_time();
    let runtime = builder.build().unwrap();

    let accepted = handle_meets_session_requirements(runtime.handle());
    assert!(!accepted);
}
|
||||
|
||||
#[test]
// Only enable_io() is set, so the time driver is missing and the handle is rejected.
fn test_handle_without_time_driver_returns_false() {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.enable_io();
    let runtime = builder.build().unwrap();

    let accepted = handle_meets_session_requirements(runtime.handle());
    assert!(!accepted);
}
|
||||
|
||||
// ── External-mode _blocking guard ────────────────────────────────────────
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// new_upload_commit_blocking returns WrongRuntimeMode on an External-mode session.
|
||||
async fn test_new_upload_commit_blocking_errors_in_external_mode() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
assert_eq!(session.runtime_mode, RuntimeMode::External);
|
||||
let err = session.new_upload_commit_blocking().err().unwrap();
|
||||
assert!(matches!(err, SessionError::WrongRuntimeMode(_)));
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
// new_download_group_blocking returns WrongRuntimeMode on an External-mode session.
|
||||
async fn test_new_download_group_blocking_errors_in_external_mode() {
|
||||
let session = XetSessionBuilder::new().build_async().await.unwrap();
|
||||
assert_eq!(session.runtime_mode, RuntimeMode::External);
|
||||
let err = session.new_download_group_blocking().err().unwrap();
|
||||
assert!(matches!(err, SessionError::WrongRuntimeMode(_)));
|
||||
}
|
||||
|
||||
// ── Owned-mode _blocking panic guard ─────────────────────────────────────
|
||||
|
||||
#[test]
// Calling new_upload_commit_blocking on an Owned-mode session from inside a tokio
// runtime must panic: external_run_async_task goes through handle.block_on(), and
// tokio rejects that when its thread-local runtime context is already set.
fn test_new_upload_commit_blocking_panics_in_async_context() {
    let session = XetSessionBuilder::new().build().unwrap();
    assert_eq!(session.runtime_mode, RuntimeMode::Owned);

    let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
    let call_inside_tokio = || rt.block_on(async { session.new_upload_commit_blocking() });
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(call_inside_tokio));

    assert!(outcome.is_err(), "new_upload_commit_blocking() must panic when called from async");
}
|
||||
|
||||
#[test]
// Calling new_download_group_blocking on an Owned-mode session from inside a tokio
// runtime must panic, for the same reason as the upload variant.
fn test_new_download_group_blocking_panics_in_async_context() {
    let session = XetSessionBuilder::new().build().unwrap();
    assert_eq!(session.runtime_mode, RuntimeMode::Owned);

    let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
    let call_inside_tokio = || rt.block_on(async { session.new_download_group_blocking() });
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(call_inside_tokio));

    assert!(outcome.is_err(), "new_download_group_blocking() must panic when called from async");
}
|
||||
}
|
||||
|
||||
286
xet_pkg/src/xet_session/sync/download_group_sync.rs
Normal file
286
xet_pkg/src/xet_session/sync/download_group_sync.rs
Normal file
@@ -0,0 +1,286 @@
|
||||
//! Sync-context download group wrapper.
|
||||
//!
|
||||
//! [`DownloadGroupSync`] is obtained from [`XetSession::new_download_group_blocking`] and
|
||||
//! provides a fully blocking API suitable for sync Rust or Python (PyO3) callers.
|
||||
//! For async Rust use [`DownloadGroup`] instead.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use ulid::Ulid;
|
||||
use xet_data::processing::XetFileInfo;
|
||||
|
||||
use super::super::download_group::{DownloadGroup, DownloadResult};
|
||||
use super::super::errors::SessionError;
|
||||
use super::super::progress::{DownloadTaskHandle, ProgressSnapshot};
|
||||
use super::super::session::XetSession;
|
||||
|
||||
/// Sync-context handle for grouping related file downloads.
|
||||
///
|
||||
/// Obtained from [`XetSession::new_download_group_blocking`]. All methods block
|
||||
/// the calling thread — **do not use from within an async runtime** (it will panic).
|
||||
/// For async Rust code use [`DownloadGroup`] from [`XetSession::new_download_group`].
|
||||
///
|
||||
/// # Cloning
|
||||
///
|
||||
/// Cloning is cheap — it simply increments an atomic reference count.
|
||||
/// All clones share the same background worker and task state.
|
||||
#[derive(Clone)]
|
||||
pub struct DownloadGroupSync {
|
||||
pub(in super::super) inner: DownloadGroup,
|
||||
}
|
||||
|
||||
impl DownloadGroupSync {
|
||||
/// Create a new download group from a **sync** (non-async) context.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if called from within an async runtime — use
|
||||
/// [`XetSession::new_download_group`] instead.
|
||||
pub(in super::super) fn new(session: XetSession) -> Result<Self, SessionError> {
|
||||
let group = session.runtime.external_run_async_task(DownloadGroup::new(session.clone()))??;
|
||||
Ok(Self { inner: group })
|
||||
}
|
||||
|
||||
/// Queue a file for download. See [`DownloadGroup::download_file_to_path`] for full documentation.
|
||||
pub fn download_file_to_path(
|
||||
&self,
|
||||
file_info: XetFileInfo,
|
||||
dest_path: PathBuf,
|
||||
) -> Result<DownloadTaskHandle, SessionError> {
|
||||
self.inner.download_file_to_path(file_info, dest_path)
|
||||
}
|
||||
|
||||
/// Return a snapshot of progress for every queued download.
|
||||
pub fn get_progress(&self) -> Result<ProgressSnapshot, SessionError> {
|
||||
self.inner.get_progress()
|
||||
}
|
||||
|
||||
/// Wait for all downloads to complete and return their results.
|
||||
///
|
||||
/// See [`DownloadGroup::finish`] for full documentation.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if called from within an async runtime. Use [`DownloadGroup::finish`] instead.
|
||||
pub fn finish(self) -> Result<HashMap<Ulid, DownloadResult>, SessionError> {
|
||||
let group = self.inner.clone();
|
||||
self.inner.runtime().external_run_async_task(group.finish())?
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tempfile::{TempDir, tempdir};

    use super::*;
    use crate::xet_session::progress::UploadTaskHandle;
    use crate::xet_session::session::{XetSession, XetSessionBuilder};

    /// Result type shared by the fallible tests and helpers below.
    type TestResult<T = ()> = Result<T, Box<dyn std::error::Error>>;

    /// Builds an Owned-mode session whose endpoint is a `local://` CAS directory under `temp`.
    fn local_session(temp: &TempDir) -> TestResult<XetSession> {
        let endpoint = format!("local://{}", temp.path().join("cas").display());
        let session = XetSessionBuilder::new().with_endpoint(endpoint).build()?;
        Ok(session)
    }

    /// Uploads `data` under `name` through a blocking commit and returns the resulting file info.
    fn upload_bytes(session: &XetSession, data: &[u8], name: &str) -> TestResult<XetFileInfo> {
        let commit = session.new_upload_commit_blocking()?;
        let handle = commit.upload_bytes(data.to_vec(), Some(name.into()))?;
        let results = commit.commit()?;
        let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
        let file_info = XetFileInfo {
            hash: meta.hash.clone(),
            file_size: meta.file_size,
        };
        Ok(file_info)
    }

    /// Uploads `data`, downloads it to `out_name` inside `temp` via the blocking API,
    /// and asserts that the bytes on disk match. Panics on any failure.
    fn assert_blocking_round_trip(session: &XetSession, temp: &TempDir, data: &[u8], out_name: &str) {
        let file_info = upload_bytes(session, data, "test.bin").unwrap();
        let dest = temp.path().join(out_name);
        let group = session.new_download_group_blocking().unwrap();
        group.download_file_to_path(file_info, dest.clone()).unwrap();
        group.finish().unwrap();
        assert_eq!(std::fs::read(&dest).unwrap(), data);
    }

    // ── Round-trip tests ─────────────────────────────────────────────────────

    #[test]
    // An uploaded file, once downloaded, is byte-identical at the destination path.
    fn test_download_file_round_trip() -> TestResult {
        let temp = tempdir()?;
        let session = local_session(&temp)?;
        let original = b"Hello, download round-trip!";
        let file_info = upload_bytes(&session, original, "payload.bin")?;

        let dest = temp.path().join("downloaded.bin");
        let group = session.new_download_group_blocking()?;
        group.download_file_to_path(file_info, dest.clone())?;
        group.finish()?;

        let downloaded = std::fs::read(&dest)?;
        assert_eq!(downloaded, original);
        Ok(())
    }

    #[test]
    // Two files queued on the same group each arrive with their own content.
    fn test_download_multiple_files() -> TestResult {
        let temp = tempdir()?;
        let session = local_session(&temp)?;

        let data_a = b"First file content";
        let data_b = b"Second file content - different";

        let commit = session.new_upload_commit_blocking()?;
        let handle_a = commit.upload_bytes(data_a.to_vec(), Some("a.bin".into()))?;
        let handle_b = commit.upload_bytes(data_b.to_vec(), Some("b.bin".into()))?;
        let results = commit.commit()?;

        // Looks up an upload's metadata and turns it into the info needed to download it.
        let file_info_for = |handle: &UploadTaskHandle| -> XetFileInfo {
            let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
            XetFileInfo {
                hash: meta.hash.clone(),
                file_size: meta.file_size,
            }
        };

        let dest_a = temp.path().join("a_out.bin");
        let dest_b = temp.path().join("b_out.bin");
        let group = session.new_download_group_blocking()?;
        group.download_file_to_path(file_info_for(&handle_a), dest_a.clone())?;
        group.download_file_to_path(file_info_for(&handle_b), dest_b.clone())?;
        group.finish()?;

        assert_eq!(std::fs::read(&dest_a)?, data_a);
        assert_eq!(std::fs::read(&dest_b)?, data_b);
        Ok(())
    }

    #[test]
    // Once finish() has succeeded, the aggregate progress shows a non-zero completed byte count.
    fn test_download_progress_reflects_bytes_after_finish() -> TestResult {
        let temp = tempdir()?;
        let session = local_session(&temp)?;
        let original = b"download progress tracking data";
        let file_info = upload_bytes(&session, original, "prog.bin")?;

        let dest = temp.path().join("out.bin");
        let group = session.new_download_group_blocking()?;
        // finish() consumes the group, so keep a clone around to read progress afterwards.
        let progress_observer = group.clone();
        group.download_file_to_path(file_info, dest)?;
        group.finish()?;

        // Wait one progress-update interval plus a second before sampling.
        let settle_time = session
            .runtime
            .config()
            .data
            .progress_update_interval
            .saturating_add(Duration::from_secs(1));
        std::thread::sleep(settle_time);

        let snapshot = progress_observer.get_progress()?;
        assert!(snapshot.total().total_bytes_completed > 0);
        Ok(())
    }

    // ── Per-task result access patterns ──────────────────────────────────────

    #[test]
    // Pattern 1: the map returned by finish() holds the per-task result under its task_id.
    fn test_download_result_accessible_via_task_id_in_finish_map() -> TestResult {
        let temp = tempdir()?;
        let session = local_session(&temp)?;
        let data = b"result via task_id in finish map";
        let file_info = upload_bytes(&session, data, "file.bin")?;

        let dest = temp.path().join("out.bin");
        let group = session.new_download_group_blocking()?;
        let handle = group.download_file_to_path(file_info, dest)?;
        let results = group.finish()?;

        let result = results.get(&handle.task_id).expect("task_id must be present in results");
        let expected_size = data.len() as u64;
        assert_eq!(result.as_ref().as_ref().unwrap().file_info.file_size, expected_size);
        Ok(())
    }

    #[test]
    // Until finish() is called, DownloadTaskHandle::result() yields None.
    fn test_download_result_none_before_finish() -> TestResult {
        let temp = tempdir()?;
        let session = local_session(&temp)?;
        let file_info = upload_bytes(&session, b"some data", "file.bin")?;

        let dest = temp.path().join("out.bin");
        let group = session.new_download_group_blocking()?;
        let handle = group.download_file_to_path(file_info, dest)?;

        assert!(handle.result().is_none(), "result must be None before finish()");
        group.finish()?;
        Ok(())
    }

    #[test]
    // After finish() completes, DownloadTaskHandle::result() yields the download's file info.
    fn test_download_result_some_after_finish() -> TestResult {
        let temp = tempdir()?;
        let session = local_session(&temp)?;
        let data = b"download result test data";
        let file_info = upload_bytes(&session, data, "file.bin")?;

        let dest = temp.path().join("out.bin");
        let group = session.new_download_group_blocking()?;
        let handle = group.download_file_to_path(file_info.clone(), dest)?;
        group.finish()?;

        let result = handle.result().unwrap();
        let dl = result.as_ref().as_ref().unwrap();
        assert_eq!(dl.file_info.file_size, data.len() as u64);
        assert_eq!(dl.file_info.hash, file_info.hash);
        Ok(())
    }

    // ── Non-tokio executor (no-panic + round-trip) ────────────────────────────
    //
    // smol, async-std, and futures::executor leave tokio's thread-local runtime
    // context unset, so the Handle::block_on call inside external_run_async_task
    // does not panic there — it only blocks the calling executor thread.

    #[test]
    // A full blocking upload+download round-trip succeeds when driven from smol.
    fn test_download_blocking_round_trip_in_smol() {
        let temp = tempdir().unwrap();
        let session = local_session(&temp).unwrap();

        smol::block_on(async {
            assert_blocking_round_trip(&session, &temp, b"download from smol executor", "out_smol.bin");
        });
    }

    #[test]
    // A full blocking upload+download round-trip succeeds when driven from futures::executor.
    fn test_download_blocking_round_trip_in_futures_executor() {
        let temp = tempdir().unwrap();
        let session = local_session(&temp).unwrap();

        futures::executor::block_on(async {
            assert_blocking_round_trip(&session, &temp, b"download from futures executor", "out_futures.bin");
        });
    }

    #[test]
    // A full blocking upload+download round-trip succeeds when driven from async-std.
    fn test_download_blocking_round_trip_in_async_std() {
        let temp = tempdir().unwrap();
        let session = local_session(&temp).unwrap();

        async_std::task::block_on(async {
            assert_blocking_round_trip(&session, &temp, b"download from async-std executor", "out_async_std.bin");
        });
    }
}
|
||||
5
xet_pkg/src/xet_session/sync/mod.rs
Normal file
5
xet_pkg/src/xet_session/sync/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod download_group_sync;
|
||||
pub mod upload_commit_sync;
|
||||
|
||||
pub use download_group_sync::DownloadGroupSync;
|
||||
pub use upload_commit_sync::UploadCommitSync;
|
||||
315
xet_pkg/src/xet_session/sync/upload_commit_sync.rs
Normal file
315
xet_pkg/src/xet_session/sync/upload_commit_sync.rs
Normal file
@@ -0,0 +1,315 @@
|
||||
//! Sync-context upload commit wrapper.
|
||||
//!
|
||||
//! [`UploadCommitSync`] is obtained from [`XetSession::new_upload_commit_blocking`] and
|
||||
//! provides a fully blocking API suitable for sync Rust or Python (PyO3) callers.
|
||||
//! For async Rust use [`UploadCommit`] instead.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use ulid::Ulid;
|
||||
use xet_data::processing::SingleFileCleaner;
|
||||
|
||||
use super::super::errors::SessionError;
|
||||
use super::super::progress::{ProgressSnapshot, TaskHandle, UploadTaskHandle};
|
||||
use super::super::session::XetSession;
|
||||
use super::super::upload_commit::{UploadCommit, UploadResult};
|
||||
|
||||
/// Sync-context handle for grouping related file uploads.
|
||||
///
|
||||
/// Obtained from [`XetSession::new_upload_commit_blocking`]. All methods block
|
||||
/// the calling thread — **do not use from within a tokio runtime** (it will panic).
/// Other executors (smol, async-std, `futures::executor`) do not panic, but the
/// executor thread is blocked for the duration of each call.
|
||||
/// For async Rust code use [`UploadCommit`] from [`XetSession::new_upload_commit`].
|
||||
///
|
||||
/// # Cloning
|
||||
///
|
||||
/// Cloning is cheap — it simply increments an atomic reference count.
|
||||
/// All clones share the same upload session and task state.
|
||||
#[derive(Clone)]
|
||||
pub struct UploadCommitSync {
|
||||
// The async upload commit that every blocking method of this wrapper delegates to.
pub(in super::super) inner: UploadCommit,
|
||||
}
|
||||
|
||||
impl UploadCommitSync {
|
||||
/// Create a new upload commit from a **sync** (non-async) context.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if called from within an async runtime — use
|
||||
/// [`XetSession::new_upload_commit`] instead.
|
||||
pub(in super::super) fn new(session: XetSession) -> Result<Self, SessionError> {
|
||||
let commit = session.runtime.external_run_async_task(UploadCommit::new(session.clone()))??;
|
||||
Ok(Self { inner: commit })
|
||||
}
|
||||
|
||||
/// Queue a file for upload. See [`UploadCommit::upload_from_path`] for full documentation.
|
||||
pub fn upload_from_path(&self, file_path: PathBuf) -> Result<UploadTaskHandle, SessionError> {
|
||||
self.inner.session.check_alive()?;
|
||||
|
||||
let commit_inner = self.inner.inner.clone();
|
||||
self.inner
|
||||
.runtime()
|
||||
.external_run_async_task(async move { commit_inner.start_upload_file_from_path(file_path).await })?
|
||||
}
|
||||
|
||||
/// Queue raw bytes for upload. See [`UploadCommit::upload_bytes`] for full documentation.
|
||||
pub fn upload_bytes(
|
||||
&self,
|
||||
bytes: Vec<u8>,
|
||||
tracking_name: Option<String>,
|
||||
) -> Result<UploadTaskHandle, SessionError> {
|
||||
self.inner.session.check_alive()?;
|
||||
|
||||
let commit_inner = self.inner.inner.clone();
|
||||
self.inner
|
||||
.runtime()
|
||||
.external_run_async_task(async move { commit_inner.start_upload_bytes(bytes, tracking_name).await })?
|
||||
}
|
||||
|
||||
/// Begin an incremental file upload, returning a [`SingleFileCleaner`] to stream bytes.
|
||||
///
|
||||
/// This is the sync-context equivalent of [`UploadCommit::upload_file`].
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// - `file_name`: optional name used for progress/telemetry reporting.
|
||||
/// - `file_size`: expected size in bytes (used for progress tracking; `0` is valid if unknown).
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if called from within an async runtime. Use [`UploadCommit::upload_file`] instead.
|
||||
pub fn upload_file(
|
||||
&self,
|
||||
file_name: Option<String>,
|
||||
file_size: u64,
|
||||
) -> Result<(TaskHandle, SingleFileCleaner), SessionError> {
|
||||
self.inner.session.check_alive()?;
|
||||
|
||||
let commit_inner = self.inner.clone();
|
||||
self.inner
|
||||
.runtime()
|
||||
.external_run_async_task(async move { commit_inner.start_upload_file(file_name, file_size).await })?
|
||||
}
|
||||
|
||||
/// Return a snapshot of progress for every queued upload.
|
||||
pub fn get_progress(&self) -> Result<ProgressSnapshot, SessionError> {
|
||||
self.inner.get_progress()
|
||||
}
|
||||
|
||||
/// Wait for all uploads to complete and push metadata to the CAS server.
|
||||
///
|
||||
/// Returns a `HashMap` keyed by task ID. See [`UploadCommit::commit`] for full documentation.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if called from within an async runtime. Use [`UploadCommit::commit`] instead.
|
||||
pub fn commit(self) -> Result<HashMap<Ulid, UploadResult>, SessionError> {
|
||||
let commit = self.inner.clone();
|
||||
self.inner.runtime().external_run_async_task(commit.commit())?
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use tempfile::{TempDir, tempdir};

    use crate::xet_session::session::{XetSession, XetSessionBuilder};

    /// Build a session whose endpoint is a `local://` CAS directory inside `temp`.
    fn local_session(temp: &TempDir) -> Result<XetSession, Box<dyn std::error::Error>> {
        let endpoint = format!("local://{}", temp.path().join("cas").display());
        let session = XetSessionBuilder::new().with_endpoint(endpoint).build()?;
        Ok(session)
    }

    // ── Round-trip tests ─────────────────────────────────────────────────────

    // Raw bytes uploaded and committed yield exactly one result with a non-empty hash and
    // the right size.
    #[test]
    fn test_upload_bytes_round_trip() -> Result<(), Box<dyn std::error::Error>> {
        let workdir = tempdir()?;
        let session = local_session(&workdir)?;
        let payload = b"Hello, upload commit round-trip!";

        let commit = session.new_upload_commit_blocking()?;
        let task = commit.upload_bytes(payload.to_vec(), Some("hello.bin".into()))?;
        let results = commit.commit()?;
        assert_eq!(results.len(), 1);

        let entry = results.get(&task.task_id).unwrap();
        let meta = entry.as_ref().as_ref().unwrap();
        assert_eq!(meta.file_size, payload.len() as u64);
        assert!(!meta.hash.is_empty());
        Ok(())
    }

    // A file uploaded from disk and committed reports the right size and a non-empty hash.
    #[test]
    fn test_upload_from_path_round_trip() -> Result<(), Box<dyn std::error::Error>> {
        let workdir = tempdir()?;
        let session = local_session(&workdir)?;
        let payload = b"file path upload content";
        let source = workdir.path().join("data.bin");
        std::fs::write(&source, payload)?;

        let commit = session.new_upload_commit_blocking()?;
        let handle = commit.upload_from_path(source)?;
        commit.commit()?;

        let outcome = handle.result().unwrap();
        let meta = outcome.as_ref().as_ref().unwrap();
        assert_eq!(meta.file_size, payload.len() as u64);
        assert!(!meta.hash.is_empty());
        Ok(())
    }

    // ── Per-task result access patterns ──────────────────────────────────────

    // Until commit() runs, UploadTaskHandle::result() has nothing to report.
    #[test]
    fn test_upload_result_none_before_commit() -> Result<(), Box<dyn std::error::Error>> {
        let workdir = tempdir()?;
        let session = local_session(&workdir)?;
        let source = workdir.path().join("data.bin");
        std::fs::write(&source, b"content")?;

        let commit = session.new_upload_commit_blocking()?;
        let handle = commit.upload_from_path(source)?;
        assert!(handle.result().is_none(), "result must be None before commit()");

        commit.commit()?;
        Ok(())
    }

    // Pattern 1: look the per-task result up by task_id in the map returned from commit().
    #[test]
    fn test_upload_result_accessible_via_task_id_in_commit_map() -> Result<(), Box<dyn std::error::Error>> {
        let workdir = tempdir()?;
        let session = local_session(&workdir)?;
        let payload = b"result via task_id";
        let source = workdir.path().join("data.bin");
        std::fs::write(&source, payload)?;

        let commit = session.new_upload_commit_blocking()?;
        let handle = commit.upload_from_path(source)?;
        let results = commit.commit()?;

        let entry = results.get(&handle.task_id).expect("task_id must be present in results");
        let meta = entry.as_ref().as_ref().unwrap();
        assert_eq!(meta.file_size, payload.len() as u64);
        Ok(())
    }

    // Pattern 2: read the per-task result straight off the UploadTaskHandle once commit()
    // is done.
    #[test]
    fn test_upload_result_accessible_via_handle_after_commit() -> Result<(), Box<dyn std::error::Error>> {
        let workdir = tempdir()?;
        let session = local_session(&workdir)?;
        let payload = b"result via handle";
        let source = workdir.path().join("data.bin");
        std::fs::write(&source, payload)?;

        let commit = session.new_upload_commit_blocking()?;
        let handle = commit.upload_from_path(source)?;
        commit.commit()?;

        let outcome = handle.result().expect("result must be set after commit");
        let meta = outcome.as_ref().as_ref().unwrap();
        assert_eq!(meta.file_size, payload.len() as u64);
        Ok(())
    }

    // Streaming upload: bytes are fed through the SingleFileCleaner returned by upload_file.
    #[test]
    fn test_upload_streaming_round_trip() -> Result<(), Box<dyn std::error::Error>> {
        let workdir = tempdir()?;
        let session = local_session(&workdir)?;
        let payload = b"streamed upload bytes";
        let runtime = session.runtime.clone();

        let commit = session.new_upload_commit_blocking()?;
        let (_task, mut cleaner) = commit.upload_file(Some("stream.bin".into()), payload.len() as u64)?;

        // Feed the cleaner on the session runtime and pull out the resulting hash and size.
        let streaming = async move {
            cleaner.add_data(payload).await.unwrap();
            let (file_info, _) = cleaner.finish().await.unwrap();
            (file_info.hash, file_info.file_size)
        };
        let (hash, file_size) = runtime.external_run_async_task(streaming)?;

        let results = commit.commit()?;
        assert!(results.is_empty());
        assert_eq!(file_size, payload.len() as u64);
        assert!(!hash.is_empty());
        Ok(())
    }

    // Several blobs queued on one commit produce one result each.
    #[test]
    fn test_upload_multiple_files_in_one_commit() -> Result<(), Box<dyn std::error::Error>> {
        let workdir = tempdir()?;
        let session = local_session(&workdir)?;
        let commit = session.new_upload_commit_blocking()?;

        let blobs = [("file one", "a.bin"), ("file two", "b.bin"), ("file three", "c.bin")];
        for (contents, name) in blobs {
            commit.upload_bytes(contents.as_bytes().to_vec(), Some(name.into()))?;
        }

        let results = commit.commit()?;
        assert_eq!(results.len(), 3);
        Ok(())
    }

    // Once a commit succeeds, the aggregate progress shows a non-zero completed byte count.
    #[test]
    fn test_upload_progress_reflects_bytes_after_commit() -> Result<(), Box<dyn std::error::Error>> {
        let workdir = tempdir()?;
        let session = local_session(&workdir)?;
        let payload = b"progress tracking upload data";

        let commit = session.new_upload_commit_blocking()?;
        // commit() consumes the handle, so keep a clone around to query progress afterwards.
        let observer = commit.clone();
        commit.upload_bytes(payload.to_vec(), Some("prog.bin".into()))?;
        commit.commit()?;

        let snapshot = observer.get_progress()?;
        assert!(snapshot.total().total_bytes_completed > 0);
        Ok(())
    }

    // ── Non-tokio executor (no-panic + round-trip) ────────────────────────────
    //
    // smol, async-std, and futures::executor never install tokio's thread-local runtime
    // context, so the Handle::block_on inside external_run_async_task does not panic there;
    // it only blocks the executor thread that made the call.

    // The blocking upload API completes a full round-trip when driven from smol.
    #[test]
    fn test_upload_blocking_round_trip_in_smol() {
        let workdir = tempdir().unwrap();
        let session = local_session(&workdir).unwrap();

        smol::block_on(async {
            let payload = b"upload from smol executor";
            let commit = session.new_upload_commit_blocking().unwrap();
            let task = commit.upload_bytes(payload.to_vec(), Some("test.bin".into())).unwrap();
            let results = commit.commit().unwrap();

            let entry = results.get(&task.task_id).unwrap();
            let meta = entry.as_ref().as_ref().unwrap();
            assert_eq!(meta.file_size, payload.len() as u64);
            assert!(!meta.hash.is_empty());
        });
    }

    // The blocking upload API completes a full round-trip when driven from futures::executor.
    #[test]
    fn test_upload_blocking_round_trip_in_futures_executor() {
        let workdir = tempdir().unwrap();
        let session = local_session(&workdir).unwrap();

        futures::executor::block_on(async {
            let payload = b"upload from futures executor";
            let commit = session.new_upload_commit_blocking().unwrap();
            let task = commit.upload_bytes(payload.to_vec(), Some("test.bin".into())).unwrap();
            let results = commit.commit().unwrap();

            let entry = results.get(&task.task_id).unwrap();
            let meta = entry.as_ref().as_ref().unwrap();
            assert_eq!(meta.file_size, payload.len() as u64);
            assert!(!meta.hash.is_empty());
        });
    }

    // The blocking upload API completes a full round-trip when driven from async-std.
    #[test]
    fn test_upload_blocking_round_trip_in_async_std() {
        let workdir = tempdir().unwrap();
        let session = local_session(&workdir).unwrap();

        async_std::task::block_on(async {
            let payload = b"upload from async-std executor";
            let commit = session.new_upload_commit_blocking().unwrap();
            let task = commit.upload_bytes(payload.to_vec(), Some("test.bin".into())).unwrap();
            let results = commit.commit().unwrap();

            let entry = results.get(&task.task_id).unwrap();
            let meta = entry.as_ref().as_ref().unwrap();
            assert_eq!(meta.file_size, payload.len() as u64);
            assert!(!meta.hash.is_empty());
        });
    }
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,13 +1,17 @@
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Display;
|
||||
use std::future::Future;
|
||||
use std::panic::AssertUnwindSafe;
|
||||
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, OnceLock};
|
||||
use std::sync::{Arc, LazyLock, OnceLock, Weak};
|
||||
|
||||
use futures::FutureExt;
|
||||
use reqwest::Client;
|
||||
use tokio::runtime::{Builder as TokioRuntimeBuilder, Handle as TokioRuntimeHandle, Runtime as TokioRuntime};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{debug, info};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use super::XetCommon;
|
||||
use crate::config::XetConfig;
|
||||
@@ -139,6 +143,10 @@ pub struct XetRuntime {
|
||||
// System monitor instance if enabled, monitor starts on initiation
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
system_monitor: Option<SystemMonitor>,
|
||||
|
||||
/// Set only for External-mode instances created via `from_external_with_config`.
|
||||
/// Used to deregister from `EXTERNAL_RUNTIME_REGISTRY` on drop.
|
||||
external_handle_id: Option<tokio::runtime::Id>,
|
||||
}
|
||||
|
||||
// Use thread-local references to the runtime that are set on initialization among all
|
||||
@@ -148,6 +156,16 @@ thread_local! {
|
||||
static THREAD_RUNTIME_REF: RefCell<Option<(u32, Arc<XetRuntime>)>> = const { RefCell::new(None) };
|
||||
}
|
||||
|
||||
// Registry for External-mode runtimes created via from_external_with_config.
|
||||
// Keyed by tokio runtime ID so current_if_exists() can find the right XetRuntime
|
||||
// (with the correct XetConfig and XetCommon) when called from the caller's tokio threads,
|
||||
// where THREAD_RUNTIME_REF is never set.
|
||||
//
|
||||
// Uses std::sync (not tokio::sync) because the registry must be accessible from non-async
|
||||
// contexts such as Drop impls and sync builder methods.
|
||||
static EXTERNAL_RUNTIME_REGISTRY: LazyLock<std::sync::RwLock<HashMap<tokio::runtime::Id, Weak<XetRuntime>>>> =
|
||||
LazyLock::new(|| std::sync::RwLock::new(HashMap::new()));
|
||||
|
||||
impl XetRuntime {
|
||||
/// Return the current threadpool that the current worker thread uses. Will fail if
|
||||
/// called from a thread that is not spawned from the current runtime.
|
||||
@@ -166,16 +184,25 @@ impl XetRuntime {
|
||||
|
||||
#[inline]
|
||||
pub fn current_if_exists() -> Option<Arc<Self>> {
|
||||
// 1. Thread-local: set by on_thread_start in new_with_config (Owned mode).
|
||||
let maybe_rt = THREAD_RUNTIME_REF.with_borrow(|rt| rt.clone());
|
||||
|
||||
if let Some((pid, rt)) = maybe_rt
|
||||
&& pid == std::process::id()
|
||||
{
|
||||
return Some(rt);
|
||||
}
|
||||
|
||||
if let Ok(tokio_rt) = TokioRuntimeHandle::try_current() {
|
||||
Some(Self::from_external(tokio_rt))
|
||||
// 2. Handle registry: set by from_external_with_config (External mode). Returns the XetRuntime with the correct
|
||||
// XetConfig and XetCommon for this runtime.
|
||||
if let Ok(handle) = TokioRuntimeHandle::try_current() {
|
||||
if let Ok(reg) = EXTERNAL_RUNTIME_REGISTRY.read()
|
||||
&& let Some(weak) = reg.get(&handle.id())
|
||||
&& let Some(rt) = weak.upgrade()
|
||||
{
|
||||
return Some(rt);
|
||||
}
|
||||
// Fallback: no XetSession owns this handle; create a bare default-config wrapper.
|
||||
Some(Self::from_external(handle))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
@@ -209,6 +236,7 @@ impl XetRuntime {
|
||||
})
|
||||
.flatten(),
|
||||
config: Arc::new(config),
|
||||
external_handle_id: None,
|
||||
});
|
||||
|
||||
// Each thread in each of the tokio worker threads holds a reference to the runtime handling
|
||||
@@ -263,6 +291,48 @@ impl XetRuntime {
|
||||
Ok(rt)
|
||||
}
|
||||
|
||||
/// Wrap an existing tokio [`TokioRuntimeHandle`] with a [`XetRuntime`] using the provided
|
||||
/// [`XetConfig`]. No new thread pool is created; `spawn()` calls will schedule work on the
|
||||
/// runtime that owns `rt_handle`.
|
||||
///
|
||||
/// The resulting `XetRuntime` is registered in `EXTERNAL_RUNTIME_REGISTRY` so that
|
||||
/// [`XetRuntime::current()`] called from tasks running on `rt_handle`'s threads will return
|
||||
/// this instance (with the correct config and shared `XetCommon`) rather than a default
|
||||
/// throwaway. The entry is removed when the last `Arc<XetRuntime>` is dropped.
|
||||
pub fn from_external_with_config(rt_handle: TokioRuntimeHandle, config: XetConfig) -> Arc<Self> {
|
||||
let id = rt_handle.id();
|
||||
let rt = Arc::new(Self {
|
||||
runtime: std::sync::RwLock::new(None),
|
||||
handle_ref: rt_handle.into(),
|
||||
external_executor_count: 0.into(),
|
||||
sigint_shutdown: false.into(),
|
||||
common: XetCommon::new(&config),
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
system_monitor: config
|
||||
.system_monitor
|
||||
.enabled
|
||||
.then(|| {
|
||||
SystemMonitor::follow_process(
|
||||
config.system_monitor.sample_interval,
|
||||
config.system_monitor.log_path.clone(),
|
||||
)
|
||||
.ok()
|
||||
})
|
||||
.flatten(),
|
||||
config: Arc::new(config),
|
||||
external_handle_id: Some(id),
|
||||
});
|
||||
if let Ok(mut reg) = EXTERNAL_RUNTIME_REGISTRY.write() {
|
||||
reg.insert(id, Arc::downgrade(&rt));
|
||||
} else {
|
||||
warn!("EXTERNAL_RUNTIME_REGISTRY poisoned; skipping registration");
|
||||
}
|
||||
rt
|
||||
}
|
||||
|
||||
/// Wrap an existing tokio [`TokioRuntimeHandle`] with a [`XetRuntime`] using a default
|
||||
/// [`XetConfig`]. Prefer [`from_external_with_config`](Self::from_external_with_config) when
|
||||
/// you have a config available.
|
||||
pub fn from_external(rt_handle: TokioRuntimeHandle) -> Arc<Self> {
|
||||
let config = XetConfig::new();
|
||||
Arc::new(Self {
|
||||
@@ -284,6 +354,7 @@ impl XetRuntime {
|
||||
})
|
||||
.flatten(),
|
||||
config: Arc::new(config),
|
||||
external_handle_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -422,6 +493,43 @@ impl XetRuntime {
|
||||
self.handle().spawn(future)
|
||||
}
|
||||
|
||||
/// Bridge a future onto this runtime's `hf-xet-*` thread pool from any async context,
|
||||
/// including non-tokio executors (smol, async-std, `futures::executor::block_on`).
|
||||
///
|
||||
/// Unlike [`external_run_async_task`](Self::external_run_async_task) which **blocks**
|
||||
/// the calling thread, this method returns a future that any executor can poll.
|
||||
/// The result is delivered via a `tokio::sync::oneshot` channel whose receiver only
|
||||
/// requires a `std::task::Waker`, making it compatible with every standard executor.
|
||||
///
|
||||
/// Returns `Err(RuntimeError::TaskPanic)` if the spawned future panics, or
|
||||
/// `Err(RuntimeError::TaskCanceled)` if the runtime shuts down before the result
|
||||
/// can be delivered.
|
||||
pub async fn bridge_to_owned<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
|
||||
where
|
||||
F: Future<Output = T> + Send + 'static,
|
||||
T: Send + 'static,
|
||||
{
|
||||
let (tx, rx) = oneshot::channel();
|
||||
self.spawn(async move {
|
||||
let result = AssertUnwindSafe(fut).catch_unwind().await;
|
||||
let _ = tx.send(result);
|
||||
});
|
||||
match rx.await {
|
||||
Ok(Ok(value)) => Ok(value),
|
||||
Ok(Err(panic_payload)) => {
|
||||
let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
|
||||
format!("{task_name}: {s}")
|
||||
} else if let Some(s) = panic_payload.downcast_ref::<String>() {
|
||||
format!("{task_name}: {s}")
|
||||
} else {
|
||||
format!("{task_name}: <unknown panic>")
|
||||
};
|
||||
Err(RuntimeError::TaskPanic(msg))
|
||||
},
|
||||
Err(_) => Err(RuntimeError::TaskCanceled(task_name.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawn a blocking task on the runtime's blocking thread pool. The task runs with this
|
||||
/// runtime stored in thread-local storage so [`XetRuntime::current()`] works inside `f`.
|
||||
///
|
||||
@@ -447,6 +555,16 @@ impl XetRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for XetRuntime {
|
||||
fn drop(&mut self) {
|
||||
if let Some(id) = &self.external_handle_id
|
||||
&& let Ok(mut reg) = EXTERNAL_RUNTIME_REGISTRY.write()
|
||||
{
|
||||
reg.remove(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for XetRuntime {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Need to be careful that this doesn't acquire locks eagerly, as this function can be called
|
||||
@@ -492,4 +610,28 @@ mod tests {
|
||||
let same = rt.external_run_async_task(async { jh.await.unwrap() }).unwrap();
|
||||
assert!(same);
|
||||
}
|
||||
|
||||
/// Tasks on an External-mode runtime must get the registered XetRuntime (carrying its
/// custom config) back from current_if_exists(), and only fall back to a default-config
/// wrapper once that runtime has been dropped.
#[test]
fn test_current_if_exists_sees_external_runtime_config() {
    const ENDPOINT: &str = "https://test-endpoint.example.com";

    let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();

    let mut config = XetConfig::new();
    config.data.default_cas_endpoint = ENDPOINT.into();
    let registered = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), config);

    // While the runtime is alive, the lookup must resolve to that exact instance.
    tokio_rt.block_on(async {
        let found = XetRuntime::current_if_exists().expect("should find a runtime");
        assert!(Arc::ptr_eq(&found, &registered), "must be the same XetRuntime instance");
        assert_eq!(found.config().data.default_cas_endpoint, ENDPOINT);
    });

    // Dropping the last strong reference deregisters it; the lookup then yields a
    // default-config wrapper, which does not carry the custom endpoint.
    drop(registered);
    tokio_rt.block_on(async {
        let found = XetRuntime::current_if_exists().expect("should still find a runtime");
        assert_ne!(found.config().data.default_cas_endpoint, ENDPOINT);
    });
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user