Add ordered and unordered download streaming to session interface (#729)

This PR adds ordered and unordered download streams on XetSession,
including optional byte-range support and per-stream progress reporting.
Blocking and async variants are supported.

On the reconstruction side, this introduces UnorderedWriter and
UnorderedDownloadStream in xet_data, and extends the FileDownloadSession
stream APIs to take optional source ranges. Ordered and unordered
streams now share the same session-facing access pattern for async and
blocking callers.

This PR also renames DownloadGroup to FileDownloadGroup; the stream data
uses the per-session memory pool but don't count towards the maximum
number of concurrent downloads in progress.

<!-- CURSOR_SUMMARY -->
---

> [!NOTE]
> **Medium Risk**
> Touches core file reconstruction/writer plumbing (including
`DataWriter` ownership and new unordered writer/stream paths) and
changes public session APIs, so regressions could impact download
correctness, cancellation, or progress reporting.
> 
> **Overview**
> Adds first-class **ordered and unordered streaming download APIs** to
`xet_pkg::xet_session`, including async and blocking variants, optional
source-relative byte ranges, and per-stream progress via new
`XetDownloadStream` / `XetUnorderedDownloadStream` wrappers.
> 
> On the data layer, introduces an **unordered reconstruction path**
(`UnorderedWriter` + `UnorderedDownloadStream`) and refactors streaming
to spawn reconstruction tasks immediately but gate execution behind
`start()`; stream abort callbacks are now registered per-stream and
automatically unregistered on drop to avoid callback accumulation.
> 
> Updates the reconstruction writer contract by making
`DataWriter::finish` consume the writer (and shifting `DataWriter` to
`&mut self` usage), adjusts `SequentialWriter` accordingly, and adds
Criterion-based reconstruction benchmarks plus extensive
unordered-stream tests. Also renames session `DownloadGroup` to
`FileDownloadGroup` (and constructors) and updates call sites/examples.
> 
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
e02890aa4b. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
This commit is contained in:
Hoyt Koepke
2026-03-20 14:40:18 -07:00
committed by GitHub
parent 602d7679f6
commit 332a456e1d
30 changed files with 3623 additions and 532 deletions

365
Cargo.lock generated
View File

@@ -62,10 +62,16 @@ dependencies = [
]
[[package]]
name = "anstream"
version = "1.0.0"
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "anstream"
version = "0.6.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a"
dependencies = [
"anstyle",
"anstyle-parse",
@@ -78,15 +84,15 @@ dependencies = [
[[package]]
name = "anstyle"
version = "1.0.14"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000"
checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78"
[[package]]
name = "anstyle-parse"
version = "1.0.0"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e"
checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2"
dependencies = [
"utf8parse",
]
@@ -358,6 +364,17 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi",
]
[[package]]
name = "autocfg"
version = "1.5.0"
@@ -493,6 +510,12 @@ dependencies = [
"serde",
]
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.11.0"
@@ -604,6 +627,12 @@ dependencies = [
"serde",
]
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbc"
version = "0.1.2"
@@ -615,9 +644,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.2.57"
version = "1.2.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423"
checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2"
dependencies = [
"find-msvc-tools",
"jobserver",
@@ -673,6 +702,33 @@ dependencies = [
"windows-link",
]
[[package]]
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]
[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"
[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half",
]
[[package]]
name = "cipher"
version = "0.4.4"
@@ -685,9 +741,21 @@ dependencies = [
[[package]]
name = "clap"
version = "4.6.0"
version = "3.2.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351"
checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123"
dependencies = [
"bitflags 1.3.2",
"clap_lex 0.2.4",
"indexmap 1.9.3",
"textwrap",
]
[[package]]
name = "clap"
version = "4.5.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a"
dependencies = [
"clap_builder",
"clap_derive",
@@ -695,21 +763,21 @@ dependencies = [
[[package]]
name = "clap_builder"
version = "4.6.0"
version = "4.5.60"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f"
checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876"
dependencies = [
"anstream",
"anstyle",
"clap_lex",
"clap_lex 1.0.0",
"strsim",
]
[[package]]
name = "clap_derive"
version = "4.6.0"
version = "4.5.55"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a"
checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5"
dependencies = [
"heck",
"proc-macro2",
@@ -719,9 +787,18 @@ dependencies = [
[[package]]
name = "clap_lex"
version = "1.1.0"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9"
checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
dependencies = [
"os_str_bytes",
]
[[package]]
name = "clap_lex"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831"
[[package]]
name = "cmake"
@@ -734,9 +811,9 @@ dependencies = [
[[package]]
name = "colorchoice"
version = "1.0.5"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
[[package]]
name = "colored"
@@ -897,6 +974,44 @@ dependencies = [
"cfg-if 1.0.4",
]
[[package]]
name = "criterion"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7c76e09c1aae2bc52b3d2f29e13c6572553b30c4aa1b8a49fd70de6412654cb"
dependencies = [
"anes",
"atty",
"cast",
"ciborium",
"clap 3.2.25",
"criterion-plot",
"futures",
"itertools 0.10.5",
"lazy_static",
"num-traits",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"tokio",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.15"
@@ -906,6 +1021,25 @@ dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.12"
@@ -1622,7 +1756,7 @@ version = "0.20.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b88256088d75a56f8ecfa070513a775dd9107f6530ef14919dac831af9cfe2b"
dependencies = [
"bitflags",
"bitflags 2.11.0",
"libc",
"libgit2-sys",
"log",
@@ -1637,7 +1771,7 @@ version = "0.2.1"
dependencies = [
"anyhow",
"async-trait",
"clap",
"clap 4.5.60",
"git-url-parse",
"git2",
"hf-xet",
@@ -1694,7 +1828,7 @@ dependencies = [
"futures-core",
"futures-sink",
"http",
"indexmap",
"indexmap 2.13.0",
"slab",
"tokio",
"tokio-util",
@@ -1712,6 +1846,12 @@ dependencies = [
"zerocopy",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hashbrown"
version = "0.15.5"
@@ -1819,7 +1959,7 @@ version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a56c94661ddfb51aa9cdfbf102cfcc340aa69267f95ebccc4af08d7c530d393"
dependencies = [
"bitflags",
"bitflags 2.11.0",
"byteorder",
"heed-traits",
"heed-types",
@@ -1851,6 +1991,15 @@ dependencies = [
"serde_json",
]
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "hermit-abi"
version = "0.5.2"
@@ -1876,7 +2025,8 @@ dependencies = [
"anyhow",
"async-std",
"async-trait",
"clap",
"bytes",
"clap 4.5.60",
"futures",
"http",
"more-asserts",
@@ -1890,7 +2040,6 @@ dependencies = [
"tokio",
"tracing",
"tracing-subscriber",
"ulid",
"xet-client",
"xet-core-structures",
"xet-data",
@@ -2243,6 +2392,16 @@ dependencies = [
"icu_properties",
]
[[package]]
name = "indexmap"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
dependencies = [
"autocfg",
"hashbrown 0.12.3",
]
[[package]]
name = "indexmap"
version = "2.13.0"
@@ -2333,6 +2492,15 @@ version = "1.70.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695"
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.14.0"
@@ -2775,7 +2943,7 @@ version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71e2746dc3a24dd78b3cfcb7be93368c6de9963d30f43a6a73998a9cf4b17b46"
dependencies = [
"bitflags",
"bitflags 2.11.0",
"cfg-if 1.0.4",
"cfg_aliases",
"libc",
@@ -2898,7 +3066,7 @@ version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b"
dependencies = [
"hermit-abi",
"hermit-abi 0.5.2",
"libc",
]
@@ -2908,7 +3076,7 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536"
dependencies = [
"bitflags",
"bitflags 2.11.0",
]
[[package]]
@@ -2932,9 +3100,9 @@ dependencies = [
[[package]]
name = "once_cell"
version = "1.21.4"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "once_cell_polyfill"
@@ -2948,6 +3116,12 @@ version = "0.1.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "269bca4c2591a28585d6bf10d9ed0332b7d76900a1b02bec41bdc3a2cdcda107"
[[package]]
name = "oorandom"
version = "11.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e"
[[package]]
name = "opaque-debug"
version = "0.3.1"
@@ -2956,11 +3130,11 @@ checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381"
[[package]]
name = "openssl"
version = "0.10.76"
version = "0.10.75"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf"
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
dependencies = [
"bitflags",
"bitflags 2.11.0",
"cfg-if 1.0.4",
"foreign-types",
"libc",
@@ -3003,9 +3177,9 @@ dependencies = [
[[package]]
name = "openssl-sys"
version = "0.9.112"
version = "0.9.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb"
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
dependencies = [
"cc",
"libc",
@@ -3311,6 +3485,34 @@ version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
[[package]]
name = "plotters"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aeb6f403d7a4911efb1e33402027fc44f29b5bf6def3effcc22d7bb75f2b747"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df42e13c12958a16b3f7f4386b9ab1f3e7933914ecea48da7139435263a4172a"
[[package]]
name = "plotters-svg"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51bae2ac328883f7acdfea3d66a7c35751187f870bc81f94563733a154d7a670"
dependencies = [
"plotters-backend",
]
[[package]]
name = "polling"
version = "3.11.0"
@@ -3319,7 +3521,7 @@ checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218"
dependencies = [
"cfg-if 1.0.4",
"concurrent-queue",
"hermit-abi",
"hermit-abi 0.5.2",
"pin-project-lite",
"rustix",
"windows-sys 0.61.2",
@@ -3486,7 +3688,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "27c6023962132f4b30eb4c172c91ce92d933da334c59c23cddee82358ddafb0b"
dependencies = [
"anyhow",
"itertools",
"itertools 0.14.0",
"proc-macro2",
"quote",
"syn 2.0.117",
@@ -3773,6 +3975,26 @@ version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "rayon"
version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91"
dependencies = [
"crossbeam-deque",
"crossbeam-utils",
]
[[package]]
name = "rdrand"
version = "0.4.0"
@@ -3788,7 +4010,7 @@ version = "0.5.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
dependencies = [
"bitflags",
"bitflags 2.11.0",
]
[[package]]
@@ -3980,7 +4202,7 @@ checksum = "82b4d036bb45d7bbe99dbfef4ec60eaeb614708d22ff107124272f8ef6b54548"
dependencies = [
"aes",
"aws-lc-rs",
"bitflags",
"bitflags 2.11.0",
"block-padding",
"byteorder",
"bytes",
@@ -4100,7 +4322,7 @@ version = "1.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190"
dependencies = [
"bitflags",
"bitflags 2.11.0",
"errno",
"libc",
"linux-raw-sys",
@@ -4294,7 +4516,7 @@ version = "3.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d"
dependencies = [
"bitflags",
"bitflags 2.11.0",
"core-foundation 0.10.1",
"core-foundation-sys",
"libc",
@@ -4546,7 +4768,7 @@ dependencies = [
"anyhow",
"bytes",
"chrono",
"clap",
"clap 4.5.60",
"duration-str",
"http",
"rand 0.9.2",
@@ -4788,7 +5010,7 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b"
dependencies = [
"bitflags",
"bitflags 2.11.0",
"core-foundation 0.9.4",
"system-configuration-sys",
]
@@ -4847,6 +5069,12 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
[[package]]
name = "textwrap"
version = "0.16.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057"
[[package]]
name = "thiserror"
version = "1.0.69"
@@ -4938,10 +5166,20 @@ dependencies = [
]
[[package]]
name = "tinyvec"
version = "1.11.0"
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "tinyvec"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
dependencies = [
"tinyvec_macros",
]
@@ -5104,7 +5342,7 @@ checksum = "ebe5ef63511595f1344e2d5cfa636d973292adc0eec1f0ad45fae9f0851ab1d4"
dependencies = [
"futures-core",
"futures-util",
"indexmap",
"indexmap 2.13.0",
"pin-project-lite",
"slab",
"sync_wrapper",
@@ -5121,7 +5359,7 @@ version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8"
dependencies = [
"bitflags",
"bitflags 2.11.0",
"bytes",
"futures-util",
"http",
@@ -5213,9 +5451,9 @@ dependencies = [
[[package]]
name = "tracing-subscriber"
version = "0.3.23"
version = "0.3.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319"
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
dependencies = [
"matchers",
"nu-ansi-term",
@@ -5570,7 +5808,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909"
dependencies = [
"anyhow",
"indexmap",
"indexmap 2.13.0",
"wasm-encoder",
"wasmparser",
]
@@ -5594,9 +5832,9 @@ version = "0.244.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe"
dependencies = [
"bitflags",
"bitflags 2.11.0",
"hashbrown 0.15.5",
"indexmap",
"indexmap 2.13.0",
"semver",
]
@@ -6100,7 +6338,7 @@ checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21"
dependencies = [
"anyhow",
"heck",
"indexmap",
"indexmap 2.13.0",
"prettyplease",
"syn 2.0.117",
"wasm-metadata",
@@ -6130,8 +6368,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2"
dependencies = [
"anyhow",
"bitflags",
"indexmap",
"bitflags 2.11.0",
"indexmap 2.13.0",
"log",
"serde",
"serde_derive",
@@ -6150,7 +6388,7 @@ checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736"
dependencies = [
"anyhow",
"id-arena",
"indexmap",
"indexmap 2.13.0",
"log",
"semver",
"serde",
@@ -6176,7 +6414,7 @@ dependencies = [
"axum",
"base64 0.22.1",
"bytes",
"clap",
"clap 4.5.60",
"crc32fast",
"ctor",
"derivative",
@@ -6229,7 +6467,7 @@ dependencies = [
"blake3",
"bytemuck",
"bytes",
"clap",
"clap 4.5.60",
"countio",
"csv",
"futures",
@@ -6238,7 +6476,7 @@ dependencies = [
"half",
"heapify",
"heed",
"itertools",
"itertools 0.14.0",
"lazy_static",
"lz4_flex",
"more-asserts",
@@ -6266,12 +6504,13 @@ dependencies = [
"async-trait",
"bytes",
"chrono",
"clap",
"clap 4.5.60",
"criterion",
"ctor",
"dirs",
"gearhash",
"http",
"itertools",
"itertools 0.14.0",
"lazy_static",
"more-asserts",
"prometheus",

View File

@@ -0,0 +1,70 @@
# Streaming download APIs
**Date**: 2026-03-19
## Summary
This update adds first-class stream download APIs in `xet_pkg::xet_session`.
It also renames the session download group type from `DownloadGroup` to
`FileDownloadGroup` and renames the corresponding session constructors.
## `xet_pkg::xet_session` API changes
### Type/constructor renames
- `DownloadGroup` -> `FileDownloadGroup`
- `XetSession::new_download_group()` -> `XetSession::new_file_download_group()`
- `XetSession::new_download_group_blocking()` ->
`XetSession::new_file_download_group_blocking()`
### New streaming session APIs
- Ordered stream:
- `XetSession::download_stream(file_info, range)` (async)
- `XetSession::download_stream_blocking(file_info, range)` (sync)
- Returns `XetDownloadStream`
- Unordered stream:
- `XetSession::download_unordered_stream(file_info, range)` (async)
- `XetSession::download_unordered_stream_blocking(file_info, range)` (sync)
- Returns `XetUnorderedDownloadStream`
Both stream types support:
- `start()`
- async and blocking next-item APIs
- `cancel()`
- per-stream progress via `get_progress()` (returns `Option<ItemProgressReport>`)
Range semantics are source-file-relative (`Option<Range<u64>>`).
Stream abort callbacks are automatically unregistered on drop to prevent
accumulation in long-lived sessions.
## `xet_data` API/config changes
### New unordered stream path
- `FileDownloadSession::download_unordered_stream(file_info, source_range)`
- `FileDownloadSession::download_stream_range(file_info, range)` — ordered
stream with `RangeBounds<u64>` (merged from main; includes open-ended ranges)
- `processing::mod` now re-exports `UnorderedDownloadStream`
### `DataWriter` trait contract update
- `finish` now consumes the writer:
- `async fn finish(self: Box<Self>) -> Result<u64>`
### Stream abort callback API
- `FileDownloadSession::register_stream_abort_callback` now takes `(UniqueID, callback)`.
- New: `FileDownloadSession::unregister_stream_abort_callback(UniqueID)`.
- Callbacks are stored in a `HashMap<UniqueID, _>` instead of a `Vec`.
## Migration notes
- Downstream code must update symbol/method names from `DownloadGroup` /
`new_download_group*` to `FileDownloadGroup` /
`new_file_download_group*`.
- Consumers that need chunk-level streaming can migrate from group/file APIs to
the new stream APIs when appropriate.

2
hf_xet/Cargo.lock generated
View File

@@ -1237,6 +1237,7 @@ version = "1.4.0"
dependencies = [
"anyhow",
"async-trait",
"bytes",
"http",
"more-asserts",
"pyo3",
@@ -1244,7 +1245,6 @@ dependencies = [
"thiserror 2.0.18",
"tokio",
"tracing",
"ulid",
"xet-client",
"xet-core-structures",
"xet-data",

View File

@@ -73,6 +73,7 @@ name = "xorb-check"
path = "examples/xorb-check/main.rs"
[dev-dependencies]
criterion = { version = "0.4", features = ["async_tokio"] }
ctor = { workspace = true }
dirs = { workspace = true }
rand = { workspace = true }
@@ -80,6 +81,10 @@ serial_test = { workspace = true }
tempfile = { workspace = true }
tracing-test = { workspace = true }
[[bench]]
name = "reconstruction_bench"
harness = false
[features]
strict = []
smoke-test = []

View File

@@ -0,0 +1,91 @@
use std::sync::Arc;
use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main};
use tempfile::TempDir;
use tokio::runtime::Runtime;
use xet_client::cas_client::{Client, MemoryClient};
use xet_data::file_reconstruction::FileReconstructor;
use xet_runtime::config::ReconstructionConfig;
struct BenchFixture {
client: Arc<dyn Client>,
file_hash: xet_core_structures::merklehash::MerkleHash,
_file_size: u64,
}
async fn create_fixture(num_xorbs: usize, chunks_per_xorb: u64, chunk_size: usize) -> BenchFixture {
let client = MemoryClient::new();
let term_spec: Vec<(u64, (u64, u64))> = (0..num_xorbs).map(|i| ((i + 1) as u64, (0, chunks_per_xorb))).collect();
let file_contents = client.insert_random_lazy_file(&term_spec, chunk_size).await.unwrap();
let file_size = file_contents.data.len() as u64;
BenchFixture {
client: client as Arc<dyn Client>,
file_hash: file_contents.file_hash,
_file_size: file_size,
}
}
fn bench_sequential_non_vectored(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let fixture = rt.block_on(create_fixture(4, 256, 65_536));
let mut config = ReconstructionConfig::default();
config.use_vectored_write = false;
c.bench_with_input(
BenchmarkId::new("reconstruct/sequential_write", format!("{}MB", fixture._file_size / (1024 * 1024))),
&fixture,
|b, fix| {
b.to_async(&rt).iter(|| {
let client = fix.client.clone();
let hash = fix.file_hash;
let cfg = config.clone();
async move {
let dir = TempDir::new().unwrap();
let path = dir.path().join("out.bin");
FileReconstructor::new(&client, hash)
.with_config(cfg)
.reconstruct_to_file(&path, None)
.await
.unwrap();
}
});
},
);
}
fn bench_sequential_vectored(c: &mut Criterion) {
let rt = Runtime::new().unwrap();
let fixture = rt.block_on(create_fixture(4, 256, 65_536));
let mut config = ReconstructionConfig::default();
config.use_vectored_write = true;
c.bench_with_input(
BenchmarkId::new("reconstruct/vectored_write", format!("{}MB", fixture._file_size / (1024 * 1024))),
&fixture,
|b, fix| {
b.to_async(&rt).iter(|| {
let client = fix.client.clone();
let hash = fix.file_hash;
let cfg = config.clone();
async move {
let dir = TempDir::new().unwrap();
let path = dir.path().join("out.bin");
FileReconstructor::new(&client, hash)
.with_config(cfg)
.reconstruct_to_file(&path, None)
.await
.unwrap();
}
});
},
);
}
criterion_group!(benches, bench_sequential_non_vectored, bench_sequential_vectored);
criterion_main!(benches);

View File

@@ -11,7 +11,7 @@ use super::super::Result;
pub type DataFuture = Pin<Box<dyn Future<Output = Result<Bytes>> + Send + 'static>>;
#[async_trait::async_trait]
pub trait DataWriter: Send + Sync + 'static {
pub trait DataWriter: Send + 'static {
/// Sets the data source for the next sequential term.
///
/// The byte range must be sequential - its start must match the end of the
@@ -23,14 +23,14 @@ pub trait DataWriter: Send + Sync + 'static {
/// An optional semaphore permit can be passed for rate limiting. The permit
/// will be released by the background writer after the data has been written.
async fn set_next_term_data_source(
&self,
&mut self,
byte_range: FileRange,
permit: Option<AdjustableSemaphorePermit>,
data_future: DataFuture,
) -> Result<()>;
/// Waits until all data has been written and returns the number of bytes written.
///
/// Once this method is called, further calls to set_next_term_data_source will fail.
async fn finish(&self) -> Result<u64>;
/// Consumes the writer, waits until all data has been written, and returns the
/// number of bytes written. Dropping the writer without calling `finish` cancels
/// the reconstruction via the shared run state.
async fn finish(mut self: Box<Self>) -> Result<u64>;
}

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use bytes::Bytes;
use tokio::sync::Notify;
use tokio::sync::mpsc::UnboundedReceiver;
use tracing::info;
@@ -11,10 +12,12 @@ use super::sequential_writer::{SequentialRetrievalItem, SequentialWriter};
/// A streaming download handle that yields data chunks as they are reconstructed.
///
/// Created by [`FileReconstructor::reconstruct_to_stream`]. The reconstruction
/// task is **not** started until [`start`](Self::start) is called explicitly, or
/// automatically on the first call to [`next`](Self::next) /
/// [`blocking_next`](Self::blocking_next).
/// Created by [`FileReconstructor::reconstruct_to_stream`]. The reconstruction
/// task is spawned immediately but pauses until [`start`](Self::start) is
/// called (or the first [`next`](Self::next) / [`blocking_next`](Self::blocking_next)).
/// Because the `tokio::spawn` happens at construction time, subsequent calls to
/// `start()`, `next()`, and `blocking_next()` do **not** require a tokio runtime
/// context.
///
/// Data is delivered by pulling items directly from the sequential writer's
/// internal queue, bypassing the synchronous writer thread entirely. Each call
@@ -23,52 +26,83 @@ use super::sequential_writer::{SequentialRetrievalItem, SequentialWriter};
/// reconstruction error is surfaced on the call that would have returned the
/// next chunk (or on the final `None` boundary) via the shared run state.
pub struct DownloadStream {
/// The `FileReconstructor` to start when `start()` is called.
/// `None` once the reconstruction has been started (or cancelled before start).
reconstructor: Option<FileReconstructor>,
/// Channel receiver for sequential retrieval items from the writer queue (set after start).
receiver: Option<UnboundedReceiver<SequentialRetrievalItem>>,
/// Channel receiver for sequential retrieval items from the writer queue.
receiver: UnboundedReceiver<SequentialRetrievalItem>,
/// Whether the stream has finished (no more data).
finished: bool,
/// Shared run state with the `FileReconstructor`. When cancelled,
/// the reconstruction loop aborts promptly at its next check point or
/// `select!` branch. Also used for progress reporting and error propagation.
run_state: Arc<RunState>,
/// Signal to unblock the spawned reconstruction task. `Some` means
/// `start()` has not yet been called; the spawned task is waiting.
start_signal: Option<Arc<Notify>>,
}
impl DownloadStream {
/// Creates a new `DownloadStream`, immediately spawning the reconstruction
/// task on the current tokio runtime. The task blocks on an internal
/// [`Notify`] until [`start`](Self::start) is called.
///
/// # Panics
///
/// Panics if called outside a tokio runtime context.
pub(crate) fn new(reconstructor: FileReconstructor, run_state: Arc<RunState>) -> Self {
let (data_writer, receiver) = SequentialWriter::new_streaming(run_state.clone());
let start_signal = Arc::new(Notify::new());
let signal = start_signal.clone();
let rs = run_state.clone();
tokio::spawn(async move {
signal.notified().await;
info!(file_hash = %rs.file_hash(), "Starting download stream");
let _ = reconstructor.run(data_writer, rs, true).await;
});
Self {
reconstructor: Some(reconstructor),
receiver: None,
receiver,
finished: false,
run_state,
start_signal: Some(start_signal),
}
}
/// Starts the reconstruction task in the background. If already started,
/// this is a no-op. Called automatically on the first [`next`](Self::next) /
/// [`blocking_next`](Self::blocking_next).
pub(crate) fn abort_callback(&self) -> Box<dyn Fn() + Send + Sync> {
let run_state = self.run_state.clone();
let start_signal = self.start_signal.clone();
Box::new(move || {
run_state.cancel();
if let Some(signal) = start_signal.as_ref() {
signal.notify_one();
}
})
}
/// Unblocks the reconstruction task so it begins producing data.
///
/// If already started, this is a no-op. Called automatically on the first
/// [`next`](Self::next) / [`blocking_next`](Self::blocking_next).
///
/// This method is non-async and does not require a tokio runtime context.
pub fn start(&mut self) {
if let Some(reconstructor) = self.reconstructor.take() {
info!(file_hash = %self.run_state.file_hash(), "Starting download stream");
let (data_writer, receiver) = SequentialWriter::new_streaming(self.run_state.clone());
let run_state = self.run_state.clone();
tokio::spawn(async move {
let _ = reconstructor.run(data_writer, run_state, true).await;
});
self.receiver = Some(receiver);
if let Some(signal) = self.start_signal.take() {
signal.notify_one();
}
}
fn ensure_started(&mut self) {
if self.reconstructor.is_some() {
if self.start_signal.is_some() {
self.start();
}
}
fn cancel_reconstruction(&self) {
self.run_state.cancel();
if let Some(signal) = self.start_signal.as_ref() {
signal.notify_one();
}
}
/// Returns the next chunk of downloaded data, blocking the current thread
/// until data is available.
///
@@ -85,9 +119,8 @@ impl DownloadStream {
return Ok(None);
}
self.ensure_started();
let receiver = self.receiver.as_mut().expect("receiver must exist after start");
match receiver.blocking_recv() {
match self.receiver.blocking_recv() {
Some(SequentialRetrievalItem::Data { receiver, permit }) => {
let data = receiver.blocking_recv().map_err(|_| {
FileReconstructionError::InternalWriterError(
@@ -108,15 +141,24 @@ impl DownloadStream {
/// Returns the next chunk of downloaded data asynchronously.
///
/// Returns `Ok(None)` when the download is complete.
/// Returns `Ok(None)` when the download is complete or cancelled.
pub async fn next(&mut self) -> Result<Option<Bytes>> {
if self.finished {
return Ok(None);
}
self.ensure_started();
let receiver = self.receiver.as_mut().expect("receiver must exist after start");
match receiver.recv().await {
let item = if let Ok(item) = self.receiver.try_recv() {
Some(item)
} else {
tokio::select! {
biased;
recv = self.receiver.recv() => recv,
_ = self.run_state.cancelled() => None,
}
};
match item {
Some(SequentialRetrievalItem::Data { receiver, permit }) => {
let data = receiver.await.map_err(|_| {
FileReconstructionError::InternalWriterError(
@@ -142,20 +184,16 @@ impl DownloadStream {
/// After calling this, subsequent calls to [`blocking_next`](Self::blocking_next)
/// / [`next`](Self::next) will return `Ok(None)`.
pub fn cancel(&mut self) {
self.run_state.cancel();
self.reconstructor.take();
if let Some(ref mut receiver) = self.receiver {
receiver.close();
}
self.cancel_reconstruction();
let _ = self.start_signal.take();
self.receiver.close();
self.finished = true;
}
}
impl Drop for DownloadStream {
fn drop(&mut self) {
self.run_state.cancel();
if let Some(ref mut receiver) = self.receiver {
receiver.close();
}
self.cancel_reconstruction();
self.receiver.close();
}
}

View File

@@ -2,7 +2,11 @@
mod data_writer;
pub mod download_stream;
mod sequential_writer;
pub mod unordered_download_stream;
mod unordered_writer;
pub use data_writer::{DataFuture, DataWriter};
pub use download_stream::DownloadStream;
pub use sequential_writer::SequentialWriter;
pub use unordered_download_stream::UnorderedDownloadStream;
pub use unordered_writer::UnorderedWriter;

View File

@@ -5,7 +5,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
use bytes::Bytes;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tokio::sync::{Mutex, oneshot};
use tokio::sync::oneshot;
use tokio::task::{JoinHandle, JoinSet};
use xet_client::cas_types::FileRange;
use xet_runtime::core::{XetRuntime, check_sigint_shutdown};
@@ -207,27 +207,24 @@ impl SyncWriterThread {
}
}
/// Mutable state for the writing queue, protected by a mutex.
struct WritingQueueState {
sender: UnboundedSender<SequentialRetrievalItem>,
next_position: u64,
finished: bool,
}
/// Writes data sequentially to an output stream from async data futures.
/// Spawns async tasks to resolve futures and a background thread to perform
/// blocking writes, allowing out-of-order future resolution with in-order writes.
pub struct SequentialWriter {
queue_state: Mutex<WritingQueueState>,
background_handle: Mutex<Option<JoinHandle<()>>>,
sender: UnboundedSender<SequentialRetrievalItem>,
next_position: u64,
background_handle: Option<JoinHandle<()>>,
run_state: Arc<RunState>,
bytes_written: Arc<AtomicU64>,
active_tasks: Arc<Mutex<JoinSet<Result<()>>>>,
active_tasks: JoinSet<Result<()>>,
finished: bool,
}
impl Drop for SequentialWriter {
fn drop(&mut self) {
self.run_state.cancel();
if !self.finished {
self.run_state.cancel();
}
}
}
@@ -237,56 +234,38 @@ impl DataWriter for SequentialWriter {
/// can be executing in the background. This must be the next one sequentially,
/// otherwise it will error out.
async fn set_next_term_data_source(
&self,
&mut self,
byte_range: FileRange,
permit: Option<AdjustableSemaphorePermit>,
data_future: DataFuture,
) -> Result<()> {
self.run_state.check_error()?;
// Check for any errors from previously spawned tasks.
{
let mut tasks = self.active_tasks.lock().await;
while let Some(result) = tasks.try_join_next() {
result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
}
while let Some(result) = self.active_tasks.try_join_next() {
result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
}
let (sender, expected_size) = {
let mut state = self.queue_state.lock().await;
if self.finished {
return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
}
if state.finished {
return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
}
if byte_range.start != self.next_position {
return Err(FileReconstructionError::InternalWriterError(format!(
"Byte range not sequential: expected start at {}, got {}",
self.next_position, byte_range.start
)));
}
if byte_range.start != state.next_position {
return Err(FileReconstructionError::InternalWriterError(format!(
"Byte range not sequential: expected start at {}, got {}",
state.next_position, byte_range.start
)));
}
let expected_size = byte_range.end - byte_range.start;
self.next_position = byte_range.end;
let expected_size = byte_range.end - byte_range.start;
state.next_position = byte_range.end;
let (sender, receiver) = oneshot::channel();
let (sender, receiver) = oneshot::channel();
if self.sender.send(SequentialRetrievalItem::Data { receiver, permit }).is_err() {
self.run_state.check_error()?;
return Err(FileReconstructionError::InternalWriterError("Background writer channel closed".to_string()));
}
if state.sender.send(SequentialRetrievalItem::Data { receiver, permit }).is_err() {
// The background writer exited. Return the original error that
// killed it (stored in RunState) instead of a generic message.
drop(state);
self.run_state.check_error()?;
return Err(FileReconstructionError::InternalWriterError(
"Background writer channel closed".to_string(),
));
}
(sender, expected_size)
};
// Spawn a task to evaluate the future and send the result.
// On error, set_error() stores the error and cancels the token,
// immediately waking the main reconstruction loop.
let run_state = self.run_state.clone();
let task = async move {
let result = async {
@@ -319,48 +298,34 @@ impl DataWriter for SequentialWriter {
result
};
{
let mut tasks = self.active_tasks.lock().await;
tasks.spawn(task);
}
self.active_tasks.spawn(task);
Ok(())
}
/// Wait for the background writer to finish and all tasks to complete.
/// Returns the number of bytes written.
async fn finish(&self) -> Result<u64> {
async fn finish(mut self: Box<Self>) -> Result<u64> {
self.run_state.check_error()?;
let expected_bytes = {
let mut state = self.queue_state.lock().await;
if state.finished {
return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
}
state.finished = true;
if state.sender.send(SequentialRetrievalItem::Finish).is_err() {
drop(state);
self.run_state.check_error()?;
return Err(FileReconstructionError::InternalWriterError(
"Background writer channel closed".to_string(),
));
}
state.next_position
};
// Wait for all spawned data-fetching tasks to complete.
{
let mut tasks = self.active_tasks.lock().await;
while let Some(result) = tasks.join_next().await {
result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
}
if self.finished {
return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
}
match self.background_handle.lock().await.take() {
self.finished = true;
if self.sender.send(SequentialRetrievalItem::Finish).is_err() {
self.run_state.check_error()?;
return Err(FileReconstructionError::InternalWriterError("Background writer channel closed".to_string()));
}
let expected_bytes = self.next_position;
while let Some(result) = self.active_tasks.join_next().await {
result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
}
match self.background_handle.take() {
Some(handle) => {
handle.await.map_err(|e| {
FileReconstructionError::InternalWriterError(format!("Background writer task failed: {e}"))
@@ -395,25 +360,20 @@ impl SequentialWriter {
/// values that the caller (typically a `DownloadStream`) consumes directly.
pub(crate) fn new_streaming(
run_state: Arc<RunState>,
) -> (Arc<dyn DataWriter>, UnboundedReceiver<SequentialRetrievalItem>) {
) -> (Box<dyn DataWriter>, UnboundedReceiver<SequentialRetrievalItem>) {
let (tx, rx) = unbounded_channel::<SequentialRetrievalItem>();
let bytes_written = Arc::new(AtomicU64::new(0));
let writing_queue_state = WritingQueueState {
let writer = Self {
sender: tx,
next_position: 0,
background_handle: None,
run_state,
bytes_written: Arc::new(AtomicU64::new(0)),
active_tasks: JoinSet::new(),
finished: false,
};
let writer = Self {
queue_state: Mutex::new(writing_queue_state),
background_handle: Mutex::new(None),
run_state,
bytes_written,
active_tasks: Arc::new(Mutex::new(JoinSet::new())),
};
(Arc::new(writer), rx)
(Box::new(writer), rx)
}
/// Creates a sequential writer backed by the given `Write` impl.
@@ -426,7 +386,7 @@ impl SequentialWriter {
writer: W,
use_vectorized: bool,
run_state: Arc<RunState>,
) -> Arc<dyn DataWriter> {
) -> Box<dyn DataWriter> {
let (tx, rx) = unbounded_channel::<SequentialRetrievalItem>();
let bytes_written = Arc::new(AtomicU64::new(0));
@@ -446,18 +406,14 @@ impl SequentialWriter {
}
});
let writing_queue_state = WritingQueueState {
Box::new(Self {
sender: tx,
next_position: 0,
finished: false,
};
Arc::new(Self {
queue_state: Mutex::new(writing_queue_state),
background_handle: Mutex::new(Some(handle)),
background_handle: Some(handle),
run_state,
bytes_written,
active_tasks: Arc::new(Mutex::new(JoinSet::new())),
active_tasks: JoinSet::new(),
finished: false,
})
}
}
@@ -630,7 +586,7 @@ mod tests {
let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
let buffer_clone = buffer.clone();
let writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
@@ -656,7 +612,7 @@ mod tests {
let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
let buffer_clone = buffer.clone();
let writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
// Create futures that resolve with delays
let f0: DataFuture = Box::pin(async {
@@ -682,7 +638,7 @@ mod tests {
#[tokio::test]
async fn test_size_mismatch_error() {
let buffer = std::io::Cursor::new(Vec::new());
let writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hello")))
@@ -705,7 +661,7 @@ mod tests {
}
}
let writer = SequentialWriter::new(Box::new(FailingWriter), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(FailingWriter), false, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 4), None, immediate_future(Bytes::from("Test")))
@@ -722,30 +678,6 @@ mod tests {
assert!(matches!(result, Err(FileReconstructionError::IoError(_))));
}
#[tokio::test]
async fn test_finish_twice_returns_error() {
let buffer = std::io::Cursor::new(Vec::new());
let writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
writer.finish().await.unwrap();
let result = writer.finish().await;
assert!(result.is_err());
assert!(matches!(result, Err(FileReconstructionError::InternalWriterError(_))));
}
#[tokio::test]
async fn test_write_after_finish_returns_error() {
let buffer = std::io::Cursor::new(Vec::new());
let writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
writer.finish().await.unwrap();
let result = writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
.await;
assert!(result.is_err());
assert!(matches!(result, Err(FileReconstructionError::InternalWriterError(_))));
}
#[tokio::test]
async fn test_flush_error_propagates() {
struct FlushFailingWriter;
@@ -769,7 +701,7 @@ mod tests {
let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
let buffer_clone = buffer.clone();
let writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
let failing_future: DataFuture =
Box::pin(async { Err(FileReconstructionError::InternalError("Simulated future error".to_string())) });
@@ -786,7 +718,7 @@ mod tests {
#[tokio::test]
async fn test_size_mismatch_too_small() {
let buffer = std::io::Cursor::new(Vec::new());
let writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hi")))
@@ -800,7 +732,7 @@ mod tests {
#[tokio::test]
async fn test_size_mismatch_too_large() {
let buffer = std::io::Cursor::new(Vec::new());
let writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 2), None, immediate_future(Bytes::from("Hello World")))
@@ -816,7 +748,7 @@ mod tests {
let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
let buffer_clone = buffer.clone();
let writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
@@ -841,7 +773,7 @@ mod tests {
#[tokio::test]
async fn test_non_sequential_range_returns_error() {
let buffer = std::io::Cursor::new(Vec::new());
let writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
@@ -858,7 +790,7 @@ mod tests {
#[tokio::test]
async fn test_first_range_must_start_at_zero() {
let buffer = std::io::Cursor::new(Vec::new());
let writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
let result = writer
.set_next_term_data_source(FileRange::new(5, 10), None, immediate_future(Bytes::from("Hello")))
@@ -873,7 +805,7 @@ mod tests {
let buffer_clone = buffer.clone();
let semaphore = AdjustableSemaphore::new(2, (0, 2));
let writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
let permit1 = semaphore.acquire().await.unwrap();
let permit2 = semaphore.acquire().await.unwrap();
@@ -910,7 +842,7 @@ mod tests {
let buffer = test_writer.buffer.clone();
let vectored_count = test_writer.vectored_write_count.clone();
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
@@ -937,7 +869,7 @@ mod tests {
let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(3));
let buffer = test_writer.buffer.clone();
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
@@ -967,7 +899,7 @@ mod tests {
let test_writer = TestWriter::new(TestWriterConfig::vectorized());
let buffer = test_writer.buffer.clone();
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
// Create futures that resolve with different delays
let f0: DataFuture = Box::pin(async {
@@ -997,7 +929,7 @@ mod tests {
let buffer = test_writer.buffer.clone();
let vectored_count = test_writer.vectored_write_count.clone();
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
// Write 100 single-byte chunks
for i in 0..100u8 {
@@ -1026,7 +958,7 @@ mod tests {
let test_writer = TestWriter::new(TestWriterConfig::vectorized_with_interrupts());
let buffer = test_writer.buffer.clone();
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
@@ -1053,7 +985,7 @@ mod tests {
let buffer = test_writer.buffer.clone();
let semaphore = AdjustableSemaphore::new(2, (0, 2));
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let permit1 = semaphore.acquire().await.unwrap();
let permit2 = semaphore.acquire().await.unwrap();
@@ -1088,7 +1020,7 @@ mod tests {
let buffer = test_writer.buffer.clone();
let semaphore = AdjustableSemaphore::new(3, (0, 3));
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let permit1 = semaphore.acquire().await.unwrap();
let permit2 = semaphore.acquire().await.unwrap();
@@ -1124,7 +1056,7 @@ mod tests {
let write_count = test_writer.write_count.clone();
let vectored_count = test_writer.vectored_write_count.clone();
let writer = SequentialWriter::new(Box::new(test_writer), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), false, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
@@ -1152,7 +1084,7 @@ mod tests {
let test_writer = TestWriter::new(TestWriterConfig::partial(3));
let buffer = test_writer.buffer.clone();
let writer = SequentialWriter::new(Box::new(test_writer), false, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), false, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
@@ -1182,7 +1114,7 @@ mod tests {
let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(1));
let buffer = test_writer.buffer.clone();
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("ABCDE")))
@@ -1205,7 +1137,7 @@ mod tests {
let test_writer = TestWriter::new(TestWriterConfig::vectorized());
let buffer = test_writer.buffer.clone();
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
// Write in chunks of 1000 bytes
for i in 0..10 {
@@ -1234,7 +1166,7 @@ mod tests {
let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(100));
let buffer = test_writer.buffer.clone();
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
// Write in chunks of 500 bytes
for i in 0..10 {
@@ -1261,7 +1193,7 @@ mod tests {
async fn test_vectorized_exceeded_max_slice() {
let test_writer = TestWriter::new(TestWriterConfig::vectorized_hard_limit(2)); // hard limit set to 2 slices at a time
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test()); // controlled writev at max 24 slices at a time
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test()); // controlled writev at max 24 slices at a time
// Write in slices of 10 bytes, creating in total 1000 slices
for i in 0..1000 {
@@ -1294,7 +1226,7 @@ mod tests {
let test_writer = TestWriter::new(TestWriterConfig::vectorized_hard_limit(40)); // hard limit set to 40 slices at a time
let buffer = test_writer.buffer.clone();
let writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test()); // controlled writev at max 24 slices at a time
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test()); // controlled writev at max 24 slices at a time
// Write in slices of 10 bytes, creating in total 1000 slices
for i in 0..1000 {

View File

@@ -0,0 +1,239 @@
use std::sync::Arc;
use std::sync::atomic::Ordering;
use bytes::Bytes;
use tokio::sync::Notify;
use tokio::sync::mpsc::UnboundedReceiver;
use tracing::info;
use super::super::error::Result;
use super::super::file_reconstructor::FileReconstructor;
use super::super::run_state::RunState;
use super::unordered_writer::{CompletedTerm, UnorderedWriterProgress};
/// A streaming download handle that yields data chunks in completion order,
/// each tagged with its byte offset in the output file.
///
/// Created by [`FileReconstructor::reconstruct_to_unordered_stream`]. The
/// reconstruction task is spawned immediately but pauses until
/// [`start`](Self::start) is called (or the first [`next`](Self::next) /
/// [`blocking_next`](Self::blocking_next)). Because the `tokio::spawn`
/// happens at construction time, subsequent calls to `start()`, `next()`,
/// and `blocking_next()` do **not** require a tokio runtime context.
///
/// Unlike [`DownloadStream`](super::download_stream::DownloadStream), data
/// chunks may arrive out of order. Each chunk is returned as `(offset, Bytes)`
/// so the consumer knows where it belongs. Progress can be monitored via
/// the tracking methods which read shared atomic counters.
///
/// Holds only `Arc<WriterProgress>`, not the writer itself, so the channel
/// sender is dropped naturally when the reconstruction task finishes.
pub struct UnorderedDownloadStream {
/// Shared atomic progress counters (also held by the writer and its tasks).
progress: Arc<UnorderedWriterProgress>,
/// Channel receiver for completed terms from spawned tasks.
receiver: UnboundedReceiver<Result<CompletedTerm>>,
/// Whether the stream has finished (no more data).
finished: bool,
/// Shared run state with the `FileReconstructor`.
run_state: Arc<RunState>,
/// Signal to unblock the spawned reconstruction task. `Some` means
/// `start()` has not yet been called; the spawned task is waiting.
start_signal: Option<Arc<Notify>>,
}
impl UnorderedDownloadStream {
/// Creates a new `UnorderedDownloadStream`, immediately spawning the
/// reconstruction task on the current tokio runtime. The task blocks
/// on an internal [`Notify`] until [`start`](Self::start) is called.
///
/// # Panics
///
/// Panics if called outside a tokio runtime context.
pub(crate) fn new(reconstructor: FileReconstructor, run_state: Arc<RunState>) -> Self {
use super::unordered_writer::UnorderedWriter;
let (writer, receiver, progress) = UnorderedWriter::new_streaming(run_state.clone());
let start_signal = Arc::new(Notify::new());
let signal = start_signal.clone();
let rs = run_state.clone();
tokio::spawn(async move {
signal.notified().await;
info!(file_hash = %rs.file_hash(), "Starting unordered download stream");
let _ = reconstructor.run(writer, rs, true).await;
});
Self {
progress,
receiver,
finished: false,
run_state,
start_signal: Some(start_signal),
}
}
pub(crate) fn abort_callback(&self) -> Box<dyn Fn() + Send + Sync> {
let run_state = self.run_state.clone();
let start_signal = self.start_signal.clone();
Box::new(move || {
run_state.cancel();
if let Some(signal) = start_signal.as_ref() {
signal.notify_one();
}
})
}
/// Unblocks the reconstruction task so it begins producing data.
///
/// If already started, this is a no-op. Called automatically on the first
/// [`next`](Self::next) / [`blocking_next`](Self::blocking_next).
///
/// This method is non-async and does not require a tokio runtime context.
pub fn start(&mut self) {
if let Some(signal) = self.start_signal.take() {
signal.notify_one();
}
}
fn ensure_started(&mut self) {
if self.start_signal.is_some() {
self.start();
}
}
fn cancel_reconstruction(&self) {
self.run_state.cancel();
if let Some(signal) = self.start_signal.as_ref() {
signal.notify_one();
}
}
/// Returns the next chunk of downloaded data with its byte offset,
/// blocking the current thread until data is available.
///
/// Returns `Ok(None)` when the download is complete.
///
/// # Panics
///
/// Panics if called from within an async runtime context. Use from a
/// regular thread or from [`tokio::task::spawn_blocking`] instead.
/// For the async-safe variant, use [`next`](Self::next).
pub fn blocking_next(&mut self) -> Result<Option<(u64, Bytes)>> {
if self.finished {
return Ok(None);
}
self.ensure_started();
match self.receiver.blocking_recv() {
Some(result) => self.process_term(result),
None => {
self.finished = true;
self.run_state.check_error()?;
Ok(None)
},
}
}
/// Returns the next chunk of downloaded data with its byte offset
/// asynchronously.
///
/// Returns `Ok(None)` when the download is complete.
pub async fn next(&mut self) -> Result<Option<(u64, Bytes)>> {
if self.finished {
return Ok(None);
}
self.ensure_started();
if let Ok(result) = self.receiver.try_recv() {
return self.process_term(result);
}
let next_item = tokio::select! {
biased;
recv = self.receiver.recv() => recv,
_ = self.run_state.cancelled() => None,
};
match next_item {
Some(result) => self.process_term(result),
None => {
self.finished = true;
self.run_state.check_error()?;
Ok(None)
},
}
}
fn process_term(&mut self, result: Result<CompletedTerm>) -> Result<Option<(u64, Bytes)>> {
let term = result?;
self.run_state.report_bytes_written(term.data.len() as u64);
let offset = term.byte_range.start;
let data = term.data;
drop(term.permit);
Ok(Some((offset, data)))
}
/// Cancels the in-progress (or not-yet-started) download.
///
/// Signals the shared run state so the reconstruction loop aborts at its
/// next check point. After calling this, subsequent calls to
/// [`blocking_next`](Self::blocking_next) / [`next`](Self::next) will
/// return `Ok(None)`.
pub fn cancel(&mut self) {
self.cancel_reconstruction();
let _ = self.start_signal.take();
self.receiver.close();
self.finished = true;
}
// ── Tracking methods ─────────────────────────────────────────────────
/// Total bytes expected for the reconstruction, read from the progress
/// updater. Returns 0 if not yet known or no progress updater is set.
pub fn total_bytes_expected(&self) -> u64 {
self.run_state
.progress_updater()
.map(|u| u.item().total_bytes.load(Ordering::Acquire))
.unwrap_or(0)
}
/// Bytes currently being fetched by in-progress tasks.
pub fn bytes_in_progress(&self) -> u64 {
self.progress.bytes_in_progress()
}
/// Bytes that have been delivered through the progress updater.
/// Returns 0 if no progress updater is set.
pub fn bytes_completed(&self) -> u64 {
self.run_state
.progress_updater()
.map(|u| u.total_bytes_completed())
.unwrap_or(0)
}
/// Number of tasks currently resolving data futures.
pub fn terms_in_progress(&self) -> u64 {
self.progress.terms_in_progress()
}
/// Returns `true` once the stream has reached terminal state.
///
/// This flips to `true` after [`next`](Self::next) / [`blocking_next`](Self::blocking_next)
/// has observed the end-of-stream (`None`), or after [`cancel`](Self::cancel).
/// Buffered but unconsumed channel items do not count as complete.
pub fn is_complete(&self) -> bool {
self.finished
}
}
impl Drop for UnorderedDownloadStream {
fn drop(&mut self) {
self.cancel_reconstruction();
self.receiver.close();
}
}

View File

@@ -0,0 +1,529 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use bytes::Bytes;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
use tokio::task::JoinSet;
use xet_client::cas_types::FileRange;
use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphorePermit;
use super::super::data_writer::{DataFuture, DataWriter};
use super::super::run_state::RunState;
use super::super::{FileReconstructionError, Result};
/// A completed term ready for consumption. Contains the byte range indicating
/// where this data belongs in the output file, the actual data bytes, and an
/// optional semaphore permit for backpressure control.
pub(crate) struct CompletedTerm {
pub byte_range: FileRange,
pub data: Bytes,
pub permit: Option<AdjustableSemaphorePermit>,
}
/// Atomic progress counters shared between the writer, its spawned tasks,
/// and the consumer stream. Wrapped in an `Arc` so each party can read/update
/// counters without holding a reference to the full `UnorderedWriter`.
pub(crate) struct UnorderedWriterProgress {
pub terms_in_progress: AtomicU64,
pub bytes_in_progress: AtomicU64,
}
impl UnorderedWriterProgress {
pub fn terms_in_progress(&self) -> u64 {
self.terms_in_progress.load(Ordering::Acquire)
}
pub fn bytes_in_progress(&self) -> u64 {
self.bytes_in_progress.load(Ordering::Relaxed)
}
}
/// Writer that delivers completed data terms in arbitrary order.
///
/// Each call to [`set_next_term_data_source`](DataWriter::set_next_term_data_source)
/// spawns a task (tracked via a [`JoinSet`]) that resolves the data future and
/// sends the result through an [`mpsc`](tokio::sync::mpsc) channel. The consumer
/// (typically an [`UnorderedDownloadStream`](super::unordered_download_stream::UnorderedDownloadStream))
/// reads from the receiver end and gets items in whatever order tasks complete.
///
/// The consumer stream holds only `Arc<UnorderedWriterProgress>`, not the writer
/// itself, so the writer's channel sender is dropped naturally when the
/// reconstruction task finishes and consumes the writer via
/// [`finish()`](DataWriter::finish).
pub struct UnorderedWriter {
result_tx: UnboundedSender<Result<CompletedTerm>>,
run_state: Arc<RunState>,
progress: Arc<UnorderedWriterProgress>,
task_set: JoinSet<Result<u64>>,
total_bytes_sent: u64,
finished: bool,
}
impl Drop for UnorderedWriter {
fn drop(&mut self) {
if !self.finished {
self.run_state.cancel();
}
}
}
#[async_trait::async_trait]
impl DataWriter for UnorderedWriter {
async fn set_next_term_data_source(
&mut self,
byte_range: FileRange,
permit: Option<AdjustableSemaphorePermit>,
data_future: DataFuture,
) -> Result<()> {
self.run_state.check_error()?;
while let Some(result) = self.task_set.try_join_next() {
self.total_bytes_sent +=
result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
}
if self.finished {
return Err(FileReconstructionError::InternalWriterError("Writer has already finished".to_string()));
}
let expected_size = byte_range.end - byte_range.start;
self.progress.terms_in_progress.fetch_add(1, Ordering::Relaxed);
self.progress.bytes_in_progress.fetch_add(expected_size, Ordering::Relaxed);
let result_tx = self.result_tx.clone();
let run_state = self.run_state.clone();
let progress = self.progress.clone();
self.task_set.spawn(async move {
let result = async {
run_state.check_error()?;
let data = data_future.await?;
if data.len() as u64 != expected_size {
return Err(FileReconstructionError::InternalWriterError(format!(
"Data size mismatch: expected {} bytes, got {} bytes",
expected_size,
data.len()
)));
}
Ok(CompletedTerm {
byte_range,
data,
permit,
})
}
.await;
if let Err(ref e) = result {
run_state.set_error(e.clone());
}
let completed_bytes = result.as_ref().map(|t| t.data.len() as u64).unwrap_or(0);
let _ = result_tx.send(result);
progress.bytes_in_progress.fetch_sub(expected_size, Ordering::Relaxed);
progress.terms_in_progress.fetch_sub(1, Ordering::Release);
if completed_bytes > 0 {
Ok(completed_bytes)
} else {
run_state.check_error()?;
Ok(0)
}
});
Ok(())
}
async fn finish(mut self: Box<Self>) -> Result<u64> {
self.run_state.check_error()?;
while let Some(result) = self.task_set.join_next().await {
self.total_bytes_sent +=
result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
}
self.finished = true;
Ok(self.total_bytes_sent)
}
}
impl UnorderedWriter {
/// Creates an unordered writer for streaming use. Returns the writer (to be
/// passed to the reconstruction task as `Box<dyn DataWriter>`), the receiver
/// end of the channel, and the shared progress counters for the consumer.
///
/// The consumer stream should hold only the `Arc<UnorderedWriterProgress>`,
/// **not** the writer itself. This way the channel sender is dropped
/// naturally when the reconstruction task finishes (consuming the writer
/// via `finish()`), closing the channel without explicit lifetime management.
pub(crate) fn new_streaming(
run_state: Arc<RunState>,
) -> (Box<dyn DataWriter>, UnboundedReceiver<Result<CompletedTerm>>, Arc<UnorderedWriterProgress>) {
let (tx, rx) = unbounded_channel();
let progress = Arc::new(UnorderedWriterProgress {
terms_in_progress: AtomicU64::new(0),
bytes_in_progress: AtomicU64::new(0),
});
let writer = Box::new(UnorderedWriter {
result_tx: tx,
run_state,
progress: progress.clone(),
task_set: JoinSet::new(),
total_bytes_sent: 0,
finished: false,
});
(writer, rx, progress)
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphore;
use super::*;
fn immediate_future(data: Bytes) -> DataFuture {
Box::pin(async move { Ok(data) })
}
fn delayed_future(data: Bytes, delay: Duration) -> DataFuture {
Box::pin(async move {
tokio::time::sleep(delay).await;
Ok(data)
})
}
/// Drains all results from the receiver, returning data sorted by offset.
/// The writer must have been dropped (after calling `finish()`) so that
/// the channel closes naturally when all spawned tasks complete.
async fn drain_sorted(rx: &mut UnboundedReceiver<Result<CompletedTerm>>) -> Result<Vec<(u64, Bytes)>> {
let mut items = Vec::new();
while let Some(result) = rx.recv().await {
let term = result?;
items.push((term.byte_range.start, term.data));
drop(term.permit);
}
items.sort_by_key(|(offset, _)| *offset);
Ok(items)
}
#[tokio::test]
async fn test_basic_unordered_writes() {
let run_state = RunState::new_for_test();
let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
.await
.unwrap();
writer
.set_next_term_data_source(FileRange::new(5, 6), None, immediate_future(Bytes::from(" ")))
.await
.unwrap();
writer
.set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
.await
.unwrap();
let total = writer.finish().await.unwrap();
assert_eq!(total, 11);
let items = drain_sorted(&mut rx).await.unwrap();
let assembled: Vec<u8> = items.into_iter().flat_map(|(_, data)| data.to_vec()).collect();
assert_eq!(&assembled, b"Hello World");
}
#[tokio::test]
async fn test_delayed_futures_complete_out_of_order() {
let run_state = RunState::new_for_test();
let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
writer
.set_next_term_data_source(
FileRange::new(0, 5),
None,
delayed_future(Bytes::from("Hello"), Duration::from_millis(80)),
)
.await
.unwrap();
writer
.set_next_term_data_source(
FileRange::new(5, 6),
None,
delayed_future(Bytes::from(" "), Duration::from_millis(40)),
)
.await
.unwrap();
writer
.set_next_term_data_source(FileRange::new(6, 11), None, immediate_future(Bytes::from("World")))
.await
.unwrap();
let total = writer.finish().await.unwrap();
assert_eq!(total, 11);
let items = drain_sorted(&mut rx).await.unwrap();
let assembled: Vec<u8> = items.into_iter().flat_map(|(_, data)| data.to_vec()).collect();
assert_eq!(&assembled, b"Hello World");
}
#[tokio::test]
async fn test_size_mismatch_error() {
let run_state = RunState::new_for_test();
let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
writer
.set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hello")))
.await
.unwrap();
let result = writer.finish().await;
assert!(result.is_err());
let result = rx.recv().await.unwrap();
assert!(result.is_err());
assert!(matches!(result, Err(FileReconstructionError::InternalWriterError(_))));
}
#[tokio::test]
async fn test_future_error_propagates() {
let run_state = RunState::new_for_test();
let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
let failing_future: DataFuture =
Box::pin(async { Err(FileReconstructionError::InternalError("Simulated error".to_string())) });
writer
.set_next_term_data_source(FileRange::new(0, 5), None, failing_future)
.await
.unwrap();
let result = writer.finish().await;
assert!(result.is_err());
let result = rx.recv().await.unwrap();
assert!(result.is_err());
}
#[tokio::test]
async fn test_semaphore_permit_released_after_consumption() {
let run_state = RunState::new_for_test();
let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
let semaphore = AdjustableSemaphore::new(2, (0, 2));
let permit1 = semaphore.acquire().await.unwrap();
let permit2 = semaphore.acquire().await.unwrap();
assert_eq!(semaphore.available_permits(), 0);
writer
.set_next_term_data_source(FileRange::new(0, 5), Some(permit1), immediate_future(Bytes::from("Hello")))
.await
.unwrap();
writer
.set_next_term_data_source(FileRange::new(5, 6), Some(permit2), immediate_future(Bytes::from(" ")))
.await
.unwrap();
writer.finish().await.unwrap();
let items = drain_sorted(&mut rx).await.unwrap();
drop(items);
assert_eq!(semaphore.available_permits(), 2);
}
#[tokio::test]
async fn test_counter_accuracy() {
let run_state = RunState::new_for_test();
let (mut writer, mut rx, progress) = UnorderedWriter::new_streaming(run_state);
writer
.set_next_term_data_source(
FileRange::new(0, 5),
None,
delayed_future(Bytes::from("Hello"), Duration::from_millis(50)),
)
.await
.unwrap();
writer
.set_next_term_data_source(
FileRange::new(5, 11),
None,
delayed_future(Bytes::from(" World"), Duration::from_millis(50)),
)
.await
.unwrap();
let total = writer.finish().await.unwrap();
assert_eq!(total, 11);
let _items = drain_sorted(&mut rx).await.unwrap();
assert_eq!(progress.bytes_in_progress(), 0);
assert_eq!(progress.terms_in_progress(), 0);
}
#[tokio::test]
async fn test_finish_returns_total_bytes() {
let run_state = RunState::new_for_test();
let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
writer
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
.await
.unwrap();
writer
.set_next_term_data_source(FileRange::new(5, 11), None, immediate_future(Bytes::from(" World")))
.await
.unwrap();
let total = writer.finish().await.unwrap();
assert_eq!(total, 11);
let _items = drain_sorted(&mut rx).await.unwrap();
}
#[tokio::test]
async fn test_error_propagation_prevents_subsequent_writes() {
let run_state = RunState::new_for_test();
let (mut writer, mut _rx, _progress) = UnorderedWriter::new_streaming(run_state.clone());
let failing_future: DataFuture =
Box::pin(async { Err(FileReconstructionError::InternalError("fail".to_string())) });
writer
.set_next_term_data_source(FileRange::new(0, 5), None, failing_future)
.await
.unwrap();
let wait_for_error = tokio::time::timeout(Duration::from_secs(1), async {
loop {
if run_state.check_error().is_err() {
break;
}
tokio::task::yield_now().await;
}
})
.await;
assert!(wait_for_error.is_ok());
let result = writer
.set_next_term_data_source(FileRange::new(5, 10), None, immediate_future(Bytes::from("World")))
.await;
assert!(result.is_err());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_test_many_concurrent_terms() {
let run_state = RunState::new_for_test();
let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
let num_terms: usize = 100;
let mut expected: Vec<(u64, Vec<u8>)> = Vec::new();
let mut offset = 0u64;
for i in 0..num_terms {
let size = 100 + (i % 50) * 10;
let data: Vec<u8> = (0..size).map(|j| ((i * 7 + j * 13) % 256) as u8).collect();
let bytes = Bytes::from(data.clone());
expected.push((offset, data));
let delay = Duration::from_micros((i % 10) as u64 * 100);
writer
.set_next_term_data_source(
FileRange::new(offset, offset + size as u64),
None,
delayed_future(bytes, delay),
)
.await
.unwrap();
offset += size as u64;
}
let total = writer.finish().await.unwrap();
assert_eq!(total, offset);
let items = drain_sorted(&mut rx).await.unwrap();
assert_eq!(items.len(), num_terms);
for ((exp_offset, exp_data), (act_offset, act_data)) in expected.iter().zip(items.iter()) {
assert_eq!(*exp_offset, *act_offset);
assert_eq!(exp_data.as_slice(), act_data.as_ref());
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_test_rapid_finish_after_writes() {
for _ in 0..50 {
let run_state = RunState::new_for_test();
let (mut writer, mut rx, _progress) = UnorderedWriter::new_streaming(run_state);
for i in 0..10u64 {
let data = Bytes::from(vec![i as u8; 100]);
writer
.set_next_term_data_source(FileRange::new(i * 100, (i + 1) * 100), None, immediate_future(data))
.await
.unwrap();
}
let total = writer.finish().await.unwrap();
assert_eq!(total, 1000);
let items = drain_sorted(&mut rx).await.unwrap();
assert_eq!(items.len(), 10);
let total_bytes: usize = items.iter().map(|(_, data)| data.len()).sum();
assert_eq!(total_bytes, 1000);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_test_mixed_immediate_and_delayed() {
for _ in 0..20 {
let run_state = RunState::new_for_test();
let (mut writer, mut rx, progress) = UnorderedWriter::new_streaming(run_state);
let mut offset = 0u64;
let mut total_size = 0u64;
let num_terms = 30usize;
for i in 0..num_terms {
let size = ((i + 1) * 50) as u64;
let data = Bytes::from(vec![(i % 256) as u8; size as usize]);
total_size += size;
let future = if i % 3 == 0 {
delayed_future(data, Duration::from_millis((i % 5) as u64))
} else {
immediate_future(data)
};
writer
.set_next_term_data_source(FileRange::new(offset, offset + size), None, future)
.await
.unwrap();
offset += size;
}
let total = writer.finish().await.unwrap();
assert_eq!(total, total_size);
let items = drain_sorted(&mut rx).await.unwrap();
assert_eq!(items.len(), num_terms);
let received_bytes: u64 = items.iter().map(|(_, data)| data.len() as u64).sum();
assert_eq!(received_bytes, total_size);
assert_eq!(progress.terms_in_progress(), 0);
}
}
}

View File

@@ -14,7 +14,7 @@ use xet_runtime::core::{XetRuntime, xet_config};
use xet_runtime::utils::ClosureGuard;
use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphore;
use super::data_writer::{DataWriter, DownloadStream, SequentialWriter};
use super::data_writer::{DataWriter, DownloadStream, SequentialWriter, UnorderedDownloadStream};
use super::error::{FileReconstructionError, Result};
use super::reconstruction_terms::ReconstructionTermManager;
use super::run_state::{RunError, RunState};
@@ -124,7 +124,9 @@ impl FileReconstructor {
}
let run_state = RunState::new(self.cancellation_token.clone(), self.file_hash, self.progress_updater.clone());
let data_writer = SequentialWriter::new(file, self.config.use_vectored_write, run_state.clone());
self.run(data_writer, run_state, false).await
}
@@ -147,15 +149,41 @@ impl FileReconstructor {
/// Reconstructs the file as a stream, returning a [`DownloadStream`] that
/// yields data chunks as they become available.
///
/// The reconstruction task is **not** started immediately. It begins when
/// [`DownloadStream::start`] is called, or automatically on the first call
/// to [`DownloadStream::next`] / [`DownloadStream::blocking_next`].
/// The reconstruction task is spawned immediately but pauses on an
/// internal [`tokio::sync::Notify`] until [`DownloadStream::start`] is
/// called (or the first [`DownloadStream::next`] /
/// [`DownloadStream::blocking_next`]).
///
/// # Panics
///
/// Panics if called outside a tokio runtime context (the constructor
/// uses [`tokio::spawn`]).
pub fn reconstruct_to_stream(self) -> DownloadStream {
let run_state = RunState::new(self.cancellation_token.clone(), self.file_hash, self.progress_updater.clone());
DownloadStream::new(self, run_state)
}
/// Reconstructs the file as an unordered stream, returning an
/// [`UnorderedDownloadStream`] that yields `(offset, Bytes)` chunks
/// in whatever order they complete.
///
/// The reconstruction task is spawned immediately but pauses on an
/// internal [`tokio::sync::Notify`] until
/// [`UnorderedDownloadStream::start`] is called (or the first
/// [`UnorderedDownloadStream::next`] /
/// [`UnorderedDownloadStream::blocking_next`]).
///
/// # Panics
///
/// Panics if called outside a tokio runtime context (the constructor
/// uses [`tokio::spawn`]).
pub fn reconstruct_to_unordered_stream(self) -> UnorderedDownloadStream {
let run_state = RunState::new(self.cancellation_token.clone(), self.file_hash, self.progress_updater.clone());
UnorderedDownloadStream::new(self, run_state)
}
/// Runs the file reconstruction with error handling and cancellation support.
/// Returns the number of bytes written.
///
@@ -164,7 +192,7 @@ impl FileReconstructor {
/// asynchronously after this method returns.
pub(crate) async fn run(
self,
data_writer: Arc<dyn DataWriter>,
data_writer: Box<dyn DataWriter>,
run_state: Arc<RunState>,
is_streaming: bool,
) -> Result<u64> {
@@ -183,7 +211,7 @@ impl FileReconstructor {
async fn run_impl(
self,
data_writer: Arc<dyn DataWriter>,
mut data_writer: Box<dyn DataWriter>,
run_state: &RunState,
_is_streaming: bool,
) -> std::result::Result<u64, RunError> {

View File

@@ -4,7 +4,7 @@ mod file_reconstructor;
mod reconstruction_terms;
mod run_state;
pub use data_writer::{DataWriter, DownloadStream, SequentialWriter};
pub use data_writer::{DataWriter, DownloadStream, SequentialWriter, UnorderedDownloadStream, UnorderedWriter};
pub use error::{FileReconstructionError, Result};
pub use file_reconstructor::FileReconstructor;
pub use reconstruction_terms::{FileTerm, ReconstructionTermManager, XorbBlock, XorbBlockData};

View File

@@ -1,9 +1,10 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::io::Write;
use std::ops::{Bound, RangeBounds};
use std::ops::{Bound, Range, RangeBounds};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use tokio::task::JoinHandle;
use tracing::instrument;
@@ -15,7 +16,7 @@ use super::configurations::TranslatorConfig;
use super::remote_client_interface::create_remote_client;
use super::{XetFileInfo, prometheus_metrics};
use crate::error::{DataError, Result};
use crate::file_reconstruction::{DownloadStream, FileReconstructor};
use crate::file_reconstruction::{DownloadStream, FileReconstructor, UnorderedDownloadStream};
use crate::progress_tracking::{GroupProgress, ItemProgressUpdater, UniqueID};
/// Manages the downloading of files from CAS storage.
@@ -25,6 +26,7 @@ use crate::progress_tracking::{GroupProgress, ItemProgressUpdater, UniqueID};
pub struct FileDownloadSession {
client: Arc<dyn Client>,
progress: Arc<GroupProgress>,
active_stream_abort_callbacks: Mutex<HashMap<UniqueID, Box<dyn Fn() + Send + Sync>>>,
finalized: AtomicBool,
}
@@ -46,6 +48,7 @@ impl FileDownloadSession {
Ok(Arc::new(Self {
client,
progress,
active_stream_abort_callbacks: Mutex::new(HashMap::new()),
finalized: AtomicBool::new(false),
}))
}
@@ -59,6 +62,7 @@ impl FileDownloadSession {
Arc::new(Self {
client,
progress,
active_stream_abort_callbacks: Mutex::new(HashMap::new()),
finalized: AtomicBool::new(false),
})
}
@@ -71,10 +75,25 @@ impl FileDownloadSession {
self.progress.item_report(id)
}
pub fn item_reports(&self) -> std::collections::HashMap<UniqueID, crate::progress_tracking::ItemProgressReport> {
pub fn item_reports(&self) -> HashMap<UniqueID, crate::progress_tracking::ItemProgressReport> {
self.progress.item_reports()
}
fn register_stream_abort_callback(&self, id: UniqueID, callback: Box<dyn Fn() + Send + Sync>) {
self.active_stream_abort_callbacks.lock().unwrap().insert(id, callback);
}
pub fn unregister_stream_abort_callback(&self, id: UniqueID) {
self.active_stream_abort_callbacks.lock().unwrap().remove(&id);
}
pub fn abort_active_streams(&self) {
let callbacks = self.active_stream_abort_callbacks.lock().unwrap();
for callback in callbacks.values() {
callback();
}
}
/// Spawns a download task that writes `file_info` to `write_path`.
///
/// Acquires a permit from the global download semaphore before starting.
@@ -170,17 +189,60 @@ impl FileDownloadSession {
Ok((id, n_bytes))
}
/// Creates a streaming download of a file.
/// Creates a streaming download of a file, optionally restricted to a
/// byte range.
///
/// Returns a [`DownloadStream`] that yields data chunks as the file is
/// reconstructed. Reconstruction starts lazily on first
/// [`DownloadStream::next`] / [`DownloadStream::blocking_next`] call
/// (or when `start()` is called explicitly).
///
/// If `source_range` is `Some`, only the specified byte range of the
/// file is reconstructed.
///
/// This path does not acquire the session-level file download semaphore.
#[instrument(skip_all, name = "FileDownloadSession::download_stream", fields(hash = file_info.hash()))]
pub async fn download_stream(&self, file_info: &XetFileInfo) -> Result<(UniqueID, DownloadStream)> {
self.download_stream_range(file_info, ..).await
pub async fn download_stream(
&self,
file_info: &XetFileInfo,
source_range: Option<Range<u64>>,
) -> Result<(UniqueID, DownloadStream)> {
self.check_not_finalized()?;
let id = UniqueID::new();
let progress_updater = self.progress.new_item(id, "stream");
let range = source_range.map(|r| FileRange::new(r.start, r.end));
let reconstructor = self.setup_reconstructor(file_info, range, Some(progress_updater))?;
let stream = reconstructor.reconstruct_to_stream();
self.register_stream_abort_callback(id, stream.abort_callback());
Ok((id, stream))
}
/// Creates an unordered streaming download of a file, optionally
/// restricted to a byte range.
///
/// Returns an [`UnorderedDownloadStream`] that yields `(offset, Bytes)`
/// chunks in whatever order they complete. The total expected size is
/// set from the range length (or `file_info.file_size()` when no range
/// is given).
///
/// If `source_range` is `Some`, only the specified byte range of the
/// file is reconstructed.
///
/// This path does not acquire the session-level file download semaphore.
#[instrument(skip_all, name = "FileDownloadSession::download_unordered_stream", fields(hash = file_info.hash()))]
pub async fn download_unordered_stream(
&self,
file_info: &XetFileInfo,
source_range: Option<Range<u64>>,
) -> Result<(UniqueID, UnorderedDownloadStream)> {
self.check_not_finalized()?;
let id = UniqueID::new();
let progress_updater = self.progress.new_item(id, "unordered_stream");
let range = source_range.map(|r| FileRange::new(r.start, r.end));
let reconstructor = self.setup_reconstructor(file_info, range, Some(progress_updater))?;
let stream = reconstructor.reconstruct_to_unordered_stream();
self.register_stream_abort_callback(id, stream.abort_callback());
Ok((id, stream))
}
/// Creates a streaming download of a byte range of a file.
@@ -199,7 +261,9 @@ impl FileDownloadSession {
let id = UniqueID::new();
let progress_updater = self.progress.new_item(id, "stream");
let reconstructor = self.setup_reconstructor(file_info, file_range, Some(progress_updater))?;
Ok((id, reconstructor.reconstruct_to_stream()))
let stream = reconstructor.reconstruct_to_stream();
self.register_stream_abort_callback(id, stream.abort_callback());
Ok((id, stream))
}
fn check_not_finalized(&self) -> Result<()> {
if self.finalized.load(Ordering::Acquire) {
@@ -512,7 +576,7 @@ mod tests {
let config = TranslatorConfig::local_config(&cas_path).unwrap();
let session = FileDownloadSession::new(config.into()).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
let mut collected = Vec::new();
while let Some(chunk) = stream.next().await.unwrap() {
@@ -539,7 +603,7 @@ mod tests {
let config = TranslatorConfig::local_config(&cas_path).unwrap();
let session = FileDownloadSession::new(config.into()).await.unwrap();
let (_id, stream) = session.download_stream(&xfi).await.unwrap();
let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
let collected = tokio::task::spawn_blocking(move || {
let mut stream = stream;
@@ -572,7 +636,7 @@ mod tests {
let config = TranslatorConfig::local_config(&cas_path).unwrap();
let session = FileDownloadSession::new(config.into()).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
while stream.next().await.unwrap().is_some() {}
@@ -601,8 +665,8 @@ mod tests {
let config = TranslatorConfig::local_config(&cas_path).unwrap();
let session = FileDownloadSession::new(config.into()).await.unwrap();
let (_id_a, mut stream_a) = session.download_stream(&xfi_a).await.unwrap();
let (_id_b, mut stream_b) = session.download_stream(&xfi_b).await.unwrap();
let (_id_a, mut stream_a) = session.download_stream(&xfi_a, None).await.unwrap();
let (_id_b, mut stream_b) = session.download_stream(&xfi_b, None).await.unwrap();
let task_a = tokio::spawn(async move {
let mut buf = Vec::new();
@@ -644,7 +708,7 @@ mod tests {
let config = TranslatorConfig::local_config(&cas_path).unwrap();
let session = FileDownloadSession::new(config.into()).await.unwrap();
let (_id, stream) = session.download_stream(&xfi).await.unwrap();
let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
drop(stream);
tokio::task::yield_now().await;
@@ -671,7 +735,7 @@ mod tests {
let session = FileDownloadSession::new(config.into()).await.unwrap();
for i in 0..5u32 {
let (_id, mut stream) = session.download_stream(&xfi).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
if i % 3 == 0 {
let _ = stream.next().await;
}
@@ -701,7 +765,7 @@ mod tests {
let config = TranslatorConfig::local_config(&cas_path).unwrap();
let session = FileDownloadSession::new(config.into()).await.unwrap();
let (_id, stream) = session.download_stream(&xfi).await.unwrap();
let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
tokio::task::spawn_blocking(move || {
let mut stream = stream;
@@ -734,7 +798,7 @@ mod tests {
let config = TranslatorConfig::local_config(&cas_path).unwrap();
let session = FileDownloadSession::new(config.into()).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
stream.cancel();
assert!(stream.next().await.unwrap().is_none());
assert!(stream.next().await.unwrap().is_none());
@@ -757,7 +821,7 @@ mod tests {
let config = TranslatorConfig::local_config(&cas_path).unwrap();
let session = FileDownloadSession::new(config.into()).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
let _ = stream.next().await.unwrap();
stream.cancel();
assert!(stream.next().await.unwrap().is_none());
@@ -998,7 +1062,7 @@ mod tests {
let config = TranslatorConfig::local_config(&cas_path).unwrap();
let session = FileDownloadSession::new(config.into()).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi_no_size).await.unwrap();
let (_id, mut stream) = session.download_stream(&xfi_no_size, None).await.unwrap();
let mut collected = Vec::new();
while let Some(chunk) = stream.next().await.unwrap() {

View File

@@ -18,7 +18,7 @@ pub use file_upload_session::FileUploadSession;
pub use xet_file::XetFileInfo;
pub use crate::deduplication::RawXorbData;
pub use crate::file_reconstruction::DownloadStream;
pub use crate::file_reconstruction::{DownloadStream, UnorderedDownloadStream};
#[cfg(debug_assertions)]
pub mod test_utils;

View File

@@ -380,7 +380,7 @@ impl HydrateDehydrateTest {
let out_filename = self.dest_dir.join(entry.file_name());
let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap();
let (_id, mut stream) = session.download_stream(&xf).await.unwrap();
let (_id, mut stream) = session.download_stream(&xf, None).await.unwrap();
let mut file = File::create(&out_filename).unwrap();
while let Some(chunk) = stream.next().await.unwrap() {

View File

@@ -224,7 +224,7 @@ pub struct ItemProgressUpdater {
impl ItemProgressUpdater {
/// Create a standalone updater for debug/testing purposes.
/// Creates its own throwaway GroupProgress.
#[cfg(test)]
#[cfg(any(debug_assertions, test))]
pub fn new_standalone(name: &str) -> Arc<Self> {
let group = GroupProgress::new();
let item = Arc::new(ItemProgress::new(UniqueID::new(), Arc::from(name)));

View File

@@ -0,0 +1,551 @@
#[cfg(test)]
mod tests {
use std::fs;
use std::sync::Arc;
use std::time::Duration;
use tempfile::TempDir;
use xet_client::cas_client::LocalTestServerBuilder;
use xet_data::processing::configurations::TranslatorConfig;
use xet_data::processing::{FileDownloadSession, FileUploadSession, Sha256Policy, XetFileInfo};
async fn upload_bytes(upload_session: &Arc<FileUploadSession>, name: &str, data: &[u8]) -> XetFileInfo {
let (_id, mut cleaner) = upload_session
.start_clean(Some(name.into()), data.len() as u64, Sha256Policy::Compute)
.unwrap();
cleaner.add_data(data).await.unwrap();
let (xfi, _metrics) = cleaner.finish().await.unwrap();
xfi
}
fn reassemble(chunks: Vec<(u64, bytes::Bytes)>, expected_len: usize) -> Vec<u8> {
let mut buf = vec![0u8; expected_len];
for (offset, data) in chunks {
buf[offset as usize..offset as usize + data.len()].copy_from_slice(&data);
}
buf
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_async_various_sizes() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let test_cases: Vec<(&str, Vec<u8>)> = vec![
("one_byte", vec![0x42]),
("small", b"hello world".to_vec()),
("medium", vec![0xAB; 4096]),
("larger", vec![0xCD; 64 * 1024]),
];
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
for (name, data) in &test_cases {
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, name, data).await;
upload_session.finalize().await.unwrap();
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.next().await.unwrap() {
chunks.push((offset, chunk));
}
let assembled = reassemble(chunks, data.len());
assert_eq!(assembled, *data, "content mismatch for {name}");
assert!(stream.is_complete());
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_blocking_various_sizes() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let test_cases: Vec<(&str, Vec<u8>)> = vec![
("one_byte", vec![0x42]),
("small", b"hello world".to_vec()),
("medium", vec![0xAB; 4096]),
("larger", vec![0xCD; 64 * 1024]),
];
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
for (name, data) in &test_cases {
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, name, data).await;
upload_session.finalize().await.unwrap();
let (_id, stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let expected_data = data.clone();
let collected = tokio::task::spawn_blocking(move || {
let mut stream = stream;
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.blocking_next().unwrap() {
chunks.push((offset, chunk));
}
reassemble(chunks, expected_data.len())
})
.await
.unwrap();
assert_eq!(collected, *data, "content mismatch for {name}");
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_reassemble_to_file() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data = vec![0xEF; 64 * 1024];
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "file_to_disk", &original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.next().await.unwrap() {
chunks.push((offset, chunk));
}
let assembled = reassemble(chunks, original_data.len());
let out_path = base_dir.path().join("reassembled.bin");
fs::write(&out_path, &assembled).unwrap();
assert_eq!(fs::read(&out_path).unwrap(), original_data);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_total_bytes_expected() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data = b"total bytes tracking test";
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "tracking", original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
assert_eq!(stream.total_bytes_expected(), original_data.len() as u64);
while stream.next().await.unwrap().is_some() {}
assert!(stream.is_complete());
assert_eq!(stream.bytes_completed(), original_data.len() as u64);
assert_eq!(stream.bytes_in_progress(), 0);
assert_eq!(stream.terms_in_progress(), 0);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_is_complete_loop_drains_all_data() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data: Vec<u8> = (0..131072u32).map(|i| (i % 251) as u8).collect();
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "is_complete_loop", &original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let mut chunks = Vec::new();
while !stream.is_complete() {
if let Some((offset, chunk)) = stream.next().await.unwrap() {
chunks.push((offset, chunk));
tokio::time::sleep(Duration::from_millis(1)).await;
}
}
let assembled = reassemble(chunks, original_data.len());
assert_eq!(assembled, original_data);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_next_returns_none_after_complete() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data = b"extra none calls";
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "none_test", original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
while stream.next().await.unwrap().is_some() {}
assert!(stream.next().await.unwrap().is_none());
assert!(stream.next().await.unwrap().is_none());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_cancel_before_start() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data = b"cancel before start";
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "cancel_pre", original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
stream.cancel();
assert!(stream.next().await.unwrap().is_none());
assert!(stream.next().await.unwrap().is_none());
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_cancel_after_partial_read() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data = b"cancel after partial read test data";
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "cancel_mid", original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let _ = stream.next().await.unwrap();
stream.cancel();
assert!(stream.next().await.unwrap().is_none());
let (_id2, mut stream2) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream2.next().await.unwrap() {
chunks.push((offset, chunk));
}
let assembled = reassemble(chunks, original_data.len());
assert_eq!(assembled, original_data);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_multiple_concurrent() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let data_a = b"Unordered stream A for concurrent download";
let data_b = b"Unordered stream B for concurrent download - different content";
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi_a = upload_bytes(&upload_session, "concurrent_a", data_a).await;
let xfi_b = upload_bytes(&upload_session, "concurrent_b", data_b).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let (_id_a, mut stream_a) = download_session.download_unordered_stream(&xfi_a, None).await.unwrap();
let (_id_b, mut stream_b) = download_session.download_unordered_stream(&xfi_b, None).await.unwrap();
let len_a = data_a.len();
let task_a = tokio::spawn(async move {
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream_a.next().await.unwrap() {
chunks.push((offset, chunk));
}
reassemble(chunks, len_a)
});
let len_b = data_b.len();
let task_b = tokio::spawn(async move {
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream_b.next().await.unwrap() {
chunks.push((offset, chunk));
}
reassemble(chunks, len_b)
});
let result_a = task_a.await.unwrap();
let result_b = task_b.await.unwrap();
assert_eq!(result_a, data_a);
assert_eq!(result_b, data_b);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_drop_without_reading() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data = b"drop without reading cleanup test";
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "drop_test", original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let (_id, stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
drop(stream);
tokio::task::yield_now().await;
let out_path = base_dir.path().join("after_drop.txt");
let (_id2, n_bytes) = download_session.download_file(&xfi, &out_path).await.unwrap();
assert_eq!(n_bytes, original_data.len() as u64);
assert_eq!(fs::read(&out_path).unwrap(), original_data);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_matches_sequential_download() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data: Vec<u8> = (0..=255u8).cycle().take(32 * 1024).collect();
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "compare_test", &original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let out_path = base_dir.path().join("sequential.bin");
let (_id, _) = download_session.download_file(&xfi, &out_path).await.unwrap();
let sequential_result = fs::read(&out_path).unwrap();
let (_id2, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.next().await.unwrap() {
chunks.push((offset, chunk));
}
let unordered_result = reassemble(chunks, original_data.len());
assert_eq!(sequential_result, unordered_result);
assert_eq!(sequential_result, original_data);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_test_repeated_blocking_downloads() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data: Vec<u8> = (0..65536u32).map(|i| (i % 251) as u8).collect();
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "stress_blocking", &original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
for _ in 0..30 {
let (_id, stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let expected_len = original_data.len();
let result = tokio::task::spawn_blocking(move || {
let mut stream = stream;
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.blocking_next().unwrap() {
chunks.push((offset, chunk));
}
reassemble(chunks, expected_len)
})
.await
.unwrap();
assert_eq!(result, original_data);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_test_repeated_async_downloads() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data: Vec<u8> = (0..65536u32).map(|i| (i % 251) as u8).collect();
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "stress_async", &original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
for _ in 0..30 {
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.next().await.unwrap() {
chunks.push((offset, chunk));
}
let result = reassemble(chunks, original_data.len());
assert_eq!(result, original_data);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_test_concurrent_blocking_downloads() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data: Vec<u8> = (0..65536u32).map(|i| (i % 251) as u8).collect();
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "stress_concurrent", &original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let mut handles = Vec::new();
for _ in 0..8 {
let (_id, stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let expected_len = original_data.len();
handles.push(tokio::task::spawn_blocking(move || {
let mut stream = stream;
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.blocking_next().unwrap() {
chunks.push((offset, chunk));
}
reassemble(chunks, expected_len)
}));
}
for handle in handles {
let result = handle.await.unwrap();
assert_eq!(result, original_data);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_test_large_file_download() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data: Vec<u8> = (0..262144u32)
.map(|i| ((i.wrapping_mul(7919) ^ (i >> 3)) % 256) as u8)
.collect();
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "stress_large", &original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
{
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.next().await.unwrap() {
chunks.push((offset, chunk));
}
assert_eq!(reassemble(chunks, original_data.len()), original_data);
}
{
let (_id, stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let expected_len = original_data.len();
let expected_data = original_data.clone();
let result = tokio::task::spawn_blocking(move || {
let mut stream = stream;
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.blocking_next().unwrap() {
chunks.push((offset, chunk));
}
reassemble(chunks, expected_len)
})
.await
.unwrap();
assert_eq!(result, expected_data);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_test_mixed_concurrent_async_and_blocking() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data: Vec<u8> = (0..65536u32).map(|i| (i % 251) as u8).collect();
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "stress_mixed", &original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let mut handles: Vec<tokio::task::JoinHandle<Vec<u8>>> = Vec::new();
for i in 0..8 {
if i % 2 == 0 {
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let expected_len = original_data.len();
handles.push(tokio::spawn(async move {
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.next().await.unwrap() {
chunks.push((offset, chunk));
}
reassemble(chunks, expected_len)
}));
} else {
let (_id, stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let expected_len = original_data.len();
handles.push(tokio::task::spawn_blocking(move || {
let mut stream = stream;
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.blocking_next().unwrap() {
chunks.push((offset, chunk));
}
reassemble(chunks, expected_len)
}));
}
}
for handle in handles {
let result = handle.await.unwrap();
assert_eq!(result, original_data);
}
}
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
async fn stress_test_rapid_create_and_drop() {
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
let original_data: Vec<u8> = (0..32768u32).map(|i| (i % 199) as u8).collect();
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "stress_drop", &original_data).await;
upload_session.finalize().await.unwrap();
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
for _ in 0..20 {
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let _ = stream.next().await;
drop(stream);
}
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let mut chunks = Vec::new();
while let Some((offset, chunk)) = stream.next().await.unwrap() {
chunks.push((offset, chunk));
}
assert_eq!(reassemble(chunks, original_data.len()), original_data);
}
}

View File

@@ -20,13 +20,13 @@ xet-data = { version = "1.4.0", path = "../xet_data" }
anyhow = { workspace = true }
async-trait = { workspace = true }
bytes = { workspace = true }
http = { workspace = true }
more-asserts = { workspace = true }
serde = { workspace = true, features = ["derive"] }
thiserror = { workspace = true }
tokio = { workspace = true, features = ["net", "time"] }
tracing = { workspace = true }
ulid = { workspace = true }
pyo3 = { workspace = true, optional = true }
[features]

View File

@@ -1,6 +1,6 @@
//! Async session-based upload/download example.
//!
//! Mirror of `example_sync.rs` using the async API (`UploadCommit` / `DownloadGroup`).
//! Mirror of `example_sync.rs` using the async API (`UploadCommit` / `FileDownloadGroup`).
//! Requires an async runtime — here provided by `#[tokio::main]`.
use std::path::PathBuf;
@@ -107,7 +107,7 @@ async fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: O
builder = builder.with_endpoint(ep);
}
let session = builder.build_async().await?;
let group = session.new_download_group().await?;
let group = session.new_file_download_group().await?;
// Enqueue all downloads; each starts immediately in the background.
let n_files = metadata.len();

View File

@@ -1,6 +1,6 @@
//! Session-based upload/download example.
//!
//! Shows the three-level hierarchy: XetSession → UploadCommit/DownloadGroup → files.
//! Shows the three-level hierarchy: XetSession → UploadCommit/FileDownloadGroup → files.
use std::path::PathBuf;
use std::time::Duration;
@@ -104,7 +104,7 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option<
builder = builder.with_endpoint(ep);
}
let session = builder.build()?;
let group = session.new_download_group_blocking()?;
let group = session.new_file_download_group_blocking()?;
// Enqueue all downloads; each starts immediately in the background.
let n_files = metadata.len();

View File

@@ -0,0 +1,182 @@
use std::sync::Arc;
use bytes::Bytes;
use xet_data::DataError;
use xet_data::processing::{DownloadStream, FileDownloadSession, UnorderedDownloadStream};
use xet_data::progress_tracking::{ItemProgressReport, UniqueID};
use super::errors::SessionError;
/// A streaming download handle with built-in progress tracking.
///
/// Wraps a [`DownloadStream`] and keeps a reference to the
/// [`FileDownloadSession`] that created it, so callers can poll progress
/// while consuming data chunks. Created by
/// [`XetSession::download_stream`](super::XetSession::download_stream) or
/// [`XetSession::download_stream_blocking`](super::XetSession::download_stream_blocking).
///
/// The reconstruction task is spawned at creation time but paused until
/// [`start`](Self::start) is called explicitly, or automatically on the
/// first call to [`next`](Self::next) / [`blocking_next`](Self::blocking_next).
/// Because the spawn happens during creation, `start()` is non-async and
/// works from any executor or plain thread.
pub struct XetDownloadStream {
inner: DownloadStream,
download_session: Arc<FileDownloadSession>,
id: UniqueID,
}
impl XetDownloadStream {
pub(super) fn new(inner: DownloadStream, download_session: Arc<FileDownloadSession>, id: UniqueID) -> Self {
Self {
inner,
download_session,
id,
}
}
/// Unblocks the reconstruction task so it begins producing data.
///
/// If already started, this is a no-op. Called automatically on the first
/// [`next`](Self::next) / [`blocking_next`](Self::blocking_next).
///
/// This method is non-async and does not require a tokio runtime context.
pub fn start(&mut self) {
self.inner.start();
}
/// Returns the next chunk of downloaded data asynchronously.
///
/// Returns `Ok(None)` when the download is complete.
pub async fn next(&mut self) -> Result<Option<Bytes>, SessionError> {
self.inner.next().await.map_err(|e| SessionError::from(DataError::from(e)))
}
/// Returns the next chunk of downloaded data, blocking the current thread
/// until data is available.
///
/// Returns `Ok(None)` when the download is complete.
///
/// # Panics
///
/// Panics if called from within an async runtime context. Use
/// [`next`](Self::next) for async contexts.
pub fn blocking_next(&mut self) -> Result<Option<Bytes>, SessionError> {
self.inner.blocking_next().map_err(|e| SessionError::from(DataError::from(e)))
}
/// Cancels the in-progress (or not-yet-started) download.
///
/// Subsequent calls to [`next`](Self::next) / [`blocking_next`](Self::blocking_next)
/// will return `Ok(None)`.
pub fn cancel(&mut self) {
self.inner.cancel();
}
/// Returns a snapshot of this stream's download progress.
///
/// The returned [`ItemProgressReport`] contains the item name,
/// total bytes, and bytes completed so far. This method is lock-free
/// (reads atomic counters) and safe to call from any thread.
pub fn get_progress(&self) -> ItemProgressReport {
self.download_session
.item_report(self.id)
.expect("progress item was registered at stream creation and is never removed")
}
}
impl Drop for XetDownloadStream {
fn drop(&mut self) {
self.download_session.unregister_stream_abort_callback(self.id);
}
}
/// A streaming download handle that yields data chunks in completion order,
/// each tagged with their byte offset in the output file.
///
/// Wraps an [`UnorderedDownloadStream`] and keeps a reference to the
/// [`FileDownloadSession`] that created it, so callers can poll progress
/// while consuming data chunks. Created by
/// [`XetSession::download_unordered_stream`](super::XetSession::download_unordered_stream) or
/// [`XetSession::download_unordered_stream_blocking`](super::XetSession::download_unordered_stream_blocking).
///
/// The reconstruction task is spawned at creation time but paused until
/// [`start`](Self::start) is called explicitly, or automatically on the
/// first call to [`next`](Self::next) / [`blocking_next`](Self::blocking_next).
/// Because the spawn happens during creation, `start()` is non-async and
/// works from any executor or plain thread.
pub struct XetUnorderedDownloadStream {
inner: UnorderedDownloadStream,
download_session: Arc<FileDownloadSession>,
id: UniqueID,
}
impl XetUnorderedDownloadStream {
pub(super) fn new(
inner: UnorderedDownloadStream,
download_session: Arc<FileDownloadSession>,
id: UniqueID,
) -> Self {
Self {
inner,
download_session,
id,
}
}
/// Unblocks the reconstruction task so it begins producing data.
///
/// If already started, this is a no-op. Called automatically on the first
/// [`next`](Self::next) / [`blocking_next`](Self::blocking_next).
///
/// This method is non-async and does not require a tokio runtime context.
pub fn start(&mut self) {
self.inner.start();
}
/// Returns the next chunk of downloaded data with its byte offset
/// asynchronously.
///
/// Returns `Ok(None)` when the download is complete.
pub async fn next(&mut self) -> Result<Option<(u64, Bytes)>, SessionError> {
self.inner.next().await.map_err(|e| SessionError::from(DataError::from(e)))
}
/// Returns the next chunk of downloaded data with its byte offset,
/// blocking the current thread until data is available.
///
/// Returns `Ok(None)` when the download is complete.
///
/// # Panics
///
/// Panics if called from within an async runtime context. Use
/// [`next`](Self::next) for async contexts.
pub fn blocking_next(&mut self) -> Result<Option<(u64, Bytes)>, SessionError> {
self.inner.blocking_next().map_err(|e| SessionError::from(DataError::from(e)))
}
/// Cancels the in-progress (or not-yet-started) download.
///
/// Subsequent calls to [`next`](Self::next) / [`blocking_next`](Self::blocking_next)
/// will return `Ok(None)`.
pub fn cancel(&mut self) {
self.inner.cancel();
}
/// Returns a snapshot of this stream's download progress.
///
/// The returned [`ItemProgressReport`] contains the item name,
/// total bytes, and bytes completed so far. This method is lock-free
/// (reads atomic counters) and safe to call from any thread.
pub fn get_progress(&self) -> ItemProgressReport {
self.download_session
.item_report(self.id)
.expect("progress item was registered at stream creation and is never removed")
}
}
impl Drop for XetUnorderedDownloadStream {
fn drop(&mut self) {
self.download_session.unregister_stream_abort_callback(self.id);
}
}

View File

@@ -0,0 +1 @@
pub use crate::error::XetError as SessionError;

View File

@@ -1,4 +1,4 @@
//! DownloadGroup - groups related downloads
//! FileDownloadGroup - groups related downloads
use std::collections::HashMap;
use std::path::PathBuf;
@@ -17,8 +17,8 @@ use crate::error::XetError;
/// API for grouping related file downloads into a single unit of work.
///
/// Obtain via [`XetSession::new_download_group`] (async) or
/// [`XetSession::new_download_group_blocking`] (sync).
/// Obtain via [`XetSession::new_file_download_group`] (async) or
/// [`XetSession::new_file_download_group_blocking`] (sync).
///
/// Queue files with [`download_file_to_path`](Self::download_file_to_path) (they start
/// downloading immediately in the background), poll progress with
@@ -38,18 +38,18 @@ use crate::error::XetError;
/// aborted, and [`XetError::AlreadyFinished`] if
/// [`finish`](Self::finish) has already been called.
#[derive(Clone)]
pub struct DownloadGroup {
pub(super) inner: Arc<DownloadGroupInner>,
pub struct FileDownloadGroup {
pub(super) inner: Arc<FileDownloadGroupInner>,
}
impl std::ops::Deref for DownloadGroup {
type Target = DownloadGroupInner;
impl std::ops::Deref for FileDownloadGroup {
type Target = FileDownloadGroupInner;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl DownloadGroup {
impl FileDownloadGroup {
/// Create a new download group from an **async** context. Initialisation logic shared by the sync and async
/// constructors.
pub(super) async fn new(session: XetSession) -> Result<Self, XetError> {
@@ -57,7 +57,7 @@ impl DownloadGroup {
let config = create_translator_config(&session)?;
let download_session = FileDownloadSession::new(Arc::new(config)).await?;
let inner = Arc::new(DownloadGroupInner {
let inner = Arc::new(FileDownloadGroupInner {
group_id,
session,
active_tasks: RwLock::new(HashMap::new()),
@@ -195,14 +195,14 @@ impl DownloadGroup {
}
}
/// Per-file result type returned by [`DownloadGroup::finish`].
/// Per-file result type returned by [`FileDownloadGroup::finish`].
///
/// The `Arc` lets the same value be stored in both the `finish()` return map
/// and the per-task [`DownloadTaskHandle`] without requiring the inner
/// `Result` to be `Clone`.
pub type DownloadResult = Arc<Result<DownloadedFile, XetError>>;
/// Handle for a single download task tracked internally by DownloadGroup.
/// Handle for a single download task tracked internally by FileDownloadGroup.
struct InnerDownloadTaskHandle {
status: Arc<Mutex<TaskStatus>>,
dest_path: PathBuf,
@@ -211,10 +211,10 @@ struct InnerDownloadTaskHandle {
result: Arc<OnceLock<DownloadResult>>,
}
/// All shared state owned by a single DownloadGroup instance.
/// Accessed through `Arc<DownloadGroupInner>`; do not use this type directly.
/// All shared state owned by a single FileDownloadGroup instance.
/// Accessed through `Arc<FileDownloadGroupInner>`; do not use this type directly.
#[doc(hidden)]
pub struct DownloadGroupInner {
pub struct FileDownloadGroupInner {
group_id: UniqueID,
session: XetSession,
@@ -228,7 +228,7 @@ pub struct DownloadGroupInner {
state: Mutex<GroupState>,
}
impl DownloadGroupInner {
impl FileDownloadGroupInner {
// ===== State helpers =====
/// Check whether the group is still accepting new tasks.
@@ -356,7 +356,7 @@ impl DownloadGroupInner {
*self.state.lock()? = GroupState::Finished;
// Unregister from session
self.session.finish_download_group(self.group_id)?;
self.session.finish_file_download_group(self.group_id)?;
Ok(results)
}
@@ -373,7 +373,7 @@ impl DownloadGroupInner {
}
}
/// Per-file result returned by [`DownloadGroup::finish`].
/// Per-file result returned by [`FileDownloadGroup::finish`].
#[derive(Clone, Debug)]
pub struct DownloadedFile {
/// Local path where the file was written.
@@ -430,9 +430,9 @@ mod tests {
fn test_finish_blocked_while_download_registration_holds_state_lock() -> Result<()> {
let session = XetSessionBuilder::new().build()?;
let runtime = session.runtime.clone();
// Create DownloadGroup directly so we can access its private state field
// Create FileDownloadGroup directly so we can access its private state field
// (accessible here because mod tests is a submodule of download_group).
let group = runtime.external_run_async_task(DownloadGroup::new(session.clone()))??;
let group = runtime.external_run_async_task(FileDownloadGroup::new(session.clone()))??;
let group_for_thread = group.clone();
let runtime_for_thread = runtime.clone();
@@ -464,8 +464,8 @@ mod tests {
// Two download groups created from the same session have distinct IDs.
async fn test_group_has_unique_id() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let g1 = session.new_download_group().await.unwrap();
let g2 = session.new_download_group().await.unwrap();
let g1 = session.new_file_download_group().await.unwrap();
let g2 = session.new_file_download_group().await.unwrap();
assert_ne!(g1.id(), g2.id());
}
@@ -475,7 +475,7 @@ mod tests {
// A fresh group has all-zero aggregate progress.
async fn test_get_progress_empty_initially() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let report = group.get_progress().unwrap();
assert_eq!(report.total_bytes, 0);
assert_eq!(report.total_bytes_completed, 0);
@@ -487,7 +487,7 @@ mod tests {
// An empty finish succeeds and returns an empty result set.
async fn test_finish_empty_succeeds() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let results = group.finish().await.unwrap();
assert!(results.is_empty());
}
@@ -496,7 +496,7 @@ mod tests {
// finish() transitions the group into the Finished state.
async fn test_finish_marks_as_finished() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let group_clone = group.clone();
group.finish().await.unwrap();
assert!(group_clone.is_finished());
@@ -506,7 +506,7 @@ mod tests {
// A second finish() call on any clone returns AlreadyFinished.
async fn test_second_finish_fails() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let g1 = session.new_download_group().await.unwrap();
let g1 = session.new_file_download_group().await.unwrap();
let g2 = g1.clone();
g1.finish().await.unwrap();
let err = g2.finish().await.unwrap_err();
@@ -517,10 +517,10 @@ mod tests {
// finish() unregisters the group from the session's active set.
async fn test_finish_unregisters_from_session() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
let group = session.new_file_download_group().await.unwrap();
assert_eq!(session.active_file_download_groups.lock().unwrap().len(), 1);
group.finish().await.unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 0);
assert_eq!(session.active_file_download_groups.lock().unwrap().len(), 0);
}
// ── Guards ───────────────────────────────────────────────────────────────
@@ -529,7 +529,7 @@ mod tests {
// download_file_to_path returns Aborted when the parent session has been aborted.
async fn test_download_file_on_aborted_session_returns_error() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
session.abort().unwrap();
let err = group
.download_file_to_path(
@@ -549,7 +549,7 @@ mod tests {
// download_file_to_path after finish returns AlreadyFinished.
async fn test_download_file_after_finish_fails() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let g1 = session.new_download_group().await.unwrap();
let g1 = session.new_file_download_group().await.unwrap();
let g2 = g1.clone();
g1.finish().await.unwrap();
let err = g2
@@ -570,7 +570,7 @@ mod tests {
// download_file_to_path on a directly-aborted group returns Aborted.
async fn test_download_file_on_aborted_group_returns_aborted() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
group.abort().unwrap();
let err = group
.download_file_to_path(
@@ -592,8 +592,8 @@ mod tests {
// Finishing one group does not affect the state of another from the same session.
async fn test_two_groups_are_independent() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let g1 = session.new_download_group().await.unwrap();
let g2 = session.new_download_group().await.unwrap();
let g1 = session.new_file_download_group().await.unwrap();
let g2 = session.new_file_download_group().await.unwrap();
g1.finish().await.unwrap();
assert!(!g2.is_finished());
}
@@ -609,7 +609,7 @@ mod tests {
let file_info = upload_bytes(&session, original, "payload.bin").await.unwrap();
let dest = temp.path().join("downloaded.bin");
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let handle = group.download_file_to_path(file_info, dest.clone()).await.unwrap();
assert!(matches!(handle.status().unwrap(), TaskStatus::Queued | TaskStatus::Running | TaskStatus::Completed));
group.finish().await.unwrap();
@@ -623,7 +623,7 @@ mod tests {
async fn test_download_status_failed_for_invalid_file_info() {
let temp = tempdir().unwrap();
let session = local_session(&temp).await.unwrap();
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let handle = group
.download_file_to_path(
XetFileInfo {
@@ -650,7 +650,7 @@ mod tests {
let file_info = upload_bytes(&session, original, "id.bin").await.unwrap();
let dest = temp.path().join("download_id.bin");
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let handle = group.download_file_to_path(file_info, dest).await.unwrap();
let download_session = group.inner.download_session.lock().unwrap().clone().unwrap();
@@ -700,7 +700,7 @@ mod tests {
let dest_a = temp.path().join("a_out.bin");
let dest_b = temp.path().join("b_out.bin");
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
group
.download_file_to_path(to_file_info(&handle_a), dest_a.clone())
.await
@@ -724,7 +724,7 @@ mod tests {
let file_info = upload_bytes(&session, original, "prog.bin").await.unwrap();
let dest = temp.path().join("out.bin");
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let progress_observer = group.clone();
group.download_file_to_path(file_info, dest).await.unwrap();
group.finish().await.unwrap();
@@ -755,7 +755,7 @@ mod tests {
let data = b"result via task_id in finish map";
let file_info = upload_bytes(&session, data, "file.bin").await.unwrap();
let dest = temp.path().join("out.bin");
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let handle = group.download_file_to_path(file_info, dest).await.unwrap();
let results = group.finish().await.unwrap();
let result = results.get(&handle.task_id).expect("task_id must be present in results");
@@ -769,7 +769,7 @@ mod tests {
let session = local_session(&temp).await.unwrap();
let file_info = upload_bytes(&session, b"some data", "file.bin").await.unwrap();
let dest = temp.path().join("out.bin");
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let handle = group.download_file_to_path(file_info, dest).await.unwrap();
assert!(handle.result().is_none(), "result must be None before finish()");
group.finish().await.unwrap();
@@ -783,7 +783,7 @@ mod tests {
let data = b"download result test data";
let file_info = upload_bytes(&session, data, "file.bin").await.unwrap();
let dest = temp.path().join("out.bin");
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let handle = group.download_file_to_path(file_info.clone(), dest).await.unwrap();
group.finish().await.unwrap();
let result = handle.result().expect("result must be set after finish()");
@@ -820,7 +820,7 @@ mod tests {
};
let dest = temp.path().join("out_futures.bin");
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
group.download_file_to_path(file_info, dest.clone()).await.unwrap();
group.finish().await.unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), data);
@@ -851,7 +851,7 @@ mod tests {
};
let dest = temp.path().join("out_smol.bin");
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
group.download_file_to_path(file_info, dest.clone()).await.unwrap();
group.finish().await.unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), data);
@@ -882,7 +882,7 @@ mod tests {
};
let dest = temp.path().join("out_async_std.bin");
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
group.download_file_to_path(file_info, dest.clone()).await.unwrap();
group.finish().await.unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), data);
@@ -918,7 +918,7 @@ mod tests {
let file_info = upload_bytes_blocking(&session, original, "payload.bin")?;
let dest = temp.path().join("downloaded.bin");
let group = session.new_download_group_blocking()?;
let group = session.new_file_download_group_blocking()?;
group.download_file_to_path_blocking(file_info, dest.clone())?;
group.finish_blocking()?;
@@ -950,7 +950,7 @@ mod tests {
let dest_a = temp.path().join("a_out.bin");
let dest_b = temp.path().join("b_out.bin");
let group = session.new_download_group_blocking()?;
let group = session.new_file_download_group_blocking()?;
group.download_file_to_path_blocking(to_file_info(&handle_a), dest_a.clone())?;
group.download_file_to_path_blocking(to_file_info(&handle_b), dest_b.clone())?;
group.finish_blocking()?;
@@ -968,7 +968,7 @@ mod tests {
let file_info = upload_bytes_blocking(&session, original, "prog.bin")?;
let dest = temp.path().join("out.bin");
let group = session.new_download_group_blocking()?;
let group = session.new_file_download_group_blocking()?;
let progress_observer = group.clone();
group.download_file_to_path_blocking(file_info, dest)?;
group.finish_blocking()?;
@@ -996,7 +996,7 @@ mod tests {
let data = b"download result access patterns";
let file_info = upload_bytes_blocking(&session, data, "file.bin")?;
let dest = temp.path().join("out.bin");
let group = session.new_download_group_blocking()?;
let group = session.new_file_download_group_blocking()?;
let handle = group.download_file_to_path_blocking(file_info.clone(), dest)?;
// Before finish, per-task result is not available yet.
@@ -1027,7 +1027,7 @@ mod tests {
let data = b"download from smol executor";
let file_info = upload_bytes_blocking(&session, data, "test.bin").unwrap();
let dest = temp.path().join("out_smol.bin");
let group = session.new_download_group_blocking().unwrap();
let group = session.new_file_download_group_blocking().unwrap();
group.download_file_to_path_blocking(file_info, dest.clone()).unwrap();
group.finish_blocking().unwrap();
assert_eq!(std::fs::read(&dest).unwrap(), data);
@@ -1056,7 +1056,7 @@ mod tests {
async fn test_download_file_to_path_blocking_errors_in_external_mode() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::External);
let group = session.new_download_group().await.unwrap();
let group = session.new_file_download_group().await.unwrap();
let file_info = XetFileInfo {
hash: String::new(),
file_size: Some(0),
@@ -1078,7 +1078,7 @@ mod tests {
fn test_download_file_to_path_blocking_panics_in_async_context() {
let session = XetSessionBuilder::new().build().unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::Owned);
let group = session.new_download_group_blocking().unwrap();
let group = session.new_file_download_group_blocking().unwrap();
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
let file_info = XetFileInfo {
hash: String::new(),

View File

@@ -6,7 +6,7 @@
//! ```text
//! XetSession — holds runtime context and authentication credentials
//! ├── UploadCommit — groups related uploads; finalised with commit()
//! └── DownloadGroup — groups related downloads; finalised with finish()
//! └── FileDownloadGroup — groups related file downloads; finalised with finish()
//! ```
//!
//! Each [`XetSession`] holds its own runtime context and configuration, so
@@ -27,28 +27,28 @@
//! transfers to finish and receive a `HashMap<`[`UniqueID`]`, `[`UploadResult`]`>`
//! keyed by task ID.
//!
//! `UploadResult` = `Arc<Result<`[`FileMetadata`]`, `[`XetError`]`>>`.
//! `UploadResult` = `Arc<Result<`[`FileMetadata`]`, `[`SessionError`]`>>`.
//! Per-task results can also be read from the returned [`UploadTaskHandle`]
//! via [`result`](UploadTaskHandle::result) after `commit()` returns.
//!
//! ## Downloads
//!
//! Create a [`DownloadGroup`] with [`XetSession::new_download_group`] (async)
//! or [`XetSession::new_download_group_blocking`] (sync), queue files with
//! [`download_file_to_path`](DownloadGroup::download_file_to_path) /
//! [`download_file_to_path_blocking`](DownloadGroup::download_file_to_path_blocking),
//! then call [`finish`](DownloadGroup::finish) (async) or
//! [`finish_blocking`](DownloadGroup::finish_blocking) (sync) to wait for all
//! Create a [`FileDownloadGroup`] with [`XetSession::new_file_download_group`] (async)
//! or [`XetSession::new_file_download_group_blocking`] (sync), queue files with
//! [`download_file_to_path`](FileDownloadGroup::download_file_to_path) /
//! [`download_file_to_path_blocking`](FileDownloadGroup::download_file_to_path_blocking),
//! then call [`finish`](FileDownloadGroup::finish) (async) or
//! [`finish_blocking`](FileDownloadGroup::finish_blocking) (sync) to wait for all
//! transfers to complete and receive a `HashMap<`[`UniqueID`]`, `[`DownloadResult`]`>`
//! keyed by task ID.
//!
//! `DownloadResult` = `Arc<Result<`[`DownloadedFile`]`, `[`XetError`]`>>`.
//! `DownloadResult` = `Arc<Result<`[`DownloadedFile`]`, `[`SessionError`]`>>`.
//! Per-task results can also be read from the returned [`DownloadTaskHandle`]
//! via [`result`](DownloadTaskHandle::result) after `finish()` returns.
//!
//! ## Progress tracking
//!
//! Both [`UploadCommit`] and [`DownloadGroup`] expose `get_progress()`,
//! Both [`UploadCommit`] and [`FileDownloadGroup`] expose `get_progress()`,
//! which returns a [`GroupProgressReport`] without acquiring a lock on the
//! calling thread (useful for Python bindings that must release the GIL).
//! Poll it from a background thread/task while the main thread/task blocks
@@ -56,9 +56,9 @@
//!
//! ## Error handling
//!
//! All public methods return `Result<_, `[`XetError`]`>`.
//! All public methods return `Result<_, `[`SessionError`]`>`.
//! [`commit`](UploadCommit::commit) returns `HashMap<`[`UniqueID`]`, `[`UploadResult`]`>`
//! keyed by task ID, and [`finish`](DownloadGroup::finish) returns
//! keyed by task ID, and [`finish`](FileDownloadGroup::finish) returns
//! `HashMap<`[`UniqueID`]`, `[`DownloadResult`]`>` keyed by task ID, so a single failed
//! file does not discard all others.
//!
@@ -77,12 +77,12 @@
//! // 2. Upload — use the _blocking factory and _blocking methods
//! let commit = session.new_upload_commit_blocking()?;
//! let handle = commit.upload_from_path_blocking("file.bin".into(), Sha256Policy::Compute)?;
//! // UploadResult = Arc<Result<FileMetadata, XetError>>
//! // UploadResult = Arc<Result<FileMetadata, SessionError>>
//! let results = commit.commit_blocking()?;
//! let m = results.values().next().unwrap().as_ref().as_ref().unwrap();
//!
//! // 3. Download — use the _blocking factory and finish_blocking
//! let group = session.new_download_group_blocking()?;
//! let group = session.new_file_download_group_blocking()?;
//! let info = XetFileInfo {
//! hash: m.hash.clone(),
//! file_size: Some(m.file_size),
@@ -90,10 +90,10 @@
//! };
//! let dl_handle = group.download_file_to_path_blocking(info, "out/file.bin".into())?;
//! let finish_results = group.finish_blocking()?;
//! // DownloadResult = Arc<Result<DownloadedFile, XetError>>
//! // DownloadResult = Arc<Result<DownloadedFile, SessionError>>
//! let r = finish_results.get(&dl_handle.task_id).unwrap().as_ref().as_ref().unwrap();
//!
//! # Ok::<(), xet::XetError>(())
//! # Ok::<(), xet::xet_session::SessionError>(())
//! ```
//!
//! # Quick start — async API
@@ -101,7 +101,7 @@
//! ```rust,no_run
//! use xet::xet_session::{Sha256Policy, XetFileInfo, XetSessionBuilder};
//!
//! # async fn example() -> Result<(), xet::XetError> {
//! # async fn example() -> Result<(), xet::xet_session::SessionError> {
//! // 1. Build a session. build_async() auto-detects the executor:
//! // - tokio (multi-thread): wraps the caller's handle, no second thread pool.
//! // - non-tokio (smol, async-std, etc.): creates an owned thread pool.
@@ -114,12 +114,12 @@
//! // 2. Upload — use the async factory and async methods
//! let commit = session.new_upload_commit().await?;
//! let handle = commit.upload_from_path("file.bin".into(), Sha256Policy::Compute).await?;
//! // UploadResult = Arc<Result<FileMetadata, XetError>>
//! // UploadResult = Arc<Result<FileMetadata, SessionError>>
//! let results = commit.commit().await?;
//! let m = results.values().next().unwrap().as_ref().as_ref().unwrap();
//!
//! // 3. Download — use the async factory and async finish
//! let group = session.new_download_group().await?;
//! let group = session.new_file_download_group().await?;
//! let info = XetFileInfo {
//! hash: m.hash.clone(),
//! file_size: Some(m.file_size),
@@ -127,19 +127,23 @@
//! };
//! let dl_handle = group.download_file_to_path(info, "out/file.bin".into()).await?;
//! let finish_results = group.finish().await?;
//! // DownloadResult = Arc<Result<DownloadedFile, XetError>>
//! // DownloadResult = Arc<Result<DownloadedFile, SessionError>>
//! let r = finish_results.get(&dl_handle.task_id).unwrap().as_ref().as_ref().unwrap();
//! # Ok(())
//! # }
//! ```
mod common;
mod download_group;
mod download_streams;
mod errors;
mod file_download_group;
mod session;
mod tasks;
mod upload_commit;
pub use download_group::{DownloadGroup, DownloadResult, DownloadedFile};
pub use download_streams::{XetDownloadStream, XetUnorderedDownloadStream};
pub use errors::SessionError;
pub use file_download_group::{DownloadResult, DownloadedFile, FileDownloadGroup};
pub use session::{XetSession, XetSessionBuilder};
pub use tasks::{DownloadTaskHandle, TaskHandle, TaskStatus, UploadTaskHandle};
pub use upload_commit::{FileMetadata, UploadCommit, UploadResult};

View File

@@ -2,22 +2,25 @@
use std::collections::HashMap;
use std::future::Future;
use std::ops::Range;
use std::pin::pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Waker};
use http::HeaderMap;
use tracing::info;
use ulid::Ulid;
use xet_client::cas_client::auth::TokenRefresher;
use xet_data::processing::{FileDownloadSession, XetFileInfo};
use xet_data::progress_tracking::UniqueID;
use xet_runtime::RuntimeError;
use xet_runtime::config::XetConfig;
use xet_runtime::core::XetRuntime;
use super::download_group::DownloadGroup;
use super::common::create_translator_config;
use super::download_streams::{XetDownloadStream, XetUnorderedDownloadStream};
use super::errors::SessionError;
use super::file_download_group::FileDownloadGroup;
use super::upload_commit::UploadCommit;
use crate::error::XetError;
/// Session state
enum SessionState {
@@ -34,7 +37,7 @@ enum SessionState {
///
/// - **`External`**: session wraps a caller-provided tokio handle via [`XetSessionBuilder::with_tokio_handle`] or
/// [`XetSessionBuilder::build_async`] (tokio context). Only async methods may be called; `_blocking` methods return
/// [`XetError::WrongRuntimeMode`]. No second thread pool is created.
/// [`SessionError::WrongRuntimeMode`]. No second thread pool is created.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(super) enum RuntimeMode {
Owned,
@@ -62,12 +65,14 @@ pub struct XetSessionInner {
// Track active upload commits and download groups.
pub(super) active_upload_commits: Mutex<HashMap<UniqueID, UploadCommit>>,
pub(super) active_download_groups: Mutex<HashMap<UniqueID, DownloadGroup>>,
pub(super) active_file_download_groups: Mutex<HashMap<UniqueID, FileDownloadGroup>>,
// Lazily-initialized download session for streaming downloads (no group-level progress).
streaming_download_session: tokio::sync::OnceCell<Arc<FileDownloadSession>>,
// Session state
state: Mutex<SessionState>,
// "id" is used to identity a group of activities on our server, and so need to be globally unique
pub(super) id: Ulid,
pub(super) id: UniqueID,
}
/// Probe whether a tokio runtime handle meets the requirements for External mode.
@@ -128,12 +133,12 @@ fn handle_meets_session_requirements(handle: &tokio::runtime::Handle) -> bool {
/// .with_endpoint("https://cas.example.com".into())
/// .with_token_info("my-token".into(), 1_700_000_000)
/// .build()?;
/// # Ok::<(), xet::XetError>(())
/// # Ok::<(), xet::xet_session::SessionError>(())
/// ```
///
/// ```rust,no_run
/// # use xet::xet_session::XetSessionBuilder;
/// # async fn example() -> Result<(), xet::XetError> {
/// # async fn example() -> Result<(), xet::xet_session::SessionError> {
/// // Async context — wraps the caller's tokio handle (External mode) if inside tokio,
/// // or creates an owned runtime (Owned mode) if called from a non-tokio executor:
/// let session = XetSessionBuilder::new()
@@ -220,8 +225,8 @@ impl XetSessionBuilder {
///
/// If the handle meets session requirements (multi-thread flavor, time driver, IO driver),
/// the session will wrap it — no second thread pool is created (External mode). Only async
/// methods (`new_upload_commit`, `new_download_group`) may be called; `_blocking` variants
/// will return [`XetError::WrongRuntimeMode`].
/// methods (`new_upload_commit`, `new_file_download_group`) may be called; `_blocking` variants
/// will return [`SessionError::WrongRuntimeMode`].
///
/// If the handle does **not** meet requirements (e.g. `current_thread` flavor or missing
/// drivers), it is silently ignored and [`build`](Self::build) will fall back to creating
@@ -253,7 +258,7 @@ impl XetSessionBuilder {
/// `with_tokio_handle`; falls back to an owned thread pool — Owned mode.
/// - **Non-tokio context** (smol, async-std, etc.): creates an owned thread pool — Owned mode; async methods use an
/// internal bridge compatible with any executor.
pub async fn build_async(self) -> Result<XetSession, XetError> {
pub async fn build_async(self) -> Result<XetSession, SessionError> {
match tokio::runtime::Handle::try_current() {
Ok(handle) => self.with_tokio_handle(handle).build(),
Err(_) => self.build(),
@@ -268,7 +273,7 @@ impl XetSessionBuilder {
/// executor, and `_blocking` methods are available.
///
/// For async contexts, prefer [`build_async`](Self::build_async).
pub fn build(self) -> Result<XetSession, XetError> {
pub fn build(self) -> Result<XetSession, SessionError> {
let (runtime, mode) = match self.tokio_handle {
Some(handle) => (XetRuntime::from_external_with_config(handle, self.config.clone()), RuntimeMode::External),
None => (XetRuntime::new_with_config(self.config.clone())?, RuntimeMode::Owned),
@@ -290,7 +295,7 @@ impl XetSessionBuilder {
/// `XetSession` is the top-level entry point for the xet-session API. It
/// owns a `XetRuntime` (tokio thread pool) and holds authentication
/// credentials that are shared by all [`UploadCommit`]s and
/// [`DownloadGroup`]s created from it.
/// [`FileDownloadGroup`]s created from it.
///
/// # Cloning
///
@@ -300,7 +305,7 @@ impl XetSessionBuilder {
/// # Lifecycle
///
/// 1. Create a session with [`XetSessionBuilder`].
/// 2. Create one or more [`UploadCommit`]s / [`DownloadGroup`]s.
/// 2. Create one or more [`UploadCommit`]s / [`FileDownloadGroup`]s.
/// 3. For an emergency stop, call [`XetSession::abort`].
#[derive(Clone)]
pub struct XetSession {
@@ -335,9 +340,10 @@ impl XetSession {
token_refresher,
custom_headers,
active_upload_commits: Mutex::new(HashMap::new()),
active_download_groups: Mutex::new(HashMap::new()),
active_file_download_groups: Mutex::new(HashMap::new()),
streaming_download_session: tokio::sync::OnceCell::new(),
state: Mutex::new(SessionState::Alive),
id: Ulid::new(),
id: UniqueID::new(),
}),
}
}
@@ -360,18 +366,18 @@ impl XetSession {
/// Create a new [`UploadCommit`] that groups related file uploads.
///
/// Returns `Err(XetError::Aborted)` if the session has been aborted.
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
///
/// # Note
///
/// This is an `async fn` and must be `.await`ed. For sync Rust or Python (PyO3) callers,
/// use [`new_upload_commit_blocking`](Self::new_upload_commit_blocking).
pub async fn new_upload_commit(&self) -> Result<UploadCommit, XetError> {
pub async fn new_upload_commit(&self) -> Result<UploadCommit, SessionError> {
// Check state before the async init; drop the guard so it is not held across .await.
{
let state = self.state.lock()?;
if matches!(*state, SessionState::Aborted) {
return Err(XetError::Aborted);
return Err(SessionError::Aborted);
}
}
@@ -391,8 +397,8 @@ impl XetSession {
/// The returned [`UploadCommit`] supports both async methods (`upload_from_path`,
/// `commit`) and blocking methods (`upload_from_path_blocking`, `commit_blocking`).
///
/// Returns `Err(XetError::Aborted)` if the session has been aborted.
/// Returns `Err(XetError::WrongRuntimeMode)` if the session uses an external
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
/// Returns `Err(SessionError::WrongRuntimeMode)` if the session uses an external
/// tokio runtime (from [`XetSessionBuilder::with_tokio_handle`] or tokio-detected
/// [`XetSessionBuilder::build_async`]).
///
@@ -403,9 +409,9 @@ impl XetSession {
/// async-std, `futures::executor`) do not set this context, so calling from those is
/// safe — it blocks the executor thread until the task completes. Use
/// [`new_upload_commit`](Self::new_upload_commit) from async contexts instead.
pub fn new_upload_commit_blocking(&self) -> Result<UploadCommit, XetError> {
pub fn new_upload_commit_blocking(&self) -> Result<UploadCommit, SessionError> {
if matches!(self.runtime_mode, RuntimeMode::External) {
return Err(XetError::wrong_mode(
return Err(SessionError::wrong_mode(
"new_upload_commit_blocking() cannot be called on a session using an \
external tokio runtime (with_tokio_handle() or tokio build_async()); \
use new_upload_commit().await instead",
@@ -414,7 +420,7 @@ impl XetSession {
{
let state = self.state.lock()?;
if matches!(*state, SessionState::Aborted) {
return Err(XetError::Aborted);
return Err(SessionError::Aborted);
}
}
@@ -423,41 +429,41 @@ impl XetSession {
Ok(commit)
}
/// Create a new [`DownloadGroup`] that groups related file downloads.
/// Create a new [`FileDownloadGroup`] that groups related file downloads.
///
/// Returns `Err(XetError::Aborted)` if the session has been aborted.
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
///
/// # Note
///
/// This is an `async fn` and must be `.await`ed. For sync Rust or Python (PyO3) callers,
/// use [`new_download_group_blocking`](Self::new_download_group_blocking).
pub async fn new_download_group(&self) -> Result<DownloadGroup, XetError> {
/// use [`new_file_download_group_blocking`](Self::new_file_download_group_blocking).
pub async fn new_file_download_group(&self) -> Result<FileDownloadGroup, SessionError> {
// Check state before the async init; drop the guard so it is not held across .await.
{
let state = self.state.lock()?;
if matches!(*state, SessionState::Aborted) {
return Err(XetError::Aborted);
return Err(SessionError::Aborted);
}
}
let session = self.clone();
let group = self
.dispatch("new_download_group", async move { DownloadGroup::new(session).await })
.dispatch("new_file_download_group", async move { FileDownloadGroup::new(session).await })
.await??;
// Register the group (sync insertion, safe in any executor context)
self.active_download_groups.lock()?.insert(group.id(), group.clone());
self.active_file_download_groups.lock()?.insert(group.id(), group.clone());
Ok(group)
}
/// Create a new [`DownloadGroup`] from a **sync** (non-async) context.
/// Create a new [`FileDownloadGroup`] from a **sync** (non-async) context.
///
/// The returned [`DownloadGroup`] supports both the async [`finish`](DownloadGroup::finish)
/// and blocking [`finish_blocking`](DownloadGroup::finish_blocking) methods.
/// The returned [`FileDownloadGroup`] supports both the async [`finish`](FileDownloadGroup::finish)
/// and blocking [`finish_blocking`](FileDownloadGroup::finish_blocking) methods.
///
/// Returns `Err(XetError::Aborted)` if the session has been aborted.
/// Returns `Err(XetError::WrongRuntimeMode)` if the session uses an external
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
/// Returns `Err(SessionError::WrongRuntimeMode)` if the session uses an external
/// tokio runtime (from [`XetSessionBuilder::with_tokio_handle`] or tokio-detected
/// [`XetSessionBuilder::build_async`]).
///
@@ -467,32 +473,183 @@ impl XetSession {
/// context that `Handle::block_on` detects and panics on). Non-tokio executors (smol,
/// async-std, `futures::executor`) do not set this context, so calling from those is
/// safe — it blocks the executor thread until the task completes. Use
/// [`new_download_group`](Self::new_download_group) from async contexts instead.
pub fn new_download_group_blocking(&self) -> Result<DownloadGroup, XetError> {
/// [`new_file_download_group`](Self::new_file_download_group) from async contexts instead.
pub fn new_file_download_group_blocking(&self) -> Result<FileDownloadGroup, SessionError> {
if matches!(self.runtime_mode, RuntimeMode::External) {
return Err(XetError::wrong_mode(
"new_download_group_blocking() cannot be called on a session using an \
return Err(SessionError::wrong_mode(
"new_file_download_group_blocking() cannot be called on a session using an \
external tokio runtime (with_tokio_handle() or tokio build_async()); \
use new_download_group().await instead",
use new_file_download_group().await instead",
));
}
{
let state = self.state.lock()?;
if matches!(*state, SessionState::Aborted) {
return Err(XetError::Aborted);
return Err(SessionError::Aborted);
}
}
let group = self.runtime.external_run_async_task(DownloadGroup::new(self.clone()))??;
self.active_download_groups.lock()?.insert(group.id(), group.clone());
let group = self.runtime.external_run_async_task(FileDownloadGroup::new(self.clone()))??;
self.active_file_download_groups.lock()?.insert(group.id(), group.clone());
Ok(group)
}
/// Initialise (or return the cached) [`FileDownloadSession`] used for
/// streaming downloads. The session is created lazily on the first call
/// with no group-level progress tracking.
async fn get_or_init_streaming_session(&self) -> Result<Arc<FileDownloadSession>, SessionError> {
self.streaming_download_session
.get_or_try_init(|| async {
let config = create_translator_config(self)?;
let session = FileDownloadSession::new(Arc::new(config)).await?;
Ok::<_, SessionError>(session)
})
.await
.cloned()
}
/// Create a [`XetDownloadStream`] for the given file, optionally
/// restricted to a byte range.
///
/// The returned stream yields data chunks as they are reconstructed,
/// with built-in progress tracking via
/// [`get_progress`](XetDownloadStream::get_progress).
/// The reconstruction task is spawned on the session's runtime but
/// paused until [`start`](XetDownloadStream::start) is called (or the
/// first [`next`](XetDownloadStream::next) /
/// [`blocking_next`](XetDownloadStream::blocking_next)). Because the
/// spawn happens during creation, `start()` and `next()` work from any
/// executor (tokio, smol, async-std, futures).
///
/// If `range` is `Some`, only the specified byte range of the file is
/// reconstructed.
///
/// The stream is independent of any [`FileDownloadGroup`] and is not
/// tracked by the session — the caller is responsible for consuming
/// or cancelling it.
///
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
pub async fn download_stream(
&self,
file_info: XetFileInfo,
range: Option<Range<u64>>,
) -> Result<XetDownloadStream, SessionError> {
self.check_alive()?;
let session = self.clone();
self.dispatch("download_stream", async move {
let dl_session = session.get_or_init_streaming_session().await?;
let (id, stream) = dl_session.download_stream(&file_info, range).await?;
Ok(XetDownloadStream::new(stream, dl_session, id))
})
.await?
}
/// Blocking version of [`download_stream`](Self::download_stream).
///
/// The reconstruction task is spawned on the session's runtime but
/// paused until [`start`](XetDownloadStream::start) is called (or the
/// first [`blocking_next`](XetDownloadStream::blocking_next)). No
/// tokio runtime context is required on the calling thread after this
/// method returns.
///
/// # Panics
///
/// Panics if called from within a tokio async runtime.
pub fn download_stream_blocking(
&self,
file_info: XetFileInfo,
range: Option<Range<u64>>,
) -> Result<XetDownloadStream, SessionError> {
if matches!(self.runtime_mode, RuntimeMode::External) {
return Err(SessionError::wrong_mode(
"download_stream_blocking() cannot be called on a session using an \
external tokio runtime (with_tokio_handle() or tokio build_async()); \
use download_stream().await instead",
));
}
self.check_alive()?;
let session = self.clone();
self.runtime.external_run_async_task(async move {
let dl_session = session.get_or_init_streaming_session().await?;
let (id, stream) = dl_session.download_stream(&file_info, range).await?;
Ok(XetDownloadStream::new(stream, dl_session, id))
})?
}
/// Create an [`XetUnorderedDownloadStream`] for the given file,
/// optionally restricted to a byte range.
///
/// The returned stream yields `(offset, Bytes)` chunks in whatever
/// order they complete, with built-in progress tracking via
/// [`get_progress`](XetUnorderedDownloadStream::get_progress).
///
/// If `range` is `Some`, only the specified byte range of the file is
/// reconstructed.
///
/// Can be awaited from any async executor (tokio, smol, async-std,
/// futures).
///
/// The stream is independent of any [`FileDownloadGroup`] and is not
/// tracked by the session — the caller is responsible for consuming
/// or cancelling it.
///
/// Returns `Err(SessionError::Aborted)` if the session has been aborted.
pub async fn download_unordered_stream(
&self,
file_info: XetFileInfo,
range: Option<Range<u64>>,
) -> Result<XetUnorderedDownloadStream, SessionError> {
self.check_alive()?;
let session = self.clone();
self.dispatch("download_unordered_stream", async move {
let dl_session = session.get_or_init_streaming_session().await?;
let (id, stream) = dl_session.download_unordered_stream(&file_info, range).await?;
Ok(XetUnorderedDownloadStream::new(stream, dl_session, id))
})
.await?
}
/// Blocking version of [`download_unordered_stream`](Self::download_unordered_stream).
///
/// The reconstruction task is spawned on the session's runtime but
/// paused until [`start`](XetUnorderedDownloadStream::start) is called
/// (or the first [`blocking_next`](XetUnorderedDownloadStream::blocking_next)).
/// No tokio runtime context is required on the calling thread after
/// this method returns.
///
/// # Panics
///
/// Panics if called from within a tokio async runtime.
pub fn download_unordered_stream_blocking(
&self,
file_info: XetFileInfo,
range: Option<Range<u64>>,
) -> Result<XetUnorderedDownloadStream, SessionError> {
if matches!(self.runtime_mode, RuntimeMode::External) {
return Err(SessionError::wrong_mode(
"download_unordered_stream_blocking() cannot be called on a session using an \
external tokio runtime (with_tokio_handle() or tokio build_async()); \
use download_unordered_stream().await instead",
));
}
self.check_alive()?;
let session = self.clone();
self.runtime.external_run_async_task(async move {
let dl_session = session.get_or_init_streaming_session().await?;
let (id, stream) = dl_session.download_unordered_stream(&file_info, range).await?;
Ok(XetUnorderedDownloadStream::new(stream, dl_session, id))
})?
}
/// Abort the session - cancel all currently running tasks
///
/// This performs a SIGINT-style shutdown, aborting all active upload and download tasks.
/// Use this when a Ctrl+C signal is detected or when you need to immediately stop all operations.
pub fn abort(&self) -> Result<(), XetError> {
pub fn abort(&self) -> Result<(), SessionError> {
// Mark as not accepting new work, hold the lock so no new task can be created when aborting
let mut state = self.state.lock()?;
*state = SessionState::Aborted;
@@ -506,27 +663,30 @@ impl XetSession {
for (_id, task) in active_upload_commits {
task.abort()?;
}
let active_download_groups = std::mem::take(&mut *self.active_download_groups.lock()?);
for (_id, task) in active_download_groups {
let active_file_download_groups = std::mem::take(&mut *self.active_file_download_groups.lock()?);
for (_id, task) in active_file_download_groups {
task.abort()?;
}
Ok(())
}
pub(super) fn check_alive(&self) -> Result<(), XetError> {
if matches!(*self.state.lock()?, SessionState::Aborted) {
return Err(XetError::Aborted);
if let Some(streaming_download_session) = self.streaming_download_session.get() {
streaming_download_session.abort_active_streams();
}
Ok(())
}
pub(super) fn finish_upload_commit(&self, commit_id: UniqueID) -> Result<(), XetError> {
pub(super) fn check_alive(&self) -> Result<(), SessionError> {
if matches!(*self.state.lock()?, SessionState::Aborted) {
return Err(SessionError::Aborted);
}
Ok(())
}
pub(super) fn finish_upload_commit(&self, commit_id: UniqueID) -> Result<(), SessionError> {
self.active_upload_commits.lock()?.remove(&commit_id);
Ok(())
}
pub(super) fn finish_download_group(&self, group_id: UniqueID) -> Result<(), XetError> {
self.active_download_groups.lock()?.remove(&group_id);
pub(super) fn finish_file_download_group(&self, group_id: UniqueID) -> Result<(), SessionError> {
self.active_file_download_groups.lock()?.remove(&group_id);
Ok(())
}
}
@@ -553,13 +713,6 @@ mod tests {
assert_ne!(s1.id, s2.id);
}
#[test]
// Session ID is a Ulid, to guard future regressions.
fn test_session_id_is_ulid() {
let s = XetSessionBuilder::new().build().unwrap();
assert!(Ulid::from_string(&s.id.to_string()).is_ok())
}
// ── Abort behavior ───────────────────────────────────────────────────────
#[test]
@@ -568,7 +721,7 @@ mod tests {
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let err = session.check_alive().unwrap_err();
assert!(matches!(err, XetError::Aborted));
assert!(matches!(err, SessionError::Aborted));
}
#[test]
@@ -577,16 +730,16 @@ mod tests {
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let err = session.new_upload_commit_blocking().err().unwrap();
assert!(matches!(err, XetError::Aborted));
assert!(matches!(err, SessionError::Aborted));
}
#[test]
// new_download_group_blocking on an aborted session returns Aborted.
fn test_new_download_group_after_abort_returns_aborted() {
// new_file_download_group_blocking on an aborted session returns Aborted.
fn test_new_file_download_group_after_abort_returns_aborted() {
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let err = session.new_download_group_blocking().err().unwrap();
assert!(matches!(err, XetError::Aborted));
let err = session.new_file_download_group_blocking().err().unwrap();
assert!(matches!(err, SessionError::Aborted));
}
#[test]
@@ -601,11 +754,11 @@ mod tests {
#[test]
// Aborting a session clears all registered download groups.
fn test_abort_clears_active_download_groups() {
fn test_abort_clears_active_file_download_groups() {
let session = XetSessionBuilder::new().build().unwrap();
let _g1 = session.new_download_group_blocking().unwrap();
let _g1 = session.new_file_download_group_blocking().unwrap();
session.abort().unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 0);
assert_eq!(session.active_file_download_groups.lock().unwrap().len(), 0);
}
// ── Registration ─────────────────────────────────────────────────────────
@@ -620,10 +773,10 @@ mod tests {
#[test]
// A new download group is registered in the session's active set.
fn test_new_download_group_registers_in_session() {
fn test_new_file_download_group_registers_in_session() {
let session = XetSessionBuilder::new().build().unwrap();
let _group = session.new_download_group_blocking().unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
let _group = session.new_file_download_group_blocking().unwrap();
assert_eq!(session.active_file_download_groups.lock().unwrap().len(), 1);
}
// ── Deregistration ───────────────────────────────────────────────────────
@@ -640,14 +793,14 @@ mod tests {
}
#[test]
// finish_download_group removes only the specified group, leaving others intact.
fn test_finish_download_group_removes_only_that_group() {
// finish_file_download_group removes only the specified group, leaving others intact.
fn test_finish_file_download_group_removes_only_that_group() {
let session = XetSessionBuilder::new().build().unwrap();
let g1 = session.new_download_group_blocking().unwrap();
let _g2 = session.new_download_group_blocking().unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 2);
session.finish_download_group(g1.id()).unwrap();
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
let g1 = session.new_file_download_group_blocking().unwrap();
let _g2 = session.new_file_download_group_blocking().unwrap();
assert_eq!(session.active_file_download_groups.lock().unwrap().len(), 2);
session.finish_file_download_group(g1.id()).unwrap();
assert_eq!(session.active_file_download_groups.lock().unwrap().len(), 1);
}
#[test]
@@ -663,14 +816,14 @@ mod tests {
// ── Async abort behavior ──────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// new_upload_commit / new_download_group on an aborted session both return Aborted.
// new_upload_commit / new_file_download_group on an aborted session both return Aborted.
async fn test_async_new_after_abort_returns_aborted() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
session.abort().unwrap();
let commit_err = session.new_upload_commit().await.err().unwrap();
let group_err = session.new_download_group().await.err().unwrap();
assert!(matches!(commit_err, XetError::Aborted));
assert!(matches!(group_err, XetError::Aborted));
let group_err = session.new_file_download_group().await.err().unwrap();
assert!(matches!(commit_err, SessionError::Aborted));
assert!(matches!(group_err, SessionError::Aborted));
}
#[tokio::test(flavor = "multi_thread")]
@@ -679,10 +832,10 @@ mod tests {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let _c1 = session.new_upload_commit().await.unwrap();
let _c2 = session.new_upload_commit().await.unwrap();
let _g1 = session.new_download_group().await.unwrap();
let _g1 = session.new_file_download_group().await.unwrap();
session.abort().unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 0);
assert_eq!(session.active_download_groups.lock().unwrap().len(), 0);
assert_eq!(session.active_file_download_groups.lock().unwrap().len(), 0);
}
// ── Async registration ────────────────────────────────────────────────────
@@ -693,9 +846,9 @@ mod tests {
async fn test_async_new_registers_in_session() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let _commit = session.new_upload_commit().await.unwrap();
let _group = session.new_download_group().await.unwrap();
let _group = session.new_file_download_group().await.unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
assert_eq!(session.active_file_download_groups.lock().unwrap().len(), 1);
}
// ── Async deregistration ──────────────────────────────────────────────────
@@ -707,14 +860,14 @@ mod tests {
let session = XetSessionBuilder::new().build_async().await.unwrap();
let c1 = session.new_upload_commit().await.unwrap();
let _c2 = session.new_upload_commit().await.unwrap();
let g1 = session.new_download_group().await.unwrap();
let _g2 = session.new_download_group().await.unwrap();
let g1 = session.new_file_download_group().await.unwrap();
let _g2 = session.new_file_download_group().await.unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 2);
assert_eq!(session.active_download_groups.lock().unwrap().len(), 2);
assert_eq!(session.active_file_download_groups.lock().unwrap().len(), 2);
session.finish_upload_commit(c1.id()).unwrap();
session.finish_download_group(g1.id()).unwrap();
session.finish_file_download_group(g1.id()).unwrap();
assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1);
assert_eq!(session.active_download_groups.lock().unwrap().len(), 1);
assert_eq!(session.active_file_download_groups.lock().unwrap().len(), 1);
}
// ── handle_meets_session_requirements ────────────────────────────────────
@@ -763,16 +916,16 @@ mod tests {
let session = XetSessionBuilder::new().build_async().await.unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::External);
let err = session.new_upload_commit_blocking().err().unwrap();
assert!(matches!(err, XetError::WrongRuntimeMode(_)));
assert!(matches!(err, SessionError::WrongRuntimeMode(_)));
}
#[tokio::test(flavor = "multi_thread")]
// new_download_group_blocking returns WrongRuntimeMode on an External-mode session.
async fn test_new_download_group_blocking_errors_in_external_mode() {
// new_file_download_group_blocking returns WrongRuntimeMode on an External-mode session.
async fn test_new_file_download_group_blocking_errors_in_external_mode() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::External);
let err = session.new_download_group_blocking().err().unwrap();
assert!(matches!(err, XetError::WrongRuntimeMode(_)));
let err = session.new_file_download_group_blocking().err().unwrap();
assert!(matches!(err, SessionError::WrongRuntimeMode(_)));
}
// ── Owned-mode _blocking panic guard ─────────────────────────────────────
@@ -792,15 +945,224 @@ mod tests {
}
#[test]
// new_download_group_blocking panics when called from within a tokio runtime on an
// new_file_download_group_blocking panics when called from within a tokio runtime on an
// Owned-mode session: same mechanism as the upload variant above.
fn test_new_download_group_blocking_panics_in_async_context() {
fn test_new_file_download_group_blocking_panics_in_async_context() {
let session = XetSessionBuilder::new().build().unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::Owned);
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
rt.block_on(async { session.new_download_group_blocking() })
rt.block_on(async { session.new_file_download_group_blocking() })
}));
assert!(result.is_err(), "new_download_group_blocking() must panic when called from async");
assert!(result.is_err(), "new_file_download_group_blocking() must panic when called from async");
}
// ── Streaming download ──────────────────────────────────────────────────
#[tokio::test(flavor = "multi_thread")]
// download_stream on an aborted session returns Aborted.
async fn test_download_stream_on_aborted_session_returns_aborted() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
session.abort().unwrap();
let result = session
.download_stream(
XetFileInfo {
hash: "abc123".to_string(),
file_size: Some(1024),
sha256: None,
},
None,
)
.await;
assert!(matches!(result, Err(SessionError::Aborted)));
}
#[test]
// download_stream_blocking on an aborted session returns Aborted.
fn test_download_stream_blocking_on_aborted_session_returns_aborted() {
let session = XetSessionBuilder::new().build().unwrap();
session.abort().unwrap();
let result = session.download_stream_blocking(
XetFileInfo {
hash: "abc123".to_string(),
file_size: Some(1024),
sha256: None,
},
None,
);
assert!(matches!(result, Err(SessionError::Aborted)));
}
#[tokio::test(flavor = "multi_thread")]
// download_stream_blocking returns WrongRuntimeMode on an External-mode session.
async fn test_download_stream_blocking_errors_in_external_mode() {
let session = XetSessionBuilder::new().build_async().await.unwrap();
assert_eq!(session.runtime_mode, RuntimeMode::External);
let result = session.download_stream_blocking(
XetFileInfo {
hash: "abc123".to_string(),
file_size: Some(1024),
sha256: None,
},
None,
);
assert!(matches!(result, Err(SessionError::WrongRuntimeMode(_))));
}
// ── Streaming download round-trip tests ─────────────────────────────────
use tempfile::{TempDir, tempdir};
use xet_data::processing::Sha256Policy;
async fn local_session(temp: &TempDir) -> Result<XetSession, Box<dyn std::error::Error>> {
let cas_path = temp.path().join("cas");
Ok(XetSessionBuilder::new()
.with_endpoint(format!("local://{}", cas_path.display()))
.build_async()
.await?)
}
fn local_session_sync(temp: &TempDir) -> Result<XetSession, Box<dyn std::error::Error>> {
let cas_path = temp.path().join("cas");
Ok(XetSessionBuilder::new()
.with_endpoint(format!("local://{}", cas_path.display()))
.build()?)
}
async fn upload_bytes(
session: &XetSession,
data: &[u8],
name: &str,
) -> Result<XetFileInfo, Box<dyn std::error::Error>> {
let commit = session.new_upload_commit().await?;
let handle = commit
.upload_bytes(data.to_vec(), Sha256Policy::Compute, Some(name.into()))
.await?;
let results = commit.commit().await?;
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
Ok(XetFileInfo {
hash: meta.hash.clone(),
file_size: Some(meta.file_size),
sha256: meta.sha256.clone(),
})
}
fn upload_bytes_blocking(
session: &XetSession,
data: &[u8],
name: &str,
) -> Result<XetFileInfo, Box<dyn std::error::Error>> {
let commit = session.new_upload_commit_blocking()?;
let handle = commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some(name.into()))?;
let results = commit.commit_blocking()?;
let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap();
Ok(XetFileInfo {
hash: meta.hash.clone(),
file_size: Some(meta.file_size),
sha256: meta.sha256.clone(),
})
}
#[tokio::test(flavor = "multi_thread")]
// Async streaming download round-trip: upload, stream, verify content.
async fn test_download_stream_round_trip() {
let temp = tempdir().unwrap();
let session = local_session(&temp).await.unwrap();
let original = b"Hello, streaming download!";
let file_info = upload_bytes(&session, original, "stream.bin").await.unwrap();
let mut stream = session.download_stream(file_info, None).await.unwrap();
let mut collected = Vec::new();
while let Some(chunk) = stream.next().await.unwrap() {
collected.extend_from_slice(&chunk);
}
assert_eq!(collected, original);
}
#[test]
// Blocking streaming download round-trip: upload, stream, verify content.
fn test_download_stream_blocking_round_trip() {
let temp = tempdir().unwrap();
let session = local_session_sync(&temp).unwrap();
let original = b"Hello, blocking streaming download!";
let file_info = upload_bytes_blocking(&session, original, "stream.bin").unwrap();
let mut stream = session.download_stream_blocking(file_info, None).unwrap();
let mut collected = Vec::new();
while let Some(chunk) = stream.blocking_next().unwrap() {
collected.extend_from_slice(&chunk);
}
assert_eq!(collected, original);
}
#[tokio::test(flavor = "multi_thread")]
// get_progress() reports correct totals after consuming the stream.
async fn test_download_stream_progress_reports_completion() {
let temp = tempdir().unwrap();
let session = local_session(&temp).await.unwrap();
let original = b"progress tracking test data for streaming";
let file_info = upload_bytes(&session, original, "progress.bin").await.unwrap();
let mut stream = session.download_stream(file_info, None).await.unwrap();
let initial = stream.get_progress();
assert_eq!(initial.total_bytes, original.len() as u64);
assert_eq!(initial.bytes_completed, 0);
let mut collected = Vec::new();
while let Some(chunk) = stream.next().await.unwrap() {
collected.extend_from_slice(&chunk);
}
assert_eq!(collected, original);
let final_progress = stream.get_progress();
assert_eq!(final_progress.total_bytes, original.len() as u64);
assert_eq!(final_progress.bytes_completed, original.len() as u64);
}
#[test]
// get_progress() works correctly in blocking mode.
fn test_download_stream_blocking_progress_reports_completion() {
let temp = tempdir().unwrap();
let session = local_session_sync(&temp).unwrap();
let original = b"blocking progress tracking test data";
let file_info = upload_bytes_blocking(&session, original, "progress.bin").unwrap();
let mut stream = session.download_stream_blocking(file_info, None).unwrap();
let mut collected = Vec::new();
while let Some(chunk) = stream.blocking_next().unwrap() {
collected.extend_from_slice(&chunk);
}
assert_eq!(collected, original);
let final_progress = stream.get_progress();
assert_eq!(final_progress.total_bytes, original.len() as u64);
assert_eq!(final_progress.bytes_completed, original.len() as u64);
}
#[tokio::test(flavor = "multi_thread")]
// Multiple sequential streaming downloads reuse the lazy FileDownloadSession.
async fn test_download_stream_multiple_sequential() {
let temp = tempdir().unwrap();
let session = local_session(&temp).await.unwrap();
let data_a = b"first stream payload";
let data_b = b"second stream payload";
let info_a = upload_bytes(&session, data_a, "a.bin").await.unwrap();
let info_b = upload_bytes(&session, data_b, "b.bin").await.unwrap();
let mut stream_a = session.download_stream(info_a, None).await.unwrap();
let mut collected_a = Vec::new();
while let Some(chunk) = stream_a.next().await.unwrap() {
collected_a.extend_from_slice(&chunk);
}
assert_eq!(collected_a, data_a);
let mut stream_b = session.download_stream(info_b, None).await.unwrap();
let mut collected_b = Vec::new();
while let Some(chunk) = stream_b.next().await.unwrap() {
collected_b.extend_from_slice(&chunk);
}
assert_eq!(collected_b, data_b);
}
}

View File

@@ -5,9 +5,9 @@ use std::sync::{Arc, Mutex, OnceLock};
use xet_data::progress_tracking::UniqueID;
use super::download_group::DownloadResult;
use super::SessionError;
use super::file_download_group::DownloadResult;
use super::upload_commit::UploadResult;
use crate::error::XetError;
/// Lifecycle state of a single upload or download task.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -84,11 +84,11 @@ impl Deref for DownloadTaskHandle {
}
impl TaskHandle {
pub fn status(&self) -> Result<TaskStatus, XetError> {
pub fn status(&self) -> Result<TaskStatus, SessionError> {
if let Some(status) = &self.status {
Ok(*status.lock()?)
} else {
Err(XetError::other("status not available"))
Err(SessionError::other("status not available"))
}
}
}
@@ -195,7 +195,7 @@ mod tests {
let result = handle.result().unwrap();
let dl = result.as_ref().as_ref().unwrap();
assert_eq!(dl.file_info.file_size, Some(99));
assert_eq!(dl.file_info.file_size(), Some(99));
assert_eq!(dl.dest_path, PathBuf::from("out/file.bin"));
}
}

View File

@@ -390,7 +390,7 @@ pub struct UploadCommitInner {
// tokio::sync::Mutex (not std) because registration methods hold this lock across
// .await points (e.g. start_clean in start_upload_file) to serialise with commit.
// DownloadGroupInner uses std::sync::Mutex because its registration is synchronous.
// FileDownloadGroupInner uses std::sync::Mutex because its registration is synchronous.
state: tokio::sync::Mutex<GroupState>,
}

File diff suppressed because it is too large Load Diff

View File

@@ -79,4 +79,5 @@ crate::config_group!({
///
/// Use the environment variable `HF_XET_RECONSTRUCTION_USE_VECTORED_WRITE` to set this value.
ref use_vectored_write: bool = true;
});