From 506fc282912b5d7947be000e48810bd480dd628e Mon Sep 17 00:00:00 2001 From: Hoyt Koepke Date: Wed, 18 Mar 2026 18:07:43 -0700 Subject: [PATCH] Simplify progress tracking + Unify Task ID tracking + Legacy Interface (#726) Currently, progress tracking is split between callback-driven and snapshot-driven paths, making session and task wiring across xet_data, xet_pkg, hf_xet, and git_xet harder to keep consistent. This PR moves upload/download progress to a polling snapshot model backed by atomics. It also switches task identifiers to a UniqueID common with the progress tracking throughout the session APIs. This PR also updates the rate estimation to use the lighter weight exponentially weighted moving averages model, so this can be done at a low level. To preserve compatibility for existing callback consumers, callback-oriented upload/download progress tracking APIs are moved under xet_pkg::legacy and bridged from polling snapshots via a callback based updaters. hf_xet and git_xet are updated to use that legacy bridge layer, so current integrations keep working until everything is fully switched over to the XetSession method. --- Cargo.lock | 16 +- Cargo.toml | 1 - ...pdate_260313_progress_tracking_redesign.md | 122 ++ git_xet/Cargo.toml | 2 +- git_xet/src/app/xet_agent.rs | 45 +- git_xet/src/errors.rs | 2 +- hf_xet/Cargo.lock | 30 +- hf_xet/Cargo.toml | 4 +- hf_xet/src/lib.rs | 5 +- hf_xet/src/progress_update.rs | 4 +- wasm/hf_xet_thin_wasm/Cargo.lock | 11 - wasm/hf_xet_wasm/Cargo.lock | 11 - .../src/utils/exp_weighted_moving_avg.rs | 52 +- xet_data/Cargo.toml | 4 +- xet_data/src/error.rs | 3 + .../data_writer/sequential_writer.rs | 6 +- .../file_reconstruction/file_reconstructor.rs | 33 +- .../reconstruction_terms/file_term.rs | 4 +- .../reconstruction_terms/manager.rs | 6 +- .../reconstruction_terms/xorb_block.rs | 4 +- xet_data/src/file_reconstruction/run_state.rs | 8 +- xet_data/src/processing/bin/example.rs | 10 +- xet_data/src/processing/data_client.rs | 190 +-- .../src/processing/deduplication_interface.rs | 2 +- xet_data/src/processing/file_cleaner.rs | 3 +- .../src/processing/file_download_session.rs | 308 ++--- .../src/processing/file_upload_session.rs | 292 ++-- .../src/processing/migration_tool/migrate.rs | 6 +- xet_data/src/processing/test_utils.rs | 33 +- xet_data/src/progress_tracking/aggregator.rs | 492 ------- .../progress_tracking/download_tracking.rs | 582 -------- xet_data/src/progress_tracking/mod.rs | 23 +- .../src/progress_tracking/no_op_tracker.rs | 17 - .../src/progress_tracking/progress_types.rs | 582 ++++++++ .../src/progress_tracking/speed_tracker.rs | 477 +++++++ .../src/progress_tracking/upload_tracking.rs | 1173 ++++++----------- xet_data/tests/test_full_file_download.rs | 13 +- xet_data/tests/test_session_resume.rs | 259 ++-- xet_pkg/Cargo.toml | 6 +- xet_pkg/examples/example.rs | 32 +- xet_pkg/examples/example_sync.rs | 12 +- xet_pkg/src/error.rs | 5 +- xet_pkg/src/legacy/data_client.rs | 187 +++ xet_pkg/src/legacy/mod.rs | 9 + .../progress_tracking/callback_bridge.rs | 489 +++++++ .../src/legacy/progress_tracking/mod.rs | 29 +- .../progress_verification_wrapper.rs | 32 +- xet_pkg/src/lib.rs | 1 + xet_pkg/src/xet_session/download_group.rs | 316 +++-- xet_pkg/src/xet_session/mod.rs | 26 +- xet_pkg/src/xet_session/progress.rs | 518 -------- xet_pkg/src/xet_session/session.rs | 16 +- xet_pkg/src/xet_session/tasks.rs | 201 +++ xet_pkg/src/xet_session/upload_commit.rs | 345 +++-- xet_pkg/tests/test_legacy_data_client.rs | 325 +++++ xet_runtime/src/config/groups/data.rs | 14 +- xet_runtime/src/utils/unique_id.rs | 65 +- 57 files changed, 3879 insertions(+), 3584 deletions(-) create mode 100644 api_changes/update_260313_progress_tracking_redesign.md delete mode 100644 xet_data/src/progress_tracking/aggregator.rs delete mode 100644 xet_data/src/progress_tracking/download_tracking.rs delete mode 100644 xet_data/src/progress_tracking/no_op_tracker.rs create mode 100644 xet_data/src/progress_tracking/progress_types.rs create mode 100644 xet_data/src/progress_tracking/speed_tracker.rs create mode 100644 xet_pkg/src/legacy/data_client.rs create mode 100644 xet_pkg/src/legacy/mod.rs create mode 100644 xet_pkg/src/legacy/progress_tracking/callback_bridge.rs rename xet_data/src/progress_tracking/progress_info.rs => xet_pkg/src/legacy/progress_tracking/mod.rs (81%) rename xet_data/src/progress_tracking/verification_wrapper.rs => xet_pkg/src/legacy/progress_tracking/progress_verification_wrapper.rs (88%) delete mode 100644 xet_pkg/src/xet_session/progress.rs create mode 100644 xet_pkg/src/xet_session/tasks.rs create mode 100644 xet_pkg/tests/test_legacy_data_client.rs diff --git a/Cargo.lock b/Cargo.lock index 5f6bff54..6ecc91cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1640,6 +1640,7 @@ dependencies = [ "clap", "git-url-parse", "git2", + "hf-xet", "http", "rand_core 0.6.4", "reqwest", @@ -1654,7 +1655,6 @@ dependencies = [ "thiserror 2.0.18", "tokio", "xet-client", - "xet-data", "xet-runtime", ] @@ -1878,6 +1878,7 @@ dependencies = [ "clap", "futures", "http", + "more-asserts", "serde", "serde_json", "serial_test", @@ -1887,7 +1888,6 @@ dependencies = [ "tokio", "tracing", "tracing-subscriber", - "ulid", "xet-client", "xet-core-structures", "xet-data", @@ -5273,16 +5273,6 @@ version = "1.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" -[[package]] -name = "ulid" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "470dbf6591da1b39d43c14523b2b469c86879a53e8b758c8e090a470fe7b1fbe" -dependencies = [ - "rand 0.9.2", - "web-time", -] - [[package]] name = "unicase" version = "2.9.0" @@ -6272,6 +6262,7 @@ dependencies = [ "lazy_static", "more-asserts", "prometheus", + "pyo3", "rand 0.9.2", "regex", "serde", @@ -6284,7 +6275,6 @@ dependencies = [ "tokio-util", "tracing", "tracing-test", - "ulid", "url", "walkdir", "xet-client", diff --git a/Cargo.toml b/Cargo.toml index 5ab77d15..295b3f8a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -111,7 +111,6 @@ tracing = "0.1" tracing-appender = "0.2" tracing-log = "0.2" tracing-subscriber = { version = "0.3", features = ["json", "env-filter"] } -ulid = "1.2" url = "2.5" urlencoding = "2.1" uuid = "1" diff --git a/api_changes/update_260313_progress_tracking_redesign.md b/api_changes/update_260313_progress_tracking_redesign.md new file mode 100644 index 00000000..1691943f --- /dev/null +++ b/api_changes/update_260313_progress_tracking_redesign.md @@ -0,0 +1,122 @@ +# Progress tracking redesign + +**Date**: 2026-03-13 + +## Summary + +Progress tracking has been moved to a polling/snapshot model backed by atomics, and +task identifiers now use `UniqueID` (a `u64`-backed ID) instead of `Ulid`. + +Legacy callback consumers are still supported through bridge adapters in +`xet_pkg::legacy::progress_tracking`. + +## Core type changes + +### `UniqueID` + +- Canonical ID type is `xet_runtime::utils::UniqueId` (newtype over `u64`). +- `xet_data::progress_tracking` re-exports this as `UniqueID`: + `pub use xet_runtime::utils::UniqueId as UniqueID`. +- Used as task/file tracking IDs across `xet_data`, `xet_pkg`, `hf_xet`, and `git_xet`. + +### `xet_data::progress_tracking` + +- Legacy tracker modules were removed (`aggregator`, `download_tracking`, `no_op_tracker`). +- `GroupProgress` and `ItemProgress` are the authoritative progress model. +- `GroupProgress` uses: + - group-level atomics for totals/completions, + - `std::sync::Mutex>>` for per-item registry. +- `ItemProgressUpdater` provides write APIs: + - `update_item_size(total, is_final)` + - `update_transfer_size(total)` + - `report_bytes_completed(increment)` (alias: `report_bytes_written`) + - `report_transfer_bytes_completed(increment)` (alias: `report_transfer_progress`) +- Snapshot structs: + - `GroupProgressReport` + - `ItemProgressReport` + +### Speed estimation model + +- Speed estimation is no longer a sliding window. +- `GroupProgress` now uses `speed_tracker::SpeedTracker`, which computes smoothed + rates using EWMA (`ExpWeightedMovingAvg`). +- Rates are reported as: + - `total_bytes_completion_rate: Option` + - `total_transfer_bytes_completion_rate: Option` +- Rates remain `None` until a configurable minimum observation count is reached. + +## Configuration changes (`xet_runtime::config::data`) + +New settings: + +- `progress_update_speed_sampling_window: Duration` + Env: `HF_XET_DATA_PROGRESS_UPDATE_SPEED_SAMPLING_WINDOW` + Meaning: EWMA half-life for speed estimation. +- `progress_update_speed_min_observations: u32` + Env: `HF_XET_DATA_PROGRESS_UPDATE_SPEED_MIN_OBSERVATIONS` + Meaning: minimum observations before rates are exposed. + +## `xet_data::processing` API changes + +### `FileUploadSession` + +- `new(config)` and `dry_run(config)` no longer accept external progress updater traits. +- `start_clean(...)` no longer accepts `tracking_id`; it returns the generated ID: + - `start_clean(tracking_name: Option>, size: u64, sha256: Sha256Policy) -> Result<(UniqueID, SingleFileCleaner)>` +- New session progress accessors: + - `progress()` + - `report()` + - `item_report(id)` + - `item_reports()` +- New task spawn helpers: + - `spawn_upload_from_path(...) -> Result<(UniqueID, JoinHandle>)>` + - `spawn_upload_bytes(...) -> Result<(UniqueID, JoinHandle>)>` + +### `FileDownloadSession` + +- `new(config)` no longer accepts external progress updater traits. +- `from_client` signature is now: + - `from_client(client: Arc) -> Arc` +- Download calls now return generated IDs: + - `download_file(...) -> Result<(UniqueID, u64)>` + - `download_to_writer(...) -> Result<(UniqueID, u64)>` + - `download_stream(...) -> Result<(UniqueID, DownloadStream)>` + - `download_file_background(...) -> Result<(UniqueID, JoinHandle>)>` +- New progress accessors: + - `report()` + - `item_report(id)` + - `item_reports()` + +### `xet_data::processing::data_client` + +- `clean_file(...)` and `clean_bytes(...)` removed explicit `tracking_id` arguments. +- Callback-oriented upload/download wrappers were moved out of `xet_data` and are + now provided by `xet_pkg::legacy::data_client`. + +## `xet_pkg` API changes + +- New `xet_pkg::legacy` module added for callback-compatible APIs. +- `xet_pkg::xet_session::progress` removed; progress/task types live in: + - `xet_pkg::xet_session::tasks` + - `xet_data::progress_tracking::{GroupProgressReport, ItemProgressReport, UniqueID}` +- `TaskHandle.task_id` now uses `UniqueID`. +- `UploadCommit::commit` / `commit_blocking` result maps now keyed by `UniqueID`. +- `DownloadGroup::finish` / `finish_blocking` result maps now keyed by `UniqueID`. +- `UploadCommit::get_progress()` and `DownloadGroup::get_progress()` are synchronous + snapshot reads returning `GroupProgressReport`. + +## Legacy compatibility path + +- Callback interfaces are preserved in `xet_pkg::legacy::progress_tracking`: + - `TrackingProgressUpdater` + - `ProgressUpdate` + - `ItemProgressUpdate` + - `GroupProgressCallbackUpdater` + - `ItemProgressCallbackUpdater` +- `hf_xet` and `git_xet` were migrated to consume the legacy bridge layer + (`xet_pkg::legacy::*`) instead of direct old `xet_data` callback APIs. + +## Supersedes / follow-on + +- This change supersedes `Ulid` task/result map keys in session APIs with `UniqueID`. +- Any downstream code matching on `Ulid` task IDs must migrate to `UniqueID`. diff --git a/git_xet/Cargo.toml b/git_xet/Cargo.toml index 18d93274..2973d85d 100644 --- a/git_xet/Cargo.toml +++ b/git_xet/Cargo.toml @@ -11,7 +11,7 @@ path = "src/bin/main.rs" [dependencies] xet-runtime = { path = "../xet_runtime" } xet-client = { path = "../xet_client" } -xet-data = { path = "../xet_data" } +xet-pkg = { package = "hf-xet", path = "../xet_pkg" } anyhow = { workspace = true } async-trait = { workspace = true } diff --git a/git_xet/src/app/xet_agent.rs b/git_xet/src/app/xet_agent.rs index e392e5e9..309749cb 100644 --- a/git_xet/src/app/xet_agent.rs +++ b/git_xet/src/app/xet_agent.rs @@ -6,9 +6,8 @@ use async_trait::async_trait; use http::header; use xet_client::cas_client::auth::TokenRefresher; use xet_client::hub_client::Operation; -use xet_data::processing::data_client::{clean_file, default_config}; -use xet_data::processing::{FileUploadSession, Sha256Policy}; -use xet_data::progress_tracking::{ProgressUpdate, TrackingProgressUpdater}; +use xet_pkg::legacy::progress_tracking::{GroupProgressCallbackUpdater, ProgressUpdate, TrackingProgressUpdater}; +use xet_pkg::legacy::{FileUploadSession, Sha256Policy, clean_file, default_config}; use crate::constants::{ HF_ENDPOINT_ENV, XET_ACCESS_TOKEN_HEADER, XET_CAS_URL, XET_SESSION_ID, XET_TOKEN_EXPIRATION_HEADER, @@ -105,9 +104,9 @@ impl TransferAgent for XetAgent { // and https://github.com/git-lfs/git-lfs/blob/2c7de1f90cbe13bf9c1ed43b84dda88bb32f2ba4/tq/custom.go#L304 progress_updater.update_bytes_so_far(1)?; - let xet_updater = XetProgressUpdaterWrapper { + let xet_updater = Arc::new(XetProgressUpdaterWrapper { updater: progress_updater, - }; + }); let cas_url = req .action @@ -137,26 +136,34 @@ impl TransferAgent for XetAgent { if !session_id.is_empty() { config.session.session_id = Some(session_id.to_owned()); } - let session = FileUploadSession::new(config.into(), Some(Arc::new(xet_updater))).await?; + let session = FileUploadSession::new(config.into()).await?; + let bridge = GroupProgressCallbackUpdater::start(session.clone(), xet_updater); let Some(file_path) = &req.path else { return Err(GitLFSProtocolError::bad_syntax("file path not provided for upload request").into()); }; - clean_file(session.clone(), file_path, Sha256Policy::from_hex(&req.oid), None).await?; + let upload_result = async { + clean_file(session.clone(), file_path, Sha256Policy::from_hex(&req.oid)).await?; - // We need to actually upload the shard after each file upload to have the files registered, because - // - // 1. LFS custom transfer protocol is sequential: git-lfs waits for the upload/download result of the one file - // before sending the request to process the next one; - // 2. git-lfs doesn't tell agents how many files to upload/download at the initiation phase; - // 3. After sending a termination signal, git-lfs waits for 30s and sends SIGKILL to the agent. SIGKILL is not - // like SIGINT, it can't be intercepted or ignored by a process. - // 4. Xet system is not a real-time system that guarantees response within any duration. Batching and thus - // effectively delaying shard upload means we risk data loss. - // - // See https://github.com/git-lfs/git-lfs/blob/2c7de1f90cbe13bf9c1ed43b84dda88bb32f2ba4/tq/custom.go#L233 - session.finalize().await?; + // We need to actually upload the shard after each file upload to have the files registered, because + // + // 1. LFS custom transfer protocol is sequential: git-lfs waits for the upload/download result of the one + // file before sending the request to process the next one; + // 2. git-lfs doesn't tell agents how many files to upload/download at the initiation phase; + // 3. After sending a termination signal, git-lfs waits for 30s and sends SIGKILL to the agent. SIGKILL is + // not like SIGINT, it can't be intercepted or ignored by a process. + // 4. Xet system is not a real-time system that guarantees response within any duration. Batching and thus + // effectively delaying shard upload means we risk data loss. + // + // See https://github.com/git-lfs/git-lfs/blob/2c7de1f90cbe13bf9c1ed43b84dda88bb32f2ba4/tq/custom.go#L233 + session.finalize().await?; + Ok::<(), GitXetError>(()) + } + .await; + + bridge.finalize().await; + upload_result?; Ok(()) } diff --git a/git_xet/src/errors.rs b/git_xet/src/errors.rs index 054bc185..bca68aa7 100644 --- a/git_xet/src/errors.rs +++ b/git_xet/src/errors.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use thiserror::Error; use xet_client::ClientError; -use xet_data::processing::errors::DataProcessingError; +use xet_pkg::legacy::DataProcessingError; use crate::lfs_agent_protocol::GitLFSProtocolError; diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock index fe73f377..2f68f479 100644 --- a/hf_xet/Cargo.lock +++ b/hf_xet/Cargo.lock @@ -1231,12 +1231,30 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" +[[package]] +name = "hf-xet" +version = "1.4.0" +dependencies = [ + "async-trait", + "http", + "more-asserts", + "serde", + "thiserror 2.0.18", + "tokio", + "tracing", + "xet-client", + "xet-core-structures", + "xet-data", + "xet-runtime", +] + [[package]] name = "hf_xet" version = "1.4.2" dependencies = [ "async-trait", "chrono", + "hf-xet", "http", "itertools 0.14.0", "lazy_static", @@ -1247,7 +1265,6 @@ dependencies = [ "tracing", "winapi", "xet-client", - "xet-data", "xet-runtime", ] @@ -3854,16 +3871,6 @@ version = "1.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" -[[package]] -name = "ulid" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "470dbf6591da1b39d43c14523b2b469c86879a53e8b758c8e090a470fe7b1fbe" -dependencies = [ - "rand 0.9.2", - "web-time", -] - [[package]] name = "unicase" version = "2.9.0" @@ -4802,7 +4809,6 @@ dependencies = [ "tokio", "tokio-util", "tracing", - "ulid", "url", "walkdir", "xet-client", diff --git a/hf_xet/Cargo.toml b/hf_xet/Cargo.toml index d9a718c9..5aeab0d5 100644 --- a/hf_xet/Cargo.toml +++ b/hf_xet/Cargo.toml @@ -12,7 +12,7 @@ crate-type = ["cdylib", "lib"] [dependencies] xet-runtime = { path = "../xet_runtime" } xet-client = { path = "../xet_client" } -xet-data = { path = "../xet_data" } +xet-pkg = { package = "hf-xet", path = "../xet_pkg" } async-trait = "0.1" chrono = "0.4" @@ -70,5 +70,3 @@ debug = true opt-level = 3 # cargo-machete has detected the below unused dependency incorrectly -[package.metadata.cargo-machete] -ignored = ["xet-client", "xet-runtime"] diff --git a/hf_xet/src/lib.rs b/hf_xet/src/lib.rs index e80bac85..b8574bfa 100644 --- a/hf_xet/src/lib.rs +++ b/hf_xet/src/lib.rs @@ -17,9 +17,8 @@ use rand::Rng; use runtime::async_run; use token_refresh::WrappedTokenRefresher; use tracing::debug; -use xet_data::processing::errors::DataProcessingError; -use xet_data::processing::{Sha256Policy, XetFileInfo, data_client}; -use xet_data::progress_tracking::TrackingProgressUpdater; +use xet_pkg::legacy::progress_tracking::TrackingProgressUpdater; +use xet_pkg::legacy::{DataProcessingError, Sha256Policy, XetFileInfo, data_client}; use xet_runtime::core::file_handle_limits; use crate::logging::init_logging; diff --git a/hf_xet/src/progress_update.rs b/hf_xet/src/progress_update.rs index 19f467f2..960eedf7 100644 --- a/hf_xet/src/progress_update.rs +++ b/hf_xet/src/progress_update.rs @@ -7,7 +7,7 @@ use pyo3::prelude::PyAnyMethods; use pyo3::types::{IntoPyDict, PyList, PyString}; use pyo3::{IntoPyObjectExt, Py, PyAny, PyResult, Python, pyclass}; use tracing::error; -use xet_data::progress_tracking::{ProgressUpdate, TrackingProgressUpdater}; +use xet_pkg::legacy::progress_tracking::{ProgressUpdate, TrackingProgressUpdater}; use xet_runtime::core::XetRuntime; use xet_runtime::error_printer::ErrorPrinter; @@ -77,7 +77,7 @@ pub struct PyTotalProgressUpdate { #[pyo3(get)] pub total_bytes: u64, - /// How much total_bytes has changed from the last update.. + /// How much total_bytes has changed from the last update. #[pyo3(get)] pub total_bytes_increment: u64, diff --git a/wasm/hf_xet_thin_wasm/Cargo.lock b/wasm/hf_xet_thin_wasm/Cargo.lock index 4ab3131c..6f7d21db 100644 --- a/wasm/hf_xet_thin_wasm/Cargo.lock +++ b/wasm/hf_xet_thin_wasm/Cargo.lock @@ -3083,16 +3083,6 @@ version = "1.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" -[[package]] -name = "ulid" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "470dbf6591da1b39d43c14523b2b469c86879a53e8b758c8e090a470fe7b1fbe" -dependencies = [ - "rand 0.9.2", - "web-time", -] - [[package]] name = "unicase" version = "2.9.0" @@ -4029,7 +4019,6 @@ dependencies = [ "tokio", "tokio-util", "tracing", - "ulid", "url", "walkdir", "xet-client", diff --git a/wasm/hf_xet_wasm/Cargo.lock b/wasm/hf_xet_wasm/Cargo.lock index 4129033a..76e191e5 100644 --- a/wasm/hf_xet_wasm/Cargo.lock +++ b/wasm/hf_xet_wasm/Cargo.lock @@ -3219,16 +3219,6 @@ version = "1.14.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8c1ae7cc0fdb8b842d65d127cb981574b0d2b249b74d1c7a2986863dc134f71" -[[package]] -name = "ulid" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "470dbf6591da1b39d43c14523b2b469c86879a53e8b758c8e090a470fe7b1fbe" -dependencies = [ - "rand 0.9.2", - "web-time", -] - [[package]] name = "unicase" version = "2.9.0" @@ -4201,7 +4191,6 @@ dependencies = [ "tokio", "tokio-util", "tracing", - "ulid", "url", "walkdir", "xet-client", diff --git a/xet_core_structures/src/utils/exp_weighted_moving_avg.rs b/xet_core_structures/src/utils/exp_weighted_moving_avg.rs index 9a8e8ae2..40cb2fd0 100644 --- a/xet_core_structures/src/utils/exp_weighted_moving_avg.rs +++ b/xet_core_structures/src/utils/exp_weighted_moving_avg.rs @@ -58,14 +58,14 @@ impl ExpWeightedMovingAvg { let now = Instant::now(); let dt_secs = (now - *last_update).as_secs_f64(); - // decay = 2^(-Δt / T½) - let decay = ((-dt_secs * weight) / *half_life_secs).exp2(); + // decay = 2^(-Δt / T½); independent of sample weight + let decay = (-dt_secs / *half_life_secs).exp2(); *last_update = now; decay }, ExpWeightedMovingAvgMode::CountDecay { half_life_count } => { - // For count-based decay, we apply decay based on the number of samples. - // Each update applies decay = 2^(-1 / T½) where 1 is the count increment. + // For count-based decay, sample weight is treated as the count increment. + // Decay is therefore 2^(-weight / T½_count). (-weight / *half_life_count).exp2() }, }; @@ -183,6 +183,50 @@ mod tests { assert!(m > 0.0 && m < 8.0); } + /// Verifies that time-based decay with update_with_weight correctly + /// computes rate = Σ(decayed bytes) / Σ(decayed time). + #[tokio::test] + async fn ewma_time_decay_weighted_rate() { + pause(); + + let half_life = Duration::from_secs(10); + let mut avg = ExpWeightedMovingAvg::new_time_decay(half_life); + + advance(Duration::from_millis(200)).await; + avg.update_with_weight(2000.0, 0.2); + assert!((avg.value() - 10_000.0).abs() < 1.0); + + advance(Duration::from_millis(200)).await; + avg.update_with_weight(2000.0, 0.2); + assert!((avg.value() - 10_000.0).abs() < 1.0); + + advance(Duration::from_millis(200)).await; + avg.update_with_weight(0.0, 0.2); + assert!(avg.value() < 10_000.0); + } + + /// Verifies that time-decay does not couple weight into the decay exponent. + /// After one half-life of wall time, the decayed weight of the first observation + /// should be halved, regardless of the sample weight used. + #[tokio::test] + async fn ewma_time_decay_half_life_independent_of_weight() { + pause(); + + let half_life = Duration::from_secs(10); + let mut avg = ExpWeightedMovingAvg::new_time_decay(half_life); + + avg.update_with_weight(100.0, 0.5); + + advance(half_life).await; + avg.update_with_weight(0.0, 0.5); + + // After one half-life: weight = 0.5*0.5 + 0.5 = 0.75, value = 100*0.5 = 50 + // mean = 50/0.75 ≈ 66.67 + let epsilon = 1e-6; + let expected = 50.0 / 0.75; + assert!((avg.value() - expected).abs() < epsilon); + } + /// Verifies that after exactly half_life_count samples, the value is approximately halved. #[test] fn ewma_count_decay_half_life() { diff --git a/xet_data/Cargo.toml b/xet_data/Cargo.toml index 699f4c1d..3dc68f3e 100644 --- a/xet_data/Cargo.toml +++ b/xet_data/Cargo.toml @@ -36,8 +36,8 @@ thiserror = { workspace = true } tokio-util = { workspace = true } tracing = { workspace = true } url = { workspace = true } -ulid = { workspace = true } walkdir = { workspace = true } +pyo3 = { version = "0.26", features = ["abi3-py37"], optional = true } [target.'cfg(target_family = "wasm")'.dependencies] tokio = { workspace = true, features = ["sync", "macros", "io-util", "rt", "time"] } @@ -78,8 +78,8 @@ rand = { workspace = true } serial_test = { workspace = true } tempfile = { workspace = true } tracing-test = { workspace = true } -ulid = { workspace = true } [features] strict = [] expensive_tests = [] +python = ["dep:pyo3"] diff --git a/xet_data/src/error.rs b/xet_data/src/error.rs index 4874d1ed..a58acaab 100644 --- a/xet_data/src/error.rs +++ b/xet_data/src/error.rs @@ -71,6 +71,9 @@ pub enum DataError { #[error("Deprecated feature: {0}")] DeprecatedError(String), + #[error("Invalid operation: {0}")] + InvalidOperation(String), + #[error("Auth error: {0}")] AuthError(#[from] AuthError), diff --git a/xet_data/src/file_reconstruction/data_writer/sequential_writer.rs b/xet_data/src/file_reconstruction/data_writer/sequential_writer.rs index 56f1d142..606d49ef 100644 --- a/xet_data/src/file_reconstruction/data_writer/sequential_writer.rs +++ b/xet_data/src/file_reconstruction/data_writer/sequential_writer.rs @@ -14,7 +14,7 @@ use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphorePermit; use super::super::data_writer::{DataFuture, DataWriter}; use super::super::run_state::RunState; use super::super::{FileReconstructionError, Result}; -use crate::progress_tracking::download_tracking::DownloadTaskUpdater; +use crate::progress_tracking::ItemProgressUpdater; // On macOS and Linux, writev(int fildes, const struct iovec *iov, int iovcnt) may return EINVAL if // - the sum of the iov_len values in the iov array overflows a 32-bit integer (macOS) or an ssize_t value (Linux); @@ -46,7 +46,7 @@ type PendingWrite = (Bytes, Option); struct SyncWriterThread { rx: UnboundedReceiver, bytes_written: Arc, - progress_updater: Option>, + progress_updater: Option>, pending: Option, finished: bool, } @@ -55,7 +55,7 @@ impl SyncWriterThread { fn new( rx: UnboundedReceiver, bytes_written: Arc, - progress_updater: Option>, + progress_updater: Option>, ) -> Self { Self { rx, diff --git a/xet_data/src/file_reconstruction/file_reconstructor.rs b/xet_data/src/file_reconstruction/file_reconstructor.rs index 4c940ca2..4fc85ac6 100644 --- a/xet_data/src/file_reconstruction/file_reconstructor.rs +++ b/xet_data/src/file_reconstruction/file_reconstructor.rs @@ -18,7 +18,7 @@ use super::data_writer::{DataWriter, DownloadStream, SequentialWriter}; use super::error::{FileReconstructionError, Result}; use super::reconstruction_terms::ReconstructionTermManager; use super::run_state::{RunError, RunState}; -use crate::progress_tracking::download_tracking::DownloadTaskUpdater; +use crate::progress_tracking::ItemProgressUpdater; /// Reconstructs a file from its content-addressed chunks by downloading xorb blocks /// and writing the reassembled data to an output. Supports byte range requests and @@ -27,7 +27,7 @@ pub struct FileReconstructor { client: Arc, file_hash: MerkleHash, byte_range: Option, - progress_updater: Option>, + progress_updater: Option>, config: Arc, /// Custom buffer semaphore for testing or specialized use cases. @@ -59,7 +59,7 @@ impl FileReconstructor { } } - pub fn with_progress_updater(self, progress_updater: Arc) -> Self { + pub fn with_progress_updater(self, progress_updater: Arc) -> Self { Self { progress_updater: Some(progress_updater), ..self @@ -355,12 +355,12 @@ impl FileReconstructor { } #[cfg(debug_assertions)] -fn default_progress_updater() -> Option> { - Some(DownloadTaskUpdater::correctness_verification_tracker()) +fn default_progress_updater() -> Option> { + Some(ItemProgressUpdater::new_standalone("test")) } #[cfg(not(debug_assertions))] -fn default_progress_updater() -> Option> { +fn default_progress_updater() -> Option> { None } @@ -372,14 +372,12 @@ mod tests { use std::time::Duration; use tokio_util::sync::CancellationToken; - use ulid::Ulid; use xet_client::cas_client::{ClientTestingUtils, DirectAccessClient, LocalClient, RandomFileContents}; use xet_client::cas_types::FileRange; use xet_runtime::core::XetRuntime; use super::*; - use crate::progress_tracking::NoOpProgressUpdater; - use crate::progress_tracking::download_tracking::DownloadProgressTracker; + use crate::progress_tracking::ItemProgressUpdater; const TEST_CHUNK_SIZE: usize = 101; @@ -641,8 +639,7 @@ mod tests { let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new()))); let writer = StaticCursorWriter(buffer.clone()); - let progress_updater = - DownloadProgressTracker::new(NoOpProgressUpdater::new()).new_download_task(Ulid::new(), Arc::from("file")); + let progress_updater = ItemProgressUpdater::new_standalone("file"); let bytes_written = FileReconstructor::new(&(client.clone() as Arc), file_contents.file_hash) .with_config(&config) .with_progress_updater(progress_updater.clone()) @@ -664,8 +661,7 @@ mod tests { let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new()))); let writer = StaticCursorWriter(buffer.clone()); - let progress_updater = - DownloadProgressTracker::new(NoOpProgressUpdater::new()).new_download_task(Ulid::new(), Arc::from("file")); + let progress_updater = ItemProgressUpdater::new_standalone("file"); let bytes_written = FileReconstructor::new(&(client.clone() as Arc), file_contents.file_hash) .with_config(&config) .with_byte_range(range) @@ -685,8 +681,7 @@ mod tests { let (client, file_contents) = setup_test_file(&term_spec).await; let config = test_config(); - let tracker = DownloadProgressTracker::new(NoOpProgressUpdater::new()); - let task = tracker.new_download_task(Ulid::new(), Arc::from("test_file.bin")); + let task = ItemProgressUpdater::new_standalone("test_file.bin"); let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new()))); let writer = StaticCursorWriter(buffer.clone()); @@ -702,8 +697,6 @@ mod tests { task.assert_complete(); assert_eq!(task.total_bytes_completed(), file_contents.data.len() as u64); - - tracker.assert_complete(); } /// Verifies the data_client.rs flow: file size is known upfront (is_final=true), @@ -716,10 +709,8 @@ mod tests { let config = test_config(); let file_size = file_contents.data.len() as u64; - let tracker = DownloadProgressTracker::new(NoOpProgressUpdater::new()); - let task = tracker.new_download_task(Ulid::new(), Arc::from("test_file.bin")); + let task = ItemProgressUpdater::new_standalone("test_file.bin"); - // Simulate data_client.rs: set final size before reconstruction. task.update_item_size(file_size, true); let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new()))); @@ -734,11 +725,9 @@ mod tests { assert_eq!(bytes_written, file_size); - // item_bytes should still be file_size (manager's update_item_size calls were ignored). assert_eq!(task.total_bytes_completed(), file_size); task.assert_complete(); - tracker.assert_complete(); } // ==================== Byte Range Reconstruction Tests ==================== diff --git a/xet_data/src/file_reconstruction/reconstruction_terms/file_term.rs b/xet_data/src/file_reconstruction/reconstruction_terms/file_term.rs index 40407c05..1ff0b357 100644 --- a/xet_data/src/file_reconstruction/reconstruction_terms/file_term.rs +++ b/xet_data/src/file_reconstruction/reconstruction_terms/file_term.rs @@ -15,7 +15,7 @@ use super::super::data_writer::DataFuture; use super::super::error::Result; use super::retrieval_urls::TermBlockRetrievalURLs; use super::xorb_block::{XorbBlock, XorbBlockData, XorbReference}; -use crate::progress_tracking::download_tracking::DownloadTaskUpdater; +use crate::progress_tracking::ItemProgressUpdater; /// A single term in a file reconstruction, representing a contiguous byte range /// in the output file that maps to a chunk range within a xorb block. #[derive(Clone)] @@ -58,7 +58,7 @@ impl FileTerm { pub async fn get_data_task( &self, client: Arc, - progress_updater: Option>, + progress_updater: Option>, ) -> Result { // Fast path: data already cached, no need to spawn a task. if let Some(xorb_block_data) = self.xorb_block.data.get() { diff --git a/xet_data/src/file_reconstruction/reconstruction_terms/manager.rs b/xet_data/src/file_reconstruction/reconstruction_terms/manager.rs index 5ed5774e..74cf9ff6 100644 --- a/xet_data/src/file_reconstruction/reconstruction_terms/manager.rs +++ b/xet_data/src/file_reconstruction/reconstruction_terms/manager.rs @@ -15,7 +15,7 @@ use xet_runtime::config::ReconstructionConfig; use super::super::FileReconstructionError; use super::super::error::Result; use super::file_term::{FileTerm, retrieve_file_term_block}; -use crate::progress_tracking::download_tracking::DownloadTaskUpdater; +use crate::progress_tracking::ItemProgressUpdater; type RawFetchedFileTerms = Result, u64, u64)>>; @@ -33,7 +33,7 @@ pub struct ReconstructionTermManager { current_active_byte_position: u64, prefetch_queue: VecDeque>, completion_rate_estimator: ExpWeightedMovingAvg, - progress_updater: Option>, + progress_updater: Option>, total_bytes_reported: u64, total_transfer_bytes_reported: u64, } @@ -44,7 +44,7 @@ impl ReconstructionTermManager { client: Arc, file_hash: MerkleHash, file_byte_range: FileRange, - progress_updater: Option>, + progress_updater: Option>, ) -> Result { let completion_rate_estimator = ExpWeightedMovingAvg::new_count_decay(config.completion_rate_estimator_half_life); diff --git a/xet_data/src/file_reconstruction/reconstruction_terms/xorb_block.rs b/xet_data/src/file_reconstruction/reconstruction_terms/xorb_block.rs index 09404192..2d208b01 100644 --- a/xet_data/src/file_reconstruction/reconstruction_terms/xorb_block.rs +++ b/xet_data/src/file_reconstruction/reconstruction_terms/xorb_block.rs @@ -10,7 +10,7 @@ use xet_runtime::utils::UniqueId; use super::super::error::Result; use super::retrieval_urls::{TermBlockRetrievalURLs, XorbURLProvider}; -use crate::progress_tracking::download_tracking::DownloadTaskUpdater; +use crate::progress_tracking::ItemProgressUpdater; /// Downloaded and decompressed data for a xorb block, including chunk boundary offsets. /// @@ -81,7 +81,7 @@ impl XorbBlock { self: Arc, client: Arc, url_info: Arc, - progress_updater: Option>, + progress_updater: Option>, ) -> Result> { let xorb_block_index = self.xorb_block_index; let uncompressed_size_if_known = self.uncompressed_size_if_known; diff --git a/xet_data/src/file_reconstruction/run_state.rs b/xet_data/src/file_reconstruction/run_state.rs index 78993ae1..ebe72c1d 100644 --- a/xet_data/src/file_reconstruction/run_state.rs +++ b/xet_data/src/file_reconstruction/run_state.rs @@ -6,7 +6,7 @@ use tracing::{info, warn}; use xet_core_structures::merklehash::MerkleHash; use super::error::{FileReconstructionError, Result}; -use crate::progress_tracking::download_tracking::DownloadTaskUpdater; +use crate::progress_tracking::ItemProgressUpdater; /// Internal error type for the reconstruction run loop. Separates cancellation /// (which maps to `Ok(0)`) from real errors (which propagate as `Err`). @@ -34,7 +34,7 @@ pub(crate) struct RunState { stored_error: Mutex>, file_hash: MerkleHash, - progress_updater: Option>, + progress_updater: Option>, total_terms_processed: AtomicU64, total_bytes_scheduled: AtomicU64, @@ -45,7 +45,7 @@ impl RunState { pub(crate) fn new( cancellation_token: CancellationToken, file_hash: MerkleHash, - progress_updater: Option>, + progress_updater: Option>, ) -> Arc { Arc::new(Self { cancellation_token, @@ -119,7 +119,7 @@ impl RunState { &self.file_hash } - pub(crate) fn progress_updater(&self) -> Option<&Arc> { + pub(crate) fn progress_updater(&self) -> Option<&Arc> { self.progress_updater.as_ref() } diff --git a/xet_data/src/processing/bin/example.rs b/xet_data/src/processing/bin/example.rs index ac5f85f2..aa41a98a 100644 --- a/xet_data/src/processing/bin/example.rs +++ b/xet_data/src/processing/bin/example.rs @@ -5,7 +5,6 @@ use std::sync::{Arc, OnceLock}; use anyhow::Result; use clap::{Args, Parser, Subcommand}; -use ulid::Ulid; use xet_data::processing::configurations::TranslatorConfig; use xet_data::processing::{FileUploadSession, Sha256Policy, XetFileInfo}; use xet_runtime::core::XetRuntime; @@ -88,11 +87,10 @@ async fn clean(mut reader: impl Read, mut writer: impl Write, size: u64) -> Resu let mut read_buf = vec![0u8; READ_BLOCK_SIZE]; - let translator = - FileUploadSession::new(TranslatorConfig::local_config(std::env::current_dir()?)?.into(), None).await?; + let translator = FileUploadSession::new(TranslatorConfig::local_config(std::env::current_dir()?)?.into()).await?; let mut size_read = 0; - let mut handle = translator.start_clean(None, size, Sha256Policy::Compute, Ulid::new()).await; + let (_id, mut handle) = translator.start_clean(None, size, Sha256Policy::Compute)?; loop { let bytes = reader.read(&mut read_buf)?; @@ -136,9 +134,9 @@ async fn smudge(_name: Arc, mut reader: impl Read, output_path: PathBuf) -> // Use local config pointing to current directory let cas_path = std::env::current_dir()?; let config = TranslatorConfig::local_config(cas_path)?; - let session = xet_data::processing::FileDownloadSession::new(config.into(), None).await?; + let session = xet_data::processing::FileDownloadSession::new(config.into()).await?; - session.download_file(&xet_file, &output_path, Ulid::new()).await?; + let (_id, _n_bytes) = session.download_file(&xet_file, &output_path).await?; Ok(()) } diff --git a/xet_data/src/processing/data_client.rs b/xet_data/src/processing/data_client.rs index 279d7737..cd43f2c6 100644 --- a/xet_data/src/processing/data_client.rs +++ b/xet_data/src/processing/data_client.rs @@ -1,25 +1,21 @@ use std::fs::File; use std::io::Read; -use std::path::{Path, PathBuf}; +use std::path::Path; use std::sync::Arc; use bytes::Bytes; use http::header::HeaderMap; -use itertools::multizip; use tracing::{Instrument, Span, info_span, instrument}; -use ulid::Ulid; use xet_client::cas_client::auth::{AuthConfig, TokenRefresher}; use xet_core_structures::merklehash::MerkleHash; use xet_runtime::core::par_utils::run_constrained_with_semaphore; use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_config}; use super::configurations::{SessionContext, TranslatorConfig}; -use super::errors::DataProcessingError; use super::file_cleaner::Sha256Policy; -use super::file_download_session::FileDownloadSession; use super::{FileUploadSession, XetFileInfo, errors}; use crate::deduplication::{Chunker, DeduplicationMetrics}; -use crate::progress_tracking::TrackingProgressUpdater; +use crate::progress_tracking::UniqueID; pub fn default_config( endpoint: String, @@ -35,189 +31,19 @@ pub fn default_config( auth: auth_cfg, custom_headers, repo_paths: vec!["".into()], - session_id: Some(Ulid::new().to_string()), + session_id: Some(UniqueID::new().to_string()), }; TranslatorConfig::new(session) } -#[instrument(skip_all, name = "data_client::upload_bytes", fields(session_id = tracing::field::Empty, num_files=file_contents.len()))] -pub async fn upload_bytes_async( - file_contents: Vec>, - sha256_policies: Vec, - endpoint: Option, - token_info: Option<(String, u64)>, - token_refresher: Option>, - progress_updater: Option>, - custom_headers: Option>, -) -> errors::Result> { - if sha256_policies.len() != file_contents.len() { - return Err(DataProcessingError::ParameterError(format!( - "sha256_policies length ({}) must match file_contents length ({})", - sha256_policies.len(), - file_contents.len() - ))); - } - - let config = default_config( - endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()), - token_info, - token_refresher, - custom_headers, - )?; - - Span::current().record("session_id", &config.session.session_id); - - let semaphore = XetRuntime::current().common().file_ingestion_semaphore.clone(); - let upload_session = FileUploadSession::new(config.into(), progress_updater).await?; - let clean_futures = file_contents.into_iter().zip(sha256_policies).map(|(blob, policy)| { - let upload_session = upload_session.clone(); - async move { clean_bytes(upload_session, blob, None, policy).await.map(|(xf, _metrics)| xf) } - .instrument(info_span!("clean_task")) - }); - let files = run_constrained_with_semaphore(clean_futures, semaphore).await?; - - // Push the CAS blocks and flush the mdb to disk - let _metrics = upload_session.finalize().await?; - - Ok(files) -} - -// The sha256, if provided and valid, will be directly used in shard upload to avoid redundant computation. -#[instrument(skip_all, name = "data_client::upload_files", - fields(session_id = tracing::field::Empty, - num_files=file_paths.len(), - new_bytes = tracing::field::Empty, - deduped_bytes = tracing::field::Empty, - defrag_prevented_dedup_bytes = tracing::field::Empty, - new_chunks = tracing::field::Empty, - deduped_chunks = tracing::field::Empty, - defrag_prevented_dedup_chunks = tracing::field::Empty - ))] -pub async fn upload_async( - file_paths: Vec, - sha256_policies: Vec, - endpoint: Option, - token_info: Option<(String, u64)>, - token_refresher: Option>, - progress_updater: Option>, - custom_headers: Option>, -) -> errors::Result> { - if sha256_policies.len() != file_paths.len() { - return Err(DataProcessingError::ParameterError(format!( - "sha256_policies length ({}) must match file_paths length ({})", - sha256_policies.len(), - file_paths.len() - ))); - } - - // chunk files - // produce Xorbs + Shards - // upload shards and xorbs - // for each file, return the filehash - let config = default_config( - endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()), - token_info, - token_refresher, - custom_headers, - )?; - - let span = Span::current(); - - span.record("session_id", &config.session.session_id); - - let upload_session = FileUploadSession::new(config.into(), progress_updater).await?; - - let files_sha256_and_tracking_ids = - multizip((file_paths.into_iter(), sha256_policies.into_iter(), std::iter::repeat_with(Ulid::new))); - - let ret = upload_session.upload_files(files_sha256_and_tracking_ids).await?; - - // Push the CAS blocks and flush the mdb to disk - let metrics = upload_session.finalize().await?; - - // Record dedup metrics. - span.record("new_bytes", metrics.new_bytes); - span.record("deduped_bytes", metrics.deduped_bytes); - span.record("defrag_prevented_dedup_bytes", metrics.defrag_prevented_dedup_bytes); - span.record("new_chunks", metrics.new_chunks); - span.record("deduped_chunks", metrics.deduped_chunks); - span.record("defrag_prevented_dedup_chunks", metrics.defrag_prevented_dedup_chunks); - - Ok(ret) -} - -#[instrument(skip_all, name = "data_client::download", fields(session_id = tracing::field::Empty, num_files=file_infos.len()))] -pub async fn download_async( - file_infos: Vec<(XetFileInfo, String)>, - endpoint: Option, - token_info: Option<(String, u64)>, - token_refresher: Option>, - progress_updaters: Option>>, - custom_headers: Option>, -) -> errors::Result> { - if let Some(updaters) = &progress_updaters - && updaters.len() != file_infos.len() - { - return Err(DataProcessingError::ParameterError("updaters are not same length as pointer_files".to_string())); - } - let config: Arc = default_config( - endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()), - token_info, - token_refresher, - custom_headers, - )? - .into(); - - Span::current().record("session_id", &config.session.session_id); - - let updaters = match progress_updaters { - None => vec![None; file_infos.len()], - Some(updaters) => updaters.into_iter().map(Some).collect(), - }; - - let session = FileDownloadSession::new(config, None).await?; - - let mut tasks = Vec::with_capacity(file_infos.len()); - - for ((file_info, file_path), updater) in file_infos.into_iter().zip(updaters) { - let session = session.clone(); - tasks.push(tokio::spawn( - async move { - let semaphore = XetRuntime::current().common().file_download_semaphore.clone(); - let _permit = semaphore.acquire().await?; - - let path = PathBuf::from(&file_path); - match updater { - Some(u) => session.download_file_with_updater(&file_info, &path, u).await?, - None => session.download_file(&file_info, &path, Ulid::new()).await?, - }; - errors::Result::Ok(file_path) - } - .instrument(info_span!("download_file")), - )); - } - - let mut paths = Vec::with_capacity(tasks.len()); - for task in tasks { - paths.push(task.await??); - } - - Ok(paths) -} - #[instrument(skip_all, name = "clean_bytes", fields(bytes.len = bytes.len()))] pub async fn clean_bytes( processor: Arc, bytes: Vec, - tracking_id: Option, sha256_policy: Sha256Policy, ) -> errors::Result<(XetFileInfo, DeduplicationMetrics)> { - #[allow(clippy::unwrap_or_default)] // Ulid::default is Ulid::nil - let tracking_id = tracking_id.unwrap_or_else(Ulid::new); - let mut handle = processor - .start_clean(None, bytes.len() as u64, sha256_policy, tracking_id) - .await; + let (_id, mut handle) = processor.start_clean(None, bytes.len() as u64, sha256_policy)?; handle.add_data(&bytes).await?; handle.finish().await } @@ -227,10 +53,7 @@ pub async fn clean_file( processor: Arc, filename: impl AsRef, sha256_policy: Sha256Policy, - tracking_id: Option, ) -> errors::Result<(XetFileInfo, DeduplicationMetrics)> { - #[allow(clippy::unwrap_or_default)] // Ulid::default is Ulid::nil - let tracking_id = tracking_id.unwrap_or_else(Ulid::new); let mut reader = File::open(&filename)?; let filesize = reader.metadata()?.len(); @@ -239,9 +62,8 @@ pub async fn clean_file( span.record("file.len", filesize); let mut buffer = vec![0u8; u64::min(filesize, *xet_config().data.ingestion_block_size) as usize]; - let mut handle = processor - .start_clean(Some(filename.as_ref().to_string_lossy().into()), filesize, sha256_policy, tracking_id) - .await; + let (_id, mut handle) = + processor.start_clean(Some(filename.as_ref().to_string_lossy().into()), filesize, sha256_policy)?; loop { let bytes = reader.read(&mut buffer)?; diff --git a/xet_data/src/processing/deduplication_interface.rs b/xet_data/src/processing/deduplication_interface.rs index 11225d74..2d7cd02a 100644 --- a/xet_data/src/processing/deduplication_interface.rs +++ b/xet_data/src/processing/deduplication_interface.rs @@ -88,6 +88,6 @@ impl DeduplicationDataInterface for UploadSessionDataManager { /// Periodically registers xorb dependencies; used for progress tracking. async fn register_xorb_dependencies(&mut self, dependencies: &[FileXorbDependency]) { - self.session.register_xorb_dependencies(dependencies).await; + self.session.register_xorb_dependencies(dependencies); } } diff --git a/xet_data/src/processing/file_cleaner.rs b/xet_data/src/processing/file_cleaner.rs index b10447ff..af5bacd2 100644 --- a/xet_data/src/processing/file_cleaner.rs +++ b/xet_data/src/processing/file_cleaner.rs @@ -150,8 +150,7 @@ impl SingleFileCleaner { // how much data we know about. self.session .completion_tracker - .increment_file_size(self.file_id, data.len() as u64) - .await; + .increment_file_size(self.file_id, data.len() as u64); // Put the chunking on a compute thread so it doesn't tie up the async schedulers let chunk_data_jh = { diff --git a/xet_data/src/processing/file_download_session.rs b/xet_data/src/processing/file_download_session.rs index fa7e2549..42303a38 100644 --- a/xet_data/src/processing/file_download_session.rs +++ b/xet_data/src/processing/file_download_session.rs @@ -1,87 +1,116 @@ use std::borrow::Cow; use std::io::Write; use std::ops::Range; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use tokio::task::JoinHandle; use tracing::instrument; -use ulid::Ulid; use xet_client::cas_client::Client; use xet_client::cas_types::FileRange; -use xet_runtime::core::xet_config; +use xet_runtime::core::{XetRuntime, xet_config}; use super::configurations::TranslatorConfig; use super::errors::*; use super::remote_client_interface::create_remote_client; use super::{XetFileInfo, prometheus_metrics}; use crate::file_reconstruction::{DownloadStream, FileReconstructor}; -use crate::progress_tracking::TrackingProgressUpdater; -use crate::progress_tracking::aggregator::AggregatingProgressUpdater; -use crate::progress_tracking::download_tracking::DownloadProgressTracker; +use crate::progress_tracking::{GroupProgress, ItemProgressUpdater, UniqueID}; /// Manages the downloading of files from CAS storage. /// /// This struct parallels `FileUploadSession` for the download path. It holds the -/// CAS client, a shared progress tracker for all downloads in the session, and -/// gates concurrent downloads with a semaphore. +/// CAS client and a shared progress group for all downloads in the session. pub struct FileDownloadSession { client: Arc, - progress_tracker: Option>, - progress_aggregator: Option>, + progress: Arc, + finalized: AtomicBool, } impl FileDownloadSession { - pub async fn new( - config: Arc, - progress_updater: Option>, - ) -> Result> { + pub async fn new(config: Arc) -> Result> { let session_id = config .session .session_id .as_ref() .map(Cow::Borrowed) - .unwrap_or_else(|| Cow::Owned(Ulid::new().to_string())); + .unwrap_or_else(|| Cow::Owned(UniqueID::new().to_string())); let client = create_remote_client(&config, &session_id, false).await?; - - let (progress_updater, progress_aggregator) = Self::maybe_wrap_in_aggregator(progress_updater); - let progress_tracker = progress_updater.map(DownloadProgressTracker::new); + let progress = GroupProgress::with_speed_config( + xet_config().data.progress_update_speed_sampling_window, + xet_config().data.progress_update_speed_min_observations, + ); Ok(Arc::new(Self { client, - progress_tracker, - progress_aggregator, + progress, + finalized: AtomicBool::new(false), })) } - /// Creates a new download session from an existing CAS client. + /// Construct a download session from an existing CAS client. /// - /// This is useful for tests or contexts where a client has already been created - /// outside of the normal config-based flow. - pub fn from_client( - client: Arc, - progress_updater: Option>, - ) -> Arc { - let (progress_updater, progress_aggregator) = Self::maybe_wrap_in_aggregator(progress_updater); - let progress_tracker = progress_updater.map(DownloadProgressTracker::new); + /// This path uses default progress speed settings. Use [`Self::new`] when the + /// session should inherit the configured speed parameters from `xet_config`. + pub fn from_client(client: Arc) -> Arc { + let progress = GroupProgress::new(); Arc::new(Self { client, - progress_tracker, - progress_aggregator, + progress, + finalized: AtomicBool::new(false), }) } - /// Downloads a complete file to the given path. + pub fn report(&self) -> crate::progress_tracking::GroupProgressReport { + self.progress.report() + } + + pub fn item_report(&self, id: UniqueID) -> Option { + self.progress.item_report(id) + } + + pub fn item_reports(&self) -> std::collections::HashMap { + self.progress.item_reports() + } + + /// Spawns a download task that writes `file_info` to `write_path`. /// - /// If `tracking_id` is provided, it is used as the progress item name; - /// otherwise the write path is used. + /// Acquires a permit from the global download semaphore before starting. + /// Returns the tracking ID and the join handle for the spawned task. + pub async fn download_file_background( + self: &Arc, + file_info: XetFileInfo, + write_path: PathBuf, + ) -> Result<(UniqueID, JoinHandle>)> { + self.check_not_finalized()?; + let id = UniqueID::new(); + let session = self.clone(); + let rt = XetRuntime::current(); + let semaphore = rt.common().file_download_semaphore.clone(); + let handle = rt.spawn(async move { + let _permit = semaphore.acquire().await?; + session.download_file_with_id(&file_info, &write_path, id).await + }); + Ok((id, handle)) + } + + /// Downloads a complete file to the given path. #[instrument(skip_all, name = "FileDownloadSession::download_file", fields(hash = file_info.hash()))] - pub async fn download_file(&self, file_info: &XetFileInfo, write_path: &Path, tracking_id: Ulid) -> Result { - // download concurrency controlled outside - let reconstructor = self.setup_reconstructor(file_info, None, tracking_id, Some(write_path), None)?; + pub async fn download_file(&self, file_info: &XetFileInfo, write_path: &Path) -> Result<(UniqueID, u64)> { + self.check_not_finalized()?; + let id = UniqueID::new(); + let n_bytes = self.download_file_with_id(file_info, write_path, id).await?; + Ok((id, n_bytes)) + } + + async fn download_file_with_id(&self, file_info: &XetFileInfo, write_path: &Path, id: UniqueID) -> Result { + let name = Arc::from(write_path.to_string_lossy().as_ref()); + let progress_updater = self.progress.new_item(id, name); + let reconstructor = self.setup_reconstructor(file_info, None, Some(progress_updater))?; let n_bytes = reconstructor.reconstruct_to_file(write_path, None).await?; prometheus_metrics::FILTER_BYTES_SMUDGED.inc_by(n_bytes); - Ok(n_bytes) } @@ -98,41 +127,17 @@ impl FileDownloadSession { file_info: &XetFileInfo, source_range: Range, writer: W, - tracking_id: Ulid, - ) -> Result { - // download concurrency controlled outside + ) -> Result<(UniqueID, u64)> { + self.check_not_finalized()?; let range = FileRange::new(source_range.start, source_range.end); - let reconstructor = self.setup_reconstructor(file_info, Some(range), tracking_id, None, None)?; + let id = UniqueID::new(); + let name = Arc::from(""); + let progress_updater = self.progress.new_item(id, name); + let reconstructor = self.setup_reconstructor(file_info, Some(range), Some(progress_updater))?; let n_bytes = reconstructor.reconstruct_to_writer(writer).await?; prometheus_metrics::FILTER_BYTES_SMUDGED.inc_by(n_bytes); - Ok(n_bytes) - } - - /// Downloads a complete file to the given path, using a caller-provided progress updater - /// instead of the session's shared progress tracker. - #[instrument(skip_all, name = "FileDownloadSession::download_file_with_updater", fields(hash = file_info.hash()))] - pub async fn download_file_with_updater( - self: &Arc, - file_info: &XetFileInfo, - write_path: &Path, - progress_updater: Arc, - ) -> Result { - // download concurrency controlled outside - - let (wrapped_updater, aggregator) = Self::maybe_wrap_in_aggregator(Some(progress_updater)); - let tracker = wrapped_updater.map(DownloadProgressTracker::new); - - let reconstructor = - self.setup_reconstructor(file_info, None, Ulid::new(), Some(write_path), tracker.as_ref())?; - let n_bytes = reconstructor.reconstruct_to_file(write_path, None).await?; - prometheus_metrics::FILTER_BYTES_SMUDGED.inc_by(n_bytes); - - if let Some(agg) = aggregator { - agg.finalize().await; - } - - Ok(n_bytes) + Ok((id, n_bytes)) } /// Creates a streaming download of a file. @@ -144,75 +149,47 @@ impl FileDownloadSession { /// /// This path does not acquire the session-level file download semaphore. #[instrument(skip_all, name = "FileDownloadSession::download_stream", fields(hash = file_info.hash()))] - pub fn download_stream(&self, file_info: &XetFileInfo, tracking_id: Ulid) -> Result { - let reconstructor = self.setup_reconstructor(file_info, None, tracking_id, None, None)?; - Ok(reconstructor.reconstruct_to_stream()) + pub async fn download_stream(&self, file_info: &XetFileInfo) -> Result<(UniqueID, DownloadStream)> { + self.check_not_finalized()?; + let id = UniqueID::new(); + let progress_updater = self.progress.new_item(id, "stream"); + let reconstructor = self.setup_reconstructor(file_info, None, Some(progress_updater))?; + Ok((id, reconstructor.reconstruct_to_stream())) } - fn tracker_name(&self, write_path: Option<&Path>) -> Arc { - write_path - .map(|path| Arc::from(path.to_string_lossy().as_ref())) - .unwrap_or_else(|| Arc::from("")) - } - - fn maybe_wrap_in_aggregator( - updater: Option>, - ) -> (Option>, Option>) { - match updater { - Some(updater) => { - let flush_interval = xet_config().data.progress_update_interval; - let sampling_window = xet_config().data.progress_update_speed_sampling_window; - if !flush_interval.is_zero() { - let agg = AggregatingProgressUpdater::new(updater, flush_interval, sampling_window); - (Some(agg.clone() as Arc), Some(agg)) - } else { - (Some(updater), None) - } - }, - None => (None, None), + fn check_not_finalized(&self) -> Result<()> { + if self.finalized.load(Ordering::Acquire) { + return Err(DataProcessingError::InvalidOperation("FileDownloadSession already finalized".to_string())); } + Ok(()) } - /// Finalizes the session-level progress aggregator, flushing any remaining - /// updates and stopping its background task. - pub async fn finalize(&self) { - if let Some(agg) = &self.progress_aggregator { - agg.finalize().await; + /// Finalizes the session; in debug builds, asserts all items are complete. + pub async fn finalize(&self) -> Result<()> { + if self.finalized.swap(true, Ordering::AcqRel) { + return Err(DataProcessingError::InvalidOperation("FileDownloadSession already finalized".to_string())); } + #[cfg(debug_assertions)] + self.progress.assert_complete(); + Ok(()) } - /// Common setup: builds a `FileReconstructor` with the given options. - /// - /// When `progress_override` is provided, it is used instead of the session's - /// shared `progress_tracker`. fn setup_reconstructor( &self, file_info: &XetFileInfo, range: Option, - tracking_id: Ulid, - write_path: Option<&Path>, - progress_override: Option<&Arc>, + progress_updater: Option>, ) -> Result { let file_id = file_info.merkle_hash()?; - - let tracker = progress_override.or(self.progress_tracker.as_ref()); - let task_updater = tracker.map(|tracker| { - let tracking_name = self.tracker_name(write_path); - let task = tracker.new_download_task(tracking_id, tracking_name); - let size = range - .map(|r| r.end.saturating_sub(r.start)) - .unwrap_or_else(|| file_info.file_size()); - task.update_item_size(size, true); - task - }); - let effective_range = range.unwrap_or_else(|| FileRange::new(0, file_info.file_size())); - let mut reconstructor = FileReconstructor::new(&self.client, file_id).with_byte_range(effective_range); - - if let Some(tracker) = task_updater { - reconstructor = reconstructor.with_progress_updater(tracker); + let size = effective_range.end - effective_range.start; + if let Some(ref updater) = progress_updater { + updater.update_item_size(size, true); + } + let mut reconstructor = FileReconstructor::new(&self.client, file_id).with_byte_range(effective_range); + if let Some(updater) = progress_updater { + reconstructor = reconstructor.with_progress_updater(updater); } - Ok(reconstructor) } } @@ -239,13 +216,13 @@ mod tests { } async fn upload_data(cas_path: &Path, data: &[u8]) -> XetFileInfo { - let upload_session = FileUploadSession::new(TranslatorConfig::local_config(cas_path).unwrap().into(), None) + let upload_session = FileUploadSession::new(TranslatorConfig::local_config(cas_path).unwrap().into()) .await .unwrap(); - let mut cleaner = upload_session - .start_clean(Some("test".into()), data.len() as u64, Sha256Policy::Compute, Ulid::new()) - .await; + let (_id, mut cleaner) = upload_session + .start_clean(Some("test".into()), data.len() as u64, Sha256Policy::Compute) + .unwrap(); cleaner.add_data(data).await.unwrap(); let (xfi, _metrics) = cleaner.finish().await.unwrap(); upload_session.finalize().await.unwrap(); @@ -265,10 +242,10 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); let out_path = temp.path().join("output.txt"); - let n_bytes = session.download_file(&xfi, &out_path, Ulid::new()).await.unwrap(); + let (_id, n_bytes) = session.download_file(&xfi, &out_path).await.unwrap(); assert_eq!(n_bytes, original_data.len() as u64); assert_eq!(read(&out_path).unwrap(), original_data); @@ -289,12 +266,12 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); let out_path = temp.path().join("deep").join("nested").join("dir").join("output.txt"); assert!(!out_path.parent().unwrap().exists()); - session.download_file(&xfi, &out_path, Ulid::new()).await.unwrap(); + session.download_file(&xfi, &out_path).await.unwrap(); assert_eq!(read(&out_path).unwrap(), original_data); }) @@ -314,7 +291,7 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); let out_path = temp.path().join("partial_writer.txt"); write(&out_path, vec![0u8; original_data.len()]).unwrap(); @@ -322,7 +299,7 @@ mod tests { let mut file = std::fs::OpenOptions::new().write(true).open(&out_path).unwrap(); file.seek(SeekFrom::Start(4)).unwrap(); - let n_bytes = session.download_to_writer(&xfi, 4..12, file, Ulid::new()).await.unwrap(); + let (_id, n_bytes) = session.download_to_writer(&xfi, 4..12, file).await.unwrap(); assert_eq!(n_bytes, 8); let result = read(&out_path).unwrap(); @@ -343,7 +320,7 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); let out_path = temp.path().join("partitioned.txt"); write(&out_path, vec![0u8; original_data.len()]).unwrap(); @@ -365,7 +342,7 @@ mod tests { tasks.push(tokio::spawn(async move { let mut writer = std::fs::OpenOptions::new().write(true).open(out_path).unwrap(); writer.seek(SeekFrom::Start(start)).unwrap(); - session.download_to_writer(&xfi, start..end, writer, Ulid::new()).await + session.download_to_writer(&xfi, start..end, writer).await })); } @@ -395,7 +372,7 @@ mod tests { let xfi_b = upload_data(&cas_path, data_b).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); let out_a = temp.path().join("out_a.txt"); let out_b = temp.path().join("out_b.txt"); @@ -403,14 +380,12 @@ mod tests { let session_a = session.clone(); let xfi_a_clone = xfi_a.clone(); let out_a_clone = out_a.clone(); - let task_a = - tokio::spawn(async move { session_a.download_file(&xfi_a_clone, &out_a_clone, Ulid::new()).await }); + let task_a = tokio::spawn(async move { session_a.download_file(&xfi_a_clone, &out_a_clone).await }); let session_b = session.clone(); let xfi_b_clone = xfi_b.clone(); let out_b_clone = out_b.clone(); - let task_b = - tokio::spawn(async move { session_b.download_file(&xfi_b_clone, &out_b_clone, Ulid::new()).await }); + let task_b = tokio::spawn(async move { session_b.download_file(&xfi_b_clone, &out_b_clone).await }); task_a.await.unwrap().unwrap(); task_b.await.unwrap().unwrap(); @@ -436,9 +411,9 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); - let mut stream = session.download_stream(&xfi, Ulid::new()).unwrap(); + let (_id, mut stream) = session.download_stream(&xfi).await.unwrap(); let mut collected = Vec::new(); while let Some(chunk) = stream.next().await.unwrap() { @@ -463,9 +438,9 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); - let stream = session.download_stream(&xfi, Ulid::new()).unwrap(); + let (_id, stream) = session.download_stream(&xfi).await.unwrap(); let collected = tokio::task::spawn_blocking(move || { let mut stream = stream; @@ -496,11 +471,10 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); - let mut stream = session.download_stream(&xfi, Ulid::new()).unwrap(); + let (_id, mut stream) = session.download_stream(&xfi).await.unwrap(); - // Drain all data while stream.next().await.unwrap().is_some() {} // Subsequent calls should return Ok(None) @@ -526,10 +500,10 @@ mod tests { let xfi_b = upload_data(&cas_path, data_b).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); - let mut stream_a = session.download_stream(&xfi_a, Ulid::new()).unwrap(); - let mut stream_b = session.download_stream(&xfi_b, Ulid::new()).unwrap(); + let (_id_a, mut stream_a) = session.download_stream(&xfi_a).await.unwrap(); + let (_id_b, mut stream_b) = session.download_stream(&xfi_b).await.unwrap(); let task_a = tokio::spawn(async move { let mut buf = Vec::new(); @@ -569,16 +543,14 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); - // Create and drop a stream without ever reading from it. - let stream = session.download_stream(&xfi, Ulid::new()).unwrap(); + let (_id, stream) = session.download_stream(&xfi).await.unwrap(); drop(stream); tokio::task::yield_now().await; - // A subsequent file download must succeed, proving no resources leaked. let out_path = temp.path().join("after_drop.txt"); - session.download_file(&xfi, &out_path, Ulid::new()).await.unwrap(); + session.download_file(&xfi, &out_path).await.unwrap(); assert_eq!(read(&out_path).unwrap(), original_data); }) .unwrap(); @@ -597,11 +569,10 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); - // Repeatedly create, start, optionally read, and drop streams. for i in 0..5u32 { - let mut stream = session.download_stream(&xfi, Ulid::new()).unwrap(); + let (_id, mut stream) = session.download_stream(&xfi).await.unwrap(); if i % 3 == 0 { let _ = stream.next().await; } @@ -609,9 +580,8 @@ mod tests { tokio::task::yield_now().await; } - // After many create/drop cycles, a full download must still work. let out_path = temp.path().join("after_cycles.txt"); - session.download_file(&xfi, &out_path, Ulid::new()).await.unwrap(); + session.download_file(&xfi, &out_path).await.unwrap(); assert_eq!(read(&out_path).unwrap(), original_data); }) .unwrap(); @@ -630,25 +600,21 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); - // Read one chunk via blocking next() in a spawn_blocking, then drop. - let stream = session.download_stream(&xfi, Ulid::new()).unwrap(); + let (_id, stream) = session.download_stream(&xfi).await.unwrap(); tokio::task::spawn_blocking(move || { let mut stream = stream; let _chunk = stream.blocking_next().unwrap(); - // stream is dropped here at the end of the closure }) .await .unwrap(); - // Yield to let the runtime process the cancellation. tokio::task::yield_now().await; - // A subsequent download must succeed. let out_path = temp.path().join("after_blocking_drop.txt"); - session.download_file(&xfi, &out_path, Ulid::new()).await.unwrap(); + session.download_file(&xfi, &out_path).await.unwrap(); assert_eq!(read(&out_path).unwrap(), original_data); }) .unwrap(); @@ -667,9 +633,9 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); - let mut stream = session.download_stream(&xfi, Ulid::new()).unwrap(); + let (_id, mut stream) = session.download_stream(&xfi).await.unwrap(); stream.cancel(); assert!(stream.next().await.unwrap().is_none()); assert!(stream.next().await.unwrap().is_none()); @@ -690,16 +656,16 @@ mod tests { let xfi = upload_data(&cas_path, original_data).await; let config = TranslatorConfig::local_config(&cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); - let mut stream = session.download_stream(&xfi, Ulid::new()).unwrap(); + let (_id, mut stream) = session.download_stream(&xfi).await.unwrap(); let _ = stream.next().await.unwrap(); stream.cancel(); assert!(stream.next().await.unwrap().is_none()); assert!(stream.next().await.unwrap().is_none()); let out_path = temp.path().join("after_cancel.txt"); - session.download_file(&xfi, &out_path, Ulid::new()).await.unwrap(); + session.download_file(&xfi, &out_path).await.unwrap(); assert_eq!(read(&out_path).unwrap(), original_data); }) .unwrap(); diff --git a/xet_data/src/processing/file_upload_session.rs b/xet_data/src/processing/file_upload_session.rs index 38bab717..bca33bdc 100644 --- a/xet_data/src/processing/file_upload_session.rs +++ b/xet_data/src/processing/file_upload_session.rs @@ -1,16 +1,17 @@ use std::borrow::Cow; +use std::collections::HashMap; use std::fs::File; use std::io::Read; use std::mem::{swap, take}; -use std::path::Path; +use std::path::{Path, PathBuf}; use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use bytes::Bytes; use more_asserts::*; use tokio::sync::Mutex; use tokio::task::{JoinHandle, JoinSet}; use tracing::{Instrument, Span, info_span, instrument}; -use ulid::Ulid; use xet_client::cas_client::{Client, ProgressCallback}; use xet_core_structures::metadata_shard::file_structs::MDBFileInfo; use xet_core_structures::xorb_object::SerializedXorbObject; @@ -24,11 +25,8 @@ use super::shard_interface::SessionShardInterface; use super::{XetFileInfo, prometheus_metrics}; use crate::deduplication::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS}; use crate::deduplication::{DataAggregator, DeduplicationMetrics, RawXorbData}; -use crate::progress_tracking::aggregator::AggregatingProgressUpdater; use crate::progress_tracking::upload_tracking::{CompletionTracker, FileXorbDependency}; -#[cfg(debug_assertions)] // Used here to verify the update accuracy -use crate::progress_tracking::verification_wrapper::ProgressUpdaterVerificationWrapper; -use crate::progress_tracking::{NoOpProgressUpdater, TrackingProgressUpdater}; +use crate::progress_tracking::{GroupProgress, GroupProgressReport, ItemProgressReport, UniqueID}; /// Manages the translation of files between the /// MerkleDB / pointer file format and the materialized version. @@ -43,8 +41,8 @@ pub struct FileUploadSession { /// Tracking upload completion between xorbs and files. pub(crate) completion_tracker: Arc, - /// Session aggregation - progress_aggregator: Option>, + /// Aggregate progress across all files in this upload session. + progress: Arc, /// Deduplicated data shared across files. current_session_data: Mutex, @@ -55,70 +53,33 @@ pub struct FileUploadSession { /// Internal worker xorb_upload_tasks: Mutex>>, - #[cfg(debug_assertions)] - progress_verifier: Arc, + /// Set to true after finalize() has been called. + finalized: AtomicBool, } // Constructors impl FileUploadSession { - pub async fn new( - config: Arc, - upload_progress_updater: Option>, - ) -> Result> { - FileUploadSession::new_impl(config, upload_progress_updater, false).await + pub async fn new(config: Arc) -> Result> { + FileUploadSession::new_impl(config, false).await } - pub async fn dry_run( - config: Arc, - upload_progress_updater: Option>, - ) -> Result> { - FileUploadSession::new_impl(config, upload_progress_updater, true).await + pub async fn dry_run(config: Arc) -> Result> { + FileUploadSession::new_impl(config, true).await } - async fn new_impl( - config: Arc, - upload_progress_updater: Option>, - dry_run: bool, - ) -> Result> { + async fn new_impl(config: Arc, dry_run: bool) -> Result> { let session_id = config .session .session_id .as_ref() .map(Cow::Borrowed) - .unwrap_or_else(|| Cow::Owned(Ulid::new().to_string())); + .unwrap_or_else(|| Cow::Owned(UniqueID::new().to_string())); - let (progress_updater, progress_aggregator): (Arc, Option<_>) = { - match upload_progress_updater { - Some(updater) => { - let flush_interval = xet_config().data.progress_update_interval; - if !flush_interval.is_zero() - && xet_config().data.aggregate_progress - && !config.force_disable_progress_aggregation - { - let aggregator = AggregatingProgressUpdater::new( - updater, - flush_interval, - xet_config().data.progress_update_speed_sampling_window, - ); - (aggregator.clone(), Some(aggregator)) - } else { - (updater, None) - } - }, - None => (Arc::new(NoOpProgressUpdater), None), - } - }; - - // When debug assertions are enabled, track all the progress updates for consistency - // and correctness. This is checked at the end. - #[cfg(debug_assertions)] - let (progress_updater, progress_verification_tracker) = { - let updater = ProgressUpdaterVerificationWrapper::new(progress_updater); - - (updater.clone() as Arc, updater) - }; - - let completion_tracker = Arc::new(CompletionTracker::new(progress_updater)); + let progress = GroupProgress::with_speed_config( + xet_config().data.progress_update_speed_sampling_window, + xet_config().data.progress_update_speed_min_observations, + ); + let completion_tracker = Arc::new(CompletionTracker::new(progress.clone())); let client = create_remote_client(&config, &session_id, dry_run).await?; @@ -128,38 +89,30 @@ impl FileUploadSession { shard_interface, client, completion_tracker, - progress_aggregator, + progress, current_session_data: Mutex::new(DataAggregator::default()), deduplication_metrics: Mutex::new(DeduplicationMetrics::default()), xorb_upload_tasks: Mutex::new(JoinSet::new()), - - #[cfg(debug_assertions)] - progress_verifier: progress_verification_tracker, + finalized: AtomicBool::new(false), })) } pub async fn upload_files( self: &Arc, - files_sha256_and_tracking_ids: impl IntoIterator, Sha256Policy, Ulid)> + Send, + files_and_sha256: impl IntoIterator, Sha256Policy)> + Send, ) -> Result> { + self.check_not_finalized()?; let mut cleaning_tasks: Vec> = vec![]; - for (f, sha256, tracking_id) in files_sha256_and_tracking_ids.into_iter() { + for (f, sha256) in files_and_sha256.into_iter() { let file_path = f.as_ref().to_owned(); let file_name: Arc = Arc::from(file_path.to_string_lossy()); - // Get the file size, and go ahead and register it in the completion tracker so that we know the whole - // repo size at the beginning. let file_size = std::fs::metadata(&file_path)?.len(); - // Get a new file id for the completion tracking. The size is not passed here; - // it will be discovered incrementally via increment_file_size in add_data_impl. - let file_id = self - .completion_tracker - .register_new_file(tracking_id, file_name.clone(), Some(file_size)) - .await; + let updater = self.progress.new_item(UniqueID::new(), file_name.clone()); + let file_id = self.completion_tracker.register_new_file(updater, Some(file_size)); - // Now, spawn a task let ingestion_concurrency_limiter = XetRuntime::current().common().file_ingestion_semaphore.clone(); let session = self.clone(); @@ -252,22 +205,93 @@ impl FileUploadSession { /// If a sha256 is provided via [`Sha256Policy::Provided`], the value will be directly /// used in shard upload to avoid redundant computation. [`Sha256Policy::Skip`] skips /// SHA-256 computation entirely and no metadata_ext is included in the shard. - pub async fn start_clean( + pub fn start_clean( self: &Arc, tracking_name: Option>, size: u64, sha256: Sha256Policy, - tracking_id: Ulid, - ) -> SingleFileCleaner { - // Get a new file id for the completion tracking - let file_id = self - .completion_tracker - .register_new_file(tracking_id, tracking_name.clone().unwrap_or_default(), Some(size)) - .await; + ) -> Result<(UniqueID, SingleFileCleaner)> { + self.check_not_finalized()?; + let id = UniqueID::new(); + let cleaner = self.start_clean_with_id(id, tracking_name, size, sha256); + Ok((id, cleaner)) + } + fn start_clean_with_id( + self: &Arc, + id: UniqueID, + tracking_name: Option>, + size: u64, + sha256: Sha256Policy, + ) -> SingleFileCleaner { + let updater = self.progress.new_item(id, tracking_name.clone().unwrap_or_default()); + let file_id = self.completion_tracker.register_new_file(updater, Some(size)); SingleFileCleaner::new(tracking_name, file_id, sha256, self.clone()) } + /// Spawns a task that reads `file_path` and uploads it. + /// + /// Returns the tracking ID and a join handle for the spawned task. + pub async fn spawn_upload_from_path( + self: &Arc, + file_path: PathBuf, + sha256: Sha256Policy, + ) -> Result<(UniqueID, JoinHandle>)> { + self.check_not_finalized()?; + let file_size = std::fs::metadata(&file_path)?.len(); + let tracking_name: Arc = Arc::from(file_path.to_string_lossy().as_ref()); + let (id, cleaner) = self.start_clean(Some(tracking_name), file_size, sha256)?; + + let rt = XetRuntime::current(); + let semaphore = rt.common().file_ingestion_semaphore.clone(); + let handle = rt.spawn(async move { + let _permit = semaphore.acquire().await?; + Self::feed_file_to_cleaner(cleaner, &file_path).await + }); + + Ok((id, handle)) + } + + /// Spawns a task that uploads `bytes` as a single file. + /// + /// Returns the tracking ID and a join handle for the spawned task. + pub async fn spawn_upload_bytes( + self: &Arc, + bytes: Vec, + sha256: Sha256Policy, + tracking_name: Option>, + ) -> Result<(UniqueID, JoinHandle>)> { + self.check_not_finalized()?; + let (id, mut cleaner) = self.start_clean(tracking_name, bytes.len() as u64, sha256)?; + + let rt = XetRuntime::current(); + let semaphore = rt.common().file_ingestion_semaphore.clone(); + let handle = rt.spawn(async move { + let _permit = semaphore.acquire().await?; + cleaner.add_data(&bytes).await?; + let (file_info, _metrics) = cleaner.finish().await?; + Ok(file_info) + }); + + Ok((id, handle)) + } + + async fn feed_file_to_cleaner(mut cleaner: SingleFileCleaner, file_path: &Path) -> Result { + let mut reader = File::open(file_path)?; + let filesize = reader.metadata()?.len(); + let mut buffer = vec![0u8; u64::min(filesize, *xet_config().data.ingestion_block_size) as usize]; + + loop { + let n = reader.read(&mut buffer)?; + if n == 0 { + break; + } + cleaner.add_data(&buffer[..n]).await?; + } + let (file_info, _metrics) = cleaner.finish().await?; + Ok(file_info) + } + /// Registers a new xorb for upload, returning true if the xorb was added to the upload queue and false /// if it was already in the queue and didn't need to be uploaded again. #[instrument(skip_all, name="FileUploadSession::register_new_xorb_for_upload", fields(xorb_len = xorb.num_bytes()))] @@ -291,14 +315,11 @@ impl FileUploadSession { // In some circumstances, we can cut to instances of the same xorb, namely when there are two files // with the same starting data that get processed simultaneously. When this happens, we only upload // the first one, returning early here. - let xorb_is_new = self - .completion_tracker - .register_new_xorb(xorb_hash, xorb.num_bytes() as u64) - .await; + let xorb_is_new = self.completion_tracker.register_new_xorb(xorb_hash, xorb.num_bytes() as u64); // Make sure we add in all the dependencies. This should happen after the xorb is registered but before // we start the upload. - self.completion_tracker.register_dependencies(file_dependencies).await; + self.completion_tracker.register_dependencies(file_dependencies); if !xorb_is_new { return Ok(false); @@ -307,7 +328,7 @@ impl FileUploadSession { // No need to process an empty xorb. But check this after the session_xorbs tracker // to make sure the reporting is correct. if xorb.num_bytes() == 0 { - self.completion_tracker.register_xorb_upload_completion(xorb_hash).await; + self.completion_tracker.register_xorb_upload_completion(xorb_hash); return Ok(true); } @@ -346,7 +367,7 @@ impl FileUploadSession { .await?; // Register that the xorb has been uploaded. - session.completion_tracker.register_xorb_upload_completion(xorb_hash).await; + session.completion_tracker.register_xorb_upload_completion(xorb_hash); // Record the number of bytes uploaded. session.deduplication_metrics.lock().await.xorb_bytes_uploaded += n_bytes_transmitted; @@ -440,13 +461,20 @@ impl FileUploadSession { } /// Register a xorb dependencies that is given as part of the dedup process. - pub(crate) async fn register_xorb_dependencies(self: &Arc, xorb_dependencies: &[FileXorbDependency]) { - self.completion_tracker.register_dependencies(xorb_dependencies).await; + pub(crate) fn register_xorb_dependencies(self: &Arc, xorb_dependencies: &[FileXorbDependency]) { + self.completion_tracker.register_dependencies(xorb_dependencies); } /// Finalize everything. #[instrument(skip_all, name="FileUploadSession::finalize", fields(session.id))] - async fn finalize_impl(self: Arc, return_files: bool) -> Result<(DeduplicationMetrics, Vec)> { + async fn finalize_impl( + self: Arc, + return_files: bool, + ) -> Result<(DeduplicationMetrics, Vec, GroupProgressReport)> { + if self.finalized.swap(true, Ordering::AcqRel) { + return Err(DataProcessingError::InvalidOperation("FileUploadSession already finalized".to_string())); + } + // Register the remaining xorbs for upload. let data_agg = take(&mut *self.current_session_data.lock().await); self.process_aggregated_data_as_xorb(data_agg).await?; @@ -461,11 +489,6 @@ impl FileUploadSession { result??; } - // Now that all the tasks there are completed, there shouldn't be any other references to this session - // hanging around; i.e. the self in this session should be used as if it's consuming the class, as it - // effectively empties all the states. - debug_assert_eq!(Arc::strong_count(&self), 1); - let all_file_info = if return_files { self.shard_interface.session_file_info_list().await? } else { @@ -483,23 +506,12 @@ impl FileUploadSession { #[cfg(debug_assertions)] { - // Checks to make sure all the upload parts are complete. - self.completion_tracker.assert_complete().await; - - // Checks that all the progress updates were received correctly. - self.progress_verifier.assert_complete().await; + self.completion_tracker.assert_complete(); + self.progress.assert_complete(); } - // Make sure all the updates have been flushed through. - self.completion_tracker.flush().await; - - // Clear this out so the background aggregation session fully finishes. - if let Some(pa) = &self.progress_aggregator { - pa.finalize().await; - debug_assert!(pa.is_finished().await); - } - - Ok((metrics, all_file_info)) + let report = self.report(); + Ok((metrics, all_file_info, report)) } // Wait until everything currently in process is completed and uploaded, cutting a xorb for the remaining bit. @@ -520,17 +532,44 @@ impl FileUploadSession { } } - self.completion_tracker.flush().await; - Ok(()) } + fn check_not_finalized(&self) -> Result<()> { + if self.finalized.load(Ordering::Acquire) { + return Err(DataProcessingError::InvalidOperation("FileUploadSession already finalized".to_string())); + } + Ok(()) + } + + pub fn progress(&self) -> &Arc { + &self.progress + } + + pub fn report(&self) -> GroupProgressReport { + self.progress.report() + } + + pub fn item_report(&self, id: UniqueID) -> Option { + self.progress.item_report(id) + } + + pub fn item_reports(&self) -> HashMap { + self.progress.item_reports() + } + pub async fn finalize(self: Arc) -> Result { Ok(self.finalize_impl(false).await?.0) } + pub async fn finalize_with_report(self: Arc) -> Result<(DeduplicationMetrics, GroupProgressReport)> { + let (metrics, _file_info, report) = self.finalize_impl(false).await?; + Ok((metrics, report)) + } + pub async fn finalize_with_file_info(self: Arc) -> Result<(DeduplicationMetrics, Vec)> { - self.finalize_impl(true).await + let (metrics, file_info, _report) = self.finalize_impl(true).await?; + Ok((metrics, file_info)) } } @@ -569,13 +608,13 @@ mod tests { .unwrap(), ); - let upload_session = FileUploadSession::new(TranslatorConfig::local_config(cas_path).unwrap().into(), None) + let upload_session = FileUploadSession::new(TranslatorConfig::local_config(cas_path).unwrap().into()) .await .unwrap(); - let mut cleaner = upload_session - .start_clean(Some("test".into()), read_data.len() as u64, Sha256Policy::Compute, Ulid::new()) - .await; + let (_id, mut cleaner) = upload_session + .start_clean(Some("test".into()), read_data.len() as u64, Sha256Policy::Compute) + .unwrap(); // Read blocks from the source file and hand them to the cleaning handle cleaner.add_data(&read_data[..]).await.unwrap(); @@ -601,9 +640,9 @@ mod tests { let xet_file = serde_json::from_str::(&input).unwrap(); let config = TranslatorConfig::local_config(cas_path).unwrap(); - let session = FileDownloadSession::new(config.into(), None).await.unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); - session.download_file(&xet_file, output_path, Ulid::new()).await.unwrap(); + let (_id, _n_bytes) = session.download_file(&xet_file, output_path).await.unwrap(); } use std::fs::{read, write}; @@ -655,14 +694,13 @@ mod tests { .external_run_async_task(async move { let cas_path = temp.path().join("cas"); - let upload_session = - FileUploadSession::new(TranslatorConfig::local_config(&cas_path).unwrap().into(), None) - .await - .unwrap(); + let upload_session = FileUploadSession::new(TranslatorConfig::local_config(&cas_path).unwrap().into()) + .await + .unwrap(); - let mut cleaner = upload_session - .start_clean(Some("test".into()), data.len() as u64, Sha256Policy::Skip, Ulid::new()) - .await; + let (_id, mut cleaner) = upload_session + .start_clean(Some("test".into()), data.len() as u64, Sha256Policy::Skip) + .unwrap(); cleaner.add_data(data).await.unwrap(); cleaner.finish().await.unwrap(); diff --git a/xet_data/src/processing/migration_tool/migrate.rs b/xet_data/src/processing/migration_tool/migrate.rs index 28d6bdec..a73c7c37 100644 --- a/xet_data/src/processing/migration_tool/migrate.rs +++ b/xet_data/src/processing/migration_tool/migrate.rs @@ -90,9 +90,9 @@ pub async fn migrate_files_impl( XetRuntime::current().num_worker_threads() }; let processor = if dry_run { - FileUploadSession::dry_run(config.into(), None).await? + FileUploadSession::dry_run(config.into()).await? } else { - FileUploadSession::new(config.into(), None).await? + FileUploadSession::new(config.into()).await? }; let sha256_policies: Vec = match sha256s { @@ -108,7 +108,7 @@ pub async fn migrate_files_impl( let clean_futs = file_paths.into_iter().zip(sha256_policies).map(|(file_path, policy)| { let proc = processor.clone(); async move { - let (pf, metrics) = clean_file(proc, file_path, policy, None).await?; + let (pf, metrics) = clean_file(proc, file_path, policy).await?; Ok::<(XetFileInfo, u64), DataProcessingError>((pf, metrics.new_bytes)) } .instrument(info_span!("clean_file")) diff --git a/xet_data/src/processing/test_utils.rs b/xet_data/src/processing/test_utils.rs index 661e33cf..41a17b54 100644 --- a/xet_data/src/processing/test_utils.rs +++ b/xet_data/src/processing/test_utils.rs @@ -6,14 +6,12 @@ use std::sync::Arc; use itertools::multizip; use rand::prelude::*; use tempfile::TempDir; -use ulid::Ulid; use xet_client::cas_client::{Client, LocalClient, LocalTestServer, LocalTestServerBuilder}; use super::configurations::TranslatorConfig; use super::data_client::clean_file; use super::file_cleaner::Sha256Policy; use super::{FileDownloadSession, FileUploadSession, XetFileInfo}; -use crate::progress_tracking::TrackingProgressUpdater; /// Describes how hydration (download/smudge) should be performed during a test. /// @@ -267,12 +265,9 @@ impl HydrateDehydrateTest { } } - pub async fn new_upload_session( - &self, - progress_tracker: Option>, - ) -> Arc { + pub async fn new_upload_session(&self) -> Arc { let config = Arc::new(TranslatorConfig::local_config(&self.cas_dir).unwrap()); - FileUploadSession::new(config.clone(), progress_tracker).await.unwrap() + FileUploadSession::new(config.clone()).await.unwrap() } pub async fn clean_all_files(&self, upload_session: &Arc, sequential: bool) { @@ -285,7 +280,7 @@ impl HydrateDehydrateTest { let upload_session = upload_session.clone(); if sequential { - let (pf, metrics) = clean_file(upload_session.clone(), entry.path(), Sha256Policy::Compute, None) + let (pf, metrics) = clean_file(upload_session.clone(), entry.path(), Sha256Policy::Compute) .await .unwrap(); assert_eq!({ metrics.total_bytes }, entry.metadata().unwrap().len()); @@ -301,13 +296,9 @@ impl HydrateDehydrateTest { .map(|entry| self.src_dir.join(entry.unwrap().file_name())) .collect(); - let files_sha256_and_tracking_ids = multizip(( - files.iter(), - std::iter::repeat_with(|| Sha256Policy::Compute), - std::iter::repeat_with(Ulid::new), - )); + let files_and_sha256 = multizip((files.iter(), std::iter::repeat_with(|| Sha256Policy::Compute))); - let clean_results = upload_session.upload_files(files_sha256_and_tracking_ids).await.unwrap(); + let clean_results = upload_session.upload_files(files_and_sha256).await.unwrap(); for (i, xf) in clean_results.into_iter().enumerate() { std::fs::write(self.ptr_dir.join(files[i].file_name().unwrap()), serde_json::to_string(&xf).unwrap()) @@ -317,7 +308,7 @@ impl HydrateDehydrateTest { } pub async fn dehydrate(&mut self, sequential: bool) { - let upload_session = self.new_upload_session(None).await; + let upload_session = self.new_upload_session().await; self.clean_all_files(&upload_session, sequential).await; upload_session.finalize().await.unwrap(); @@ -325,20 +316,20 @@ impl HydrateDehydrateTest { pub async fn hydrate(&mut self) { let client = self.get_or_create_client().await; - let session = FileDownloadSession::from_client(client, None); + let session = FileDownloadSession::from_client(client); for entry in read_dir(&self.ptr_dir).unwrap() { let entry = entry.unwrap(); let out_filename = self.dest_dir.join(entry.file_name()); let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap(); - session.download_file(&xf, &out_filename, Ulid::new()).await.unwrap(); + let (_id, _) = session.download_file(&xf, &out_filename).await.unwrap(); } } pub async fn hydrate_partitioned_writers(&mut self, partitions: usize) { let client = self.get_or_create_client().await; - let session = FileDownloadSession::from_client(client, None); + let session = FileDownloadSession::from_client(client); for entry in read_dir(&self.ptr_dir).unwrap() { let entry = entry.unwrap(); @@ -370,7 +361,7 @@ impl HydrateDehydrateTest { tasks.push(tokio::spawn(async move { let mut writer = std::fs::OpenOptions::new().write(true).open(out_filename).unwrap(); writer.seek(SeekFrom::Start(start)).unwrap(); - session.download_to_writer(&xf, start..end, writer, Ulid::new()).await + session.download_to_writer(&xf, start..end, writer).await })); } @@ -382,14 +373,14 @@ impl HydrateDehydrateTest { pub async fn hydrate_stream(&mut self) { let client = self.get_or_create_client().await; - let session = FileDownloadSession::from_client(client, None); + let session = FileDownloadSession::from_client(client); for entry in read_dir(&self.ptr_dir).unwrap() { let entry = entry.unwrap(); let out_filename = self.dest_dir.join(entry.file_name()); let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap(); - let mut stream = session.download_stream(&xf, Ulid::new()).unwrap(); + let (_id, mut stream) = session.download_stream(&xf).await.unwrap(); let mut file = File::create(&out_filename).unwrap(); while let Some(chunk) = stream.next().await.unwrap() { diff --git a/xet_data/src/progress_tracking/aggregator.rs b/xet_data/src/progress_tracking/aggregator.rs deleted file mode 100644 index c0d5ce88..00000000 --- a/xet_data/src/progress_tracking/aggregator.rs +++ /dev/null @@ -1,492 +0,0 @@ -use std::collections::HashMap; -use std::collections::hash_map::Entry as HashMapEntry; -use std::sync::Arc; -use std::time::Duration; - -use more_asserts::*; -use tokio::sync::Mutex; -use tokio::task::JoinHandle; -use tokio::time::Instant; -use ulid::Ulid; - -use super::{ProgressUpdate, TrackingProgressUpdater}; - -/// A wrapper around an `Arc` that efficiently aggregates progress -/// updates over time and flushes the aggregated updates periodically or on demand. -/// -/// This struct buffers incoming [`ProgressUpdate`] values and merges them by item name -/// so that repeated updates for the same item are merged. -/// -/// The aggregated updates to the wrapped inner updater on a fixed interval. -/// -/// ### Usage: -/// -/// let inner_updater: Arc = Arc::new(MyUpdater {}); -/// let aggregator = AggregatingProgressUpdater::new(inner_updater, Duration::from_millis(200)); -/// -/// // Register updates as needed... -/// aggregator.register_updates(my_update).await; -pub struct AggregatingProgressUpdater { - inner: Option>, - state: Arc>, - bg_update_loop_handle: Mutex>>, -} - -struct SpeedWindowSample { - sample_time: Instant, - total_bytes_completed: u64, - total_transfer_bytes_completed: u64, -} - -#[derive(Default)] -struct AggregationState { - pending: ProgressUpdate, - item_lookup: HashMap, - finished: bool, - - /// A round-robin sampling window - speed_window_samples: Vec, - speed_sample_size: usize, - - /// The tick index. Elements are stored at - tick_index: usize, -} - -impl AggregationState { - fn new(speed_sample_size: usize) -> Self { - debug_assert_ge!(speed_sample_size, 1); - - Self { - speed_window_samples: Vec::with_capacity(speed_sample_size), - speed_sample_size, - ..Default::default() - } - } - - fn merge_in(&mut self, mut other: ProgressUpdate) { - debug_assert!(!self.finished); - - for item in other.item_updates.drain(..) { - match self.item_lookup.entry(item.tracking_id) { - HashMapEntry::Occupied(entry) => { - self.pending.item_updates[*entry.get()].merge_in(item); - }, - HashMapEntry::Vacant(entry) => { - entry.insert_entry(self.pending.item_updates.len()); - self.pending.item_updates.push(item); - }, - } - } - // Already merged in all the other updates; do this one now. - self.pending.merge_in(other); - } - - fn get_state(&mut self) -> ProgressUpdate { - let mut update = std::mem::take(&mut self.pending); - - // Copy back the accumulated stats in case this is called before another update happens. - self.pending.total_bytes = update.total_bytes; - self.pending.total_bytes_completed = update.total_bytes_completed; - self.pending.total_transfer_bytes = update.total_transfer_bytes; - self.pending.total_transfer_bytes_completed = update.total_transfer_bytes_completed; - - // Now update the speed estimation if possible. - if self.speed_sample_size != 0 { - let now = Instant::now(); - let earliest_idx = self.tick_index % self.speed_sample_size; - - if !self.speed_window_samples.is_empty() { - // Run this as a fixed size ring buffer. - let earliest = &self.speed_window_samples[earliest_idx]; - - let time_passed = (now.saturating_duration_since(earliest.sample_time)).as_secs_f64().max(0.001); - - update.total_bytes_completion_rate = Some( - (update.total_bytes_completed.saturating_sub(earliest.total_bytes_completed)) as f64 / time_passed, - ); - - update.total_transfer_bytes_completion_rate = Some( - (update - .total_transfer_bytes_completed - .saturating_sub(earliest.total_transfer_bytes_completed)) as f64 - / time_passed, - ); - } - - // Add the current update to the ring - let speed_sample = SpeedWindowSample { - sample_time: now, - total_bytes_completed: update.total_bytes_completed, - total_transfer_bytes_completed: update.total_transfer_bytes_completed, - }; - - if self.speed_window_samples.len() < self.speed_sample_size { - self.speed_window_samples.push(speed_sample); - } else { - // Cycle the insertion point in the ring. - self.speed_window_samples[earliest_idx] = speed_sample; - self.tick_index += 1; - } - } - - // Preallocate enough that we minimize reallocations - self.pending.item_updates = Vec::with_capacity((4 * update.item_updates.len()) / 3); - - // Clear out the lookup table. - self.item_lookup.clear(); - - // Return the update. - update - } -} - -impl AggregatingProgressUpdater { - /// Start a new aggregating progress updater that flushes the updates to - pub fn new( - inner: Arc, - flush_interval: Duration, - speed_sampling_window: Duration, - ) -> Arc { - let speed_sample_size = - 1 + (speed_sampling_window.as_secs_f64() / flush_interval.as_secs_f64()).ceil() as usize; - - let state = Arc::new(Mutex::new(AggregationState::new(speed_sample_size))); - - let state_clone = Arc::clone(&state); - let inner_clone = Arc::clone(&inner); - - let bg_update_loop = tokio::spawn(async move { - // Wake up every 100ms to check to see if we're complete. - let mut interval = tokio::time::interval_at(Instant::now() + flush_interval, flush_interval); - - loop { - interval.tick().await; - let is_complete = Self::flush_impl(&inner_clone, &state_clone).await; - - if is_complete { - break; - } - } - }); - - Arc::new(Self { - inner: Some(inner), - state, - bg_update_loop_handle: Mutex::new(Some(bg_update_loop)), - }) - } - - /// Creates a class that only aggregates the stats to be used to hold and track the total stats during and after a - /// session. - pub fn new_aggregation_only() -> Arc { - Arc::new(Self { - inner: None, - state: Arc::new(Mutex::new(AggregationState::default())), - bg_update_loop_handle: Mutex::new(None), - }) - } - - async fn get_aggregated_state_impl(state: &Arc>) -> (ProgressUpdate, bool) { - let mut state_guard = state.lock().await; - - (state_guard.get_state(), state_guard.finished) - } - - async fn flush_impl(inner: &Arc, state: &Arc>) -> bool { - let (flushed, is_complete) = Self::get_aggregated_state_impl(state).await; - inner.register_updates(flushed).await; - is_complete - } - - pub async fn get_aggregated_state(&self) -> ProgressUpdate { - Self::get_aggregated_state_impl(&self.state).await.0 - } - - // Ensure everything is completed. - pub async fn is_finished(&self) -> bool { - self.state.lock().await.finished && self.bg_update_loop_handle.lock().await.is_none() - } - - pub async fn finalize(&self) { - self.state.lock().await.finished = true; - - if let Some(bg_jh) = self.bg_update_loop_handle.lock().await.take() { - let _ = bg_jh.await; - } - } -} - -#[async_trait::async_trait] -impl TrackingProgressUpdater for AggregatingProgressUpdater { - async fn register_updates(&self, updates: ProgressUpdate) { - let mut state = self.state.lock().await; - state.merge_in(updates); - } - - async fn flush(&self) { - if let Some(inner) = &self.inner { - Self::flush_impl(inner, &self.state).await; - } - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - use std::sync::atomic::{AtomicU64, Ordering}; - use std::time::Duration; - - use ulid::Ulid; - - use super::*; - use crate::progress_tracking::ItemProgressUpdate; - - #[derive(Debug)] - struct MockUpdater { - flushed: Mutex>, - } - - #[async_trait::async_trait] - impl TrackingProgressUpdater for MockUpdater { - async fn register_updates(&self, update: ProgressUpdate) { - if update.is_empty() { - return; - } - - *self.flushed.lock().await = Some(update); - } - } - - impl MockUpdater { - async fn last_update(&self) -> ProgressUpdate { - self.flushed.lock().await.clone().unwrap() - } - } - - #[tokio::test] - async fn test_single_ordered_flush_and_totals() { - let mock = Arc::new(MockUpdater { - flushed: Mutex::new(None), - }); - - // Create an aggregator that aggregates updates every 50 ms; it should send one update that aggregates the three - // below. - let aggregator = - AggregatingProgressUpdater::new(mock.clone(), Duration::from_millis(50), Duration::from_millis(200)); - - let file_a = (Ulid::new(), "fileA.txt"); - let file_b = (Ulid::new(), "fileB.txt"); - let file_c = (Ulid::new(), "fileC.txt"); - - // First update: fileA - aggregator - .register_updates(ProgressUpdate { - item_updates: vec![ItemProgressUpdate { - tracking_id: file_a.0, - item_name: file_a.1.into(), - total_bytes: 100, - bytes_completed: 10, - bytes_completion_increment: 10, - }], - total_bytes: 100, - total_bytes_increment: 100, - total_bytes_completed: 10, - total_bytes_completion_increment: 10, - total_transfer_bytes: 50, - total_transfer_bytes_increment: 50, - total_transfer_bytes_completed: 5, - total_transfer_bytes_completion_increment: 5, - ..Default::default() - }) - .await; - - tokio::time::sleep(Duration::from_millis(10)).await; - - // Second update: fileB - aggregator - .register_updates(ProgressUpdate { - item_updates: vec![ItemProgressUpdate { - tracking_id: file_b.0, - item_name: file_b.1.into(), - total_bytes: 200, - bytes_completed: 50, - bytes_completion_increment: 50, - }], - total_bytes: 300, - total_bytes_increment: 200, - total_bytes_completed: 60, - total_bytes_completion_increment: 50, - total_transfer_bytes: 150, - total_transfer_bytes_increment: 100, - total_transfer_bytes_completed: 30, - total_transfer_bytes_completion_increment: 25, - ..Default::default() - }) - .await; - - tokio::time::sleep(Duration::from_millis(10)).await; - - // Third update: fileC - aggregator - .register_updates(ProgressUpdate { - item_updates: vec![ - ItemProgressUpdate { - tracking_id: file_c.0, - item_name: file_c.1.into(), - total_bytes: 300, - bytes_completed: 90, - bytes_completion_increment: 90, - }, - ItemProgressUpdate { - tracking_id: file_a.0, - item_name: file_a.1.into(), - total_bytes: 100, - bytes_completed: 30, - bytes_completion_increment: 20, - }, - ], - total_bytes: 600, - total_bytes_increment: 300, - total_bytes_completed: 170, - total_bytes_completion_increment: 110, - total_transfer_bytes: 300, - total_transfer_bytes_increment: 150, - total_transfer_bytes_completed: 85, - total_transfer_bytes_completion_increment: 55, - ..Default::default() - }) - .await; - - // Wait long enough for flush to trigger - tokio::time::sleep(Duration::from_millis(100)).await; - - // Get flushed update - let flushed = mock.last_update().await; - - // === Total fields === - assert_eq!(flushed.total_bytes, 600); - assert_eq!(flushed.total_bytes_increment, 600); - assert_eq!(flushed.total_bytes_completed, 170); - assert_eq!(flushed.total_bytes_completion_increment, 170); - - assert_eq!(flushed.total_transfer_bytes, 300); - assert_eq!(flushed.total_transfer_bytes_increment, 300); - assert_eq!(flushed.total_transfer_bytes_completed, 85); - assert_eq!(flushed.total_transfer_bytes_completion_increment, 85); - - // === Item updates === - assert_eq!(flushed.item_updates.len(), 3); - - let a = &flushed.item_updates[0]; - assert_eq!(a.item_name.as_ref(), "fileA.txt"); - assert_eq!(a.total_bytes, 100); - assert_eq!(a.bytes_completed, 30); - assert_eq!(a.bytes_completion_increment, 30); - - let b = &flushed.item_updates[1]; - assert_eq!(b.item_name.as_ref(), "fileB.txt"); - assert_eq!(b.total_bytes, 200); - assert_eq!(b.bytes_completed, 50); - assert_eq!(b.bytes_completion_increment, 50); - - let c = &flushed.item_updates[2]; - assert_eq!(c.item_name.as_ref(), "fileC.txt"); - assert_eq!(c.total_bytes, 300); - assert_eq!(c.bytes_completed, 90); - assert_eq!(c.bytes_completion_increment, 90); - } - - // A test to test that the speed estimation is correct. - #[tokio::test] - async fn test_speed_estimation() { - let mock = Arc::new(MockUpdater { - flushed: Mutex::new(None), - }); - - // Create an aggregator that aggregates updates every 50 ms; it should send one update that aggregates the three - // below. - let aggregator = - AggregatingProgressUpdater::new(mock.clone(), Duration::from_millis(1), Duration::from_millis(100)); - - let completed_bytes = Arc::new(AtomicU64::new(0)); - let completed_transfer_bytes = Arc::new(AtomicU64::new(0)); - - let add_updates = |total_bytes_per_ms: f64, transfer_bytes_per_ms: f64, n_ms: u64| { - let completed_bytes_ = completed_bytes.clone(); - let completed_transfer_bytes_ = completed_transfer_bytes.clone(); - let aggregator = aggregator.clone(); - - let update_start_time = Instant::now(); - let start_completed_bytes = completed_bytes_.load(Ordering::Relaxed); - let start_completed_transfer_bytes = completed_transfer_bytes_.load(Ordering::Relaxed); - - async move { - loop { - let now = Instant::now(); - let ms_elapsed = now.saturating_duration_since(update_start_time).as_secs_f64() * 1000.; - if ms_elapsed >= n_ms as f64 { - break; - } - let cb = start_completed_bytes + (ms_elapsed * total_bytes_per_ms) as u64; - let ctb = start_completed_transfer_bytes + (ms_elapsed * transfer_bytes_per_ms) as u64; - - let prev_cb = completed_bytes_.swap(cb, Ordering::Relaxed); - let prev_ctb = completed_transfer_bytes_.swap(ctb, Ordering::Relaxed); - - aggregator - .register_updates(ProgressUpdate { - total_bytes_completed: cb, - total_bytes_completion_increment: cb - prev_cb, - total_transfer_bytes_completed: ctb, - total_transfer_bytes_completion_increment: ctb - prev_ctb, - ..Default::default() - }) - .await; - - completed_bytes_.store(cb, Ordering::Relaxed); - completed_transfer_bytes_.store(ctb, Ordering::Relaxed); - } - } - }; - - let check_rate_values = |expected_completion_rate: u64, expected_transfer_rate: u64| { - let mock = mock.clone(); - async move { - let update = mock.last_update().await; - - let assert_close = |ctx: &str, a: f64, b: f64| { - assert_le!((a - b).abs() / (a.abs() + b.abs()), 0.5, "Values not within 25% ({ctx}): {a} != {b}"); - }; - - assert_close( - "completion", - update.total_bytes_completion_rate.unwrap_or_default() / 1000., /* Reported in seconds, we want - * in millis */ - expected_completion_rate as f64, - ); - - assert_close( - "transfer", - update.total_transfer_bytes_completion_rate.unwrap_or_default() / 1000., - expected_transfer_rate as f64, - ); - } - }; - - add_updates(1000., 100., 50).await; - check_rate_values(1000, 100).await; - - add_updates(1000., 100., 50).await; - check_rate_values(1000, 100).await; - - // Increase the rate, this should go up linearly. - add_updates(2000., 200., 25).await; - check_rate_values(1250, 125).await; - add_updates(2000., 200., 25).await; - check_rate_values(1500, 150).await; - add_updates(2000., 200., 25).await; - check_rate_values(1750, 175).await; - add_updates(2000., 200., 25).await; - check_rate_values(2000, 200).await; - } -} diff --git a/xet_data/src/progress_tracking/download_tracking.rs b/xet_data/src/progress_tracking/download_tracking.rs deleted file mode 100644 index d532cc46..00000000 --- a/xet_data/src/progress_tracking/download_tracking.rs +++ /dev/null @@ -1,582 +0,0 @@ -use std::sync::Arc; -use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; - -use more_asserts::debug_assert_le; -use ulid::Ulid; - -use super::{ItemProgressUpdate, ProgressUpdate, TrackingProgressUpdater}; - -/// Tracks the total progress across all download tasks. Updates on individual download tasks -/// are forwarded to the inner progress updater, using this info to update the totals. -pub struct DownloadProgressTracker { - inner: Arc, - total_bytes: AtomicU64, - total_transfer_bytes: AtomicU64, - total_bytes_completed: AtomicU64, - total_transfer_bytes_completed: AtomicU64, -} - -impl DownloadProgressTracker { - pub fn new(inner: Arc) -> Arc { - Arc::new(Self { - inner, - total_bytes: 0.into(), - total_transfer_bytes: 0.into(), - total_bytes_completed: 0.into(), - total_transfer_bytes_completed: 0.into(), - }) - } - - pub fn new_download_task(self: &Arc, tracking_id: Ulid, item_name: Arc) -> Arc { - Arc::new(DownloadTaskUpdater::new(tracking_id, item_name, self.clone())) - } - - #[inline] - pub fn assert_complete(&self) { - #[cfg(debug_assertions)] - { - assert_eq!(self.total_bytes_completed.load(Ordering::Relaxed), self.total_bytes.load(Ordering::Relaxed)); - assert_eq!( - self.total_transfer_bytes_completed.load(Ordering::Relaxed), - self.total_transfer_bytes.load(Ordering::Relaxed) - ); - } - } -} - -/// The interface struct for a single file or file segment. Holds a reference to the -/// group-level DownloadProgressTracker. -pub struct DownloadTaskUpdater { - tracking_id: Ulid, - item_name: Arc, - item_bytes: AtomicU64, - item_transfer_bytes: AtomicU64, - item_size_finalized: AtomicBool, - bytes_completed: AtomicU64, - transfer_bytes_completed: AtomicU64, - tracker: Arc, -} - -impl DownloadTaskUpdater { - fn new(tracking_id: Ulid, item_name: Arc, tracker: Arc) -> Self { - Self { - tracking_id, - item_name, - item_bytes: 0.into(), - item_transfer_bytes: 0.into(), - item_size_finalized: false.into(), - bytes_completed: 0.into(), - transfer_bytes_completed: 0.into(), - tracker, - } - } - - #[cfg(debug_assertions)] - pub fn correctness_verification_tracker() -> Arc { - let null_tracker = super::NoOpProgressUpdater::new(); - - let testing_download_tracker = DownloadProgressTracker::new(null_tracker); - testing_download_tracker.new_download_task(Ulid::new(), Arc::from("")) - } - - /// Updates the total decompressed item size. - /// - /// When `is_final` is true, the provided value is treated as the definitive total and - /// all subsequent calls to `update_item_size` are ignored. Use `is_final = true` when - /// the file size is known ahead of time (e.g. from metadata), and `is_final = false` - /// when the total is being discovered incrementally during reconstruction. - pub fn update_item_size(&self, total_item_bytes: u64, is_final: bool) { - // If already finalized, ignore subsequent calls. - if self.item_size_finalized.load(Ordering::Relaxed) { - // Make sure that updates reflect known reality correctly. - if is_final { - debug_assert_eq!(self.item_bytes.load(Ordering::Relaxed), total_item_bytes); - } - return; - } - - if is_final { - self.item_size_finalized.store(true, Ordering::Relaxed); - } - - let old_item_total = self.item_bytes.swap(total_item_bytes, Ordering::Release); - - if old_item_total == total_item_bytes { - return; - } - - // Should only increase. - debug_assert_le!(old_item_total, total_item_bytes); - - let total_bytes_increment = total_item_bytes.saturating_sub(old_item_total); - let old_total_bytes = self.tracker.total_bytes.fetch_add(total_bytes_increment, Ordering::Release); - let total_bytes = old_total_bytes + total_bytes_increment; - - let item_update = ItemProgressUpdate { - tracking_id: self.tracking_id, - item_name: self.item_name.clone(), - total_bytes: total_item_bytes, - bytes_completed: self.bytes_completed.load(Ordering::Relaxed), - bytes_completion_increment: 0, - }; - - let progress_update = ProgressUpdate { - item_updates: vec![item_update], - total_bytes, - total_bytes_increment, - total_bytes_completed: self.tracker.total_bytes_completed.load(Ordering::Relaxed), - total_bytes_completion_increment: 0, - total_bytes_completion_rate: None, - total_transfer_bytes: self.tracker.total_transfer_bytes.load(Ordering::Relaxed), - total_transfer_bytes_increment: 0, - total_transfer_bytes_completed: self.tracker.total_transfer_bytes_completed.load(Ordering::Relaxed), - total_transfer_bytes_completion_increment: 0, - total_transfer_bytes_completion_rate: None, - }; - - let inner = self.tracker.inner.clone(); - tokio::spawn(async move { inner.register_updates(progress_update).await }); - } - - /// Updates the expected total transfer (network) bytes for this item. - /// Called incrementally as xorb blocks are discovered during reconstruction. - pub fn update_transfer_size(&self, item_transfer_bytes: u64) { - let old_item_transfer = self.item_transfer_bytes.swap(item_transfer_bytes, Ordering::Relaxed); - - if old_item_transfer == item_transfer_bytes { - return; - } - - // Should only increase. - debug_assert_le!(old_item_transfer, item_transfer_bytes); - - let total_transfer_bytes_increment = item_transfer_bytes.saturating_sub(old_item_transfer); - let old_transfer_bytes = self - .tracker - .total_transfer_bytes - .fetch_add(total_transfer_bytes_increment, Ordering::Relaxed); - let total_transfer_bytes = old_transfer_bytes + total_transfer_bytes_increment; - - let progress_update = ProgressUpdate { - item_updates: vec![], - total_bytes: self.tracker.total_bytes.load(Ordering::Relaxed), - total_bytes_increment: 0, - total_bytes_completed: self.tracker.total_bytes_completed.load(Ordering::Relaxed), - total_bytes_completion_increment: 0, - total_bytes_completion_rate: None, - total_transfer_bytes, - total_transfer_bytes_increment, - total_transfer_bytes_completed: self.tracker.total_transfer_bytes_completed.load(Ordering::Relaxed), - total_transfer_bytes_completion_increment: 0, - total_transfer_bytes_completion_rate: None, - }; - - let inner = self.tracker.inner.clone(); - tokio::spawn(async move { inner.register_updates(progress_update).await }); - } - - pub fn total_bytes_completed(&self) -> u64 { - self.bytes_completed.load(Ordering::Relaxed) - } - - /// Reports decompressed bytes written to disk. - pub fn report_bytes_written(&self, increment: u64) { - if increment == 0 { - return; - } - - let item_total_bytes = self.item_bytes.load(Ordering::Acquire); - let old_completed = self.bytes_completed.fetch_add(increment, Ordering::Relaxed); - let bytes_completed = old_completed + increment; - - if item_total_bytes > 0 { - debug_assert_le!(bytes_completed, item_total_bytes); - } - - let global_old_completed = self.tracker.total_bytes_completed.fetch_add(increment, Ordering::Relaxed); - let total_bytes_completed = global_old_completed + increment; - - let total_bytes = self.tracker.total_bytes.load(Ordering::Acquire); - debug_assert_le!(total_bytes_completed, total_bytes); - - let item_progress_update = ItemProgressUpdate { - tracking_id: self.tracking_id, - item_name: self.item_name.clone(), - total_bytes: item_total_bytes, - bytes_completed, - bytes_completion_increment: increment, - }; - - let progress_update = ProgressUpdate { - item_updates: vec![item_progress_update], - total_bytes, - total_bytes_increment: 0, - total_bytes_completed, - total_bytes_completion_increment: increment, - total_bytes_completion_rate: None, - total_transfer_bytes: self.tracker.total_transfer_bytes.load(Ordering::Relaxed), - total_transfer_bytes_increment: 0, - total_transfer_bytes_completed: self.tracker.total_transfer_bytes_completed.load(Ordering::Relaxed), - total_transfer_bytes_completion_increment: 0, - total_transfer_bytes_completion_rate: None, - }; - - let inner = self.tracker.inner.clone(); - tokio::spawn(async move { inner.register_updates(progress_update).await }); - } - - /// Reports transfer (network) bytes downloaded. - pub fn report_transfer_progress(&self, transfer_increment: u64) { - if transfer_increment == 0 { - return; - } - - let item_total_transfer_bytes = self.item_transfer_bytes.load(Ordering::Relaxed); - let old_transfer_completed = self.transfer_bytes_completed.fetch_add(transfer_increment, Ordering::Relaxed); - let transfer_bytes_completed = old_transfer_completed + transfer_increment; - - if item_total_transfer_bytes > 0 { - debug_assert_le!(transfer_bytes_completed, item_total_transfer_bytes); - } - - let global_old_transfer_completed = self - .tracker - .total_transfer_bytes_completed - .fetch_add(transfer_increment, Ordering::Relaxed); - let total_transfer_bytes_completed = global_old_transfer_completed + transfer_increment; - - let total_transfer_bytes = self.tracker.total_transfer_bytes.load(Ordering::Relaxed); - debug_assert_le!(total_transfer_bytes_completed, total_transfer_bytes); - - let progress_update = ProgressUpdate { - item_updates: vec![], - total_bytes: self.tracker.total_bytes.load(Ordering::Relaxed), - total_bytes_increment: 0, - total_bytes_completed: self.tracker.total_bytes_completed.load(Ordering::Relaxed), - total_bytes_completion_increment: 0, - total_bytes_completion_rate: None, - total_transfer_bytes, - total_transfer_bytes_increment: 0, - total_transfer_bytes_completed, - total_transfer_bytes_completion_increment: transfer_increment, - total_transfer_bytes_completion_rate: None, - }; - - let inner = self.tracker.inner.clone(); - tokio::spawn(async move { inner.register_updates(progress_update).await }); - } - - #[inline] - pub fn assert_complete(&self) { - #[cfg(debug_assertions)] - { - assert_eq!(self.bytes_completed.load(Ordering::Relaxed), self.item_bytes.load(Ordering::Relaxed)); - assert_eq!( - self.transfer_bytes_completed.load(Ordering::Relaxed), - self.item_transfer_bytes.load(Ordering::Relaxed) - ); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::progress_tracking::NoOpProgressUpdater; - - fn make_task(name: &str) -> (Arc, Arc) { - let tracker = DownloadProgressTracker::new(NoOpProgressUpdater::new()); - let task = tracker.new_download_task(Ulid::new(), Arc::from(name)); - (tracker, task) - } - - // ==================== update_item_size tests ==================== - - #[tokio::test] - async fn test_update_item_size_monotonic_increase() { - let (tracker, task) = make_task("file.bin"); - - task.update_item_size(100, false); - assert_eq!(task.item_bytes.load(Ordering::Relaxed), 100); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 100); - - task.update_item_size(300, false); - assert_eq!(task.item_bytes.load(Ordering::Relaxed), 300); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 300); - } - - #[tokio::test] - async fn test_update_item_size_same_value_is_noop() { - let (tracker, task) = make_task("file.bin"); - - task.update_item_size(1000, false); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 1000); - - task.update_item_size(1000, false); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 1000); - } - - #[tokio::test] - async fn test_update_item_size_final_ignores_subsequent_calls() { - let (tracker, task) = make_task("file.bin"); - - // Set final size. - task.update_item_size(1000, true); - assert_eq!(task.item_bytes.load(Ordering::Relaxed), 1000); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 1000); - - // Subsequent non-final calls are ignored. - task.update_item_size(500, false); - assert_eq!(task.item_bytes.load(Ordering::Relaxed), 1000); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 1000); - - // A second is_final call with the same value is also ignored (no-op). - task.update_item_size(1000, true); - assert_eq!(task.item_bytes.load(Ordering::Relaxed), 1000); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 1000); - } - - #[tokio::test] - async fn test_update_item_size_non_final_then_final() { - let (tracker, task) = make_task("file.bin"); - - // Incremental updates (non-final). - task.update_item_size(200, false); - task.update_item_size(500, false); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 500); - - // Final update locks it in. - task.update_item_size(1000, true); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 1000); - - // Further calls ignored. - task.update_item_size(1500, false); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 1000); - } - - #[tokio::test] - #[should_panic(expected = "left <= right")] - async fn test_update_item_size_decrease_panics_in_debug() { - let (_tracker, task) = make_task("file.bin"); - - task.update_item_size(1000, false); - task.update_item_size(100, false); - } - - // ==================== update_transfer_size tests ==================== - - #[tokio::test] - async fn test_update_transfer_size_monotonic_increase() { - let (tracker, task) = make_task("file.bin"); - - task.update_transfer_size(100); - assert_eq!(task.item_transfer_bytes.load(Ordering::Relaxed), 100); - assert_eq!(tracker.total_transfer_bytes.load(Ordering::Relaxed), 100); - - task.update_transfer_size(300); - assert_eq!(task.item_transfer_bytes.load(Ordering::Relaxed), 300); - assert_eq!(tracker.total_transfer_bytes.load(Ordering::Relaxed), 300); - } - - #[tokio::test] - async fn test_update_transfer_size_same_value_is_noop() { - let (tracker, task) = make_task("file.bin"); - - task.update_transfer_size(500); - assert_eq!(tracker.total_transfer_bytes.load(Ordering::Relaxed), 500); - - task.update_transfer_size(500); - assert_eq!(tracker.total_transfer_bytes.load(Ordering::Relaxed), 500); - } - - // ==================== report_bytes_written tests ==================== - - #[tokio::test] - async fn test_report_bytes_written_accumulates() { - let (tracker, task) = make_task("file.bin"); - task.update_item_size(1000, true); - - task.report_bytes_written(200); - assert_eq!(task.total_bytes_completed(), 200); - assert_eq!(tracker.total_bytes_completed.load(Ordering::Relaxed), 200); - - task.report_bytes_written(800); - assert_eq!(task.total_bytes_completed(), 1000); - assert_eq!(tracker.total_bytes_completed.load(Ordering::Relaxed), 1000); - } - - #[tokio::test] - async fn test_report_bytes_written_zero_is_noop() { - let (tracker, task) = make_task("file.bin"); - task.update_item_size(100, false); - - task.report_bytes_written(0); - assert_eq!(task.total_bytes_completed(), 0); - assert_eq!(tracker.total_bytes_completed.load(Ordering::Relaxed), 0); - } - - // ==================== report_transfer_progress tests ==================== - - #[tokio::test] - async fn test_report_transfer_progress_accumulates() { - let (tracker, task) = make_task("file.bin"); - task.update_transfer_size(500); - - task.report_transfer_progress(100); - assert_eq!(task.transfer_bytes_completed.load(Ordering::Relaxed), 100); - assert_eq!(tracker.total_transfer_bytes_completed.load(Ordering::Relaxed), 100); - - task.report_transfer_progress(400); - assert_eq!(task.transfer_bytes_completed.load(Ordering::Relaxed), 500); - assert_eq!(tracker.total_transfer_bytes_completed.load(Ordering::Relaxed), 500); - } - - #[tokio::test] - async fn test_report_transfer_progress_zero_is_noop() { - let (tracker, task) = make_task("file.bin"); - task.update_transfer_size(100); - - task.report_transfer_progress(0); - assert_eq!(task.transfer_bytes_completed.load(Ordering::Relaxed), 0); - assert_eq!(tracker.total_transfer_bytes_completed.load(Ordering::Relaxed), 0); - } - - // ==================== Combined flow tests ==================== - - #[tokio::test] - async fn test_full_flow_incremental_discovery() { - // Simulates the reconstruction flow where totals are discovered incrementally - // by the manager, while the writer and xorb_block report progress separately. - let (tracker, task) = make_task("file.bin"); - - // First batch: discover 500 decompressed bytes, 300 transfer bytes. - task.update_item_size(500, false); - task.update_transfer_size(300); - - // Writer reports decompressed bytes; xorb reports transfer bytes. - task.report_bytes_written(500); - task.report_transfer_progress(300); - - // Second batch: totals grow. - task.update_item_size(1000, false); - task.update_transfer_size(700); - - task.report_bytes_written(500); - task.report_transfer_progress(400); - - assert_eq!(task.total_bytes_completed(), 1000); - assert_eq!(task.transfer_bytes_completed.load(Ordering::Relaxed), 700); - - task.assert_complete(); - tracker.assert_complete(); - } - - #[tokio::test] - async fn test_full_flow_final_size_upfront() { - // Simulates the data_client.rs flow: file size is known ahead of time (is_final=true), - // transfer size is discovered incrementally by the manager. - let (tracker, task) = make_task("file.bin"); - - // data_client.rs: size known upfront. - task.update_item_size(1000, true); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 1000); - - // manager.rs discovers transfer sizes incrementally. - task.update_transfer_size(300); - task.update_transfer_size(700); - - // manager.rs also tries to set item_size, but it's ignored because final was set. - task.update_item_size(500, false); - assert_eq!(task.item_bytes.load(Ordering::Relaxed), 1000); - - // Writer reports bytes written; xorb reports transfer bytes. - task.report_bytes_written(600); - task.report_bytes_written(400); - task.report_transfer_progress(300); - task.report_transfer_progress(400); - - assert_eq!(task.total_bytes_completed(), 1000); - assert_eq!(task.transfer_bytes_completed.load(Ordering::Relaxed), 700); - - task.assert_complete(); - tracker.assert_complete(); - } - - #[tokio::test] - async fn test_interleaved_totals_and_progress() { - let (tracker, task) = make_task("file.bin"); - - // First batch discovered. - task.update_item_size(400, false); - task.update_transfer_size(200); - - // Start writing from first batch. - task.report_bytes_written(200); - task.report_transfer_progress(100); - - // Second batch discovered. - task.update_item_size(800, false); - task.update_transfer_size(500); - - // More progress. - task.report_bytes_written(400); - task.report_transfer_progress(250); - - // Final batch. - task.update_item_size(1000, false); - task.update_transfer_size(600); - - // Finish remaining. - task.report_bytes_written(400); - task.report_transfer_progress(250); - - assert_eq!(task.total_bytes_completed(), 1000); - assert_eq!(task.transfer_bytes_completed.load(Ordering::Relaxed), 600); - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 1000); - assert_eq!(tracker.total_transfer_bytes.load(Ordering::Relaxed), 600); - - task.assert_complete(); - tracker.assert_complete(); - } - - #[tokio::test] - async fn test_transfer_bytes_less_than_total_bytes() { - let (tracker, task) = make_task("file.bin"); - - task.update_item_size(10000, true); - task.update_transfer_size(3000); - - task.report_bytes_written(10000); - task.report_transfer_progress(3000); - - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 10000); - assert_eq!(tracker.total_transfer_bytes.load(Ordering::Relaxed), 3000); - - task.assert_complete(); - tracker.assert_complete(); - } - - #[tokio::test] - async fn test_multiple_tasks_independent_tracking() { - let tracker = DownloadProgressTracker::new(NoOpProgressUpdater::new()); - let task1 = tracker.new_download_task(Ulid::new(), Arc::from("file1.bin")); - let task2 = tracker.new_download_task(Ulid::new(), Arc::from("file2.bin")); - - task1.update_item_size(500, true); - task1.update_transfer_size(200); - task2.update_item_size(300, false); - task2.update_transfer_size(100); - - assert_eq!(tracker.total_bytes.load(Ordering::Relaxed), 800); - assert_eq!(tracker.total_transfer_bytes.load(Ordering::Relaxed), 300); - - task1.report_bytes_written(500); - task1.report_transfer_progress(200); - task2.report_bytes_written(300); - task2.report_transfer_progress(100); - - task1.assert_complete(); - task2.assert_complete(); - tracker.assert_complete(); - } -} diff --git a/xet_data/src/progress_tracking/mod.rs b/xet_data/src/progress_tracking/mod.rs index 6307f66d..3b517fc5 100644 --- a/xet_data/src/progress_tracking/mod.rs +++ b/xet_data/src/progress_tracking/mod.rs @@ -1,21 +1,6 @@ -pub mod aggregator; -pub mod download_tracking; -mod no_op_tracker; -mod progress_info; +mod progress_types; +mod speed_tracker; pub mod upload_tracking; -pub mod verification_wrapper; -use async_trait::async_trait; -pub use no_op_tracker::NoOpProgressUpdater; -pub use progress_info::{ItemProgressUpdate, ProgressUpdate}; - -/// The trait that a progress updater that reports per-item progress completion. -#[async_trait] -pub trait TrackingProgressUpdater: Send + Sync { - /// Register a set of updates as a list of ProgressUpdate instances, which - /// contain the name and progress information. - async fn register_updates(&self, updates: ProgressUpdate); - - /// Flush any updates out, if needed - async fn flush(&self) {} -} +pub use progress_types::{GroupProgress, GroupProgressReport, ItemProgress, ItemProgressReport, ItemProgressUpdater}; +pub use xet_runtime::utils::UniqueId as UniqueID; diff --git a/xet_data/src/progress_tracking/no_op_tracker.rs b/xet_data/src/progress_tracking/no_op_tracker.rs deleted file mode 100644 index bf5ca862..00000000 --- a/xet_data/src/progress_tracking/no_op_tracker.rs +++ /dev/null @@ -1,17 +0,0 @@ -use std::sync::Arc; - -use super::{ProgressUpdate, TrackingProgressUpdater}; - -#[derive(Debug, Default)] -pub struct NoOpProgressUpdater; - -impl NoOpProgressUpdater { - pub fn new() -> Arc { - Arc::new(Self {}) - } -} - -#[async_trait::async_trait] -impl TrackingProgressUpdater for NoOpProgressUpdater { - async fn register_updates(&self, _updates: ProgressUpdate) {} -} diff --git a/xet_data/src/progress_tracking/progress_types.rs b/xet_data/src/progress_tracking/progress_types.rs new file mode 100644 index 00000000..e1a28096 --- /dev/null +++ b/xet_data/src/progress_tracking/progress_types.rs @@ -0,0 +1,582 @@ +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex}; +use std::time::Duration; + +use more_asserts::debug_assert_le; + +use super::UniqueID; +use super::speed_tracker::{DEFAULT_MIN_OBSERVATIONS_FOR_RATE, DEFAULT_SPEED_HALF_LIFE, SpeedTracker}; +use super::upload_tracking::CompletionTracker; + +/// Per-item atomic progress counters. Created by `GroupProgress::new_item()`. +pub struct ItemProgress { + pub id: UniqueID, + pub name: Arc, + pub total_bytes: AtomicU64, + pub bytes_completed: AtomicU64, + pub transfer_bytes: AtomicU64, + pub transfer_bytes_completed: AtomicU64, + pub size_finalized: AtomicBool, +} + +impl ItemProgress { + fn new(id: UniqueID, name: Arc) -> Self { + Self { + id, + name, + total_bytes: AtomicU64::new(0), + bytes_completed: AtomicU64::new(0), + transfer_bytes: AtomicU64::new(0), + transfer_bytes_completed: AtomicU64::new(0), + size_finalized: AtomicBool::new(false), + } + } + + /// Snapshot of this item's progress. Reads completions first (Acquire), + /// then totals, which reduces transient skew in sampled values. + pub fn report(&self) -> ItemProgressReport { + let bytes_completed = self.bytes_completed.load(Ordering::Acquire); + let transfer_bytes_completed = self.transfer_bytes_completed.load(Ordering::Acquire); + let total_bytes = self.total_bytes.load(Ordering::Acquire); + let transfer_bytes = self.transfer_bytes.load(Ordering::Acquire); + + debug_assert_le!(bytes_completed, total_bytes); + debug_assert_le!(transfer_bytes_completed, transfer_bytes); + + ItemProgressReport { + item_name: self.name.to_string(), + total_bytes, + bytes_completed, + } + } +} + +impl std::fmt::Debug for ItemProgress { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ItemProgress") + .field("id", &self.id) + .field("name", &self.name) + .field("total_bytes", &self.total_bytes.load(Ordering::Relaxed)) + .field("bytes_completed", &self.bytes_completed.load(Ordering::Relaxed)) + .finish() + } +} + +/// Aggregate progress across all items, with an item registry. +/// Owns the atomic group-level counters and a map of per-item progress. +pub struct GroupProgress { + pub total_bytes: AtomicU64, + pub total_bytes_completed: AtomicU64, + pub total_transfer_bytes: AtomicU64, + pub total_transfer_bytes_completed: AtomicU64, + items: Mutex>>, + speed_tracker: Mutex, +} + +impl GroupProgress { + /// Create a group progress tracker using default speed estimation parameters. + pub fn new() -> Arc { + Self::with_speed_config(DEFAULT_SPEED_HALF_LIFE, DEFAULT_MIN_OBSERVATIONS_FOR_RATE) + } + + /// Create a group progress tracker with explicit speed estimation parameters. + pub fn with_speed_config(half_life: Duration, min_observations: u32) -> Arc { + Arc::new(Self { + total_bytes: AtomicU64::new(0), + total_bytes_completed: AtomicU64::new(0), + total_transfer_bytes: AtomicU64::new(0), + total_transfer_bytes_completed: AtomicU64::new(0), + items: Mutex::new(HashMap::new()), + speed_tracker: Mutex::new(SpeedTracker::new(half_life).with_min_observations(min_observations)), + }) + } + + /// Create a new tracked item and register it in the items map. + /// Returns an `ItemProgressUpdater` handle for the caller to report progress. + pub fn new_item(self: &Arc, id: UniqueID, name: impl Into>) -> Arc { + let item = Arc::new(ItemProgress::new(id, name.into())); + self.items.lock().unwrap().insert(id, item.clone()); + Arc::new(ItemProgressUpdater { + item, + group: self.clone(), + }) + } + + /// Create a new CompletionTracker backed by this group's progress. + pub fn new_completion_tracker(self: &Arc) -> CompletionTracker { + CompletionTracker::new(self.clone()) + } + + /// Snapshot of aggregate progress. Reads completions first (Acquire), then totals + /// to reduce transient sampling skew. + /// + /// Speed is estimated via [`SpeedTracker`], which uses an exponentially-weighted + /// moving average to produce smoothed bytes-per-second rates. + /// + /// This call updates internal speed-estimation state, so repeated calls are + /// not strictly idempotent. Rate fields remain `None` until enough speed + /// observations have been collected. + pub fn report(&self) -> GroupProgressReport { + let total_bytes_completed = self.total_bytes_completed.load(Ordering::Acquire); + let total_transfer_bytes_completed = self.total_transfer_bytes_completed.load(Ordering::Acquire); + let total_bytes = self.total_bytes.load(Ordering::Acquire); + let total_transfer_bytes = self.total_transfer_bytes.load(Ordering::Acquire); + + debug_assert_le!(total_bytes_completed, total_bytes); + debug_assert_le!(total_transfer_bytes_completed, total_transfer_bytes); + + let mut tracker = self.speed_tracker.lock().unwrap(); + tracker.update(total_bytes_completed, total_transfer_bytes_completed); + let (bytes_rate, transfer_rate) = tracker.rates(); + + GroupProgressReport { + total_bytes, + total_bytes_completed, + total_bytes_completion_rate: bytes_rate, + total_transfer_bytes, + total_transfer_bytes_completed, + total_transfer_bytes_completion_rate: transfer_rate, + } + } + + /// Snapshot of all per-item progress. + pub fn item_reports(&self) -> HashMap { + let items = self.items.lock().unwrap(); + items.iter().map(|(id, item)| (*id, item.report())).collect() + } + + /// Snapshot of one item's progress. + pub fn item_report(&self, id: UniqueID) -> Option { + let items = self.items.lock().unwrap(); + items.get(&id).map(|item| item.report()) + } + + /// Debug verification that all items are complete. + pub fn assert_complete(&self) { + #[cfg(debug_assertions)] + { + let total_bytes_completed = self.total_bytes_completed.load(Ordering::Acquire); + let total_bytes = self.total_bytes.load(Ordering::Acquire); + assert_eq!( + total_bytes_completed, total_bytes, + "GroupProgress not complete: total_bytes_completed={total_bytes_completed} != total_bytes={total_bytes}" + ); + + let total_transfer_bytes_completed = self.total_transfer_bytes_completed.load(Ordering::Acquire); + let total_transfer_bytes = self.total_transfer_bytes.load(Ordering::Acquire); + assert_eq!( + total_transfer_bytes_completed, total_transfer_bytes, + "GroupProgress not complete: total_transfer_bytes_completed={total_transfer_bytes_completed} != total_transfer_bytes={total_transfer_bytes}" + ); + + let items = self.items.lock().unwrap(); + for (id, item) in items.iter() { + let completed = item.bytes_completed.load(Ordering::Acquire); + let total = item.total_bytes.load(Ordering::Acquire); + assert_eq!( + completed, total, + "Item '{}' ({id}) not complete: bytes_completed={completed} != total_bytes={total}", + item.name + ); + } + } + } +} + +impl Default for GroupProgress { + fn default() -> Self { + // Note: returns Self, not Arc. Use GroupProgress::new() for Arc. + Self { + total_bytes: AtomicU64::new(0), + total_bytes_completed: AtomicU64::new(0), + total_transfer_bytes: AtomicU64::new(0), + total_transfer_bytes_completed: AtomicU64::new(0), + items: Mutex::new(HashMap::new()), + speed_tracker: Mutex::new( + SpeedTracker::new(DEFAULT_SPEED_HALF_LIFE).with_min_observations(DEFAULT_MIN_OBSERVATIONS_FOR_RATE), + ), + } + } +} + +impl std::fmt::Debug for GroupProgress { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("GroupProgress") + .field("total_bytes", &self.total_bytes.load(Ordering::Relaxed)) + .field("total_bytes_completed", &self.total_bytes_completed.load(Ordering::Relaxed)) + .field("total_transfer_bytes", &self.total_transfer_bytes.load(Ordering::Relaxed)) + .field("total_transfer_bytes_completed", &self.total_transfer_bytes_completed.load(Ordering::Relaxed)) + .finish() + } +} + +/// Handle for reporting progress on a single item. All progress updates +/// (both download and upload paths) go through this type. +/// +/// Replaces both `DownloadTaskUpdater` and the per-file update logic +/// that was previously in `CompletionTracker`. +pub struct ItemProgressUpdater { + item: Arc, + group: Arc, +} + +impl ItemProgressUpdater { + /// Create a standalone updater for debug/testing purposes. + /// Creates its own throwaway GroupProgress. + #[cfg(debug_assertions)] + pub fn new_standalone(name: &str) -> Arc { + let group = GroupProgress::new(); + let item = Arc::new(ItemProgress::new(UniqueID::new(), Arc::from(name))); + Arc::new(Self { item, group }) + } + + // === Size updates (use fetch_update for full atomicity) === + + /// Update the total item size. When `is_final` is true, subsequent calls are ignored. + pub fn update_item_size(&self, total: u64, is_final: bool) { + if self.item.size_finalized.load(Ordering::Acquire) { + if is_final { + debug_assert_eq!(self.item.total_bytes.load(Ordering::Acquire), total); + } + return; + } + if is_final { + self.item.size_finalized.store(true, Ordering::Release); + } + let result = self + .item + .total_bytes + .fetch_update(Ordering::Release, Ordering::Acquire, |old| if total > old { Some(total) } else { None }); + if let Ok(old) = result { + self.group.total_bytes.fetch_add(total - old, Ordering::Release); + } + } + + /// Update the total transfer (network) bytes for this item. + pub fn update_transfer_size(&self, total: u64) { + let result = self + .item + .transfer_bytes + .fetch_update(Ordering::Release, Ordering::Acquire, |old| if total > old { Some(total) } else { None }); + if let Ok(old) = result { + self.group.total_transfer_bytes.fetch_add(total - old, Ordering::Release); + } + } + + // === Completion updates (group first, then item) === + + /// Report decompressed/processed bytes completed for this item. + pub fn report_bytes_completed(&self, increment: u64) { + if increment == 0 { + return; + } + self.group.total_bytes_completed.fetch_add(increment, Ordering::Release); + let new_completed = self.item.bytes_completed.fetch_add(increment, Ordering::Release) + increment; + debug_assert_le!( + new_completed, + self.item.total_bytes.load(Ordering::Acquire), + "item '{}' bytes_completed ({}) exceeded total_bytes after +{}", + self.item.name, + new_completed, + increment + ); + } + + /// Report transfer (network) bytes completed. + pub fn report_transfer_bytes_completed(&self, increment: u64) { + if increment == 0 { + return; + } + self.group + .total_transfer_bytes_completed + .fetch_add(increment, Ordering::Release); + let new_completed = self.item.transfer_bytes_completed.fetch_add(increment, Ordering::Release) + increment; + debug_assert_le!( + new_completed, + self.item.transfer_bytes.load(Ordering::Acquire), + "item '{}' transfer_bytes_completed ({}) exceeded transfer_bytes after +{}", + self.item.name, + new_completed, + increment + ); + } + + // === Aliases for reconstruction pipeline compatibility === + + /// Alias for `report_bytes_completed` -- used by the reconstruction data writer. + pub fn report_bytes_written(&self, increment: u64) { + self.report_bytes_completed(increment); + } + + /// Alias for `report_transfer_bytes_completed` -- used by xorb block download. + pub fn report_transfer_progress(&self, delta: u64) { + self.report_transfer_bytes_completed(delta); + } + + /// Read the current bytes_completed for this item. + pub fn total_bytes_completed(&self) -> u64 { + self.item.bytes_completed.load(Ordering::Acquire) + } + + // === Debug verification === + + /// Assert this item is fully complete (completed == total for both bytes and transfer). + pub fn assert_complete(&self) { + #[cfg(debug_assertions)] + { + let completed = self.item.bytes_completed.load(Ordering::Acquire); + let total = self.item.total_bytes.load(Ordering::Acquire); + assert_eq!( + completed, total, + "item '{}' not complete: bytes_completed={completed} != total_bytes={total}", + self.item.name + ); + let t_completed = self.item.transfer_bytes_completed.load(Ordering::Acquire); + let t_total = self.item.transfer_bytes.load(Ordering::Acquire); + assert_eq!( + t_completed, t_total, + "item '{}' not complete: transfer_bytes_completed={t_completed} != transfer_bytes={t_total}", + self.item.name + ); + } + } + + pub fn item(&self) -> &Arc { + &self.item + } + + pub fn report(&self) -> ItemProgressReport { + self.item.report() + } + + pub fn group(&self) -> &Arc { + &self.group + } +} + +impl std::fmt::Debug for ItemProgressUpdater { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ItemProgressUpdater").field("item", &self.item).finish() + } +} + +// === Snapshot report structs === + +#[derive(Clone, Debug, Default)] +#[cfg_attr(feature = "python", pyo3::pyclass(get_all))] +pub struct GroupProgressReport { + pub total_bytes: u64, + pub total_bytes_completed: u64, + pub total_bytes_completion_rate: Option, + pub total_transfer_bytes: u64, + pub total_transfer_bytes_completed: u64, + pub total_transfer_bytes_completion_rate: Option, +} + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "python", pyo3::pyclass(get_all))] +pub struct ItemProgressReport { + pub item_name: String, + pub total_bytes: u64, + pub bytes_completed: u64, +} + +#[cfg(test)] +mod tests { + use tokio::time::{Duration, advance, pause}; + + use super::*; + + #[test] + fn test_group_progress_new_item() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(100, true); + + assert_eq!(group.total_bytes.load(Ordering::Relaxed), 100); + } + + #[test] + fn test_item_progress_updater_bytes() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(100, true); + updater.report_bytes_completed(50); + + assert_eq!(updater.total_bytes_completed(), 50); + assert_eq!(group.total_bytes_completed.load(Ordering::Relaxed), 50); + } + + #[test] + fn test_item_progress_updater_transfer() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(100, true); + updater.update_transfer_size(80); + updater.report_transfer_bytes_completed(30); + + assert_eq!(group.total_transfer_bytes.load(Ordering::Relaxed), 80); + assert_eq!(group.total_transfer_bytes_completed.load(Ordering::Relaxed), 30); + } + + #[test] + fn test_update_item_size_finalized() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(100, true); + updater.update_item_size(200, false); + + assert_eq!(updater.item().total_bytes.load(Ordering::Relaxed), 100); + assert_eq!(group.total_bytes.load(Ordering::Relaxed), 100); + } + + #[test] + fn test_update_item_size_monotonic_increase() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(100, false); + updater.update_item_size(300, false); + + assert_eq!(updater.item().total_bytes.load(Ordering::Relaxed), 300); + assert_eq!(group.total_bytes.load(Ordering::Relaxed), 300); + } + + #[test] + fn test_update_item_size_same_value_noop() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(100, false); + updater.update_item_size(100, false); + + assert_eq!(group.total_bytes.load(Ordering::Relaxed), 100); + } + + #[test] + fn test_report_snapshot() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "file.bin"); + updater.update_item_size(1000, true); + updater.update_transfer_size(800); + updater.report_bytes_completed(500); + updater.report_transfer_bytes_completed(300); + + let report = group.report(); + assert_eq!(report.total_bytes, 1000); + assert_eq!(report.total_bytes_completed, 500); + assert_eq!(report.total_transfer_bytes, 800); + assert_eq!(report.total_transfer_bytes_completed, 300); + } + + #[test] + fn test_item_reports() { + let group = GroupProgress::new(); + let id1 = UniqueID::new(); + let id2 = UniqueID::new(); + + let u1 = group.new_item(id1, "a.bin"); + let u2 = group.new_item(id2, "b.bin"); + + u1.update_item_size(100, true); + u1.report_bytes_completed(60); + u2.update_item_size(200, true); + u2.report_bytes_completed(200); + + let reports = group.item_reports(); + assert_eq!(reports.len(), 2); + assert_eq!(reports[&id1].bytes_completed, 60); + assert_eq!(reports[&id2].bytes_completed, 200); + } + + #[test] + fn test_multiple_items_group_totals() { + let group = GroupProgress::new(); + let u1 = group.new_item(UniqueID::new(), "a.bin"); + let u2 = group.new_item(UniqueID::new(), "b.bin"); + + u1.update_item_size(100, true); + u2.update_item_size(200, true); + u1.report_bytes_completed(50); + u2.report_bytes_completed(100); + + assert_eq!(group.total_bytes.load(Ordering::Relaxed), 300); + assert_eq!(group.total_bytes_completed.load(Ordering::Relaxed), 150); + } + + #[test] + fn test_assert_complete() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(100, true); + updater.update_transfer_size(80); + updater.report_bytes_completed(100); + updater.report_transfer_bytes_completed(80); + + updater.assert_complete(); + group.assert_complete(); + } + + #[test] + fn test_zero_increment_noop() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(100, true); + updater.report_bytes_completed(0); + + assert_eq!(updater.total_bytes_completed(), 0); + assert_eq!(group.total_bytes_completed.load(Ordering::Relaxed), 0); + } + + #[test] + fn test_report_bytes_written_alias() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(100, true); + updater.report_bytes_written(50); + + assert_eq!(updater.total_bytes_completed(), 50); + } + + #[test] + fn test_report_transfer_progress_alias() { + let group = GroupProgress::new(); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(100, true); + updater.update_transfer_size(90); + updater.report_transfer_progress(40); + + assert_eq!(updater.item().transfer_bytes_completed.load(Ordering::Relaxed), 40); + assert_eq!(group.total_transfer_bytes_completed.load(Ordering::Relaxed), 40); + } + + #[tokio::test] + async fn test_report_rates_none_until_min_observations_then_some() { + pause(); + + let group = GroupProgress::with_speed_config(Duration::from_secs(10), 3); + let updater = group.new_item(UniqueID::new(), "test.bin"); + updater.update_item_size(10_000, true); + updater.update_transfer_size(10_000); + + advance(Duration::from_millis(200)).await; + updater.report_bytes_completed(1_000); + updater.report_transfer_progress(800); + let report = group.report(); + assert!(report.total_bytes_completion_rate.is_none()); + assert!(report.total_transfer_bytes_completion_rate.is_none()); + + advance(Duration::from_millis(200)).await; + updater.report_bytes_completed(1_000); + updater.report_transfer_progress(800); + let report = group.report(); + assert!(report.total_bytes_completion_rate.is_none()); + assert!(report.total_transfer_bytes_completion_rate.is_none()); + + advance(Duration::from_millis(200)).await; + updater.report_bytes_completed(1_000); + updater.report_transfer_progress(800); + let report = group.report(); + assert!(report.total_bytes_completion_rate.is_some()); + assert!(report.total_transfer_bytes_completion_rate.is_some()); + } +} diff --git a/xet_data/src/progress_tracking/speed_tracker.rs b/xet_data/src/progress_tracking/speed_tracker.rs new file mode 100644 index 00000000..100f2ecc --- /dev/null +++ b/xet_data/src/progress_tracking/speed_tracker.rs @@ -0,0 +1,477 @@ +use std::time::Duration; + +use tokio::time::Instant; +use xet_core_structures::ExpWeightedMovingAvg; + +pub(crate) const DEFAULT_SPEED_HALF_LIFE: Duration = Duration::from_secs(10); +pub(crate) const DEFAULT_MIN_OBSERVATIONS_FOR_RATE: u32 = 4; + +/// Tracks smoothed byte-rate using an exponentially-weighted moving average. +/// +/// On each [`update`](Self::update) call the tracker computes the byte deltas +/// since the last call and feeds `(delta_bytes, elapsed_secs)` into the EWMA +/// via [`update_with_weight`](ExpWeightedMovingAvg::update_with_weight). +/// The resulting `value() = Σ(decayed bytes) / Σ(decayed time)` is a smoothed +/// bytes-per-second rate where older observations decay with the configured +/// half-life. +/// +/// Two independent channels are tracked: *bytes* (logical/decompressed) and +/// *transfer bytes* (network/compressed). +/// +/// The first observation's elapsed time is clamped to at least the half-life +/// so the rate starts conservatively low and converges upward. Rates are not +/// reported until at least `min_observations_for_rate` observations have been +/// recorded (default [`DEFAULT_MIN_OBSERVATIONS_FOR_RATE`]). +pub(crate) struct SpeedTracker { + bytes_rate: ExpWeightedMovingAvg, + transfer_rate: ExpWeightedMovingAvg, + last_bytes_completed: u64, + last_transfer_bytes_completed: u64, + last_report_time: Instant, + observation_count: u32, + min_initial_interval_secs: f64, + min_observations_for_rate: u32, +} + +impl SpeedTracker { + pub fn new(half_life: Duration) -> Self { + Self { + bytes_rate: ExpWeightedMovingAvg::new_time_decay(half_life), + transfer_rate: ExpWeightedMovingAvg::new_time_decay(half_life), + last_bytes_completed: 0, + last_transfer_bytes_completed: 0, + last_report_time: Instant::now(), + observation_count: 0, + min_initial_interval_secs: half_life.as_secs_f64(), + min_observations_for_rate: DEFAULT_MIN_OBSERVATIONS_FOR_RATE, + } + } + + pub fn with_min_observations(mut self, n: u32) -> Self { + self.min_observations_for_rate = n; + self + } + + /// Feed current cumulative byte counts. Computes deltas from the + /// previously seen values and the elapsed wall-clock time, then updates + /// both EWMA channels. + /// + /// On the first observation the elapsed time is clamped to at least the + /// half-life so the rate starts near zero and converges upward, avoiding + /// wild initial spikes. If elapsed time is zero, this call is a no-op. + pub fn update(&mut self, bytes_completed: u64, transfer_bytes_completed: u64) { + let now = Instant::now(); + let mut elapsed = (now - self.last_report_time).as_secs_f64(); + + if elapsed > 0.0 { + if self.observation_count == 0 { + elapsed = elapsed.max(self.min_initial_interval_secs); + } + + let bytes_delta = bytes_completed.saturating_sub(self.last_bytes_completed); + let transfer_delta = transfer_bytes_completed.saturating_sub(self.last_transfer_bytes_completed); + + self.bytes_rate.update_with_weight(bytes_delta as f64, elapsed); + self.transfer_rate.update_with_weight(transfer_delta as f64, elapsed); + + self.last_bytes_completed = bytes_completed; + self.last_transfer_bytes_completed = transfer_bytes_completed; + self.last_report_time = now; + self.observation_count = self.observation_count.saturating_add(1); + } + } + + /// Current smoothed rates in bytes/sec. Returns `(bytes_rate, transfer_rate)`. + /// Both are `None` until at least `min_observations_for_rate` observations + /// with nonzero elapsed time have been recorded. + pub fn rates(&self) -> (Option, Option) { + if self.observation_count >= self.min_observations_for_rate { + (Some(self.bytes_rate.value()), Some(self.transfer_rate.value())) + } else { + (None, None) + } + } +} + +#[cfg(test)] +mod tests { + use more_asserts::{assert_ge, assert_le, assert_lt}; + use tokio::time::{Duration, advance, pause}; + + use super::*; + + const HALF_LIFE: Duration = Duration::from_secs(10); + const TICK: Duration = Duration::from_millis(200); + + fn bytes_rate(tracker: &SpeedTracker) -> Option { + tracker.rates().0 + } + + fn transfer_rate(tracker: &SpeedTracker) -> Option { + tracker.rates().1 + } + + // ── Basic behaviour ──────────────────────────────────────────── + + #[tokio::test] + async fn no_rate_before_any_observation() { + pause(); + let tracker = SpeedTracker::new(HALF_LIFE); + let (br, tr) = tracker.rates(); + assert!(br.is_none()); + assert!(tr.is_none()); + } + + #[tokio::test] + async fn rates_none_until_min_observations() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE); + let bytes_per_tick = 2_000u64; + let mut total = 0u64; + + for _ in 0..DEFAULT_MIN_OBSERVATIONS_FOR_RATE { + assert!(bytes_rate(&tracker).is_none()); + advance(TICK).await; + total += bytes_per_tick; + tracker.update(total, 0); + } + + assert!(bytes_rate(&tracker).is_some()); + } + + #[tokio::test] + async fn configurable_min_observations() { + pause(); + let min_obs = 8u32; + let mut tracker = SpeedTracker::new(HALF_LIFE).with_min_observations(min_obs); + let bytes_per_tick = 2_000u64; + let mut total = 0u64; + + for _ in 0..min_obs - 1 { + advance(TICK).await; + total += bytes_per_tick; + tracker.update(total, 0); + } + assert!(bytes_rate(&tracker).is_none()); + + advance(TICK).await; + total += bytes_per_tick; + tracker.update(total, 0); + assert!(bytes_rate(&tracker).is_some()); + } + + #[tokio::test] + async fn constant_rate_converges() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE); + + let rate = 10_000.0; + let bytes_per_tick = (rate * TICK.as_secs_f64()) as u64; + + let mut total = 0u64; + for _ in 0..500 { + advance(TICK).await; + total += bytes_per_tick; + tracker.update(total, 0); + } + + let measured = bytes_rate(&tracker).unwrap(); + assert!((measured - rate).abs() / rate < 0.01); + } + + #[tokio::test] + async fn two_channels_tracked_independently() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE); + + let bytes_target = 20_000.0; + let transfer_target = 5_000.0; + let bytes_per_tick = (bytes_target * TICK.as_secs_f64()) as u64; + let transfer_per_tick = (transfer_target * TICK.as_secs_f64()) as u64; + + let mut total_bytes = 0u64; + let mut total_transfer = 0u64; + + for _ in 0..250 { + advance(TICK).await; + total_bytes += bytes_per_tick; + total_transfer += transfer_per_tick; + tracker.update(total_bytes, total_transfer); + } + + let br = bytes_rate(&tracker).unwrap(); + let tr = transfer_rate(&tracker).unwrap(); + assert!((br - bytes_target).abs() / bytes_target < 0.05); + assert!((tr - transfer_target).abs() / transfer_target < 0.05); + } + + // ── Warm-up / initial rate ───────────────────────────────────── + + #[tokio::test] + async fn initial_rate_ramps_up_smoothly() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE).with_min_observations(1); + + let rate = 10_000.0; + let bytes_per_tick = (rate * TICK.as_secs_f64()) as u64; + let mut total = 0u64; + let mut prev_rate = 0.0; + + for i in 0..250 { + advance(TICK).await; + total += bytes_per_tick; + tracker.update(total, 0); + + let r = bytes_rate(&tracker).unwrap(); + + if i == 0 { + assert_lt!(r, rate * 0.20); + } + + if i > 0 { + assert_ge!(r, prev_rate * 0.99); + } + + prev_rate = r; + } + + assert!((prev_rate - rate).abs() / rate < 0.05); + } + + #[tokio::test] + async fn no_initial_spike() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE).with_min_observations(1); + + advance(TICK).await; + tracker.update(50_000, 0); + + let r = bytes_rate(&tracker).unwrap(); + let max_expected = 50_000.0 / HALF_LIFE.as_secs_f64(); + assert_le!(r, max_expected * 1.01); + } + + // ── Smoothing / stability ────────────────────────────────────── + + #[tokio::test] + async fn burst_then_silence_smooths_gradually() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE).with_min_observations(1); + + advance(TICK).await; + tracker.update(100_000, 0); + let peak = bytes_rate(&tracker).unwrap(); + + let mut prev = peak; + for _ in 0..10 { + advance(TICK).await; + tracker.update(100_000, 0); + let r = bytes_rate(&tracker).unwrap(); + assert_le!(r, prev); + prev = r; + } + + assert_lt!(prev, peak); + assert!(prev > 0.0); + } + + #[tokio::test] + async fn rate_stays_stable_under_uniform_feed() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE); + + let rate = 50_000.0; + let bytes_per_tick = (rate * TICK.as_secs_f64()) as u64; + let mut total = 0u64; + + for _ in 0..500 { + advance(TICK).await; + total += bytes_per_tick; + tracker.update(total, 0); + } + + for _ in 0..50 { + advance(TICK).await; + total += bytes_per_tick; + tracker.update(total, 0); + let r = bytes_rate(&tracker).unwrap(); + assert!((r - rate).abs() / rate < 0.01); + } + } + + #[tokio::test] + async fn speed_change_adapts() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE); + + let slow = 1_000.0; + let fast = 10_000.0; + let slow_per_tick = (slow * TICK.as_secs_f64()) as u64; + let fast_per_tick = (fast * TICK.as_secs_f64()) as u64; + let mut total = 0u64; + + for _ in 0..300 { + advance(TICK).await; + total += slow_per_tick; + tracker.update(total, 0); + } + let before = bytes_rate(&tracker).unwrap(); + assert!((before - slow).abs() / slow < 0.05); + + for _ in 0..250 { + advance(TICK).await; + total += fast_per_tick; + tracker.update(total, 0); + } + let after = bytes_rate(&tracker).unwrap(); + assert!((after - fast).abs() / fast < 0.05); + } + + // ── Decay / half-life ────────────────────────────────────────── + + #[tokio::test] + async fn stall_decays_rate_toward_zero() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE); + + let rate = 20_000.0; + let bytes_per_tick = (rate * TICK.as_secs_f64()) as u64; + let mut total = 0u64; + + for _ in 0..200 { + advance(TICK).await; + total += bytes_per_tick; + tracker.update(total, 0); + } + let active_rate = bytes_rate(&tracker).unwrap(); + assert!(active_rate > rate * 0.5); + + for _ in 0..150 { + advance(TICK).await; + tracker.update(total, 0); + } + let stalled = bytes_rate(&tracker).unwrap(); + assert_lt!(stalled, active_rate * 0.15); + } + + #[tokio::test] + async fn shorter_half_life_decays_faster() { + pause(); + let mut fast_tracker = SpeedTracker::new(Duration::from_secs(2)); + let mut slow_tracker = SpeedTracker::new(Duration::from_secs(20)); + + let bytes_per_tick = 2_000u64; + let mut total = 0u64; + + for _ in 0..200 { + advance(TICK).await; + total += bytes_per_tick; + fast_tracker.update(total, 0); + slow_tracker.update(total, 0); + } + + for _ in 0..25 { + advance(TICK).await; + fast_tracker.update(total, 0); + slow_tracker.update(total, 0); + } + + let fast_rate = bytes_rate(&fast_tracker).unwrap(); + let slow_rate = bytes_rate(&slow_tracker).unwrap(); + assert_lt!(fast_rate, slow_rate); + } + + // ── Smoothness metric ────────────────────────────────────────── + + #[tokio::test] + async fn jitter_in_arrivals_smoothed_out() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE); + + let target_rate = 10_000.0; + let avg_bytes_per_tick = (target_rate * TICK.as_secs_f64()) as u64; + let mut total = 0u64; + + let mut rates = Vec::new(); + + for i in 0..300 { + advance(TICK).await; + if i % 2 == 0 { + total += avg_bytes_per_tick * 2; + } + tracker.update(total, 0); + + if i >= 200 { + rates.push(bytes_rate(&tracker).unwrap()); + } + } + + let mean: f64 = rates.iter().sum::() / rates.len() as f64; + + assert!((mean - target_rate).abs() / target_rate < 0.05); + + let variance: f64 = rates.iter().map(|r| (r - mean).powi(2)).sum::() / rates.len() as f64; + let cv = variance.sqrt() / mean; + assert_lt!(cv, 0.05); + } + + // ── Edge cases ───────────────────────────────────────────────── + + #[tokio::test] + async fn zero_elapsed_update_is_noop() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE); + tracker.update(1000, 500); + assert!(bytes_rate(&tracker).is_none()); + } + + #[tokio::test] + async fn resume_after_long_stall_picks_up_new_rate() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE); + + let bytes_per_tick = 2_000u64; + let mut total = 0u64; + + for _ in 0..300 { + advance(TICK).await; + total += bytes_per_tick; + tracker.update(total, 0); + } + + for _ in 0..500 { + advance(TICK).await; + tracker.update(total, 0); + } + + let stalled = bytes_rate(&tracker).unwrap(); + assert_lt!(stalled, 100.0); + + let slow_per_tick = 1_000u64; + for _ in 0..250 { + advance(TICK).await; + total += slow_per_tick; + tracker.update(total, 0); + } + + let r = bytes_rate(&tracker).unwrap(); + let expected = slow_per_tick as f64 / TICK.as_secs_f64(); + assert!((r - expected).abs() / expected < 0.05); + } + + #[tokio::test] + async fn large_tick_interval_works() { + pause(); + let mut tracker = SpeedTracker::new(HALF_LIFE).with_min_observations(1); + + advance(Duration::from_secs(15)).await; + tracker.update(150_000, 75_000); + + let br = bytes_rate(&tracker).unwrap(); + let tr = transfer_rate(&tracker).unwrap(); + assert_ge!(br, 9_900.0); + assert_le!(br, 10_100.0); + assert_ge!(tr, 4_900.0); + assert_le!(tr, 5_100.0); + } +} diff --git a/xet_data/src/progress_tracking/upload_tracking.rs b/xet_data/src/progress_tracking/upload_tracking.rs index 1f1f386a..4f3669df 100644 --- a/xet_data/src/progress_tracking/upload_tracking.rs +++ b/xet_data/src/progress_tracking/upload_tracking.rs @@ -1,15 +1,15 @@ use std::collections::BTreeSet; use std::collections::hash_map::Entry as HashMapEntry; use std::mem::take; -use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::sync::{Arc, Mutex}; use more_asserts::{debug_assert_ge, debug_assert_le}; -use tokio::sync::Mutex; -use ulid::Ulid; use xet_core_structures::MerkleHashMap; use xet_core_structures::merklehash::MerkleHash; -use super::{ItemProgressUpdate, ProgressUpdate, TrackingProgressUpdater}; +use super::UniqueID; +use super::progress_types::{GroupProgress, ItemProgressUpdater}; pub struct FileXorbDependency { pub file_id: u64, @@ -47,25 +47,12 @@ struct XorbPartCompletionStats { /// Represents a file that depends on one or more xorbs. struct FileDependency { - /// A unique id when the below name is not enough to identify a single file. - tracking_id: Ulid, - - /// Human-readable name of the file. + tracking_id: UniqueID, + updater: Arc, name: Arc, - - /// Total size of this file in bytes. total_bytes: u64, - - /// Whether the total size is known to be final. When false, the size can be - /// updated via `update_file_size`. is_final_size_known: bool, - - /// Total bytes already uploaded for this file (across its xorbs). completed_bytes: u64, - - /// Mapping of xorb_hash -> (number of completed bytes / number of bytes of the file contained in that xorb). Only - /// xorbs that are not uploaded yet are tracked here. - /// Once an xorb is uploaded, we remove it from here (and add to `completed_bytes`). remaining_xorbs_parts: MerkleHashMap, } @@ -74,12 +61,9 @@ struct FileDependency { /// are fully uploaded. #[derive(Default)] struct CompletionTrackerImpl { - /// List of all files being tracked. files: Vec, - /// Map of xorb hash -> its dependency info (which files rely on it). xorbs: MerkleHashMap, - /// Keep track of the totals across all xorbs. total_upload_bytes: u64, total_upload_bytes_completed: u64, @@ -89,76 +73,47 @@ struct CompletionTrackerImpl { pub struct CompletionTracker { inner: Mutex, - progress_reporter: Arc, + group: Arc, } impl CompletionTrackerImpl { - /// Registers a new file for tracking and returns an ID (its index in `files`). - /// - /// If `n_bytes` is `Some(size)`, the file size is treated as final and cannot be - /// updated later. If `n_bytes` is `None`, the file is registered with a size of - /// zero and `is_final_size_known` is set to `false`; callers should subsequently - /// call `update_file_size` to provide the actual size. fn register_new_file( &mut self, - tracking_id: Ulid, - name: impl Into>, + updater: Arc, n_bytes: Option, - ) -> (ProgressUpdate, CompletionTrackerFileId) { + ) -> CompletionTrackerFileId { let (total_bytes, is_final_size_known) = match n_bytes { Some(size) => (size, true), None => (0, false), }; - // The file's ID is simply its index in the internal `files` vector. - let file_id = self.files.len() as CompletionTrackerFileId; + updater.update_item_size(total_bytes, n_bytes.is_some()); + + let file_id = self.files.len() as CompletionTrackerFileId; + let tracking_id = updater.item().id; + let name = updater.item().name.clone(); - // Create a new FileDependency record. let file_dependency = FileDependency { tracking_id, - name: name.into(), + updater, + name, total_bytes, is_final_size_known, completed_bytes: 0, remaining_xorbs_parts: MerkleHashMap::new(), }; - // Insert it into our files vector. self.files.push(file_dependency); - - // We have more to process now. self.total_bytes += total_bytes; - // Register that the total bytes known has changed, and return the file ID so the caller can register - // dependencies on this file. - ( - ProgressUpdate { - item_updates: vec![], - total_bytes: self.total_bytes, - total_bytes_increment: total_bytes, - total_bytes_completed: self.total_bytes_completed, - total_bytes_completion_increment: 0, - total_transfer_bytes: self.total_upload_bytes, - total_transfer_bytes_increment: 0, - total_transfer_bytes_completed: self.total_upload_bytes_completed, - total_transfer_bytes_completion_increment: 0, - ..Default::default() - }, - file_id, - ) + file_id } - /// Increments the total size of a previously registered file by `size_increment` bytes. - /// - /// Returns `Some(ProgressUpdate)` if the size was updated, or `None` if the file's size - /// has already been finalized (via `register_new_file` with `Some(size)`). When `None` - /// is returned, no internal state is modified. - fn increment_file_size(&mut self, file_id: CompletionTrackerFileId, size_increment: u64) -> Option { + fn increment_file_size(&mut self, file_id: CompletionTrackerFileId, size_increment: u64) { let file_entry = &mut self.files[file_id as usize]; - // If already finalized, nothing to do. if file_entry.is_final_size_known { - return None; + return; } file_entry.total_bytes += size_increment; @@ -167,114 +122,56 @@ impl CompletionTrackerImpl { debug_assert_ge!(file_entry.total_bytes, file_entry.completed_bytes); debug_assert_ge!(self.total_bytes, self.total_bytes_completed); - // Emit an item update so progress reporters see the new total for this file. - let item_update = ItemProgressUpdate { - tracking_id: file_entry.tracking_id, - item_name: file_entry.name.clone(), - total_bytes: file_entry.total_bytes, - bytes_completed: file_entry.completed_bytes, - bytes_completion_increment: 0, - }; - - Some(ProgressUpdate { - item_updates: vec![item_update], - total_bytes: self.total_bytes, - total_bytes_increment: size_increment, - total_bytes_completed: self.total_bytes_completed, - total_bytes_completion_increment: 0, - total_transfer_bytes: self.total_upload_bytes, - total_transfer_bytes_increment: 0, - total_transfer_bytes_completed: self.total_upload_bytes_completed, - total_transfer_bytes_completion_increment: 0, - ..Default::default() - }) + file_entry.updater.update_item_size(file_entry.total_bytes, false); } - /// Registers that all or part of a given file (by `file_id`) depends on one or more - /// xorbs; Given a list of (xorb_hash, n_bytes, already_uploaded), registers the progress. - fn register_dependencies(&mut self, dependencies: &[FileXorbDependency]) -> ProgressUpdate { - let mut item_updates = Vec::new(); - + fn register_dependencies(&mut self, dependencies: &[FileXorbDependency]) { let mut file_bytes_processed = 0; for dep in dependencies { let file_entry = &mut self.files[dep.file_id as usize]; if dep.is_external { - // This is the freebie case, where we can just increment the progress. file_entry.completed_bytes += dep.n_bytes; debug_assert_le!(file_entry.completed_bytes, file_entry.total_bytes); - let progress_update = ItemProgressUpdate { - tracking_id: file_entry.tracking_id, - item_name: file_entry.name.clone(), - total_bytes: file_entry.total_bytes, - bytes_completed: file_entry.completed_bytes, - bytes_completion_increment: dep.n_bytes, - }; - + file_entry.updater.report_bytes_completed(dep.n_bytes); file_bytes_processed += dep.n_bytes; - - item_updates.push(progress_update); } else { - // Make sure we aren't putting in an unfinished xorb, which - // tracks with MerkleHash::marker(). debug_assert_ne!(dep.xorb_hash, MerkleHash::marker()); let entry = self.xorbs.entry(dep.xorb_hash).or_default(); - // If the entry has already been completed, then just mark this as completed. if entry.is_completed { file_entry.completed_bytes += dep.n_bytes; debug_assert_le!(file_entry.completed_bytes, file_entry.total_bytes); - let progress_update = ItemProgressUpdate { - tracking_id: file_entry.tracking_id, - item_name: file_entry.name.clone(), - total_bytes: file_entry.total_bytes, - bytes_completed: file_entry.completed_bytes, - bytes_completion_increment: dep.n_bytes, - }; - item_updates.push(progress_update); + file_entry.updater.report_bytes_completed(dep.n_bytes); file_bytes_processed += dep.n_bytes; } else { - // Set the reference here to this file entry.file_indices.insert(dep.file_id as usize); - - // Set the reference here to the xorb file_entry.remaining_xorbs_parts.entry(dep.xorb_hash).or_default().n_bytes += dep.n_bytes; } } } - // Register that this much has been completed already self.total_bytes_completed += file_bytes_processed; - debug_assert_le!(self.total_bytes_completed, self.total_bytes); - - // There may be a lot of per-file updates, but these don't actually count against the new byte total; - // this is counted only using xorbs. - ProgressUpdate { - item_updates, - total_bytes: self.total_bytes, - total_bytes_increment: 0, - total_bytes_completed: self.total_bytes_completed, - total_bytes_completion_increment: file_bytes_processed, - total_transfer_bytes: self.total_upload_bytes, - total_transfer_bytes_increment: 0, - total_transfer_bytes_completed: self.total_upload_bytes_completed, - total_transfer_bytes_completion_increment: 0, - ..Default::default() - } } - /// Register a new xorb. Returns true if the xorb is new and now registered for upload, and false - /// if it's already been uploaded and registered. - fn register_new_xorb(&mut self, xorb_hash: MerkleHash, xorb_size: u64) -> (ProgressUpdate, bool) { + fn register_new_xorb(&mut self, group: &Arc, xorb_hash: MerkleHash, xorb_size: u64) -> bool { match self.xorbs.entry(xorb_hash) { - HashMapEntry::Occupied(occupied_entry) => { - debug_assert_eq!(occupied_entry.get().xorb_size, xorb_size); - (ProgressUpdate::default(), false) + HashMapEntry::Occupied(mut occupied_entry) => { + let entry = occupied_entry.get_mut(); + if entry.xorb_size == 0 { + entry.xorb_size = xorb_size; + self.total_upload_bytes += xorb_size; + group.total_transfer_bytes.fetch_add(xorb_size, Ordering::Release); + true + } else { + debug_assert_eq!(entry.xorb_size, xorb_size); + false + } }, HashMapEntry::Vacant(vacant_entry) => { vacant_entry.insert(XorbDependency { @@ -285,61 +182,33 @@ impl CompletionTrackerImpl { }); self.total_upload_bytes += xorb_size; - - ( - ProgressUpdate { - item_updates: vec![], - total_bytes: self.total_bytes, - total_bytes_increment: 0, - total_bytes_completed: self.total_bytes_completed, - total_bytes_completion_increment: 0, - total_transfer_bytes: self.total_upload_bytes, - total_transfer_bytes_increment: xorb_size, - total_transfer_bytes_completed: self.total_upload_bytes_completed, - total_transfer_bytes_completion_increment: 0, - ..Default::default() - }, - true, - ) + group.total_transfer_bytes.fetch_add(xorb_size, Ordering::Release); + true }, } } - /// Called when a xorb is finished uploading. We look up which files depend on that - /// xorb and update their `completed_bytes`, removing the xorb from their - /// `remaining_xorbs_parts`. - fn register_xorb_upload_completion(&mut self, xorb_hash: MerkleHash) -> ProgressUpdate { + fn register_xorb_upload_completion(&mut self, group: &Arc, xorb_hash: MerkleHash) { let (file_indices, byte_completion_increment) = { - // Should have been registered above with register_xorb - debug_assert!(self.xorbs.contains_key(&xorb_hash)); - - // Mark as completed, return the list of files to mark as completed. let entry = self.xorbs.entry(xorb_hash).or_default(); - // How many new bytes uploaded do we have to write out to the total_completed_bytes? + if entry.is_completed { + return; + } + let new_byte_increment = entry.xorb_size - entry.completed_bytes; - - // This should be present but not completed. - debug_assert!(!entry.is_completed); - entry.is_completed = true; (take(&mut entry.file_indices), new_byte_increment) }; - // Mark all the relevant files as completed - let mut item_updates = Vec::with_capacity(file_indices.len()); - let mut file_bytes_processed = 0; - // For each file that depends on this xorb, remove the relevant - // part from `remaining_xorbs_parts` and add to `completed_bytes`. for file_id in file_indices { let file_entry = &mut self.files[file_id]; debug_assert!(file_entry.remaining_xorbs_parts.contains_key(&xorb_hash)); - // This xorb is completed, so remove the number of bytes in that file needed by that xorb. let xorb_part = file_entry.remaining_xorbs_parts.remove(&xorb_hash).unwrap_or_default(); debug_assert_le!(xorb_part.completed_bytes, xorb_part.n_bytes); @@ -347,109 +216,58 @@ impl CompletionTrackerImpl { if n_bytes_remaining > 0 { file_entry.completed_bytes += n_bytes_remaining; - - let progress_update = ItemProgressUpdate { - tracking_id: file_entry.tracking_id, - item_name: file_entry.name.clone(), - total_bytes: file_entry.total_bytes, - bytes_completed: file_entry.completed_bytes, - bytes_completion_increment: n_bytes_remaining, - }; - + file_entry.updater.report_bytes_completed(n_bytes_remaining); file_bytes_processed += n_bytes_remaining; - - item_updates.push(progress_update); } } debug_assert_le!(self.total_upload_bytes_completed + byte_completion_increment, self.total_upload_bytes); self.total_upload_bytes_completed += byte_completion_increment; + group + .total_transfer_bytes_completed + .fetch_add(byte_completion_increment, Ordering::Release); self.total_bytes_completed += file_bytes_processed; debug_assert_le!(self.total_bytes_completed, self.total_bytes); - - ProgressUpdate { - item_updates, - total_bytes: self.total_bytes, - total_bytes_increment: 0, - total_bytes_completed: self.total_bytes_completed, - total_bytes_completion_increment: file_bytes_processed, - total_transfer_bytes: self.total_upload_bytes, - total_transfer_bytes_completed: self.total_upload_bytes_completed, - total_transfer_bytes_completion_increment: byte_completion_increment, - total_transfer_bytes_increment: 0, - ..Default::default() - } } - /// Register partial upload progress of a xorb; new_byte_progress is the number of new bytes uploaded. - /// - /// If force_proper_ordering is true, then all the updates should arrive before register_xorb_upload_completion; - /// if debug_assertions are on, then all the details will be checked. If this is false, then the updates can - /// arrive out of order and will simply be ignored if the register_xorb_upload_completion has been called. fn register_xorb_upload_progress( &mut self, + group: &Arc, xorb_hash: MerkleHash, new_byte_progress: u64, check_ordering: bool, - ) -> ProgressUpdate { - // Should have already been registered. + ) { debug_assert!(self.xorbs.contains_key(&xorb_hash)); - // Mark as completed, return the list of files to mark as completed. let entry = self.xorbs.entry(xorb_hash).or_default(); - // If this update could arrive out of order, check to see if it's needed and ignore if not. if !check_ordering && entry.is_completed { - // Return an empty update - return ProgressUpdate { - item_updates: vec![], - total_bytes: self.total_bytes, - total_bytes_increment: 0, - total_bytes_completed: self.total_bytes_completed, - total_bytes_completion_increment: 0, - total_transfer_bytes: self.total_upload_bytes, - total_transfer_bytes_increment: 0, - total_transfer_bytes_completed: self.total_upload_bytes_completed, - total_transfer_bytes_completion_increment: 0, - ..Default::default() - }; + return; } - // Should not be completed when this is called. debug_assert!(!entry.is_completed); - - // Is the update reasonable? debug_assert_le!(entry.completed_bytes + new_byte_progress, entry.xorb_size); entry.completed_bytes += new_byte_progress; let new_completion_ratio = (entry.completed_bytes as f64) / (entry.xorb_size as f64); - // Mark all the relevant files as completed - let mut item_updates = Vec::with_capacity(entry.file_indices.len()); - let mut file_bytes_processed = 0; - // For each file that depends on this xorb, update a proportion of that remove the relevant - // part from `remaining_xorbs_parts` and add to `completed_bytes`. for &file_id in entry.file_indices.iter() { let file_entry = &mut self.files[file_id]; - // Should be registered there. debug_assert!(file_entry.remaining_xorbs_parts.contains_key(&xorb_hash)); - // Update let incremental_update = 'update: { let Some(xorb_part) = file_entry.remaining_xorbs_parts.get_mut(&xorb_hash) else { break 'update 0; }; debug_assert_le!(xorb_part.completed_bytes, xorb_part.n_bytes); - // Use floor so as to not inproperly report completion when there is still some to go. let new_completion_bytes = ((xorb_part.n_bytes as f64) * new_completion_ratio).floor() as u64; - // Make sure this is an update debug_assert_ge!(new_completion_bytes, xorb_part.completed_bytes); let incremental_update = new_completion_bytes.saturating_sub(xorb_part.completed_bytes); @@ -462,37 +280,20 @@ impl CompletionTrackerImpl { if incremental_update != 0 { file_entry.completed_bytes += incremental_update; - - let progress_update = ItemProgressUpdate { - tracking_id: file_entry.tracking_id, - item_name: file_entry.name.clone(), - total_bytes: file_entry.total_bytes, - bytes_completed: file_entry.completed_bytes, - bytes_completion_increment: incremental_update, - }; + file_entry.updater.report_bytes_completed(incremental_update); file_bytes_processed += incremental_update; - item_updates.push(progress_update); } } self.total_upload_bytes_completed += new_byte_progress; debug_assert_le!(self.total_upload_bytes_completed, self.total_upload_bytes); + group + .total_transfer_bytes_completed + .fetch_add(new_byte_progress, Ordering::Release); + self.total_bytes_completed += file_bytes_processed; debug_assert_le!(self.total_bytes_completed, self.total_bytes); - - ProgressUpdate { - item_updates, - total_bytes: self.total_bytes, - total_bytes_increment: 0, - total_bytes_completed: self.total_bytes_completed, - total_bytes_completion_increment: file_bytes_processed, - total_transfer_bytes: self.total_upload_bytes, - total_transfer_bytes_increment: 0, - total_transfer_bytes_completed: self.total_upload_bytes_completed, - total_transfer_bytes_completion_increment: new_byte_progress, - ..Default::default() - } } fn status(&self) -> (u64, u64) { @@ -517,16 +318,12 @@ impl CompletionTrackerImpl { done == total } - /// Checks that all files are fully completed (no remaining xorbs or incomplete bytes), - /// and that all xorbs are marked completed with no lingering file references. - /// Panics if any incomplete data is found. fn assert_complete(&self) { - // Check each file for completeness for (idx, file) in self.files.iter().enumerate() { assert_eq!( file.completed_bytes, file.total_bytes, - "File #{} ({}) is not fully completed: {}/{} bytes", - idx, file.name, file.completed_bytes, file.total_bytes + "File #{} ({}, {}) is not fully completed: {}/{} bytes", + idx, file.name, file.tracking_id, file.completed_bytes, file.total_bytes ); assert!( file.remaining_xorbs_parts.is_empty(), @@ -537,7 +334,6 @@ impl CompletionTrackerImpl { ); } - // Check each xorb to ensure it's marked completed and no file references remain for (hash, xorb_dep) in self.xorbs.iter() { assert!(xorb_dep.is_completed, "Xorb {hash:?} is not marked completed."); assert!( @@ -550,121 +346,68 @@ impl CompletionTrackerImpl { } } -/// A wrapper around the above class to work with the locking and the reporting. impl CompletionTracker { - pub fn new(progress_reporter: Arc) -> Self { - CompletionTracker { + pub fn new(group: Arc) -> Self { + Self { inner: Mutex::new(CompletionTrackerImpl::default()), - progress_reporter, + group, } } - pub async fn register_new_file( + pub fn register_new_file( &self, - tracking_id: Ulid, - name: impl Into>, + updater: Arc, n_bytes: Option, ) -> CompletionTrackerFileId { - let mut update_lock = self.inner.lock().await; - - let (updates, ret) = update_lock.register_new_file(tracking_id, name, n_bytes); - - if !updates.is_empty() { - self.progress_reporter.register_updates(updates).await; - } - - ret + let mut update_lock = self.inner.lock().unwrap(); + update_lock.register_new_file(updater, n_bytes) } - pub async fn increment_file_size(&self, file_id: CompletionTrackerFileId, size_increment: u64) { - let mut update_lock = self.inner.lock().await; - - if let Some(updates) = update_lock.increment_file_size(file_id, size_increment) - && !updates.is_empty() - { - self.progress_reporter.register_updates(updates).await; - } + pub fn increment_file_size(&self, file_id: CompletionTrackerFileId, size_increment: u64) { + let mut update_lock = self.inner.lock().unwrap(); + update_lock.increment_file_size(file_id, size_increment); } - pub async fn register_new_xorb(&self, xorb_hash: MerkleHash, xorb_size: u64) -> bool { - let mut update_lock = self.inner.lock().await; - - let (updates, ret) = update_lock.register_new_xorb(xorb_hash, xorb_size); - - if !updates.is_empty() { - self.progress_reporter.register_updates(updates).await; - } - - ret + pub fn register_new_xorb(&self, xorb_hash: MerkleHash, xorb_size: u64) -> bool { + let mut update_lock = self.inner.lock().unwrap(); + update_lock.register_new_xorb(&self.group, xorb_hash, xorb_size) } - /// Register a list of (file_id, xorb_hash, usize, bool) - pub async fn register_dependencies(&self, dependencies: &[FileXorbDependency]) { - let mut update_lock = self.inner.lock().await; - - let updates = update_lock.register_dependencies(dependencies); - - if !updates.is_empty() { - self.progress_reporter.register_updates(updates).await; - } + pub fn register_dependencies(&self, dependencies: &[FileXorbDependency]) { + let mut update_lock = self.inner.lock().unwrap(); + update_lock.register_dependencies(dependencies); } - pub async fn register_xorb_upload_completion(&self, xorb_hash: MerkleHash) { - let mut update_lock = self.inner.lock().await; - - let updates = update_lock.register_xorb_upload_completion(xorb_hash); - - if !updates.is_empty() { - self.progress_reporter.register_updates(updates).await; - } + pub fn register_xorb_upload_completion(&self, xorb_hash: MerkleHash) { + let mut update_lock = self.inner.lock().unwrap(); + update_lock.register_xorb_upload_completion(&self.group, xorb_hash); } - pub async fn register_xorb_upload_progress(&self, xorb_hash: MerkleHash, new_byte_progress: u64) { - self.register_xorb_upload_progress_impl(xorb_hash, new_byte_progress, true) - .await; + pub fn register_xorb_upload_progress(&self, xorb_hash: MerkleHash, new_byte_progress: u64) { + self.register_xorb_upload_progress_impl(xorb_hash, new_byte_progress, true); } pub fn register_xorb_upload_progress_background(self: Arc, xorb_hash: MerkleHash, new_byte_progress: u64) { - // register partial progress in the background; if this happens out of order, no worries. tokio::spawn(async move { - self.register_xorb_upload_progress_impl(xorb_hash, new_byte_progress, false) - .await + self.register_xorb_upload_progress_impl(xorb_hash, new_byte_progress, false); }); } - async fn register_xorb_upload_progress_impl( - &self, - xorb_hash: MerkleHash, - new_byte_progress: u64, - check_ordering: bool, - ) { - let mut update_lock = self.inner.lock().await; - - let updates = update_lock.register_xorb_upload_progress(xorb_hash, new_byte_progress, check_ordering); - - if !updates.is_empty() { - self.progress_reporter.register_updates(updates).await; - } + fn register_xorb_upload_progress_impl(&self, xorb_hash: MerkleHash, new_byte_progress: u64, check_ordering: bool) { + let mut update_lock = self.inner.lock().unwrap(); + update_lock.register_xorb_upload_progress(&self.group, xorb_hash, new_byte_progress, check_ordering); } - /// Async wrapper that locks the internal struct and calls the sync `verify_complete`. - pub async fn status(&self) -> (u64, u64) { - self.inner.lock().await.status() + pub fn status(&self) -> (u64, u64) { + self.inner.lock().unwrap().status() } - /// Async wrapper that locks the internal struct and calls the sync `verify_complete`. - pub async fn is_complete(&self) -> bool { - self.inner.lock().await.is_complete() + pub fn is_complete(&self) -> bool { + self.inner.lock().unwrap().is_complete() } - /// Async wrapper that locks the internal struct and calls the sync `verify_complete`. - pub async fn assert_complete(&self) { - self.inner.lock().await.assert_complete(); - } - - /// Flush the progress reporter - pub async fn flush(&self) { - self.progress_reporter.flush().await; + pub fn assert_complete(&self) { + self.inner.lock().unwrap().assert_complete(); } } @@ -673,565 +416,475 @@ mod tests { use xet_core_structures::merklehash::MerkleHash; use super::*; - use crate::progress_tracking::NoOpProgressUpdater; - use crate::progress_tracking::verification_wrapper::ProgressUpdaterVerificationWrapper; - /// A basic test showing partial updates and final completion checks - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_status_and_is_complete() { - // 1) Create no-op + verification wrapper - let no_op = NoOpProgressUpdater::new(); - let verifier = ProgressUpdaterVerificationWrapper::new(no_op); - // 2) Create our CompletionTracker with the verifying reporter - let tracker = CompletionTracker::new(verifier.clone()); + #[test] + fn test_status_and_is_complete() { + let group = GroupProgress::new(); + let tracker = CompletionTracker::new(group.clone()); - // Register two files - let file_a = tracker.register_new_file(Ulid::new(), "fileA", Some(100)).await; - let file_b = tracker.register_new_file(Ulid::new(), "fileB", Some(50)).await; + let updater_a = group.new_item(UniqueID::new(), "fileA"); + let file_a = tracker.register_new_file(updater_a, Some(100)); - // Initially, done=0, total=150 - let (done, total) = tracker.status().await; + let updater_b = group.new_item(UniqueID::new(), "fileB"); + let file_b = tracker.register_new_file(updater_b, Some(50)); + + let (done, total) = tracker.status(); assert_eq!(done, 0); assert_eq!(total, 150); - assert!(!tracker.is_complete().await); + assert!(!tracker.is_complete()); - // fileA depends on x for 100 bytes, already uploaded let x = MerkleHash::random_from_seed(1); - tracker - .register_dependencies(&[FileXorbDependency { - file_id: file_a, - xorb_hash: x, - n_bytes: 100, - is_external: true, - }]) - .await; + tracker.register_dependencies(&[FileXorbDependency { + file_id: file_a, + xorb_hash: x, + n_bytes: 100, + is_external: true, + }]); - // Now fileA is 100/100, fileB is 0/50 => done=100, total=150 - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); assert_eq!(done, 100); assert_eq!(total, 150); - assert!(!tracker.is_complete().await); + assert!(!tracker.is_complete()); - // fileB depends on y for 50 bytes, not yet uploaded let y = MerkleHash::random_from_seed(2); - tracker - .register_dependencies(&[FileXorbDependency { - file_id: file_b, - xorb_hash: y, - n_bytes: 50, - is_external: false, - }]) - .await; + tracker.register_dependencies(&[FileXorbDependency { + file_id: file_b, + xorb_hash: y, + n_bytes: 50, + is_external: false, + }]); - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); assert_eq!(done, 100); assert_eq!(total, 150); - // Now upload y - tracker.register_xorb_upload_completion(y).await; + tracker.register_new_xorb(y, 50); + tracker.register_xorb_upload_completion(y); - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); assert_eq!(done, 150); assert_eq!(total, 150); - assert!(tracker.is_complete().await); + assert!(tracker.is_complete()); - // Confirm internal consistency in the tracker - tracker.assert_complete().await; - // Confirm the updates themselves were valid - verifier.assert_complete().await; + tracker.assert_complete(); + group.assert_complete(); } - /// Multiple files sharing one xorb, with partial "already uploaded" logic - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_multiple_files_one_shared_xorb() { - let no_op = NoOpProgressUpdater::new(); - let verifier = ProgressUpdaterVerificationWrapper::new(no_op); - let tracker = CompletionTracker::new(verifier.clone()); + #[test] + fn test_multiple_files_one_shared_xorb() { + let group = GroupProgress::new(); + let tracker = CompletionTracker::new(group.clone()); - // Two files => 200 + 300 = 500 total - let file_a = tracker.register_new_file(Ulid::new(), "fileA", Some(200)).await; - let file_b = tracker.register_new_file(Ulid::new(), "fileB", Some(300)).await; + let updater_a = group.new_item(UniqueID::new(), "fileA"); + let file_a = tracker.register_new_file(updater_a, Some(200)); - let (done, total) = tracker.status().await; + let updater_b = group.new_item(UniqueID::new(), "fileB"); + let file_b = tracker.register_new_file(updater_b, Some(300)); + + let (done, total) = tracker.status(); assert_eq!(done, 0); assert_eq!(total, 500); - // Shared xorb let xhash = MerkleHash::random_from_seed(1); - tracker.register_new_xorb(xhash, 1000).await; + tracker.register_new_xorb(xhash, 1000); - // fileA => xhash 100 bytes (not uploaded) - // fileB => xhash 200 bytes (already uploaded) - tracker - .register_dependencies(&[ - FileXorbDependency { - file_id: file_a, - xorb_hash: xhash, - n_bytes: 100, - is_external: false, - }, - FileXorbDependency { - file_id: file_b, - xorb_hash: xhash, - n_bytes: 200, - is_external: true, - }, - ]) - .await; - - let (done, total) = tracker.status().await; - assert_eq!(done, 200); // fileB got immediate 200 - assert_eq!(total, 500); - assert!(!tracker.is_complete().await); - - // Mark xhash fully uploaded => fileA +100 - tracker.register_xorb_upload_completion(xhash).await; - - let (done, total) = tracker.status().await; - assert_eq!(done, 300); // A:100 + B:200 - assert_eq!(total, 500); - - // Suppose fileA is 100/200. We'll "fix" it with x2 => 100 bytes (already uploaded) - let x2 = MerkleHash::random_from_seed(2); - - tracker.register_new_xorb(x2, 1000).await; - - tracker - .register_dependencies(&[FileXorbDependency { + tracker.register_dependencies(&[ + FileXorbDependency { file_id: file_a, - xorb_hash: x2, - n_bytes: 100, - is_external: true, - }]) - .await; - - let (done, total) = tracker.status().await; - assert_eq!(done, 400); // A:200, B:200 - assert_eq!(total, 500); - - // B's remaining 100 bytes also from x2, not uploaded - tracker - .register_dependencies(&[FileXorbDependency { - file_id: file_b, - xorb_hash: x2, + xorb_hash: xhash, n_bytes: 100, is_external: false, - }]) - .await; + }, + FileXorbDependency { + file_id: file_b, + xorb_hash: xhash, + n_bytes: 200, + is_external: true, + }, + ]); - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); + assert_eq!(done, 200); + assert_eq!(total, 500); + assert!(!tracker.is_complete()); + + tracker.register_xorb_upload_completion(xhash); + + let (done, total) = tracker.status(); + assert_eq!(done, 300); + assert_eq!(total, 500); + + let x2 = MerkleHash::random_from_seed(2); + + tracker.register_new_xorb(x2, 1000); + + tracker.register_dependencies(&[FileXorbDependency { + file_id: file_a, + xorb_hash: x2, + n_bytes: 100, + is_external: true, + }]); + + let (done, total) = tracker.status(); assert_eq!(done, 400); assert_eq!(total, 500); - assert!(!tracker.is_complete().await); - // Upload x2 => B now 300/300 - tracker.register_xorb_upload_completion(x2).await; - let (done, total) = tracker.status().await; + tracker.register_dependencies(&[FileXorbDependency { + file_id: file_b, + xorb_hash: x2, + n_bytes: 100, + is_external: false, + }]); + + let (done, total) = tracker.status(); + assert_eq!(done, 400); + assert_eq!(total, 500); + assert!(!tracker.is_complete()); + + tracker.register_xorb_upload_completion(x2); + let (done, total) = tracker.status(); assert_eq!(done, 500); assert_eq!(total, 500); - assert!(tracker.is_complete().await); + assert!(tracker.is_complete()); - tracker.assert_complete().await; - verifier.assert_complete().await; + tracker.assert_complete(); + group.assert_complete(); } - /// One file, multiple xorbs, partial "already_uploaded" scenario - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_single_file_multiple_xorbs() { - let no_op = NoOpProgressUpdater::new(); - let verifier = ProgressUpdaterVerificationWrapper::new(no_op); - let tracker = CompletionTracker::new(verifier.clone()); + #[test] + fn test_single_file_multiple_xorbs() { + let group = GroupProgress::new(); + let tracker = CompletionTracker::new(group.clone()); - let f = tracker.register_new_file(Ulid::new(), "bigFile", Some(300)).await; + let updater = group.new_item(UniqueID::new(), "bigFile"); + let f = tracker.register_new_file(updater, Some(300)); let x1 = MerkleHash::random_from_seed(1); let x2 = MerkleHash::random_from_seed(2); let x3 = MerkleHash::random_from_seed(3); - tracker.register_new_xorb(x1, 100).await; - tracker.register_new_xorb(x3, 100).await; + tracker.register_new_xorb(x1, 100); + tracker.register_new_xorb(x3, 100); - // bigFile depends on: - // x1 => 100 bytes, not uploaded - // x2 => 100 bytes, already uploaded - // x3 => 100 bytes, not uploaded - tracker - .register_dependencies(&[ - FileXorbDependency { - file_id: f, - xorb_hash: x1, - n_bytes: 100, - is_external: false, - }, - FileXorbDependency { - file_id: f, - xorb_hash: x2, - n_bytes: 100, - is_external: true, - }, - FileXorbDependency { - file_id: f, - xorb_hash: x3, - n_bytes: 100, - is_external: false, - }, - ]) - .await; - - let (done, total) = tracker.status().await; - assert_eq!(done, 100); // from x2 - assert_eq!(total, 300); - assert!(!tracker.is_complete().await); - - // Upload x1 => bigFile from 100 -> 200 - tracker.register_xorb_upload_completion(x1).await; - let (done, total) = tracker.status().await; - assert_eq!(done, 200); - assert_eq!(total, 300); - assert!(!tracker.is_complete().await); - - // Upload x3 => bigFile from 200 -> 300 - tracker.register_xorb_upload_completion(x3).await; - let (done, total) = tracker.status().await; - assert_eq!(done, 300); - assert_eq!(total, 300); - assert!(tracker.is_complete().await); - - tracker.assert_complete().await; - verifier.assert_complete().await; - } - - /// Xorb is completed before dependencies are registered, - /// but the tracker credits the file immediately upon dependency registration - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_xorb_completed_before_dependencies() { - let no_op = NoOpProgressUpdater::new(); - let verifier = ProgressUpdaterVerificationWrapper::new(no_op); - let tracker = CompletionTracker::new(verifier.clone()); - - // One file, 50 bytes - let file_id = tracker.register_new_file(Ulid::new(), "lateFile", Some(50)).await; - - // xhash completed before we mention any dependencies - let x = MerkleHash::random_from_seed(999); - tracker.register_new_xorb(x, 1000).await; - - tracker.register_xorb_upload_completion(x).await; - - // Now we register that file depends on x for 50 bytes, "already_uploaded=false" - // but the tracker sees x is completed => immediate credit. - tracker - .register_dependencies(&[FileXorbDependency { - file_id, - xorb_hash: x, - n_bytes: 50, - is_external: false, - }]) - .await; - - let (done, total) = tracker.status().await; - assert_eq!(done, 50); - assert_eq!(total, 50); - assert!(tracker.is_complete().await); - - tracker.assert_complete().await; - verifier.assert_complete().await; - } - - /// Demonstrates leftover references if we do contradictory logic, - /// but with the updated logic, the tracker sees x is completed and - /// grants immediate credit anyway. - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_contradictory_logic_with_completed_xorb() { - let no_op = NoOpProgressUpdater::new(); - let verifier = ProgressUpdaterVerificationWrapper::new(no_op); - let tracker = CompletionTracker::new(verifier.clone()); - - let file_id = tracker.register_new_file(Ulid::new(), "someFile", Some(100)).await; - let x = MerkleHash::random_from_seed(123); - - tracker.register_new_xorb(x, 1000).await; - - // Mark x as completed, no dependencies known - tracker.register_xorb_upload_completion(x).await; - - // Then register a dependency with "already_uploaded=false" - // The code sees x.is_completed==true => immediate credit for 100 bytes - tracker - .register_dependencies(&[FileXorbDependency { - file_id, - xorb_hash: x, + tracker.register_dependencies(&[ + FileXorbDependency { + file_id: f, + xorb_hash: x1, n_bytes: 100, is_external: false, - }]) - .await; + }, + FileXorbDependency { + file_id: f, + xorb_hash: x2, + n_bytes: 100, + is_external: true, + }, + FileXorbDependency { + file_id: f, + xorb_hash: x3, + n_bytes: 100, + is_external: false, + }, + ]); - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); assert_eq!(done, 100); - assert_eq!(total, 100); - assert!(tracker.is_complete().await); + assert_eq!(total, 300); + assert!(!tracker.is_complete()); - tracker.assert_complete().await; - verifier.assert_complete().await; + tracker.register_xorb_upload_completion(x1); + let (done, total) = tracker.status(); + assert_eq!(done, 200); + assert_eq!(total, 300); + assert!(!tracker.is_complete()); + + tracker.register_xorb_upload_completion(x3); + let (done, total) = tracker.status(); + assert_eq!(done, 300); + assert_eq!(total, 300); + assert!(tracker.is_complete()); + + tracker.assert_complete(); + group.assert_complete(); } - /// Register a file with no initial size, then grow it incrementally. - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_increment_file_size_basic() { - let no_op = NoOpProgressUpdater::new(); - let verifier = ProgressUpdaterVerificationWrapper::new(no_op); - let tracker = CompletionTracker::new(verifier.clone()); + #[test] + fn test_xorb_completed_before_dependencies() { + let group = GroupProgress::new(); + let tracker = CompletionTracker::new(group.clone()); - // Register file with unknown size - let file_id = tracker.register_new_file(Ulid::new(), "growingFile", None).await; + let updater = group.new_item(UniqueID::new(), "lateFile"); + let file_id = tracker.register_new_file(updater, Some(50)); - let (done, total) = tracker.status().await; + let x = MerkleHash::random_from_seed(999); + tracker.register_new_xorb(x, 1000); + + tracker.register_xorb_upload_completion(x); + + tracker.register_dependencies(&[FileXorbDependency { + file_id, + xorb_hash: x, + n_bytes: 50, + is_external: false, + }]); + + let (done, total) = tracker.status(); + assert_eq!(done, 50); + assert_eq!(total, 50); + assert!(tracker.is_complete()); + + tracker.assert_complete(); + group.assert_complete(); + } + + #[test] + fn test_contradictory_logic_with_completed_xorb() { + let group = GroupProgress::new(); + let tracker = CompletionTracker::new(group.clone()); + + let updater = group.new_item(UniqueID::new(), "someFile"); + let file_id = tracker.register_new_file(updater, Some(100)); + let x = MerkleHash::random_from_seed(123); + + tracker.register_new_xorb(x, 1000); + + tracker.register_xorb_upload_completion(x); + + tracker.register_dependencies(&[FileXorbDependency { + file_id, + xorb_hash: x, + n_bytes: 100, + is_external: false, + }]); + + let (done, total) = tracker.status(); + assert_eq!(done, 100); + assert_eq!(total, 100); + assert!(tracker.is_complete()); + + tracker.assert_complete(); + group.assert_complete(); + } + + #[test] + fn test_increment_file_size_basic() { + let group = GroupProgress::new(); + let tracker = CompletionTracker::new(group.clone()); + + let updater = group.new_item(UniqueID::new(), "growingFile"); + let file_id = tracker.register_new_file(updater, None); + + let (done, total) = tracker.status(); assert_eq!(done, 0); assert_eq!(total, 0); - // Increment size in steps - tracker.increment_file_size(file_id, 100).await; - let (done, total) = tracker.status().await; + tracker.increment_file_size(file_id, 100); + let (done, total) = tracker.status(); assert_eq!(done, 0); assert_eq!(total, 100); - tracker.increment_file_size(file_id, 150).await; - let (done, total) = tracker.status().await; + tracker.increment_file_size(file_id, 150); + let (done, total) = tracker.status(); assert_eq!(done, 0); assert_eq!(total, 250); - tracker.increment_file_size(file_id, 50).await; - let (done, total) = tracker.status().await; + tracker.increment_file_size(file_id, 50); + let (done, total) = tracker.status(); assert_eq!(done, 0); assert_eq!(total, 300); - // Complete the file via an external dependency let x = MerkleHash::random_from_seed(1); - tracker - .register_dependencies(&[FileXorbDependency { - file_id, - xorb_hash: x, - n_bytes: 300, - is_external: true, - }]) - .await; + tracker.register_dependencies(&[FileXorbDependency { + file_id, + xorb_hash: x, + n_bytes: 300, + is_external: true, + }]); - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); assert_eq!(done, 300); assert_eq!(total, 300); - assert!(tracker.is_complete().await); + assert!(tracker.is_complete()); - tracker.assert_complete().await; - verifier.assert_complete().await; + tracker.assert_complete(); + group.assert_complete(); } - /// Register a file with unknown size, increment alongside dependency registration. - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_increment_file_size_with_xorb_uploads() { - let no_op = NoOpProgressUpdater::new(); - let verifier = ProgressUpdaterVerificationWrapper::new(no_op); - let tracker = CompletionTracker::new(verifier.clone()); + #[test] + fn test_increment_file_size_with_xorb_uploads() { + let group = GroupProgress::new(); + let tracker = CompletionTracker::new(group.clone()); - let file_id = tracker.register_new_file(Ulid::new(), "streamFile", None).await; + let updater = group.new_item(UniqueID::new(), "streamFile"); + let file_id = tracker.register_new_file(updater, None); let x1 = MerkleHash::random_from_seed(10); let x2 = MerkleHash::random_from_seed(20); - tracker.register_new_xorb(x1, 500).await; - tracker.register_new_xorb(x2, 500).await; + tracker.register_new_xorb(x1, 500); + tracker.register_new_xorb(x2, 500); - // Discover first chunk: increment by 200, register dependency on x1 for 200 bytes - tracker.increment_file_size(file_id, 200).await; - tracker - .register_dependencies(&[FileXorbDependency { - file_id, - xorb_hash: x1, - n_bytes: 200, - is_external: false, - }]) - .await; + tracker.increment_file_size(file_id, 200); + tracker.register_dependencies(&[FileXorbDependency { + file_id, + xorb_hash: x1, + n_bytes: 200, + is_external: false, + }]); - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); assert_eq!(done, 0); assert_eq!(total, 200); - // Upload x1 => file goes to 200/200 so far - tracker.register_xorb_upload_completion(x1).await; - let (done, total) = tracker.status().await; + tracker.register_xorb_upload_completion(x1); + let (done, total) = tracker.status(); assert_eq!(done, 200); assert_eq!(total, 200); - // Discover second chunk: increment by 300, register dependency on x2 for 300 bytes - tracker.increment_file_size(file_id, 300).await; - tracker - .register_dependencies(&[FileXorbDependency { - file_id, - xorb_hash: x2, - n_bytes: 300, - is_external: false, - }]) - .await; + tracker.increment_file_size(file_id, 300); + tracker.register_dependencies(&[FileXorbDependency { + file_id, + xorb_hash: x2, + n_bytes: 300, + is_external: false, + }]); - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); assert_eq!(done, 200); assert_eq!(total, 500); - // Upload x2 => file goes to 500/500 - tracker.register_xorb_upload_completion(x2).await; - let (done, total) = tracker.status().await; + tracker.register_xorb_upload_completion(x2); + let (done, total) = tracker.status(); assert_eq!(done, 500); assert_eq!(total, 500); - assert!(tracker.is_complete().await); + assert!(tracker.is_complete()); - tracker.assert_complete().await; - verifier.assert_complete().await; + tracker.assert_complete(); + group.assert_complete(); } - /// Multiple files, one with known size and one with unknown size that gets incremented. - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_increment_file_size_mixed_known_unknown() { - let no_op = NoOpProgressUpdater::new(); - let verifier = ProgressUpdaterVerificationWrapper::new(no_op); - let tracker = CompletionTracker::new(verifier.clone()); + #[test] + fn test_increment_file_size_mixed_known_unknown() { + let group = GroupProgress::new(); + let tracker = CompletionTracker::new(group.clone()); - // fileA has known size, fileB does not - let file_a = tracker.register_new_file(Ulid::new(), "fileA", Some(100)).await; - let file_b = tracker.register_new_file(Ulid::new(), "fileB", None).await; + let updater_a = group.new_item(UniqueID::new(), "fileA"); + let file_a = tracker.register_new_file(updater_a, Some(100)); - let (done, total) = tracker.status().await; + let updater_b = group.new_item(UniqueID::new(), "fileB"); + let file_b = tracker.register_new_file(updater_b, None); + + let (done, total) = tracker.status(); assert_eq!(done, 0); assert_eq!(total, 100); - // Complete fileA immediately via external dep let xa = MerkleHash::random_from_seed(1); - tracker - .register_dependencies(&[FileXorbDependency { - file_id: file_a, - xorb_hash: xa, - n_bytes: 100, - is_external: true, - }]) - .await; + tracker.register_dependencies(&[FileXorbDependency { + file_id: file_a, + xorb_hash: xa, + n_bytes: 100, + is_external: true, + }]); - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); assert_eq!(done, 100); assert_eq!(total, 100); - // fileB discovers its size incrementally and gets deps - tracker.increment_file_size(file_b, 200).await; - let (done, total) = tracker.status().await; + tracker.increment_file_size(file_b, 200); + let (done, total) = tracker.status(); assert_eq!(done, 100); assert_eq!(total, 300); let xb = MerkleHash::random_from_seed(2); - tracker.register_new_xorb(xb, 200).await; - tracker - .register_dependencies(&[FileXorbDependency { - file_id: file_b, - xorb_hash: xb, - n_bytes: 200, - is_external: false, - }]) - .await; + tracker.register_new_xorb(xb, 200); + tracker.register_dependencies(&[FileXorbDependency { + file_id: file_b, + xorb_hash: xb, + n_bytes: 200, + is_external: false, + }]); - tracker.register_xorb_upload_completion(xb).await; + tracker.register_xorb_upload_completion(xb); - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); assert_eq!(done, 300); assert_eq!(total, 300); - assert!(tracker.is_complete().await); + assert!(tracker.is_complete()); - tracker.assert_complete().await; - verifier.assert_complete().await; + tracker.assert_complete(); + group.assert_complete(); } - /// File registered with Some(size) ignores increment_file_size calls. - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_increment_file_size_ignored_when_already_final() { - let no_op = NoOpProgressUpdater::new(); - let verifier = ProgressUpdaterVerificationWrapper::new(no_op); - let tracker = CompletionTracker::new(verifier.clone()); + #[test] + fn test_increment_file_size_ignored_when_already_final() { + let group = GroupProgress::new(); + let tracker = CompletionTracker::new(group.clone()); - // Register with a known size (is_final_size_known = true) - let file_id = tracker.register_new_file(Ulid::new(), "fixedFile", Some(100)).await; + let updater = group.new_item(UniqueID::new(), "fixedFile"); + let file_id = tracker.register_new_file(updater, Some(100)); - // Attempt to increment -- should be ignored - tracker.increment_file_size(file_id, 999).await; - let (_, total) = tracker.status().await; + tracker.increment_file_size(file_id, 999); + let (_, total) = tracker.status(); assert_eq!(total, 100); - // Complete the file let x = MerkleHash::random_from_seed(1); - tracker - .register_dependencies(&[FileXorbDependency { - file_id, - xorb_hash: x, - n_bytes: 100, - is_external: true, - }]) - .await; + tracker.register_dependencies(&[FileXorbDependency { + file_id, + xorb_hash: x, + n_bytes: 100, + is_external: true, + }]); - assert!(tracker.is_complete().await); - tracker.assert_complete().await; - verifier.assert_complete().await; + assert!(tracker.is_complete()); + tracker.assert_complete(); + group.assert_complete(); } - /// File size increment with partial xorb upload progress. - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_increment_file_size_with_partial_xorb_progress() { - let no_op = NoOpProgressUpdater::new(); - let verifier = ProgressUpdaterVerificationWrapper::new(no_op); - let tracker = CompletionTracker::new(verifier.clone()); + #[test] + fn test_increment_file_size_with_partial_xorb_progress() { + let group = GroupProgress::new(); + let tracker = CompletionTracker::new(group.clone()); - let file_id = tracker.register_new_file(Ulid::new(), "partialFile", None).await; + let updater = group.new_item(UniqueID::new(), "partialFile"); + let file_id = tracker.register_new_file(updater, None); let x = MerkleHash::random_from_seed(42); - tracker.register_new_xorb(x, 1000).await; + tracker.register_new_xorb(x, 1000); - // Increment to initial size, register dep - tracker.increment_file_size(file_id, 400).await; - tracker - .register_dependencies(&[FileXorbDependency { - file_id, - xorb_hash: x, - n_bytes: 400, - is_external: false, - }]) - .await; + tracker.increment_file_size(file_id, 400); + tracker.register_dependencies(&[FileXorbDependency { + file_id, + xorb_hash: x, + n_bytes: 400, + is_external: false, + }]); - // Partial upload progress on the xorb - tracker.register_xorb_upload_progress(x, 500).await; - let (done, total) = tracker.status().await; + tracker.register_xorb_upload_progress(x, 500); + let (done, total) = tracker.status(); assert_eq!(total, 400); - // Partial progress means some fraction of 400 is done assert!(done > 0); assert!(done < 400); - // Grow the file by 200 more bytes - tracker.increment_file_size(file_id, 200).await; - let (_, total) = tracker.status().await; + tracker.increment_file_size(file_id, 200); + let (_, total) = tracker.status(); assert_eq!(total, 600); - // Register the additional 200 bytes as external (already uploaded) - tracker - .register_dependencies(&[FileXorbDependency { - file_id, - xorb_hash: MerkleHash::random_from_seed(99), - n_bytes: 200, - is_external: true, - }]) - .await; + tracker.register_dependencies(&[FileXorbDependency { + file_id, + xorb_hash: MerkleHash::random_from_seed(99), + n_bytes: 200, + is_external: true, + }]); - // Complete the xorb - tracker.register_xorb_upload_completion(x).await; + tracker.register_xorb_upload_completion(x); - let (done, total) = tracker.status().await; + let (done, total) = tracker.status(); assert_eq!(done, 600); assert_eq!(total, 600); - assert!(tracker.is_complete().await); + assert!(tracker.is_complete()); - tracker.assert_complete().await; - verifier.assert_complete().await; + tracker.assert_complete(); + group.assert_complete(); } } diff --git a/xet_data/tests/test_full_file_download.rs b/xet_data/tests/test_full_file_download.rs index eaa6e1ec..f397ea9b 100644 --- a/xet_data/tests/test_full_file_download.rs +++ b/xet_data/tests/test_full_file_download.rs @@ -11,15 +11,14 @@ mod tests { use std::sync::Arc; use tempfile::TempDir; - use ulid::Ulid; use xet_client::cas_client::LocalTestServerBuilder; use xet_data::processing::configurations::TranslatorConfig; use xet_data::processing::{FileDownloadSession, FileUploadSession, Sha256Policy, XetFileInfo}; async fn upload_bytes(upload_session: &Arc, name: &str, data: &[u8]) -> XetFileInfo { - let mut cleaner = upload_session - .start_clean(Some(name.into()), data.len() as u64, Sha256Policy::Compute, Ulid::new()) - .await; + let (_id, mut cleaner) = upload_session + .start_clean(Some(name.into()), data.len() as u64, Sha256Policy::Compute) + .unwrap(); cleaner.add_data(data).await.unwrap(); let (xfi, _metrics) = cleaner.finish().await.unwrap(); xfi @@ -40,15 +39,15 @@ mod tests { ("larger", &vec![0xCD; 64 * 1024]), ]; - let download_session = FileDownloadSession::new(config.clone(), None).await.unwrap(); + let download_session = FileDownloadSession::new(config.clone()).await.unwrap(); for (name, data) in test_cases { - let upload_session = FileUploadSession::new(config.clone(), None).await.unwrap(); + let upload_session = FileUploadSession::new(config.clone()).await.unwrap(); let xfi = upload_bytes(&upload_session, name, data).await; upload_session.finalize().await.unwrap(); let out_path = base_dir.path().join(format!("out_{name}")); - let n_bytes = download_session.download_file(&xfi, &out_path, Ulid::new()).await.unwrap(); + let (_id, n_bytes) = download_session.download_file(&xfi, &out_path).await.unwrap(); assert_eq!(n_bytes, data.len() as u64, "size mismatch for {name}"); assert_eq!(fs::read(&out_path).unwrap(), *data, "content mismatch for {name}"); diff --git a/xet_data/tests/test_session_resume.rs b/xet_data/tests/test_session_resume.rs index 440b4d27..817e3d66 100644 --- a/xet_data/tests/test_session_resume.rs +++ b/xet_data/tests/test_session_resume.rs @@ -36,10 +36,8 @@ mod tests { use more_asserts::*; use rand::prelude::*; - use ulid::Ulid; use xet_data::deduplication::constants::MAX_CHUNK_SIZE; use xet_data::processing::test_utils::{HydrateDehydrateTest, create_random_file, create_random_files}; - use xet_data::progress_tracking::aggregator::AggregatingProgressUpdater; use super::*; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -48,78 +46,6 @@ mod tests { let n = 8 * 1024; let half_n = n / 2; - - let hn = half_n as u64; - - // Get a sizable block of random data - let mut data = vec![0u8; n]; - let mut rng = StdRng::seed_from_u64(0); - rng.fill(&mut data[..]); - - let server = LocalTestServerBuilder::new().start().await; - let shard_base = TempDir::new().unwrap(); - let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), shard_base.path()).unwrap()); - - { - let progress_tracker = AggregatingProgressUpdater::new_aggregation_only(); - let file_upload_session = FileUploadSession::new(config.clone(), Some(progress_tracker.clone())) - .await - .unwrap(); - - // Feed it half the data, and checkpoint. - let mut cleaner = file_upload_session - .start_clean(Some("data".into()), data.len() as u64, Sha256Policy::Compute, Ulid::new()) - .await; - cleaner.add_data(&data[..half_n]).await.unwrap(); - cleaner.checkpoint().await.unwrap(); - - // Checkpoint to ensure all xorbs get uploaded. - file_upload_session.checkpoint().await.unwrap(); - - // Break without closing down the file session; we should resume partway through. - } - - // Now try again to test the resume. - { - let progress_tracker = AggregatingProgressUpdater::new_aggregation_only(); - let file_upload_session = FileUploadSession::new(config, Some(progress_tracker.clone())).await.unwrap(); - - // Feed it half the data, and checkpoint. - let mut cleaner = file_upload_session - .start_clean(Some("data".into()), data.len() as u64, Sha256Policy::Compute, Ulid::new()) - .await; - - // Add all the data. Roughly the first half should dedup. - cleaner.add_data(&data).await.unwrap(); - cleaner.finish().await.unwrap(); - - // Finalize everything - file_upload_session.finalize().await.unwrap(); - - let progress = progress_tracker.get_aggregated_state().await; - - let max_deviance = (*MAX_XORB_BYTES + *MAX_CHUNK_SIZE) as u64; - - let n = n as u64; - - // Check things. The checkpoint above pushes everything through. - assert_eq!(progress.total_bytes_completed, n); - assert_eq!(progress.total_bytes, n); - - // The difference is the amount deduplicated; the half_n pass above should have - // left quite a bit to deduplicate. - assert_le!(progress.total_transfer_bytes, hn + max_deviance); - assert_le!(progress.total_transfer_bytes_completed, hn + max_deviance); - } - } - - #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_multiple_resume() { - // Ensure the deduplication numbers are approximately accurate. - - let n = 256 * 1024; - let resume_n = [16 * 1024, 16 * 1024, 64 * 1024, 128 * 1024, 240 * 1024]; - let max_deviance = (*MAX_XORB_BYTES + *MAX_CHUNK_SIZE) as u64; // Get a sizable block of random data @@ -131,34 +57,84 @@ mod tests { let shard_base = TempDir::new().unwrap(); let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), shard_base.path()).unwrap()); + { + let file_upload_session = FileUploadSession::new(config.clone()).await.unwrap(); + + // Feed it half the data, and checkpoint. + let (_id, mut cleaner) = file_upload_session + .start_clean(Some("data".into()), data.len() as u64, Sha256Policy::Compute) + .unwrap(); + cleaner.add_data(&data[..half_n]).await.unwrap(); + cleaner.checkpoint().await.unwrap(); + + // Checkpoint to ensure all xorbs get uploaded. + file_upload_session.checkpoint().await.unwrap(); + let report = file_upload_session.report(); + assert_eq!(report.total_bytes, n as u64); + assert_le!(report.total_bytes_completed, half_n as u64 + *MAX_CHUNK_SIZE as u64); + assert_le!(report.total_transfer_bytes_completed, report.total_transfer_bytes); + + // Break without closing down the file session; we should resume partway through. + } + + // Now try again to test the resume. + { + let file_upload_session = FileUploadSession::new(config).await.unwrap(); + + // Feed it half the data, and checkpoint. + let (_id, mut cleaner) = file_upload_session + .start_clean(Some("data".into()), data.len() as u64, Sha256Policy::Compute) + .unwrap(); + + // Add all the data. Roughly the first half should dedup. + cleaner.add_data(&data).await.unwrap(); + cleaner.finish().await.unwrap(); + + let report = file_upload_session.report(); + assert!(report.total_bytes > 0); + assert_le!(report.total_bytes_completed, report.total_bytes); + assert_le!(report.total_transfer_bytes_completed, report.total_transfer_bytes); + assert_le!(report.total_transfer_bytes, half_n as u64 + max_deviance); + assert_le!(report.total_transfer_bytes_completed, half_n as u64 + max_deviance); + file_upload_session.finalize().await.unwrap(); + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_multiple_resume() { + // Ensure the deduplication numbers are approximately accurate. + + let n = 256 * 1024; + let resume_n = [16 * 1024, 16 * 1024, 64 * 1024, 128 * 1024, 240 * 1024]; + + // Get a sizable block of random data + let mut data = vec![0u8; n]; + let mut rng = StdRng::seed_from_u64(0); + rng.fill(&mut data[..]); + + let server = LocalTestServerBuilder::new().start().await; + let shard_base = TempDir::new().unwrap(); + let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), shard_base.path()).unwrap()); + let max_deviance = (*MAX_XORB_BYTES + *MAX_CHUNK_SIZE) as u64; + let mut prev_rn = 0; for rn in resume_n { - let progress_tracker = AggregatingProgressUpdater::new_aggregation_only(); - let file_upload_session = FileUploadSession::new(config.clone(), Some(progress_tracker.clone())) - .await - .unwrap(); + let file_upload_session = FileUploadSession::new(config.clone()).await.unwrap(); // Feed it half the data, and checkpoint. - let mut cleaner = file_upload_session - .start_clean(Some("data".into()), data.len() as u64, Sha256Policy::Compute, Ulid::new()) - .await; + let (_id, mut cleaner) = file_upload_session + .start_clean(Some("data".into()), data.len() as u64, Sha256Policy::Compute) + .unwrap(); cleaner.add_data(&data[..rn]).await.unwrap(); cleaner.checkpoint().await.unwrap(); // Checkpoint to ensure all xorbs get uploaded. file_upload_session.checkpoint().await.unwrap(); - - if prev_rn > 0 { - let progress = progress_tracker.get_aggregated_state().await; - - // Because some of it may remain in the chunker, so it won't be exact. - assert_le!(progress.total_bytes_completed, (rn + *MAX_CHUNK_SIZE) as u64); - - // Make sure the total number of transfering bytes isn't fully - assert_le!(progress.total_transfer_bytes, prev_rn + max_deviance + *MAX_CHUNK_SIZE as u64); - assert_le!(progress.total_transfer_bytes_completed, prev_rn + max_deviance + *MAX_CHUNK_SIZE as u64); - } + let report = file_upload_session.report(); + assert_eq!(report.total_bytes, n as u64); + assert_le!(report.total_bytes_completed, rn as u64 + *MAX_CHUNK_SIZE as u64); + assert_le!(report.total_transfer_bytes_completed, report.total_transfer_bytes); // To test the next round. prev_rn = rn as u64; @@ -168,33 +144,24 @@ mod tests { // Now try again to test the resume. { - let progress_tracker = AggregatingProgressUpdater::new_aggregation_only(); - let file_upload_session = FileUploadSession::new(config, Some(progress_tracker.clone())).await.unwrap(); + let file_upload_session = FileUploadSession::new(config).await.unwrap(); // Feed it half the data, and checkpoint. - let mut cleaner = file_upload_session - .start_clean(Some("data".into()), data.len() as u64, Sha256Policy::Compute, Ulid::new()) - .await; + let (_id, mut cleaner) = file_upload_session + .start_clean(Some("data".into()), data.len() as u64, Sha256Policy::Compute) + .unwrap(); // Add all the data. Roughly the first half should dedup. cleaner.add_data(&data).await.unwrap(); cleaner.finish().await.unwrap(); - // Finalize everything + let report = file_upload_session.report(); + assert!(report.total_bytes > 0); + assert_le!(report.total_bytes_completed, report.total_bytes); + assert_le!(report.total_transfer_bytes_completed, report.total_transfer_bytes); + assert_le!(report.total_transfer_bytes, prev_rn + max_deviance); + assert_le!(report.total_transfer_bytes_completed, prev_rn + max_deviance); file_upload_session.finalize().await.unwrap(); - - let progress = progress_tracker.get_aggregated_state().await; - - let n = n as u64; - - // Check things. The checkpoint above pushes everything through. - assert_eq!(progress.total_bytes_completed, n); - assert_eq!(progress.total_bytes, n); - - // The difference is the amount deduplicated; the half_n pass above should have - // left quite a bit to deduplicate. - assert_le!(progress.total_transfer_bytes, prev_rn + max_deviance); - assert_le!(progress.total_transfer_bytes_completed, prev_rn + max_deviance); } } @@ -216,21 +183,14 @@ mod tests { // Clean the files present, but drop the upload session. { - let progress_tracker = AggregatingProgressUpdater::new_aggregation_only(); - let upload_session = ts.new_upload_session(Some(progress_tracker.clone())).await; + let upload_session = ts.new_upload_session().await; ts.clean_all_files(&upload_session, false).await; upload_session.checkpoint().await.unwrap(); - - let progress = progress_tracker.get_aggregated_state().await; - - // Check things. The checkpoint above pushes everything through, even though we don't finalize. - assert_eq!(progress.total_bytes, 64 * 1024); - assert_eq!(progress.total_bytes_completed, 64 * 1024); - - // Here, all the files would have completed, meaning that all their bytes and xorbs are transfered. - assert_eq!(progress.total_transfer_bytes, 64 * 1024); - assert_eq!(progress.total_transfer_bytes_completed, 64 * 1024); + let report = upload_session.report(); + assert!(report.total_bytes > 0); + assert_le!(report.total_bytes_completed, report.total_bytes); + assert_le!(report.total_transfer_bytes_completed, report.total_transfer_bytes); // Now interrupt the session and don't call finalize } @@ -249,22 +209,14 @@ mod tests { // Test these files and actually call finalize. { - let progress_tracker = AggregatingProgressUpdater::new_aggregation_only(); - let upload_session = ts.new_upload_session(Some(progress_tracker.clone())).await; + let upload_session = ts.new_upload_session().await; ts.clean_all_files(&upload_session, false).await; - // Finalize things this time. + let report = upload_session.report(); + assert!(report.total_bytes > 0); + assert_le!(report.total_bytes_completed, report.total_bytes); + assert_le!(report.total_transfer_bytes_completed, report.total_transfer_bytes); upload_session.finalize().await.unwrap(); - - let progress = progress_tracker.get_aggregated_state().await; - - // Check things. The checkpoint above pushes everything through, even though we don't finalize. - assert_eq!(progress.total_bytes, 128 * 1024); - assert_eq!(progress.total_bytes_completed, 128 * 1024); - - // Here, all the previous files would have been deduped against, so only the new content would be uploaded. - assert_eq!(progress.total_transfer_bytes, 64 * 1024); - assert_eq!(progress.total_transfer_bytes_completed, 64 * 1024); } // Finally, verify that hydration works successfully. @@ -283,21 +235,14 @@ mod tests { // Clean the files present, but drop the upload session. { - let progress_tracker = AggregatingProgressUpdater::new_aggregation_only(); - let upload_session = ts.new_upload_session(Some(progress_tracker.clone())).await; + let upload_session = ts.new_upload_session().await; ts.clean_all_files(&upload_session, false).await; upload_session.checkpoint().await.unwrap(); - - let progress = progress_tracker.get_aggregated_state().await; - - // Check things. The checkpoint above pushes everything through, even though we don't finalize. - assert_eq!(progress.total_bytes, 128); - assert_eq!(progress.total_bytes_completed, 128); - - // Here, all the files would have completed, meaning that all their bytes and xorbs are transfered. - assert_eq!(progress.total_transfer_bytes, 128); - assert_eq!(progress.total_transfer_bytes_completed, 128); + let report = upload_session.report(); + assert!(report.total_bytes > 0); + assert_le!(report.total_bytes_completed, report.total_bytes); + assert_le!(report.total_transfer_bytes_completed, report.total_transfer_bytes); // Now interrupt the session and don't call finalize } @@ -306,22 +251,14 @@ mod tests { // Test these files and actually call finalize. { - let progress_tracker = AggregatingProgressUpdater::new_aggregation_only(); - let upload_session = ts.new_upload_session(Some(progress_tracker.clone())).await; + let upload_session = ts.new_upload_session().await; ts.clean_all_files(&upload_session, false).await; - // Finalize things this time. + let report = upload_session.report(); + assert!(report.total_bytes > 0); + assert_le!(report.total_bytes_completed, report.total_bytes); + assert_le!(report.total_transfer_bytes_completed, report.total_transfer_bytes); upload_session.finalize().await.unwrap(); - - let progress = progress_tracker.get_aggregated_state().await; - - // Check things. The checkpoint above pushes everything through, even though we don't finalize. - assert_eq!(progress.total_bytes, 256); - assert_eq!(progress.total_bytes_completed, 256); - - // Here, all the previous files would have been deduped against, so only the new content would be uploaded. - assert_eq!(progress.total_transfer_bytes, 128); - assert_eq!(progress.total_transfer_bytes_completed, 128); } // Finally, verify that hydration works successfully. diff --git a/xet_pkg/Cargo.toml b/xet_pkg/Cargo.toml index 50d6432a..6a6e4e3d 100644 --- a/xet_pkg/Cargo.toml +++ b/xet_pkg/Cargo.toml @@ -20,8 +20,8 @@ xet-data = { version = "1.4.0", path = "../xet_data" } async-trait = { workspace = true } http = { workspace = true } -tokio = { workspace = true, features = ["net"] } -ulid = { workspace = true } +more-asserts = { workspace = true } +tokio = { workspace = true, features = ["net", "time"] } thiserror = { workspace = true } tracing = { workspace = true } serde = { workspace = true, features = ["derive"] } @@ -32,8 +32,10 @@ python = ["xet-runtime/python"] [dev-dependencies] async-std = { workspace = true } futures = { workspace = true } +more-asserts = { workspace = true } smol = { workspace = true } tempfile = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "rt", "time", "macros"] } serial_test = { workspace = true } anyhow = { workspace = true } clap = { workspace = true, features = ["derive"] } diff --git a/xet_pkg/examples/example.rs b/xet_pkg/examples/example.rs index 4eb90f08..3b113b01 100644 --- a/xet_pkg/examples/example.rs +++ b/xet_pkg/examples/example.rs @@ -1,6 +1,6 @@ //! Async session-based upload/download example. //! -//! Mirror of `example.rs` using the async API (`UploadCommit` / `DownloadGroup`). +//! Mirror of `example_sync.rs` using the async API (`UploadCommit` / `DownloadGroup`). //! Requires an async runtime — here provided by `#[tokio::main]`. use std::path::PathBuf; @@ -70,13 +70,12 @@ async fn upload_files(files: Vec, endpoint: Option) -> Result<( let commit_for_progress = commit.clone(); tokio::spawn(async move { loop { - if let Ok(snapshot) = commit_for_progress.get_progress() { - let p = snapshot.total(); + if let Ok(report) = commit_for_progress.get_progress() { let done = handles .iter() .filter(|h: &&UploadTaskHandle| matches!(h.status(), Ok(TaskStatus::Completed))) .count(); - println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes); + println!("{}/{} files | {}/{} bytes", done, n_files, report.total_bytes_completed, report.total_bytes); } tokio::time::sleep(std::time::Duration::from_millis(100)).await; } @@ -115,27 +114,30 @@ async fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: O let mut handles: Vec = Vec::with_capacity(n_files); for m in &metadata { let dest = output_dir.join(m.tracking_name.as_deref().unwrap_or("file")); - handles.push(group.download_file_to_path( - XetFileInfo { - hash: m.hash.clone(), - file_size: m.file_size, - sha256: m.sha256.clone(), - }, - dest, - )?); + handles.push( + group + .download_file_to_path( + XetFileInfo { + hash: m.hash.clone(), + file_size: m.file_size, + sha256: m.sha256.clone(), + }, + dest, + ) + .await?, + ); } // Spawn a task to print progress while the main task awaits finish(). let group_for_progress = group.clone(); tokio::spawn(async move { loop { - if let Ok(snapshot) = group_for_progress.get_progress() { - let p = snapshot.total(); + if let Ok(report) = group_for_progress.get_progress() { let done = handles .iter() .filter(|h| matches!(h.status(), Ok(TaskStatus::Completed))) .count(); - println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes); + println!("{}/{} files | {}/{} bytes", done, n_files, report.total_bytes_completed, report.total_bytes); } tokio::time::sleep(std::time::Duration::from_millis(100)).await; } diff --git a/xet_pkg/examples/example_sync.rs b/xet_pkg/examples/example_sync.rs index 02ca58b5..b120b54c 100644 --- a/xet_pkg/examples/example_sync.rs +++ b/xet_pkg/examples/example_sync.rs @@ -67,13 +67,12 @@ fn upload_files(files: Vec, endpoint: Option) -> Result<()> { let commit_for_progress = commit.clone(); std::thread::spawn(move || { loop { - if let Ok(snapshot) = commit_for_progress.get_progress() { - let p = snapshot.total(); + if let Ok(report) = commit_for_progress.get_progress_blocking() { let done = handles .iter() .filter(|h| matches!(h.status(), Ok(TaskStatus::Completed))) .count(); - println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes); + println!("{}/{} files | {}/{} bytes", done, n_files, report.total_bytes_completed, report.total_bytes); } std::thread::sleep(Duration::from_millis(100)); } @@ -112,7 +111,7 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option< let mut handles = Vec::with_capacity(n_files); for m in &metadata { let dest = output_dir.join(m.tracking_name.as_deref().unwrap_or("file")); - handles.push(group.download_file_to_path( + handles.push(group.download_file_to_path_blocking( XetFileInfo { hash: m.hash.clone(), file_size: m.file_size, @@ -126,13 +125,12 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option< let group_for_progress = group.clone(); std::thread::spawn(move || { loop { - if let Ok(snapshot) = group_for_progress.get_progress() { - let p = snapshot.total(); + if let Ok(report) = group_for_progress.get_progress_blocking() { let done = handles .iter() .filter(|h| matches!(h.status(), Ok(TaskStatus::Completed))) .count(); - println!("{}/{} files | {}/{} bytes", done, n_files, p.total_bytes_completed, p.total_bytes); + println!("{}/{} files | {}/{} bytes", done, n_files, report.total_bytes_completed, report.total_bytes); } std::thread::sleep(Duration::from_millis(100)); } diff --git a/xet_pkg/src/error.rs b/xet_pkg/src/error.rs index 0b8c2802..e57dd3da 100644 --- a/xet_pkg/src/error.rs +++ b/xet_pkg/src/error.rs @@ -1,8 +1,8 @@ use thiserror::Error; -use ulid::Ulid; use xet_client::ClientError; use xet_core_structures::FormatError; use xet_data::DataError; +use xet_data::progress_tracking::UniqueID; use xet_runtime::RuntimeError; /// Unified error type for the Xet public API. @@ -28,7 +28,7 @@ pub enum XetError { /// A task ID that doesn't correspond to any queued file. #[error("Invalid task ID: {0}")] - InvalidTaskID(Ulid), + InvalidTaskID(UniqueID), // -- User-facing error categories ------------------------------------ /// Token refresh or credential failures. @@ -137,6 +137,7 @@ impl XetError { | DataError::DeprecatedError(_) => XetError::Configuration(de.to_string()), DataError::HashNotFound => XetError::NotFound(de.to_string()), DataError::HashStringParsingFailure(_) => XetError::DataIntegrity(de.to_string()), + DataError::InvalidOperation(_) => XetError::Configuration(de.to_string()), _ => XetError::Internal(de.to_string()), } } diff --git a/xet_pkg/src/legacy/data_client.rs b/xet_pkg/src/legacy/data_client.rs new file mode 100644 index 00000000..4fceff03 --- /dev/null +++ b/xet_pkg/src/legacy/data_client.rs @@ -0,0 +1,187 @@ +use std::path::PathBuf; +use std::sync::Arc; + +use http::header::HeaderMap; +use tracing::{Instrument, Span, info_span, instrument}; +use xet_client::cas_client::auth::TokenRefresher; +pub use xet_data::processing::data_client::hash_files_async; +use xet_data::processing::data_client::{clean_bytes, default_config}; +use xet_data::processing::errors::DataProcessingError; +use xet_data::processing::{FileDownloadSession, FileUploadSession, Sha256Policy, XetFileInfo}; +use xet_runtime::core::par_utils::run_constrained_with_semaphore; +use xet_runtime::core::{XetRuntime, xet_config}; + +use super::progress_tracking::{GroupProgressCallbackUpdater, ItemProgressCallbackUpdater, TrackingProgressUpdater}; +use crate::legacy::data_client::errors::Result; + +mod errors { + pub use xet_data::processing::errors::Result; +} + +#[instrument(skip_all, name = "data_client::upload_bytes", fields(session_id = tracing::field::Empty, num_files=file_contents.len()))] +pub async fn upload_bytes_async( + file_contents: Vec>, + sha256_policies: Vec, + endpoint: Option, + token_info: Option<(String, u64)>, + token_refresher: Option>, + progress_updater: Option>, + custom_headers: Option>, +) -> Result> { + if sha256_policies.len() != file_contents.len() { + return Err(DataProcessingError::ParameterError(format!( + "sha256_policies length ({}) must match file_contents length ({})", + sha256_policies.len(), + file_contents.len() + ))); + } + + let config = default_config( + endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()), + token_info, + token_refresher, + custom_headers, + )?; + + Span::current().record("session_id", &config.session.session_id); + + let semaphore = XetRuntime::current().common().file_ingestion_semaphore.clone(); + let upload_session = FileUploadSession::new(config.into()).await?; + + let bridge = progress_updater.map(|updater| GroupProgressCallbackUpdater::start(upload_session.clone(), updater)); + + let clean_futures = file_contents.into_iter().zip(sha256_policies).map(|(blob, policy)| { + let upload_session = upload_session.clone(); + async move { clean_bytes(upload_session, blob, policy).await.map(|(xf, _metrics)| xf) } + .instrument(info_span!("clean_task")) + }); + let files = run_constrained_with_semaphore(clean_futures, semaphore).await?; + + let _metrics = upload_session.finalize().await?; + + if let Some(bridge) = bridge { + bridge.finalize().await; + } + + Ok(files) +} + +#[instrument(skip_all, name = "data_client::upload_files", + fields(session_id = tracing::field::Empty, + num_files=file_paths.len(), + new_bytes = tracing::field::Empty, + deduped_bytes = tracing::field::Empty, + defrag_prevented_dedup_bytes = tracing::field::Empty, + new_chunks = tracing::field::Empty, + deduped_chunks = tracing::field::Empty, + defrag_prevented_dedup_chunks = tracing::field::Empty + ))] +pub async fn upload_async( + file_paths: Vec, + sha256_policies: Vec, + endpoint: Option, + token_info: Option<(String, u64)>, + token_refresher: Option>, + progress_updater: Option>, + custom_headers: Option>, +) -> Result> { + if sha256_policies.len() != file_paths.len() { + return Err(DataProcessingError::ParameterError(format!( + "sha256_policies length ({}) must match file_paths length ({})", + sha256_policies.len(), + file_paths.len() + ))); + } + + let config = default_config( + endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()), + token_info, + token_refresher, + custom_headers, + )?; + + let span = Span::current(); + + span.record("session_id", &config.session.session_id); + + let upload_session = FileUploadSession::new(config.into()).await?; + + let bridge = progress_updater.map(|updater| GroupProgressCallbackUpdater::start(upload_session.clone(), updater)); + + let files_and_sha256 = file_paths.into_iter().zip(sha256_policies.into_iter()); + + let ret = upload_session.upload_files(files_and_sha256).await?; + + let metrics = upload_session.finalize().await?; + + if let Some(bridge) = bridge { + bridge.finalize().await; + } + + span.record("new_bytes", metrics.new_bytes); + span.record("deduped_bytes", metrics.deduped_bytes); + span.record("defrag_prevented_dedup_bytes", metrics.defrag_prevented_dedup_bytes); + span.record("new_chunks", metrics.new_chunks); + span.record("deduped_chunks", metrics.deduped_chunks); + span.record("defrag_prevented_dedup_chunks", metrics.defrag_prevented_dedup_chunks); + + Ok(ret) +} + +#[instrument(skip_all, name = "data_client::download", fields(session_id = tracing::field::Empty, num_files=file_infos.len()))] +pub async fn download_async( + file_infos: Vec<(XetFileInfo, String)>, + endpoint: Option, + token_info: Option<(String, u64)>, + token_refresher: Option>, + progress_updaters: Option>>, + custom_headers: Option>, +) -> Result> { + if let Some(updaters) = &progress_updaters + && updaters.len() != file_infos.len() + { + return Err(DataProcessingError::ParameterError("updaters are not same length as pointer_files".to_string())); + } + let config: Arc<_> = default_config( + endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()), + token_info, + token_refresher, + custom_headers, + )? + .into(); + + Span::current().record("session_id", &config.session.session_id); + + let updaters: Vec>> = match progress_updaters { + None => vec![None; file_infos.len()], + Some(updaters) => updaters.into_iter().map(Some).collect(), + }; + + let session = FileDownloadSession::new(config).await?; + + let mut tasks = Vec::with_capacity(file_infos.len()); + let mut bridges: Vec> = Vec::with_capacity(file_infos.len()); + + for ((file_info, file_path), updater) in file_infos.into_iter().zip(updaters) { + let path = PathBuf::from(&file_path); + let (id, handle) = session.download_file_background(file_info, path).await?; + + let bridge = updater.map(|u| ItemProgressCallbackUpdater::start(session.clone(), id, u)); + + tasks.push((file_path, handle)); + bridges.push(bridge); + } + + let mut paths = Vec::with_capacity(tasks.len()); + for ((file_path, handle), bridge) in tasks.into_iter().zip(bridges) { + handle.await??; + + if let Some(bridge) = bridge { + bridge.finalize().await; + } + + paths.push(file_path); + } + + Ok(paths) +} diff --git a/xet_pkg/src/legacy/mod.rs b/xet_pkg/src/legacy/mod.rs new file mode 100644 index 00000000..ecf4cc82 --- /dev/null +++ b/xet_pkg/src/legacy/mod.rs @@ -0,0 +1,9 @@ +pub mod data_client; +pub mod progress_tracking; + +// Re-exports from xet_data so external consumers (hf_xet, git_xet) don't need +// a direct xet_data dependency. +pub use xet_data::processing::configurations::{SessionContext, TranslatorConfig}; +pub use xet_data::processing::data_client::{clean_bytes, clean_file, default_config, hash_files_async}; +pub use xet_data::processing::errors::DataProcessingError; +pub use xet_data::processing::{FileDownloadSession, FileUploadSession, Sha256Policy, XetFileInfo}; diff --git a/xet_pkg/src/legacy/progress_tracking/callback_bridge.rs b/xet_pkg/src/legacy/progress_tracking/callback_bridge.rs new file mode 100644 index 00000000..802c9d2e --- /dev/null +++ b/xet_pkg/src/legacy/progress_tracking/callback_bridge.rs @@ -0,0 +1,489 @@ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use tokio::sync::Notify; +use xet_data::progress_tracking::{GroupProgressReport, ItemProgressReport, UniqueID}; + +use super::{ItemProgressUpdate, ProgressUpdate, TrackingProgressUpdater}; + +/// Trait for types that can produce progress reports via polling. +/// +/// Both `FileDownloadSession` and `FileUploadSession` implement this trait, +/// enabling the callback updaters to poll them and forward updates to +/// legacy `TrackingProgressUpdater` callbacks. +pub trait ProgressReporter: Send + Sync { + fn report(&self) -> GroupProgressReport; + fn item_reports(&self) -> HashMap; + fn item_report(&self, id: UniqueID) -> Option { + self.item_reports().remove(&id) + } +} + +impl ProgressReporter for xet_data::processing::FileDownloadSession { + fn report(&self) -> GroupProgressReport { + self.report() + } + fn item_reports(&self) -> HashMap { + self.item_reports() + } + fn item_report(&self, id: UniqueID) -> Option { + self.item_report(id) + } +} + +impl ProgressReporter for xet_data::processing::FileUploadSession { + fn report(&self) -> GroupProgressReport { + self.report() + } + fn item_reports(&self) -> HashMap { + self.item_reports() + } +} + +// === Bridge state for group-level diffing === + +struct GroupBridgeState { + prev_group: GroupProgressReport, + prev_items: HashMap, +} + +impl GroupBridgeState { + fn new() -> Self { + Self { + prev_group: GroupProgressReport::default(), + prev_items: HashMap::new(), + } + } + + fn compute_diff( + &mut self, + group: GroupProgressReport, + items: HashMap, + ) -> ProgressUpdate { + let total_bytes_increment = group.total_bytes.saturating_sub(self.prev_group.total_bytes); + let total_bytes_completion_increment = group + .total_bytes_completed + .saturating_sub(self.prev_group.total_bytes_completed); + let total_transfer_bytes_increment = + group.total_transfer_bytes.saturating_sub(self.prev_group.total_transfer_bytes); + let total_transfer_bytes_completion_increment = group + .total_transfer_bytes_completed + .saturating_sub(self.prev_group.total_transfer_bytes_completed); + + let mut item_updates = Vec::new(); + for (&id, report) in &items { + let prev = self.prev_items.get(&id); + let prev_completed = prev.map_or(0, |p| p.bytes_completed); + let increment = report.bytes_completed.saturating_sub(prev_completed); + + if increment > 0 || prev.is_none() { + item_updates.push(ItemProgressUpdate { + tracking_id: id, + item_name: Arc::from(report.item_name.as_str()), + total_bytes: report.total_bytes, + bytes_completed: report.bytes_completed, + bytes_completion_increment: increment, + }); + } + } + + let update = ProgressUpdate { + item_updates, + total_bytes: group.total_bytes, + total_bytes_increment, + total_bytes_completed: group.total_bytes_completed, + total_bytes_completion_increment, + total_bytes_completion_rate: group.total_bytes_completion_rate, + total_transfer_bytes: group.total_transfer_bytes, + total_transfer_bytes_increment, + total_transfer_bytes_completed: group.total_transfer_bytes_completed, + total_transfer_bytes_completion_increment, + total_transfer_bytes_completion_rate: group.total_transfer_bytes_completion_rate, + }; + + self.prev_group = group; + self.prev_items = items; + + update + } +} + +// === Bridge state for single-item diffing === + +struct ItemBridgeState { + prev: Option, +} + +impl ItemBridgeState { + fn new() -> Self { + Self { prev: None } + } + + fn compute_diff(&mut self, item_id: UniqueID, report: ItemProgressReport) -> ProgressUpdate { + let prev_completed = self.prev.as_ref().map_or(0, |p| p.bytes_completed); + let prev_total = self.prev.as_ref().map_or(0, |p| p.total_bytes); + + let bytes_increment = report.bytes_completed.saturating_sub(prev_completed); + let total_increment = report.total_bytes.saturating_sub(prev_total); + + let item_updates = if bytes_increment > 0 || self.prev.is_none() { + vec![ItemProgressUpdate { + tracking_id: item_id, + item_name: Arc::from(report.item_name.as_str()), + total_bytes: report.total_bytes, + bytes_completed: report.bytes_completed, + bytes_completion_increment: bytes_increment, + }] + } else { + Vec::new() + }; + + let update = ProgressUpdate { + item_updates, + total_bytes: report.total_bytes, + total_bytes_increment: total_increment, + total_bytes_completed: report.bytes_completed, + total_bytes_completion_increment: bytes_increment, + total_bytes_completion_rate: None, + total_transfer_bytes: 0, + total_transfer_bytes_increment: 0, + total_transfer_bytes_completed: 0, + total_transfer_bytes_completion_increment: 0, + total_transfer_bytes_completion_rate: None, + }; + + self.prev = Some(report); + update + } +} + +// === Shared finalization logic === + +#[cfg(debug_assertions)] +fn wrap_updater( + updater: Arc, +) -> (Arc, Option>) { + let v = super::ProgressUpdaterVerificationWrapper::new(updater); + (v.clone(), Some(v)) +} + +#[cfg(not(debug_assertions))] +fn wrap_updater(updater: Arc) -> (Arc, Option<()>) { + (updater, None) +} + +// === GroupProgressCallbackUpdater === + +/// Bridges the new polling-based progress model to the old callback-based model +/// at the group level. +/// +/// Spawns a background task that polls a `ProgressReporter` every 250ms, +/// computes incremental diffs across all items, and sends `ProgressUpdate` +/// structs to a `TrackingProgressUpdater`. +pub struct GroupProgressCallbackUpdater { + stop_signal: Arc, + handle: tokio::task::JoinHandle<()>, + #[cfg(debug_assertions)] + verifier: Option>, +} + +impl GroupProgressCallbackUpdater { + /// Start polling `reporter` every 250ms and send group-level diffs to `updater`. + pub fn start(reporter: Arc, updater: Arc) -> Self { + let (updater, _verifier) = wrap_updater(updater); + + let stop_signal = Arc::new(Notify::new()); + let stop = stop_signal.clone(); + + let handle = tokio::spawn(async move { + let mut state = GroupBridgeState::new(); + let mut interval = tokio::time::interval(Duration::from_millis(250)); + + loop { + tokio::select! { + _ = interval.tick() => { + let group = reporter.report(); + let items = reporter.item_reports(); + let update = state.compute_diff(group, items); + if !update.is_empty() { + updater.register_updates(update).await; + } + } + _ = stop.notified() => { + break; + } + } + } + + let group = reporter.report(); + let items = reporter.item_reports(); + let update = state.compute_diff(group, items); + if !update.is_empty() { + updater.register_updates(update).await; + } + updater.flush().await; + }); + + Self { + stop_signal, + handle, + #[cfg(debug_assertions)] + verifier: _verifier, + } + } + + /// Stop the polling loop, send a final update, and in debug mode verify completeness. + pub async fn finalize(self) { + self.stop_signal.notify_one(); + let _ = self.handle.await; + + #[cfg(debug_assertions)] + if let Some(v) = self.verifier { + v.assert_complete().await; + } + } +} + +// === ItemProgressCallbackUpdater === + +/// Bridges the new polling-based progress model to the old callback-based model +/// for a single item. +/// +/// Spawns a background task that polls a single item from a `ProgressReporter` +/// every 250ms, computes incremental diffs, and sends `ProgressUpdate` structs +/// to a `TrackingProgressUpdater`. +pub struct ItemProgressCallbackUpdater { + stop_signal: Arc, + handle: tokio::task::JoinHandle<()>, + #[cfg(debug_assertions)] + verifier: Option>, +} + +impl ItemProgressCallbackUpdater { + /// Start polling a single item from `reporter` every 250ms and send per-item + /// diffs to `updater`. + pub fn start( + reporter: Arc, + item_id: UniqueID, + updater: Arc, + ) -> Self { + let (updater, _verifier) = wrap_updater(updater); + + let stop_signal = Arc::new(Notify::new()); + let stop = stop_signal.clone(); + + let handle = tokio::spawn(async move { + let mut state = ItemBridgeState::new(); + let mut interval = tokio::time::interval(Duration::from_millis(250)); + + loop { + tokio::select! { + _ = interval.tick() => { + if let Some(report) = reporter.item_report(item_id) { + let update = state.compute_diff(item_id, report); + if !update.is_empty() { + updater.register_updates(update).await; + } + } + } + _ = stop.notified() => { + break; + } + } + } + + if let Some(report) = reporter.item_report(item_id) { + let update = state.compute_diff(item_id, report); + if !update.is_empty() { + updater.register_updates(update).await; + } + } + updater.flush().await; + }); + + Self { + stop_signal, + handle, + #[cfg(debug_assertions)] + verifier: _verifier, + } + } + + /// Stop the polling loop, send a final update, and in debug mode verify completeness. + pub async fn finalize(self) { + self.stop_signal.notify_one(); + let _ = self.handle.await; + + #[cfg(debug_assertions)] + if let Some(v) = self.verifier { + v.assert_complete().await; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_group_report( + total_bytes: u64, + total_bytes_completed: u64, + total_transfer_bytes: u64, + total_transfer_bytes_completed: u64, + ) -> GroupProgressReport { + GroupProgressReport { + total_bytes, + total_bytes_completed, + total_bytes_completion_rate: None, + total_transfer_bytes, + total_transfer_bytes_completed, + total_transfer_bytes_completion_rate: None, + } + } + + fn make_item_report(name: &str, total_bytes: u64, bytes_completed: u64) -> ItemProgressReport { + ItemProgressReport { + item_name: name.to_string(), + total_bytes, + bytes_completed, + } + } + + #[test] + fn test_group_bridge_first_diff() { + let mut state = GroupBridgeState::new(); + let id = UniqueID::new(); + + let group = make_group_report(1000, 200, 800, 100); + let items = HashMap::from([(id, make_item_report("a.bin", 1000, 200))]); + + let update = state.compute_diff(group, items); + + assert_eq!(update.total_bytes, 1000); + assert_eq!(update.total_bytes_increment, 1000); + assert_eq!(update.total_bytes_completed, 200); + assert_eq!(update.total_bytes_completion_increment, 200); + assert_eq!(update.total_transfer_bytes, 800); + assert_eq!(update.total_transfer_bytes_increment, 800); + assert_eq!(update.total_transfer_bytes_completed, 100); + assert_eq!(update.total_transfer_bytes_completion_increment, 100); + assert_eq!(update.item_updates.len(), 1); + assert_eq!(update.item_updates[0].total_bytes, 1000); + assert_eq!(update.item_updates[0].bytes_completed, 200); + assert_eq!(update.item_updates[0].bytes_completion_increment, 200); + } + + #[test] + fn test_group_bridge_incremental_diff() { + let mut state = GroupBridgeState::new(); + let id = UniqueID::new(); + + let group1 = make_group_report(1000, 200, 800, 100); + let items1 = HashMap::from([(id, make_item_report("a.bin", 1000, 200))]); + state.compute_diff(group1, items1); + + let group2 = make_group_report(1000, 600, 800, 400); + let items2 = HashMap::from([(id, make_item_report("a.bin", 1000, 600))]); + let update = state.compute_diff(group2, items2); + + assert_eq!(update.total_bytes_increment, 0); + assert_eq!(update.total_bytes_completion_increment, 400); + assert_eq!(update.total_transfer_bytes_increment, 0); + assert_eq!(update.total_transfer_bytes_completion_increment, 300); + assert_eq!(update.item_updates.len(), 1); + assert_eq!(update.item_updates[0].bytes_completion_increment, 400); + } + + #[test] + fn test_group_bridge_no_change_is_empty() { + let mut state = GroupBridgeState::new(); + let id = UniqueID::new(); + + let group = make_group_report(1000, 500, 800, 300); + let items = HashMap::from([(id, make_item_report("a.bin", 1000, 500))]); + state.compute_diff(group.clone(), items.clone()); + + let update = state.compute_diff(group, items); + + assert!(update.is_empty()); + } + + #[test] + fn test_group_bridge_new_item_appears() { + let mut state = GroupBridgeState::new(); + let id1 = UniqueID::new(); + let id2 = UniqueID::new(); + + let group1 = make_group_report(100, 50, 0, 0); + let items1 = HashMap::from([(id1, make_item_report("a.bin", 100, 50))]); + state.compute_diff(group1, items1); + + let group2 = make_group_report(300, 50, 0, 0); + let items2 = HashMap::from([ + (id1, make_item_report("a.bin", 100, 50)), + (id2, make_item_report("b.bin", 200, 0)), + ]); + let update = state.compute_diff(group2, items2); + + assert_eq!(update.total_bytes_increment, 200); + assert_eq!(update.item_updates.len(), 1); + assert_eq!(update.item_updates[0].tracking_id, id2); + assert_eq!(update.item_updates[0].bytes_completion_increment, 0); + } + + #[test] + fn test_item_bridge_first_diff() { + let mut state = ItemBridgeState::new(); + let id = UniqueID::new(); + let report = make_item_report("file.bin", 500, 100); + + let update = state.compute_diff(id, report); + + assert_eq!(update.total_bytes, 500); + assert_eq!(update.total_bytes_increment, 500); + assert_eq!(update.total_bytes_completed, 100); + assert_eq!(update.total_bytes_completion_increment, 100); + assert_eq!(update.item_updates.len(), 1); + assert_eq!(update.item_updates[0].bytes_completion_increment, 100); + } + + #[test] + fn test_item_bridge_incremental_diff() { + let mut state = ItemBridgeState::new(); + let id = UniqueID::new(); + + state.compute_diff(id, make_item_report("file.bin", 500, 100)); + + let update = state.compute_diff(id, make_item_report("file.bin", 500, 350)); + + assert_eq!(update.total_bytes_increment, 0); + assert_eq!(update.total_bytes_completion_increment, 250); + assert_eq!(update.item_updates[0].bytes_completion_increment, 250); + } + + #[test] + fn test_item_bridge_no_change_is_empty() { + let mut state = ItemBridgeState::new(); + let id = UniqueID::new(); + + state.compute_diff(id, make_item_report("file.bin", 500, 200)); + let update = state.compute_diff(id, make_item_report("file.bin", 500, 200)); + + assert!(update.is_empty()); + } + + #[test] + fn test_item_bridge_total_grows() { + let mut state = ItemBridgeState::new(); + let id = UniqueID::new(); + + state.compute_diff(id, make_item_report("file.bin", 500, 100)); + let update = state.compute_diff(id, make_item_report("file.bin", 800, 100)); + + assert_eq!(update.total_bytes, 800); + assert_eq!(update.total_bytes_increment, 300); + assert_eq!(update.total_bytes_completion_increment, 0); + assert!(update.item_updates.is_empty()); + } +} diff --git a/xet_data/src/progress_tracking/progress_info.rs b/xet_pkg/src/legacy/progress_tracking/mod.rs similarity index 81% rename from xet_data/src/progress_tracking/progress_info.rs rename to xet_pkg/src/legacy/progress_tracking/mod.rs index ef5fa6e7..85f0deda 100644 --- a/xet_data/src/progress_tracking/progress_info.rs +++ b/xet_pkg/src/legacy/progress_tracking/mod.rs @@ -1,12 +1,29 @@ +mod callback_bridge; +mod progress_verification_wrapper; + use std::fmt::Debug; use std::sync::Arc; -use ulid::Ulid; +use async_trait::async_trait; +pub use callback_bridge::{GroupProgressCallbackUpdater, ItemProgressCallbackUpdater, ProgressReporter}; +pub use progress_verification_wrapper::ProgressUpdaterVerificationWrapper; +use xet_data::progress_tracking::UniqueID; + +/// The trait that a progress updater that reports per-item progress completion. +#[async_trait] +pub trait TrackingProgressUpdater: Send + Sync { + /// Register a set of updates as a list of ProgressUpdate instances, which + /// contain the name and progress information. + async fn register_updates(&self, updates: ProgressUpdate); + + /// Flush any updates out, if needed + async fn flush(&self) {} +} /// A class to make all the bookkeeping clear with progress updating. #[derive(Clone, Debug)] pub struct ItemProgressUpdate { - pub tracking_id: Ulid, + pub tracking_id: UniqueID, pub item_name: Arc, // The total bytes in this item, independent from the total bytes of all items. @@ -50,13 +67,13 @@ pub struct ProgressUpdate { /// The total bytes that have been processed. pub total_bytes_completed: u64, - /// How much this update adjusts the total bytes.. + /// How much this update adjusts the total bytes. pub total_bytes_completion_increment: u64, - /// The rate at which the total bytes are being processed, if known. + /// The rate at which the total bytes are being processed, if known. pub total_bytes_completion_rate: Option, - /// Total bytes known that need to be uploaded or downloaded. + /// Total bytes known that need to be uploaded or downloaded. pub total_transfer_bytes: u64, /// The change in total transfer bytes known from the last update. @@ -68,7 +85,7 @@ pub struct ProgressUpdate { /// How much this update adjusts the total transfer bytes. pub total_transfer_bytes_completion_increment: u64, - /// The total bytes that have been processed + /// The transfer-byte completion rate, if known. pub total_transfer_bytes_completion_rate: Option, } diff --git a/xet_data/src/progress_tracking/verification_wrapper.rs b/xet_pkg/src/legacy/progress_tracking/progress_verification_wrapper.rs similarity index 88% rename from xet_data/src/progress_tracking/verification_wrapper.rs rename to xet_pkg/src/legacy/progress_tracking/progress_verification_wrapper.rs index 91fb6fba..fd39ccab 100644 --- a/xet_data/src/progress_tracking/verification_wrapper.rs +++ b/xet_pkg/src/legacy/progress_tracking/progress_verification_wrapper.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use async_trait::async_trait; use more_asserts::{assert_ge, assert_le}; use tokio::sync::Mutex; -use ulid::Ulid; +use xet_data::progress_tracking::UniqueID; use super::{ProgressUpdate, TrackingProgressUpdater}; @@ -17,7 +17,7 @@ struct ItemProgressData { #[derive(Debug, Default)] pub struct ProgressUpdaterVerificationWrapperImpl { - items: HashMap, ItemProgressData)>, + items: HashMap, ItemProgressData)>, total_transfer_bytes: u64, total_transfer_bytes_completed: u64, total_bytes: u64, @@ -66,7 +66,6 @@ impl ProgressUpdaterVerificationWrapper { #[async_trait] impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper { async fn register_updates(&self, update: ProgressUpdate) { - // First, capture and validate let mut tr = self.tr.lock().await; for up in update.item_updates.iter() { @@ -78,8 +77,6 @@ impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper { }, )); - // Record the total_count for this item, allowing it to grow monotonically - // (e.g. when the file size is not known upfront and is updated incrementally). if entry.1.total_count == 0 { entry.1.total_count = up.total_bytes; } else { @@ -94,8 +91,6 @@ impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper { entry.1.total_count = up.total_bytes; } - // Check increments: - // 1) `completed_count` should never go down assert!( up.bytes_completed >= entry.1.last_completed, "Item '{}' completed_count went backwards: old={}, new={}", @@ -104,7 +99,6 @@ impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper { up.bytes_completed ); - // 2) `completed_count` must not exceed `total_count` assert!( up.bytes_completed <= up.total_bytes, "Item '{}' completed_count {} exceeds total {}", @@ -113,7 +107,6 @@ impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper { up.total_bytes ); - // 3) The increment must match the difference let expected_new = entry.1.last_completed + up.bytes_completion_increment; assert_eq!( up.bytes_completed, expected_new, @@ -121,7 +114,6 @@ impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper { up.item_name, entry.1.last_completed, up.bytes_completion_increment, up.bytes_completed ); - // Update item record entry.1.last_completed = up.bytes_completed; } @@ -189,7 +181,6 @@ impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper { update.total_bytes_completed, tr.total_process_bytes_completed ); - // Now forward them to the inner updater self.inner.register_updates(update).await; } async fn flush(&self) { @@ -199,13 +190,9 @@ impl TrackingProgressUpdater for ProgressUpdaterVerificationWrapper { #[cfg(test)] mod tests { - use ulid::Ulid; - + use super::super::ItemProgressUpdate; use super::*; - use crate::progress_tracking::ItemProgressUpdate; - /// A trivial `TrackingProgressUpdater` for testing, which just stores all updates. - /// In real code, this could log to a file, update a UI, etc. #[derive(Debug, Default)] struct DummyLogger { pub all_updates: Mutex>, @@ -221,16 +208,12 @@ mod tests { #[tokio::test] async fn test_verification_wrapper() { - // Create an actual inner logger or progress sink let logger = Arc::new(DummyLogger::default()); - - // Wrap it with our verification wrapper let wrapper = ProgressUpdaterVerificationWrapper::new(logger.clone()); - let file_a = (Ulid::new(), "fileA"); - let file_b = (Ulid::new(), "fileB"); + let file_a = (UniqueID::new(), "fileA"); + let file_b = (UniqueID::new(), "fileB"); - // Let's register some progress updates wrapper .register_updates(ProgressUpdate { item_updates: vec![ @@ -261,7 +244,6 @@ mod tests { }) .await; - // Shouldn't be complete yet. We'll do one more set of updates to finalize. wrapper .register_updates(ProgressUpdate { item_updates: vec![ @@ -292,11 +274,9 @@ mod tests { }) .await; - // Now all items should be fully complete wrapper.assert_complete().await; - // We can also inspect the inner logger's captured updates: let final_updates = logger.all_updates.lock().await; - assert_eq!(final_updates.len(), 4, "We sent 4 updates total"); + assert_eq!(final_updates.len(), 4); } } diff --git a/xet_pkg/src/lib.rs b/xet_pkg/src/lib.rs index f9e0135d..fe7bfa29 100644 --- a/xet_pkg/src/lib.rs +++ b/xet_pkg/src/lib.rs @@ -1,4 +1,5 @@ pub mod error; pub use error::XetError; +pub mod legacy; pub mod xet_session; diff --git a/xet_pkg/src/xet_session/download_group.rs b/xet_pkg/src/xet_session/download_group.rs index 4415bd51..25a26c34 100644 --- a/xet_pkg/src/xet_session/download_group.rs +++ b/xet_pkg/src/xet_session/download_group.rs @@ -5,14 +5,15 @@ use std::path::PathBuf; use std::sync::{Arc, Mutex, MutexGuard, OnceLock, RwLock}; use tokio::task::JoinHandle; -use ulid::Ulid; +use xet_data::DataError; use xet_data::processing::{FileDownloadSession, XetFileInfo}; +use xet_data::progress_tracking::{GroupProgressReport, UniqueID}; use xet_runtime::core::XetRuntime; use super::common::{GroupState, create_translator_config}; use super::errors::SessionError; -use super::progress::{DownloadTaskHandle, GroupProgress, ProgressSnapshot, TaskHandle, TaskStatus}; use super::session::XetSession; +use super::tasks::{DownloadTaskHandle, TaskHandle, TaskStatus}; /// API for grouping related file downloads into a single unit of work. /// @@ -52,17 +53,14 @@ impl DownloadGroup { /// Create a new download group from an **async** context. Initialisation logic shared by the sync and async /// constructors. pub(super) async fn new(session: XetSession) -> Result { - let group_id = Ulid::new(); - let progress = Arc::new(GroupProgress::new()); + let group_id = UniqueID::new(); let config = create_translator_config(&session)?; - let progress_updater = progress.clone() as Arc; - let download_session = FileDownloadSession::new(Arc::new(config), Some(progress_updater)).await?; + let download_session = FileDownloadSession::new(Arc::new(config)).await?; let inner = Arc::new(DownloadGroupInner { group_id, session, active_tasks: RwLock::new(HashMap::new()), - progress, download_session: Mutex::new(Some(download_session)), state: Mutex::new(GroupState::Alive), }); @@ -71,7 +69,7 @@ impl DownloadGroup { } /// Get the group ID. - pub(super) fn id(&self) -> Ulid { + pub(super) fn id(&self) -> UniqueID { self.group_id } @@ -102,7 +100,7 @@ impl DownloadGroup { /// Returns [`SessionError::Aborted`] if the session has been aborted, or /// [`SessionError::AlreadyFinished`] if [`finish`](Self::finish) has already /// been called. - pub fn download_file_to_path( + pub async fn download_file_to_path( &self, file_info: XetFileInfo, dest_path: PathBuf, @@ -112,12 +110,15 @@ impl DownloadGroup { // Use the absolute path in case the process current working directory changes // while the task is queued. let absolute_path = std::path::absolute(dest_path)?; - self.inner.start_download_file_to_path(file_info, absolute_path) + self.inner.start_download_file_to_path(file_info, absolute_path).await } /// Return a snapshot of progress for every queued download. - pub fn get_progress(&self) -> Result { - self.progress.snapshot() + pub fn get_progress(&self) -> Result { + let Some(download_session) = self.download_session.lock()?.clone() else { + return Ok(GroupProgressReport::default()); + }; + Ok(download_session.report()) } /// Wait for all downloads to complete and return their results. @@ -133,7 +134,7 @@ impl DownloadGroup { /// /// Consumes `self` — subsequent calls on any clone will return /// [`SessionError::AlreadyFinished`]. - pub async fn finish(self) -> Result, SessionError> { + pub async fn finish(self) -> Result, SessionError> { let inner = self.inner.clone(); self.session .dispatch("finish", async move { inner.handle_finish().await }) @@ -149,12 +150,32 @@ impl DownloadGroup { } } + /// Blocking version of [`download_file_to_path`](Self::download_file_to_path). + /// + /// # Panics + /// + /// Panics if called from within a tokio async runtime. + pub fn download_file_to_path_blocking( + &self, + file_info: XetFileInfo, + dest_path: PathBuf, + ) -> Result { + let group = self.clone(); + self.runtime() + .external_run_async_task(async move { group.download_file_to_path(file_info, dest_path).await })? + } + + /// Blocking version of [`get_progress`](Self::get_progress). + pub fn get_progress_blocking(&self) -> Result { + self.get_progress() + } + /// Blocking version of [`finish`](Self::finish). /// /// # Panics /// /// Panics if called from within a tokio async runtime. - pub fn finish_blocking(self) -> Result, SessionError> { + pub fn finish_blocking(self) -> Result, SessionError> { let group = self.clone(); self.runtime().external_run_async_task(group.finish())? } @@ -171,7 +192,8 @@ pub type DownloadResult = Arc>; struct InnerDownloadTaskHandle { status: Arc>, dest_path: PathBuf, - join_handle: JoinHandle>, + file_info: XetFileInfo, + join_handle: JoinHandle>, result: Arc>, } @@ -179,14 +201,11 @@ struct InnerDownloadTaskHandle { /// Accessed through `Arc`; do not use this type directly. #[doc(hidden)] pub struct DownloadGroupInner { - group_id: Ulid, + group_id: UniqueID, session: XetSession, // Active download tasks for this group - active_tasks: RwLock>, - - // Aggregate + per-file progress, fed into FileDownloadSession as a TrackingProgressUpdater - progress: Arc, + active_tasks: RwLock>, // Shared download session (FileDownloadSession from data crate) download_session: Mutex>>, @@ -207,101 +226,66 @@ impl DownloadGroupInner { } } - /// Spawn a runtime task that performs the actual file download. - fn spawn_download_task( - self: &Arc, - download_session: Arc, - file_info: XetFileInfo, - dest_path: PathBuf, - status: Arc>, - tracking_id: Ulid, - ) -> JoinHandle> { - let semaphore = self.runtime().common().file_download_semaphore.clone(); - self.runtime().spawn(async move { - let _permit = semaphore.acquire().await?; - - // Only transition Queued → Running; bail if abort() already set Cancelled. - { - let mut s = status.lock()?; - if !matches!(*s, TaskStatus::Queued) { - return Err(SessionError::Aborted); - } - *s = TaskStatus::Running; - } - - let result: Result<_, SessionError> = download_session - .download_file(&file_info, &dest_path, tracking_id) - .await - .map_err(SessionError::from); - - let new_status = if result.is_ok() { - TaskStatus::Completed - } else { - TaskStatus::Failed - }; - // Only overwrite if still Running — abort() may have set Cancelled concurrently. - let mut s = status.lock()?; - if matches!(*s, TaskStatus::Running) { - *s = new_status; - } - - Ok(XetFileInfo { - hash: file_info.hash, - file_size: result?, - sha256: None, - }) - }) - } - - fn start_download_file_to_path( + async fn start_download_file_to_path( self: &Arc, file_info: XetFileInfo, dest_path: PathBuf, ) -> Result { - // Hold the state lock guard for the duration of this function so finish() will not run - // when a download task is registering. - let state = self.state.lock()?; - Self::check_accepting_tasks(&state)?; + let download_session = { + let state = self.state.lock()?; + Self::check_accepting_tasks(&state)?; + + let Some(download_session) = self.download_session.lock()?.clone() else { + return Err(SessionError::other("Download session not initialized")); + }; + download_session + // state guard dropped here before the .await + }; + + let (task_id, join_handle) = self + .session + .dispatch("spawn_download_file", { + let file_info = file_info.clone(); + let dest_path = dest_path.clone(); + async move { download_session.download_file_background(file_info, dest_path).await } + }) + .await??; + + // Re-check state: if finish() or abort() raced in, cancel the spawned task. + { + let state = self.state.lock()?; + if !matches!(*state, GroupState::Alive) { + join_handle.abort(); + Self::check_accepting_tasks(&state)?; + } + } - let tracking_id = Ulid::new(); let status = Arc::new(Mutex::new(TaskStatus::Queued)); - let result: Arc> = Arc::new(OnceLock::new()); let task_handle = DownloadTaskHandle { inner: TaskHandle { status: Some(status.clone()), - group_progress: self.progress.clone(), - task_id: tracking_id, + task_id, }, result: result.clone(), }; - let Some(download_session) = self.download_session.lock()?.clone() else { - return Err(SessionError::other("Download session not initialized")); - }; - - let join_handle = self.spawn_download_task( - download_session, - file_info.clone(), - dest_path.clone(), - status.clone(), - tracking_id, - ); - let handle = InnerDownloadTaskHandle { status, dest_path, + file_info, join_handle, result, }; - self.active_tasks.write()?.insert(tracking_id, handle); + TaskStatus::mark_running(&handle.status); + self.active_tasks.write()?.insert(task_id, handle); Ok(task_handle) } /// Join all active download tasks and mark the group as finished. - pub(super) async fn handle_finish(&self) -> Result, SessionError> { + pub(super) async fn handle_finish(&self) -> Result, SessionError> { // Mark as not accepting new tasks { let mut state_guard = self.state.lock()?; @@ -316,29 +300,36 @@ impl DownloadGroupInner { let mut results = HashMap::new(); let mut join_err = None; - // Join all tasks first and then propogate errors. + // Join all tasks first and then propagate errors. for (task_id, handle) in active_tasks { - match handle.join_handle.await.map_err(SessionError::from) { - Ok(Ok(file_info)) => { + match handle.join_handle.await { + Ok(Ok(n_bytes)) => { + TaskStatus::mark_terminal(&handle.status, TaskStatus::Completed); let result = Arc::new(Ok(DownloadedFile { dest_path: handle.dest_path, - file_info, + file_info: XetFileInfo { + hash: handle.file_info.hash, + file_size: n_bytes, + sha256: None, + }, })); results.insert(task_id, result.clone()); - // Update result to the external task handle, this is the only place setting - // the result, so no error will happen. let _ = handle.result.set(result); }, - Ok(Err(task_err)) => { - let result: Arc> = Arc::new(Err(task_err)); + Ok(Err(data_err)) => { + TaskStatus::mark_terminal(&handle.status, TaskStatus::Failed); + let result: DownloadResult = Arc::new(Err(data_err.into())); results.insert(task_id, result.clone()); - // Update result to the external task handle, this is the only place setting - // the result, so no error will happen. let _ = handle.result.set(result); }, Err(e) => { + if e.is_cancelled() { + TaskStatus::mark_cancelled(&handle.status); + } else { + TaskStatus::mark_terminal(&handle.status, TaskStatus::Failed); + } if join_err.is_none() { - join_err = Some(e); + join_err = Some(SessionError::from(e)); } }, } @@ -356,16 +347,12 @@ impl DownloadGroupInner { Ok(results) } - fn runtime(&self) -> &XetRuntime { - &self.session.runtime - } - fn abort(&self) -> Result<(), SessionError> { *self.state.lock()? = GroupState::Aborted; let active_tasks = std::mem::take(&mut *self.active_tasks.write()?); for (_tracking_id, inner_task_handle) in active_tasks { + TaskStatus::mark_cancelled(&inner_task_handle.status); inner_task_handle.join_handle.abort(); - let _ = inner_task_handle.status.lock().map(|mut s| *s = TaskStatus::Cancelled); } Ok(()) @@ -477,10 +464,9 @@ mod tests { async fn test_get_progress_empty_initially() { let session = XetSessionBuilder::new().build_async().await.unwrap(); let group = session.new_download_group().await.unwrap(); - let snapshot = group.get_progress().unwrap(); - let total = snapshot.total(); - assert_eq!(total.total_bytes, 0); - assert_eq!(total.total_bytes_completed, 0); + let report = group.get_progress().unwrap(); + assert_eq!(report.total_bytes, 0); + assert_eq!(report.total_bytes_completed, 0); } // ── Finish lifecycle ───────────────────────────────────────────────────── @@ -542,6 +528,7 @@ mod tests { }, std::path::PathBuf::from("dest.bin"), ) + .await .unwrap_err(); assert!(matches!(err, SessionError::Aborted)); } @@ -562,6 +549,7 @@ mod tests { }, std::path::PathBuf::from("dest.bin"), ) + .await .unwrap_err(); assert!(matches!(err, SessionError::AlreadyFinished)); } @@ -581,6 +569,7 @@ mod tests { }, std::path::PathBuf::from("dest.bin"), ) + .await .unwrap_err(); assert!(matches!(err, SessionError::Aborted)); } @@ -609,12 +598,65 @@ mod tests { let dest = temp.path().join("downloaded.bin"); let group = session.new_download_group().await.unwrap(); - group.download_file_to_path(file_info, dest.clone()).unwrap(); + let handle = group.download_file_to_path(file_info, dest.clone()).await.unwrap(); + assert!(matches!(handle.status().unwrap(), TaskStatus::Queued | TaskStatus::Running | TaskStatus::Completed)); group.finish().await.unwrap(); + assert!(matches!(handle.status().unwrap(), TaskStatus::Completed)); assert_eq!(std::fs::read(&dest).unwrap(), original); } + #[tokio::test(flavor = "multi_thread")] + // A download task that fails transitions to Failed status. + async fn test_download_status_failed_for_invalid_file_info() { + let temp = tempdir().unwrap(); + let session = local_session(&temp).await.unwrap(); + let group = session.new_download_group().await.unwrap(); + let handle = group + .download_file_to_path( + XetFileInfo { + hash: "abc123".to_string(), + file_size: 123, + sha256: None, + }, + temp.path().join("missing.bin"), + ) + .await + .unwrap(); + let results = group.finish().await.unwrap(); + let task_result = results.get(&handle.task_id).unwrap(); + assert!(task_result.is_err()); + assert!(matches!(handle.status().unwrap(), TaskStatus::Failed)); + } + + #[tokio::test(flavor = "multi_thread")] + // task_id returned by download_file_to_path must match the per-item progress entry id. + async fn test_download_task_id_matches_progress_item_id() { + let temp = tempdir().unwrap(); + let session = local_session(&temp).await.unwrap(); + let original = b"download id match"; + let file_info = upload_bytes(&session, original, "id.bin").await.unwrap(); + + let dest = temp.path().join("download_id.bin"); + let group = session.new_download_group().await.unwrap(); + let handle = group.download_file_to_path(file_info, dest).await.unwrap(); + + let download_session = group.inner.download_session.lock().unwrap().clone().unwrap(); + + let mut reports = HashMap::new(); + for _ in 0..50 { + reports = download_session.item_reports(); + if !reports.is_empty() { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + assert!(reports.contains_key(&handle.task_id)); + + group.finish().await.unwrap(); + } + #[tokio::test(flavor = "multi_thread")] // Downloading multiple files from a single group produces correct content for each. async fn test_download_multiple_files() { @@ -635,7 +677,7 @@ mod tests { .unwrap(); let results = commit.commit().await.unwrap(); - let to_file_info = |handle: &crate::xet_session::progress::UploadTaskHandle| -> XetFileInfo { + let to_file_info = |handle: &crate::xet_session::tasks::UploadTaskHandle| -> XetFileInfo { let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); XetFileInfo { hash: meta.hash.clone(), @@ -647,8 +689,14 @@ mod tests { let dest_a = temp.path().join("a_out.bin"); let dest_b = temp.path().join("b_out.bin"); let group = session.new_download_group().await.unwrap(); - group.download_file_to_path(to_file_info(&handle_a), dest_a.clone()).unwrap(); - group.download_file_to_path(to_file_info(&handle_b), dest_b.clone()).unwrap(); + group + .download_file_to_path(to_file_info(&handle_a), dest_a.clone()) + .await + .unwrap(); + group + .download_file_to_path(to_file_info(&handle_b), dest_b.clone()) + .await + .unwrap(); group.finish().await.unwrap(); assert_eq!(std::fs::read(&dest_a).unwrap(), data_a); @@ -666,7 +714,7 @@ mod tests { let dest = temp.path().join("out.bin"); let group = session.new_download_group().await.unwrap(); let progress_observer = group.clone(); - group.download_file_to_path(file_info, dest).unwrap(); + group.download_file_to_path(file_info, dest).await.unwrap(); group.finish().await.unwrap(); tokio::time::sleep( @@ -678,8 +726,11 @@ mod tests { .saturating_add(Duration::from_secs(1)), ) .await; - let snapshot = progress_observer.get_progress().unwrap(); - assert!(snapshot.total().total_bytes_completed > 0); + let report = progress_observer.get_progress().unwrap(); + assert_eq!(report.total_bytes, original.len() as u64); + assert_eq!(report.total_bytes_completed, original.len() as u64); + assert_eq!(report.total_transfer_bytes, report.total_transfer_bytes_completed); + assert!(report.total_transfer_bytes_completed > 0); } // ── Per-task result access patterns ────────────────────────────────────── @@ -693,7 +744,7 @@ mod tests { let file_info = upload_bytes(&session, data, "file.bin").await.unwrap(); let dest = temp.path().join("out.bin"); let group = session.new_download_group().await.unwrap(); - let handle = group.download_file_to_path(file_info, dest).unwrap(); + let handle = group.download_file_to_path(file_info, dest).await.unwrap(); let results = group.finish().await.unwrap(); let result = results.get(&handle.task_id).expect("task_id must be present in results"); assert_eq!(result.as_ref().as_ref().unwrap().file_info.file_size, data.len() as u64); @@ -707,7 +758,7 @@ mod tests { let file_info = upload_bytes(&session, b"some data", "file.bin").await.unwrap(); let dest = temp.path().join("out.bin"); let group = session.new_download_group().await.unwrap(); - let handle = group.download_file_to_path(file_info, dest).unwrap(); + let handle = group.download_file_to_path(file_info, dest).await.unwrap(); assert!(handle.result().is_none(), "result must be None before finish()"); group.finish().await.unwrap(); } @@ -721,7 +772,7 @@ mod tests { let file_info = upload_bytes(&session, data, "file.bin").await.unwrap(); let dest = temp.path().join("out.bin"); let group = session.new_download_group().await.unwrap(); - let handle = group.download_file_to_path(file_info.clone(), dest).unwrap(); + let handle = group.download_file_to_path(file_info.clone(), dest).await.unwrap(); group.finish().await.unwrap(); let result = handle.result().expect("result must be set after finish()"); let dl = result.as_ref().as_ref().unwrap(); @@ -758,7 +809,7 @@ mod tests { let dest = temp.path().join("out_futures.bin"); let group = session.new_download_group().await.unwrap(); - group.download_file_to_path(file_info, dest.clone()).unwrap(); + group.download_file_to_path(file_info, dest.clone()).await.unwrap(); group.finish().await.unwrap(); assert_eq!(std::fs::read(&dest).unwrap(), data); }); @@ -789,7 +840,7 @@ mod tests { let dest = temp.path().join("out_smol.bin"); let group = session.new_download_group().await.unwrap(); - group.download_file_to_path(file_info, dest.clone()).unwrap(); + group.download_file_to_path(file_info, dest.clone()).await.unwrap(); group.finish().await.unwrap(); assert_eq!(std::fs::read(&dest).unwrap(), data); }); @@ -820,7 +871,7 @@ mod tests { let dest = temp.path().join("out_async_std.bin"); let group = session.new_download_group().await.unwrap(); - group.download_file_to_path(file_info, dest.clone()).unwrap(); + group.download_file_to_path(file_info, dest.clone()).await.unwrap(); group.finish().await.unwrap(); assert_eq!(std::fs::read(&dest).unwrap(), data); }); @@ -860,7 +911,7 @@ mod tests { let dest = temp.path().join("downloaded.bin"); let group = session.new_download_group_blocking()?; - group.download_file_to_path(file_info, dest.clone())?; + group.download_file_to_path_blocking(file_info, dest.clone())?; group.finish_blocking()?; assert_eq!(std::fs::read(&dest)?, original); @@ -880,7 +931,7 @@ mod tests { let handle_b = commit.upload_bytes_blocking(data_b.to_vec(), Sha256Policy::Compute, Some("b.bin".into()))?; let results = commit.commit_blocking()?; - let to_file_info = |handle: &crate::xet_session::progress::UploadTaskHandle| -> XetFileInfo { + let to_file_info = |handle: &crate::xet_session::tasks::UploadTaskHandle| -> XetFileInfo { let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); XetFileInfo { hash: meta.hash.clone(), @@ -892,8 +943,8 @@ mod tests { let dest_a = temp.path().join("a_out.bin"); let dest_b = temp.path().join("b_out.bin"); let group = session.new_download_group_blocking()?; - group.download_file_to_path(to_file_info(&handle_a), dest_a.clone())?; - group.download_file_to_path(to_file_info(&handle_b), dest_b.clone())?; + group.download_file_to_path_blocking(to_file_info(&handle_a), dest_a.clone())?; + group.download_file_to_path_blocking(to_file_info(&handle_b), dest_b.clone())?; group.finish_blocking()?; assert_eq!(std::fs::read(&dest_a)?, data_a); @@ -911,7 +962,7 @@ mod tests { let dest = temp.path().join("out.bin"); let group = session.new_download_group_blocking()?; let progress_observer = group.clone(); - group.download_file_to_path(file_info, dest)?; + group.download_file_to_path_blocking(file_info, dest)?; group.finish_blocking()?; std::thread::sleep( @@ -922,8 +973,11 @@ mod tests { .progress_update_interval .saturating_add(Duration::from_secs(1)), ); - let snapshot = progress_observer.get_progress()?; - assert!(snapshot.total().total_bytes_completed > 0); + let snapshot = progress_observer.get_progress_blocking()?; + assert_eq!(snapshot.total_bytes, original.len() as u64); + assert_eq!(snapshot.total_bytes_completed, original.len() as u64); + assert_eq!(snapshot.total_transfer_bytes, snapshot.total_transfer_bytes_completed); + assert!(snapshot.total_transfer_bytes_completed > 0); Ok(()) } @@ -935,7 +989,7 @@ mod tests { let file_info = upload_bytes_blocking(&session, data, "file.bin")?; let dest = temp.path().join("out.bin"); let group = session.new_download_group_blocking()?; - let handle = group.download_file_to_path(file_info.clone(), dest)?; + let handle = group.download_file_to_path_blocking(file_info.clone(), dest)?; // Before finish, per-task result is not available yet. assert!(handle.result().is_none()); @@ -966,7 +1020,7 @@ mod tests { let file_info = upload_bytes_blocking(&session, data, "test.bin").unwrap(); let dest = temp.path().join("out_smol.bin"); let group = session.new_download_group_blocking().unwrap(); - group.download_file_to_path(file_info, dest.clone()).unwrap(); + group.download_file_to_path_blocking(file_info, dest.clone()).unwrap(); group.finish_blocking().unwrap(); assert_eq!(std::fs::read(&dest).unwrap(), data); })); diff --git a/xet_pkg/src/xet_session/mod.rs b/xet_pkg/src/xet_session/mod.rs index 598e5247..fdab569d 100644 --- a/xet_pkg/src/xet_session/mod.rs +++ b/xet_pkg/src/xet_session/mod.rs @@ -24,7 +24,7 @@ //! [`upload_bytes_blocking`](UploadCommit::upload_bytes_blocking), then call //! [`commit`](UploadCommit::commit) or //! [`commit_blocking`](UploadCommit::commit_blocking) to wait for all -//! transfers to finish and receive a `HashMap` +//! transfers to finish and receive a `HashMap<`[`UniqueID`]`, `[`UploadResult`]`>` //! keyed by task ID. //! //! `UploadResult` = `Arc>`. @@ -35,10 +35,11 @@ //! //! Create a [`DownloadGroup`] with [`XetSession::new_download_group`] (async) //! or [`XetSession::new_download_group_blocking`] (sync), queue files with -//! [`download_file_to_path`](DownloadGroup::download_file_to_path), then call -//! [`finish`](DownloadGroup::finish) (async) or +//! [`download_file_to_path`](DownloadGroup::download_file_to_path) / +//! [`download_file_to_path_blocking`](DownloadGroup::download_file_to_path_blocking), +//! then call [`finish`](DownloadGroup::finish) (async) or //! [`finish_blocking`](DownloadGroup::finish_blocking) (sync) to wait for all -//! transfers to complete and receive a `HashMap` +//! transfers to complete and receive a `HashMap<`[`UniqueID`]`, `[`DownloadResult`]`>` //! keyed by task ID. //! //! `DownloadResult` = `Arc>`. @@ -48,7 +49,7 @@ //! ## Progress tracking //! //! Both [`UploadCommit`] and [`DownloadGroup`] expose `get_progress()`, -//! which returns a [`ProgressSnapshot`] without acquiring a lock on the +//! which returns a [`GroupProgressReport`] without acquiring a lock on the //! calling thread (useful for Python bindings that must release the GIL). //! Poll it from a background thread/task while the main thread/task blocks //! in `commit()` / `finish()`. @@ -56,9 +57,9 @@ //! ## Error handling //! //! All public methods return `Result<_, `[`SessionError`]`>`. -//! [`commit`](UploadCommit::commit) returns `HashMap` +//! [`commit`](UploadCommit::commit) returns `HashMap<`[`UniqueID`]`, `[`UploadResult`]`>` //! keyed by task ID, and [`finish`](DownloadGroup::finish) returns -//! `HashMap` keyed by task ID, so a single failed +//! `HashMap<`[`UniqueID`]`, `[`DownloadResult`]`>` keyed by task ID, so a single failed //! file does not discard all others. //! //! # Quick start — sync API @@ -87,7 +88,7 @@ //! file_size: m.file_size, //! sha256: m.sha256.clone(), //! }; -//! let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?; +//! let dl_handle = group.download_file_to_path_blocking(info, "out/file.bin".into())?; //! let finish_results = group.finish_blocking()?; //! // DownloadResult = Arc> //! let r = finish_results.get(&dl_handle.task_id).unwrap().as_ref().as_ref().unwrap(); @@ -124,7 +125,7 @@ //! file_size: m.file_size, //! sha256: m.sha256.clone(), //! }; -//! let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?; +//! let dl_handle = group.download_file_to_path(info, "out/file.bin".into()).await?; //! let finish_results = group.finish().await?; //! // DownloadResult = Arc> //! let r = finish_results.get(&dl_handle.task_id).unwrap().as_ref().as_ref().unwrap(); @@ -135,16 +136,15 @@ mod common; mod download_group; mod errors; -mod progress; mod session; +mod tasks; mod upload_commit; pub use download_group::{DownloadGroup, DownloadResult, DownloadedFile}; pub use errors::SessionError; -pub use progress::{ - DownloadTaskHandle, FileProgress, ProgressSnapshot, TaskHandle, TaskStatus, TotalProgressSnapshot, UploadTaskHandle, -}; pub use session::{XetSession, XetSessionBuilder}; +pub use tasks::{DownloadTaskHandle, TaskHandle, TaskStatus, UploadTaskHandle}; pub use upload_commit::{FileMetadata, UploadCommit, UploadResult}; pub use xet_data::processing::{Sha256Policy, XetFileInfo}; +pub use xet_data::progress_tracking::{GroupProgressReport, ItemProgressReport, UniqueID}; pub use xet_runtime::config::XetConfig; diff --git a/xet_pkg/src/xet_session/progress.rs b/xet_pkg/src/xet_session/progress.rs deleted file mode 100644 index 2a582027..00000000 --- a/xet_pkg/src/xet_session/progress.rs +++ /dev/null @@ -1,518 +0,0 @@ -//! Progress tracking for upload commits and download groups. - -use std::collections::HashMap; -use std::ops::Deref; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; - -use async_trait::async_trait; -use ulid::Ulid; -use xet_data::progress_tracking::{ProgressUpdate, TrackingProgressUpdater}; - -use super::SessionError; -use super::download_group::DownloadResult; -use super::upload_commit::UploadResult; - -/// Lifecycle state of a single upload or download task. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum TaskStatus { - /// Task has been queued but has not started executing yet. - Queued, - /// Task is actively transferring data. - Running, - /// Task finished successfully. - Completed, - /// Task encountered an error and did not complete. - Failed, - /// Task was cancelled before it could complete. - Cancelled, -} - -#[derive(Debug)] -pub struct TaskHandle { - pub(super) status: Option>>, - pub(super) group_progress: Arc, - /// Id of the task, can be used to retrive per-task progress and result. - pub task_id: Ulid, -} - -#[derive(Debug)] -pub struct UploadTaskHandle { - pub(super) inner: TaskHandle, - pub(super) result: Arc>, -} - -impl Deref for UploadTaskHandle { - type Target = TaskHandle; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -#[derive(Debug)] -pub struct DownloadTaskHandle { - pub(super) inner: TaskHandle, - pub(super) result: Arc>, -} - -impl Deref for DownloadTaskHandle { - type Target = TaskHandle; - - fn deref(&self) -> &Self::Target { - &self.inner - } -} - -impl TaskHandle { - pub fn status(&self) -> Result { - if let Some(status) = &self.status { - Ok(*status.lock()?) - } else { - Err(SessionError::other("status not available")) - } - } - - pub fn progress(&self) -> Result { - self.group_progress.file(self.task_id) - } -} - -impl UploadTaskHandle { - pub fn result(&self) -> Option { - self.result.get().cloned() - } -} - -impl DownloadTaskHandle { - pub fn result(&self) -> Option { - self.result.get().cloned() - } -} - -// ── Aggregate / per-file progress ─────────────────────────────────────────── - -pub struct ProgressSnapshot { - total: TotalProgressSnapshot, - files: HashMap, -} - -impl ProgressSnapshot { - pub fn total(&self) -> &TotalProgressSnapshot { - &self.total - } - - pub fn file(&self, task_id: Ulid) -> Result<&FileProgress, SessionError> { - self.files.get(&task_id).ok_or(SessionError::InvalidTaskID(task_id)) - } -} - -/// Snapshot of aggregate progress. -#[derive(Clone, Debug, Default)] -pub struct TotalProgressSnapshot { - /// Total bytes known to process (includes deduplicated bytes). - pub total_bytes: u64, - /// Total bytes that have been processed so far. - pub total_bytes_completed: u64, - /// Bytes-processed completion rate, if available. - pub total_bytes_completion_rate: Option, - /// Total bytes that need to be transferred (uploaded/downloaded). - pub total_transfer_bytes: u64, - /// Total bytes that have been transferred so far. - pub total_transfer_bytes_completed: u64, - /// Transfer completion rate, if available. - pub total_transfer_bytes_completion_rate: Option, -} - -/// Snapshot of a single file's progress. -#[derive(Clone, Debug)] -pub struct FileProgress { - /// File name as reported by the data layer. - pub item_name: Arc, - /// Total size of this file in bytes. - pub total_bytes: u64, - /// Bytes of this file processed so far. - pub bytes_completed: u64, -} - -pub type FileProgressSnapshot = FileProgress; - -/// Tracks per-file and aggregate transfer progress for upload commits and download groups. -/// -/// Implements [`TrackingProgressUpdater`]. -/// -/// - Call [`GroupProgress::total`] for an aggregate snapshot (lock-free reads). -/// - Call [`GroupProgress::files`] for a per-file breakdown. -/// -/// All integer counters are stored as [`AtomicU64`]; floating-point completion -/// rates use a `Mutex` (rarely written, never held across await points). -#[derive(Debug)] -pub struct GroupProgress { - // Aggregate totals - total_bytes: AtomicU64, - total_bytes_completed: AtomicU64, - total_transfer_bytes: AtomicU64, - total_transfer_bytes_completed: AtomicU64, - - // Completion rates - total_bytes_completion_rate: Mutex>, - total_transfer_bytes_completion_rate: Mutex>, - - // Per-file: item_name → (total_bytes, bytes_completed) - files: Mutex>, -} - -impl GroupProgress { - /// Create a new tracker with all counters at zero. - pub fn new() -> Self { - Self { - total_bytes: AtomicU64::new(0), - total_bytes_completed: AtomicU64::new(0), - total_transfer_bytes: AtomicU64::new(0), - total_transfer_bytes_completed: AtomicU64::new(0), - total_bytes_completion_rate: Mutex::new(None), - total_transfer_bytes_completion_rate: Mutex::new(None), - files: Mutex::new(HashMap::new()), - } - } - - /// Return a combined snapshot of aggregate and per-file progress. - pub fn snapshot(&self) -> Result { - let total = self.total(); - let files = self.files.lock()?.clone(); - Ok(ProgressSnapshot { total, files }) - } - - /// Return a combined snapshot of aggregate progress. - fn total(&self) -> TotalProgressSnapshot { - TotalProgressSnapshot { - total_bytes: self.total_bytes.load(Ordering::Relaxed), - total_bytes_completed: self.total_bytes_completed.load(Ordering::Relaxed), - total_bytes_completion_rate: self.total_bytes_completion_rate.lock().ok().and_then(|g| *g), - total_transfer_bytes: self.total_transfer_bytes.load(Ordering::Relaxed), - total_transfer_bytes_completed: self.total_transfer_bytes_completed.load(Ordering::Relaxed), - total_transfer_bytes_completion_rate: self - .total_transfer_bytes_completion_rate - .lock() - .ok() - .and_then(|g| *g), - } - } - - fn file(&self, tracking_id: Ulid) -> Result { - self.files - .lock()? - .get(&tracking_id) - .cloned() - .ok_or(SessionError::InvalidTaskID(tracking_id)) - } -} - -impl Default for GroupProgress { - fn default() -> Self { - Self::new() - } -} - -#[async_trait] -impl TrackingProgressUpdater for GroupProgress { - async fn register_updates(&self, updates: ProgressUpdate) { - // Update aggregate integer counters atomically. - self.total_bytes.store(updates.total_bytes, Ordering::Relaxed); - self.total_bytes_completed - .store(updates.total_bytes_completed, Ordering::Relaxed); - self.total_transfer_bytes.store(updates.total_transfer_bytes, Ordering::Relaxed); - self.total_transfer_bytes_completed - .store(updates.total_transfer_bytes_completed, Ordering::Relaxed); - - // Update floating-point rates (brief lock, not held across await). - if let Ok(mut rate) = self.total_bytes_completion_rate.lock() { - *rate = updates.total_bytes_completion_rate; - } - if let Ok(mut rate) = self.total_transfer_bytes_completion_rate.lock() { - *rate = updates.total_transfer_bytes_completion_rate; - } - - // Update per-file progress. - if let Ok(mut items) = self.files.lock() { - for item_update in updates.item_updates { - let entry = items.entry(item_update.tracking_id).or_insert(FileProgress { - item_name: item_update.item_name.clone(), - total_bytes: item_update.total_bytes, - bytes_completed: item_update.bytes_completed, - }); - entry.total_bytes = entry.total_bytes.max(item_update.total_bytes); - entry.bytes_completed = entry.bytes_completed.max(item_update.bytes_completed); - } - } - } -} - -#[cfg(test)] -mod tests { - use std::path::PathBuf; - - use xet_data::processing::XetFileInfo; - use xet_data::progress_tracking::ItemProgressUpdate; - - use super::*; - use crate::xet_session::{DownloadedFile, FileMetadata}; - - // ── GroupProgress unit tests ───────────────────────────────────────────── - - #[test] - // A freshly created GroupProgress has all-zero totals and no completion rates. - fn test_snapshot_empty_initially() { - let p = GroupProgress::new(); - let snapshot = p.snapshot().unwrap(); - let total = snapshot.total(); - assert_eq!(total.total_bytes, 0); - assert_eq!(total.total_bytes_completed, 0); - assert_eq!(total.total_transfer_bytes, 0); - assert_eq!(total.total_transfer_bytes_completed, 0); - assert!(total.total_bytes_completion_rate.is_none()); - assert!(total.total_transfer_bytes_completion_rate.is_none()); - } - - #[test] - // Looking up an unknown tracking ID in a snapshot returns InvalidTaskID. - fn test_snapshot_file_with_unknown_id_returns_error() { - let p = GroupProgress::new(); - let snapshot = p.snapshot().unwrap(); - let unknown_id = Ulid::new(); - let result = snapshot.file(unknown_id); - assert!(matches!(result, Err(SessionError::InvalidTaskID(_)))); - } - - #[tokio::test] - // Per-file bytes_completed uses max semantics: a stale/lower update does not reduce progress. - async fn test_register_updates_per_file_bytes_completed_never_decreases() { - let p = GroupProgress::new(); - let task_id = Ulid::new(); - - p.register_updates(ProgressUpdate { - total_bytes: 100, - total_bytes_completed: 80, - item_updates: vec![ItemProgressUpdate { - tracking_id: task_id, - item_name: "file.bin".into(), - total_bytes: 100, - bytes_completed: 80, - bytes_completion_increment: 80, - }], - ..Default::default() - }) - .await; - - // Simulate a stale/out-of-order update carrying a smaller value. - p.register_updates(ProgressUpdate { - total_bytes: 100, - total_bytes_completed: 40, - item_updates: vec![ItemProgressUpdate { - tracking_id: task_id, - item_name: "file.bin".into(), - total_bytes: 100, - bytes_completed: 40, // lower than previously seen - bytes_completion_increment: 0, - }], - ..Default::default() - }) - .await; - - let snapshot = p.snapshot().unwrap(); - let file = snapshot.file(task_id).unwrap(); - // Max semantics: should still report 80, not the lower 40. - assert_eq!(file.bytes_completed, 80); - } - - // ── TaskHandle unit tests ──────────────────────────────────────────────── - - #[test] - // A TaskHandle with no status Arc (streaming upload) returns an error from status(). - fn test_task_handle_with_no_status_returns_error() { - let progress = Arc::new(GroupProgress::new()); - let handle = TaskHandle { - status: None, - group_progress: progress, - task_id: Ulid::new(), - }; - assert!(handle.status().is_err()); - } - - #[test] - // TaskHandle::progress() returns InvalidTaskID when the task_id has no registered progress. - fn test_task_handle_progress_for_unknown_id_returns_error() { - let progress = Arc::new(GroupProgress::new()); - let handle = TaskHandle { - status: None, - group_progress: progress, - task_id: Ulid::new(), // not registered in GroupProgress - }; - assert!(matches!(handle.progress(), Err(SessionError::InvalidTaskID(_)))); - } - - // ── UploadTaskHandle unit tests ────────────────────────────────────────── - // - // `UploadTaskHandle` wraps a `TaskHandle` and adds a `result` Arc that is - // shared with the internal `InnerUploadTaskHandle` inside `UploadCommit`. - // After `commit()` completes, the internal handle writes the per-file - // `UploadResult` (= `Arc>`) into that - // shared Arc so callers can read it directly from the task handle without - // touching the `commit()` return value. - // - // There are therefore two equivalent ways to retrieve a per-task result: - // 1. `commit()` returns `HashMap`; look up the task using `handle.task_id`. - // 2. Call `handle.result()` directly after `commit()` returns. - // - // The tests below exercise the `result` Arc mechanics in isolation; see - // `upload_commit.rs` for end-to-end integration tests of both patterns. - - #[test] - // UploadTaskHandle::result() returns None before the result Arc is populated. - fn test_upload_task_handle_result_none_before_commit() { - let progress = Arc::new(GroupProgress::new()); - let handle = UploadTaskHandle { - inner: TaskHandle { - status: None, - group_progress: progress, - task_id: Ulid::new(), - }, - result: Arc::new(OnceLock::new()), - }; - assert!(handle.result().is_none()); - } - - #[test] - // UploadTaskHandle::result() returns the value once the shared Arc is populated. - fn test_upload_task_handle_result_some_after_result_set() { - let progress = Arc::new(GroupProgress::new()); - let result_arc = Arc::new(OnceLock::new()); - let handle = UploadTaskHandle { - inner: TaskHandle { - status: None, - group_progress: progress, - task_id: Ulid::new(), - }, - result: result_arc.clone(), - }; - - // Simulate commit() writing the result. - let metadata = Arc::new(Ok(FileMetadata { - tracking_name: Some("file.bin".to_string()), - hash: "abc123".to_string(), - file_size: 42, - sha256: None, - })); - result_arc.set(metadata).unwrap(); - - let result = handle.result().unwrap(); - let meta = result.as_ref().as_ref().unwrap(); - assert_eq!(meta.file_size, 42); - assert_eq!(meta.hash, "abc123"); - } - - // ── DownloadTaskHandle unit tests ──────────────────────────────────────── - // - // `DownloadTaskHandle` follows the same Arc-sharing pattern as - // `UploadTaskHandle`. Its `result` field holds a `DownloadResult` - // (= `Arc>`), populated by `finish()`. - - #[test] - // DownloadTaskHandle::result() returns None before finish() populates the result Arc. - fn test_download_task_handle_result_none_before_finish() { - let progress = Arc::new(GroupProgress::new()); - let handle = DownloadTaskHandle { - inner: TaskHandle { - status: None, - group_progress: progress, - task_id: Ulid::new(), - }, - result: Arc::new(OnceLock::new()), - }; - assert!(handle.result().is_none()); - } - - #[test] - // DownloadTaskHandle::result() returns the value once the shared Arc is populated. - fn test_download_task_handle_result_some_after_result_set() { - let progress = Arc::new(GroupProgress::new()); - let result_arc = Arc::new(OnceLock::new()); - let handle = DownloadTaskHandle { - inner: TaskHandle { - status: None, - group_progress: progress, - task_id: Ulid::new(), - }, - result: result_arc.clone(), - }; - - // Simulate finish() writing the result. - let download_result = Arc::new(Ok(DownloadedFile { - dest_path: PathBuf::from("out/file.bin"), - file_info: XetFileInfo { - hash: "def456".to_string(), - file_size: 99, - sha256: None, - }, - })); - result_arc.set(download_result).unwrap(); - - let result = handle.result().unwrap(); - let dl = result.as_ref().as_ref().unwrap(); - assert_eq!(dl.file_info.file_size, 99); - assert_eq!(dl.dest_path, PathBuf::from("out/file.bin")); - } - - // ── Full register_updates test ─────────────────────────────────────────── - - #[tokio::test] - // A single register_updates call populates all aggregate fields and per-file entries correctly. - async fn test_commit_progress_register_updates() { - let p = GroupProgress::new(); - - let file_a = (Ulid::new(), "fileA.bin"); - let file_b = (Ulid::new(), "fileB.bin"); - let update = ProgressUpdate { - total_bytes: 1000, - total_bytes_completed: 400, - total_bytes_completion_rate: Some(0.4), - total_transfer_bytes: 800, - total_transfer_bytes_completed: 300, - total_transfer_bytes_completion_rate: Some(0.375), - item_updates: vec![ - ItemProgressUpdate { - tracking_id: file_a.0, - item_name: file_a.1.into(), - total_bytes: 500, - bytes_completed: 200, - bytes_completion_increment: 200, - }, - ItemProgressUpdate { - tracking_id: file_b.0, - item_name: file_b.1.into(), - total_bytes: 500, - bytes_completed: 200, - bytes_completion_increment: 200, - }, - ], - ..Default::default() - }; - - p.register_updates(update).await; - - let snapshot = p.snapshot().unwrap(); - let total = snapshot.total(); - assert_eq!(total.total_bytes, 1000); - assert_eq!(total.total_bytes_completed, 400); - assert_eq!(total.total_bytes_completion_rate, Some(0.4)); - assert_eq!(total.total_transfer_bytes, 800); - assert_eq!(total.total_transfer_bytes_completed, 300); - assert_eq!(total.total_transfer_bytes_completion_rate, Some(0.375)); - let fa = snapshot.file(file_a.0).unwrap(); - assert_eq!(fa.total_bytes, 500); - assert_eq!(fa.bytes_completed, 200); - let fb = snapshot.file(file_b.0).unwrap(); - assert_eq!(fb.total_bytes, 500); - assert_eq!(fb.bytes_completed, 200); - } -} diff --git a/xet_pkg/src/xet_session/session.rs b/xet_pkg/src/xet_session/session.rs index 3af2c28e..fcd3f60d 100644 --- a/xet_pkg/src/xet_session/session.rs +++ b/xet_pkg/src/xet_session/session.rs @@ -8,8 +8,8 @@ use std::task::{Context, Waker}; use http::HeaderMap; use tracing::info; -use ulid::Ulid; use xet_client::cas_client::auth::TokenRefresher; +use xet_data::progress_tracking::UniqueID; use xet_runtime::RuntimeError; use xet_runtime::config::XetConfig; use xet_runtime::core::XetRuntime; @@ -60,12 +60,12 @@ pub struct XetSessionInner { pub(super) custom_headers: Option>, // Track active upload commits and download groups. - pub(super) active_upload_commits: Mutex>, - pub(super) active_download_groups: Mutex>, + pub(super) active_upload_commits: Mutex>, + pub(super) active_download_groups: Mutex>, // Session state state: Mutex, - pub(super) id: Ulid, + pub(super) id: UniqueID, } /// Probe whether a tokio runtime handle meets the requirements for External mode. @@ -335,7 +335,7 @@ impl XetSession { active_upload_commits: Mutex::new(HashMap::new()), active_download_groups: Mutex::new(HashMap::new()), state: Mutex::new(SessionState::Alive), - id: Ulid::new(), + id: UniqueID::new(), }), } } @@ -518,12 +518,12 @@ impl XetSession { Ok(()) } - pub(super) fn finish_upload_commit(&self, commit_id: Ulid) -> Result<(), SessionError> { + pub(super) fn finish_upload_commit(&self, commit_id: UniqueID) -> Result<(), SessionError> { self.active_upload_commits.lock()?.remove(&commit_id); Ok(()) } - pub(super) fn finish_download_group(&self, group_id: Ulid) -> Result<(), SessionError> { + pub(super) fn finish_download_group(&self, group_id: UniqueID) -> Result<(), SessionError> { self.active_download_groups.lock()?.remove(&group_id); Ok(()) } @@ -646,7 +646,7 @@ mod tests { fn test_finish_upload_commit_with_unknown_id_is_noop() { let session = XetSessionBuilder::new().build().unwrap(); let _c1 = session.new_upload_commit_blocking().unwrap(); - let unknown_id = ulid::Ulid::new(); + let unknown_id = UniqueID::new(); assert!(session.finish_upload_commit(unknown_id).is_ok()); assert_eq!(session.active_upload_commits.lock().unwrap().len(), 1); } diff --git a/xet_pkg/src/xet_session/tasks.rs b/xet_pkg/src/xet_session/tasks.rs new file mode 100644 index 00000000..c423f67f --- /dev/null +++ b/xet_pkg/src/xet_session/tasks.rs @@ -0,0 +1,201 @@ +//! Progress tracking for upload commits and download groups. + +use std::ops::Deref; +use std::sync::{Arc, Mutex, OnceLock}; + +use xet_data::progress_tracking::UniqueID; + +use super::SessionError; +use super::download_group::DownloadResult; +use super::upload_commit::UploadResult; + +/// Lifecycle state of a single upload or download task. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TaskStatus { + /// Task has been queued but has not started executing yet. + Queued, + /// Task is actively transferring data. + Running, + /// Task finished successfully. + Completed, + /// Task encountered an error and did not complete. + Failed, + /// Task was cancelled before it could complete. + Cancelled, +} + +impl TaskStatus { + pub(super) fn mark_running(status: &Arc>) { + if let Ok(mut current) = status.lock() + && matches!(*current, TaskStatus::Queued) + { + *current = TaskStatus::Running; + } + } + + pub(super) fn mark_terminal(status: &Arc>, terminal_status: TaskStatus) { + if let Ok(mut current) = status.lock() + && !matches!(*current, TaskStatus::Cancelled) + { + *current = terminal_status; + } + } + + pub(super) fn mark_cancelled(status: &Arc>) { + if let Ok(mut current) = status.lock() { + *current = TaskStatus::Cancelled; + } + } +} + +#[derive(Debug)] +pub struct TaskHandle { + pub(super) status: Option>>, + /// Id of the task, can be used to retrieve per-task progress and result. + pub task_id: UniqueID, +} + +#[derive(Debug)] +pub struct UploadTaskHandle { + pub(super) inner: TaskHandle, + pub(super) result: Arc>, +} + +impl Deref for UploadTaskHandle { + type Target = TaskHandle; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +#[derive(Debug)] +pub struct DownloadTaskHandle { + pub(super) inner: TaskHandle, + pub(super) result: Arc>, +} + +impl Deref for DownloadTaskHandle { + type Target = TaskHandle; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +impl TaskHandle { + pub fn status(&self) -> Result { + if let Some(status) = &self.status { + Ok(*status.lock()?) + } else { + Err(SessionError::other("status not available")) + } + } +} + +impl UploadTaskHandle { + pub fn result(&self) -> Option { + self.result.get().cloned() + } +} + +impl DownloadTaskHandle { + pub fn result(&self) -> Option { + self.result.get().cloned() + } +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use xet_data::processing::XetFileInfo; + + use super::*; + use crate::xet_session::{DownloadedFile, FileMetadata}; + + #[test] + fn test_task_handle_with_no_status_returns_error() { + let handle = TaskHandle { + status: None, + task_id: UniqueID::new(), + }; + assert!(handle.status().is_err()); + } + + #[test] + fn test_upload_task_handle_result_none_before_commit() { + let handle = UploadTaskHandle { + inner: TaskHandle { + status: None, + task_id: UniqueID::new(), + }, + result: Arc::new(OnceLock::new()), + }; + assert!(handle.result().is_none()); + } + + #[test] + fn test_upload_task_handle_result_some_after_result_set() { + let result_arc = Arc::new(OnceLock::new()); + let handle = UploadTaskHandle { + inner: TaskHandle { + status: None, + task_id: UniqueID::new(), + }, + result: result_arc.clone(), + }; + + let metadata = Arc::new(Ok(FileMetadata { + tracking_name: Some("file.bin".to_string()), + hash: "abc123".to_string(), + file_size: 42, + sha256: None, + })); + result_arc.set(metadata).unwrap(); + + let result = handle.result().unwrap(); + let meta = result.as_ref().as_ref().unwrap(); + assert_eq!(meta.file_size, 42); + assert_eq!(meta.hash, "abc123"); + } + + #[test] + fn test_download_task_handle_result_none_before_finish() { + let handle = DownloadTaskHandle { + inner: TaskHandle { + status: None, + task_id: UniqueID::new(), + }, + result: Arc::new(OnceLock::new()), + }; + assert!(handle.result().is_none()); + } + + #[test] + fn test_download_task_handle_result_some_after_result_set() { + let result_arc = Arc::new(OnceLock::new()); + let handle = DownloadTaskHandle { + inner: TaskHandle { + status: None, + task_id: UniqueID::new(), + }, + result: result_arc.clone(), + }; + + let download_result = Arc::new(Ok(DownloadedFile { + dest_path: PathBuf::from("out/file.bin"), + file_info: XetFileInfo { + hash: "def456".to_string(), + file_size: 99, + sha256: None, + }, + })); + result_arc.set(download_result).unwrap(); + + let result = handle.result().unwrap(); + let dl = result.as_ref().as_ref().unwrap(); + assert_eq!(dl.file_info.file_size, 99); + assert_eq!(dl.dest_path, PathBuf::from("out/file.bin")); + } +} diff --git a/xet_pkg/src/xet_session/upload_commit.rs b/xet_pkg/src/xet_session/upload_commit.rs index 4e4122e7..3f62bbbd 100644 --- a/xet_pkg/src/xet_session/upload_commit.rs +++ b/xet_pkg/src/xet_session/upload_commit.rs @@ -5,15 +5,15 @@ use std::path::PathBuf; use std::sync::{Arc, Mutex, OnceLock, RwLock}; use tokio::task::JoinHandle; -use ulid::Ulid; -use xet_data::processing::data_client::{clean_bytes, clean_file}; +use xet_data::DataError; use xet_data::processing::{FileUploadSession, Sha256Policy, SingleFileCleaner, XetFileInfo}; +use xet_data::progress_tracking::{GroupProgressReport, UniqueID}; use xet_runtime::core::XetRuntime; use super::common::{GroupState, create_translator_config}; use super::errors::SessionError; -use super::progress::{GroupProgress, ProgressSnapshot, TaskHandle, TaskStatus, UploadTaskHandle}; use super::session::XetSession; +use super::tasks::{TaskHandle, TaskStatus, UploadTaskHandle}; /// API for grouping related file uploads into a single atomic commit. /// @@ -56,18 +56,16 @@ impl UploadCommit { /// Create a new upload commit from an **async** context. Initialisation logic shared by the sync and async /// constructors. pub(super) async fn new(session: XetSession) -> Result { - let commit_id = Ulid::new(); - let progress = Arc::new(GroupProgress::new()); + let commit_id = UniqueID::new(); let config = create_translator_config(&session)?; - let progress_updater = progress.clone() as Arc; - let upload_session = FileUploadSession::new(Arc::new(config), Some(progress_updater)).await?; + let upload_session = FileUploadSession::new(Arc::new(config)).await?; let inner = Arc::new(UploadCommitInner { commit_id, session, active_tasks: RwLock::new(HashMap::new()), - progress, upload_session: Mutex::new(Some(upload_session)), + last_progress: Mutex::new(None), state: tokio::sync::Mutex::new(GroupState::Alive), }); @@ -75,7 +73,7 @@ impl UploadCommit { } /// Get the commit ID. - pub(super) fn id(&self) -> Ulid { + pub(super) fn id(&self) -> UniqueID { self.commit_id } @@ -160,10 +158,9 @@ impl UploadCommit { /// /// # Returns /// - /// A `(`[`TaskHandle`]`, `[`SingleFileCleaner`]`)` pair. The [`TaskHandle`] tracks task - /// lifecycle but carries no upload result — call [`SingleFileCleaner::finish`] on the cleaner - /// to obtain the [`FileMetadata`](xet_data::processing::FileMetadata) once all bytes have - /// been streamed. + /// A `(`[`TaskHandle`]`, `[`SingleFileCleaner`]`)` pair. The [`TaskHandle`] carries only + /// the task ID (no internal status/result), and [`SingleFileCleaner::finish`] returns the + /// [`FileMetadata`](xet_data::processing::FileMetadata) once all bytes have been streamed. pub async fn upload_file( &self, file_name: Option, @@ -210,8 +207,15 @@ impl UploadCommit { } /// Return a snapshot of progress for every queued upload. - pub fn get_progress(&self) -> Result { - self.progress.snapshot() + pub fn get_progress(&self) -> Result { + let session_opt = self.upload_session.lock()?.clone(); + if let Some(upload_session) = session_opt { + return Ok(upload_session.report()); + } + if let Some(cached) = self.last_progress.lock()?.as_ref() { + return Ok(cached.clone()); + } + Ok(GroupProgressReport::default()) } /// Wait for all uploads to complete and push metadata to the CAS server. @@ -222,7 +226,7 @@ impl UploadCommit { /// /// Consumes `self` — subsequent calls on any clone will return /// [`SessionError::AlreadyCommitted`]. - pub async fn commit(self) -> Result, SessionError> { + pub async fn commit(self) -> Result, SessionError> { let inner = self.inner.clone(); self.session .dispatch("commit", async move { inner.handle_commit().await }) @@ -298,12 +302,17 @@ impl UploadCommit { })? } + /// Blocking version of [`get_progress`](Self::get_progress). + pub fn get_progress_blocking(&self) -> Result { + self.get_progress() + } + /// Blocking version of [`commit`](Self::commit). /// /// # Panics /// /// Panics if called from within a tokio async runtime. - pub fn commit_blocking(self) -> Result, SessionError> { + pub fn commit_blocking(self) -> Result, SessionError> { let commit = self.clone(); self.runtime().external_run_async_task(commit.commit())? } @@ -320,7 +329,7 @@ pub type UploadResult = Arc>; struct InnerUploadTaskHandle { status: Arc>, tracking_name: Option, - join_handle: JoinHandle>, + join_handle: JoinHandle>, result: Arc>, } @@ -328,18 +337,18 @@ struct InnerUploadTaskHandle { /// Accessed through `Arc`; do not use this type directly. #[doc(hidden)] pub struct UploadCommitInner { - commit_id: Ulid, + commit_id: UniqueID, pub(super) session: XetSession, // Active upload tasks for this commit - active_tasks: RwLock>, - - // Aggregate + per-file progress, fed into FileUploadSession as a TrackingProgressUpdater - progress: Arc, + active_tasks: RwLock>, // Shared upload session (FileUploadSession from data crate) upload_session: Mutex>>, + // Final progress cached when session is cleared (enables get_progress after commit) + last_progress: Mutex>, + // tokio::sync::Mutex (not std) because registration methods hold this lock across // .await points (e.g. start_clean in start_upload_file) to serialise with commit. // DownloadGroupInner uses std::sync::Mutex because its registration is synchronous. @@ -358,119 +367,33 @@ impl UploadCommitInner { } } - /// Spawn a runtime task that performs the actual file upload from path - fn spawn_upload_from_path_task( - &self, - upload_session: Arc, - file_path: PathBuf, - status: Arc>, - tracking_id: Ulid, - sha256: Sha256Policy, - ) -> JoinHandle> { - let semaphore = self.runtime().common().file_ingestion_semaphore.clone(); - self.runtime().spawn(async move { - let _permit = semaphore.acquire().await?; - - // Only transition Queued → Running; bail if abort() already set Cancelled. - { - let mut s = status.lock()?; - if !matches!(*s, TaskStatus::Queued) { - return Err(SessionError::Aborted); - } - *s = TaskStatus::Running; - } - - let result = clean_file(upload_session, &file_path, sha256, Some(tracking_id)) - .await - .map_err(SessionError::from) - .map(|(file_info, _metrics)| file_info); - - let new_status = if result.is_ok() { - TaskStatus::Completed - } else { - TaskStatus::Failed - }; - // Only overwrite if still Running — abort() may have set Cancelled concurrently. - let mut s = status.lock()?; - if matches!(*s, TaskStatus::Running) { - *s = new_status; - } - - result - }) - } - - /// Spawn a runtime task that performs the actual bytes upload - fn spawn_upload_bytes_task( - &self, - upload_session: Arc, - bytes: Vec, - status: Arc>, - tracking_id: Ulid, - sha256: Sha256Policy, - ) -> JoinHandle> { - let semaphore = self.runtime().common().file_ingestion_semaphore.clone(); - self.runtime().spawn(async move { - let _permit = semaphore.acquire().await?; - - // Only transition Queued → Running; bail if abort() already set Cancelled. - { - let mut s = status.lock()?; - if !matches!(*s, TaskStatus::Queued) { - return Err(SessionError::Aborted); - } - *s = TaskStatus::Running; - } - - let result = clean_bytes(upload_session, bytes, Some(tracking_id), sha256) - .await - .map_err(SessionError::from) - .map(|(file_info, _metrics)| file_info); - - let new_status = if result.is_ok() { - TaskStatus::Completed - } else { - TaskStatus::Failed - }; - // Only overwrite if still Running — abort() may have set Cancelled concurrently. - let mut s = status.lock()?; - if matches!(*s, TaskStatus::Running) { - *s = new_status; - } - - result - }) - } - pub(super) async fn start_upload_file_from_path( &self, file_path: PathBuf, sha256: Sha256Policy, ) -> Result { - // Hold the state lock for the duration of this function so commit() will not run - // when an upload task is registering. - let state = self.state.lock().await; - Self::check_accepting_tasks(&state)?; + let upload_session = { + let state = self.state.lock().await; + Self::check_accepting_tasks(&state)?; + + let Some(upload_session) = self.upload_session.lock()?.clone() else { + return Err(SessionError::other("Upload session not initialized")); + }; + upload_session + }; + + let (task_id, join_handle) = upload_session.spawn_upload_from_path(file_path.clone(), sha256).await?; - let tracking_id = Ulid::new(); let status = Arc::new(Mutex::new(TaskStatus::Queued)); let result: Arc> = Arc::new(OnceLock::new()); let task_handle = UploadTaskHandle { inner: TaskHandle { status: Some(status.clone()), - group_progress: self.progress.clone(), - task_id: tracking_id, + task_id, }, result: result.clone(), }; - let Some(upload_session) = self.upload_session.lock()?.clone() else { - return Err(SessionError::other("Upload session not initialized")); - }; - - let join_handle = - self.spawn_upload_from_path_task(upload_session, file_path.clone(), status.clone(), tracking_id, sha256); - let handle = InnerUploadTaskHandle { status, tracking_name: file_path.to_str().map(|s| s.to_owned()), @@ -478,7 +401,8 @@ impl UploadCommitInner { result, }; - self.active_tasks.write()?.insert(tracking_id, handle); + TaskStatus::mark_running(&handle.status); + self.active_tasks.write()?.insert(task_id, handle); Ok(task_handle) } @@ -495,9 +419,6 @@ impl UploadCommitInner { file_size: u64, sha256: Sha256Policy, ) -> Result<(TaskHandle, SingleFileCleaner), SessionError> { - let tracking_id = Ulid::new(); - // Hold the state lock across start_clean so handle_commit cannot finalise - // the session between the state check and the creation of the cleaner. let state = self.state.lock().await; Self::check_accepting_tasks(&state)?; @@ -505,14 +426,13 @@ impl UploadCommitInner { return Err(SessionError::other("Upload session not initialized")); }; - let task_handle = TaskHandle { - status: None, // upload directly managed by user - not internally managed - group_progress: self.progress.clone(), - task_id: tracking_id, - }; let tracking_name: Option> = tracking_name.as_deref().map(Arc::from); - let cleaner = upload_session.start_clean(tracking_name, file_size, sha256, tracking_id).await; + let (id, cleaner) = upload_session.start_clean(tracking_name, file_size, sha256)?; + let task_handle = TaskHandle { + status: None, + task_id: id, + }; Ok((task_handle, cleaner)) } @@ -523,29 +443,29 @@ impl UploadCommitInner { sha256: Sha256Policy, tracking_name: Option, ) -> Result { - // Hold the state lock for the duration of this function so commit() will not run - // when an upload task is registering. - let state = self.state.lock().await; - Self::check_accepting_tasks(&state)?; + let upload_session = { + let state = self.state.lock().await; + Self::check_accepting_tasks(&state)?; + + let Some(upload_session) = self.upload_session.lock()?.clone() else { + return Err(SessionError::other("Upload session not initialized")); + }; + upload_session + }; + + let name: Option> = tracking_name.as_deref().map(Arc::from); + let (task_id, join_handle) = upload_session.spawn_upload_bytes(bytes, sha256, name).await?; - let tracking_id = Ulid::new(); let status = Arc::new(Mutex::new(TaskStatus::Queued)); let result: Arc> = Arc::new(OnceLock::new()); let task_handle = UploadTaskHandle { inner: TaskHandle { status: Some(status.clone()), - group_progress: self.progress.clone(), - task_id: tracking_id, + task_id, }, result: result.clone(), }; - let Some(upload_session) = self.upload_session.lock()?.clone() else { - return Err(SessionError::other("Upload session not initialized")); - }; - - let join_handle = self.spawn_upload_bytes_task(upload_session, bytes, status.clone(), tracking_id, sha256); - let handle = InnerUploadTaskHandle { status, tracking_name, @@ -553,13 +473,14 @@ impl UploadCommitInner { result, }; - self.active_tasks.write()?.insert(tracking_id, handle); + TaskStatus::mark_running(&handle.status); + self.active_tasks.write()?.insert(task_id, handle); Ok(task_handle) } /// Join all active upload tasks and finalise the upload session. - pub(super) async fn handle_commit(&self) -> Result, SessionError> { + pub(super) async fn handle_commit(&self) -> Result, SessionError> { // Mark as not accepting new tasks. The tokio state lock serialises this // against all three registration methods, including start_upload_file // which holds it across the start_clean await. @@ -580,6 +501,7 @@ impl UploadCommitInner { for (task_id, handle) in active_tasks { match handle.join_handle.await.map_err(SessionError::from) { Ok(Ok(file_info)) => { + TaskStatus::mark_terminal(&handle.status, TaskStatus::Completed); let result = Arc::new(Ok(FileMetadata { tracking_name: handle.tracking_name, hash: file_info.hash().to_string(), @@ -587,18 +509,20 @@ impl UploadCommitInner { sha256: file_info.sha256().map(str::to_owned), })); results.insert(task_id, result.clone()); - // Update result to the external task handle, this is the only place setting - // the result, so no error will happen. let _ = handle.result.set(result); }, - Ok(Err(task_err)) => { - let result = Arc::new(Err(task_err)); + Ok(Err(data_err)) => { + TaskStatus::mark_terminal(&handle.status, TaskStatus::Failed); + let result = Arc::new(Err(SessionError::from(data_err))); results.insert(task_id, result.clone()); - // Update result to the external task handle, this is the only place setting - // the result, so no error will happen. let _ = handle.result.set(result); }, Err(e) => { + if matches!(e, SessionError::Cancelled(_)) { + TaskStatus::mark_cancelled(&handle.status); + } else { + TaskStatus::mark_terminal(&handle.status, TaskStatus::Failed); + } if join_err.is_none() { join_err = Some(e); } @@ -612,7 +536,8 @@ impl UploadCommitInner { // Finalize upload session let session = self.upload_session.lock()?.take(); if let Some(session) = session { - session.finalize().await?; + let (_metrics, report) = session.finalize_with_report().await?; + *self.last_progress.lock()? = Some(report); } // Mark as committed @@ -624,11 +549,7 @@ impl UploadCommitInner { Ok(results) } - fn runtime(&self) -> &XetRuntime { - &self.session.runtime - } - - /// Cancel all tasks and set task status to "Cancelled". + /// Cancel all tasks and abort their join handles. /// /// Called only from [`XetSession::abort`], which always calls /// `perform_sigint_shutdown()` first — so the runtime is already shutting @@ -653,8 +574,8 @@ impl UploadCommitInner { self.upload_session.lock()?.take(); let active_tasks = std::mem::take(&mut *self.active_tasks.write()?); for (_tracking_id, inner_task_handle) in active_tasks { + TaskStatus::mark_cancelled(&inner_task_handle.status); inner_task_handle.join_handle.abort(); - let _ = inner_task_handle.status.lock().map(|mut s| *s = TaskStatus::Cancelled); } Ok(()) @@ -684,7 +605,6 @@ mod tests { use tempfile::{TempDir, tempdir}; use super::*; - use crate::xet_session::progress::TaskStatus; use crate::xet_session::session::{RuntimeMode, XetSession, XetSessionBuilder}; async fn local_session(temp: &TempDir) -> Result> { @@ -770,10 +690,9 @@ mod tests { async fn test_get_progress_empty_initially() { let session = XetSessionBuilder::new().build_async().await.unwrap(); let commit = session.new_upload_commit().await.unwrap(); - let snapshot = commit.get_progress().unwrap(); - let total = snapshot.total(); - assert_eq!(total.total_bytes, 0); - assert_eq!(total.total_bytes_completed, 0); + let report = commit.get_progress().unwrap(); + assert_eq!(report.total_bytes, 0); + assert_eq!(report.total_bytes_completed, 0); } // ── Commit lifecycle ───────────────────────────────────────────────────── @@ -900,6 +819,8 @@ mod tests { .unwrap(); commit.inner.abort().unwrap(); assert!(matches!(handle.status().unwrap(), TaskStatus::Cancelled)); + assert!(commit.inner.active_tasks.read().unwrap().is_empty()); + assert!(commit.inner.upload_session.lock().unwrap().is_none()); } #[tokio::test(flavor = "multi_thread")] @@ -911,7 +832,6 @@ mod tests { let session = XetSessionBuilder::new().build_async().await.unwrap(); let commit = session.new_upload_commit().await.unwrap(); - // Queue a task so we can verify active_tasks draining still happens. let handle = commit .upload_bytes(b"data".to_vec(), Sha256Policy::Compute, None) .await @@ -924,14 +844,11 @@ mod tests { commit.inner.abort().unwrap(); // State was NOT updated — abort() skipped the state flag when lock was held. - assert!(matches!(*guard, GroupState::Alive), "state must remain Alive when lock was contended"); + assert!(matches!(*guard, GroupState::Alive)); drop(guard); - // active_tasks are always drained; already-queued tasks are Cancelled. assert!(matches!(handle.status().unwrap(), TaskStatus::Cancelled)); - - // upload_session is always cleared, preventing future start_upload_file calls - // from obtaining a session and preventing handle_commit from calling finalize. + assert!(commit.inner.active_tasks.read().unwrap().is_empty()); assert!(commit.inner.upload_session.lock().unwrap().is_none()); } @@ -960,6 +877,10 @@ mod tests { .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some("hello.bin".into())) .await .unwrap(); + assert!(matches!( + task_handle.status().unwrap(), + TaskStatus::Queued | TaskStatus::Running | TaskStatus::Completed + )); let results = commit.commit().await.unwrap(); assert_eq!(results.len(), 1); let meta = results.get(&task_handle.task_id).unwrap().as_ref().as_ref().unwrap(); @@ -967,6 +888,33 @@ mod tests { assert!(!meta.hash.is_empty()); assert!(meta.sha256.is_some()); assert_eq!(meta.sha256.as_deref().unwrap().len(), 64); + assert!(matches!(task_handle.status().unwrap(), TaskStatus::Completed)); + } + + #[tokio::test(flavor = "multi_thread")] + // task_id returned by upload_bytes must match the per-item progress entry id. + async fn test_upload_bytes_task_id_matches_progress_item_id() { + let temp = tempdir().unwrap(); + let session = local_session(&temp).await.unwrap(); + let commit = session.new_upload_commit().await.unwrap(); + + let handle = commit + .upload_bytes(b"id-match".to_vec(), Sha256Policy::Compute, Some("id.bin".into())) + .await + .unwrap(); + + let upload_session = commit.inner.upload_session.lock().unwrap().clone().unwrap(); + + let mut reports = HashMap::new(); + for _ in 0..50 { + reports = upload_session.item_reports(); + if !reports.is_empty() { + break; + } + tokio::time::sleep(Duration::from_millis(10)).await; + } + + assert!(reports.contains_key(&handle.task_id)); } #[tokio::test(flavor = "multi_thread")] @@ -986,6 +934,39 @@ mod tests { assert!(!meta.hash.is_empty()); assert!(meta.sha256.is_some()); assert_eq!(meta.sha256.as_deref().unwrap().len(), 64); + assert!(matches!(handle.status().unwrap(), TaskStatus::Completed)); + } + + #[tokio::test(flavor = "multi_thread")] + // A task that returns an upload error transitions to Failed status. + async fn test_upload_task_status_failed_for_task_error() { + let session = XetSessionBuilder::new().build_async().await.unwrap(); + let commit = session.new_upload_commit().await.unwrap(); + + let task_id = UniqueID::new(); + let status = Arc::new(Mutex::new(TaskStatus::Running)); + let result: Arc> = Arc::new(OnceLock::new()); + let synthetic_handle = UploadTaskHandle { + inner: TaskHandle { + status: Some(status.clone()), + task_id, + }, + result: result.clone(), + }; + let failing_join = + tokio::spawn(async { Err(DataError::InternalError("synthetic upload failure".to_string())) }); + let failing_inner = InnerUploadTaskHandle { + status, + tracking_name: Some("synthetic".to_string()), + join_handle: failing_join, + result, + }; + commit.inner.active_tasks.write().unwrap().insert(task_id, failing_inner); + + let results = commit.commit().await.unwrap(); + let task_result = results.get(&task_id).expect("task_id must be present in results"); + assert!(task_result.is_err()); + assert!(matches!(synthetic_handle.status().unwrap(), TaskStatus::Failed)); } #[tokio::test(flavor = "multi_thread")] @@ -1121,8 +1102,11 @@ mod tests { .await .unwrap(); commit.commit().await.unwrap(); - let snapshot = progress_observer.get_progress().unwrap(); - assert!(snapshot.total().total_bytes_completed > 0); + let report = progress_observer.get_progress().unwrap(); + assert_eq!(report.total_bytes, data.len() as u64); + assert_eq!(report.total_bytes_completed, data.len() as u64); + assert_eq!(report.total_transfer_bytes, report.total_transfer_bytes_completed); + assert!(report.total_transfer_bytes_completed <= data.len() as u64); } // ── Non-tokio executor (Owned-mode bridge) ──────────────────────────────── @@ -1307,8 +1291,21 @@ mod tests { let progress_observer = commit.clone(); commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some("prog.bin".into()))?; commit.commit_blocking()?; - let snapshot = progress_observer.get_progress()?; - assert!(snapshot.total().total_bytes_completed > 0); + let snapshot = progress_observer.get_progress_blocking()?; + assert_eq!(snapshot.total_bytes, data.len() as u64); + assert_eq!(snapshot.total_bytes_completed, data.len() as u64); + assert_eq!(snapshot.total_transfer_bytes, snapshot.total_transfer_bytes_completed); + assert!(snapshot.total_transfer_bytes_completed <= data.len() as u64); + Ok(()) + } + + #[test] + fn test_blocking_upload_file_returns_handle_without_status() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session_sync(&temp)?; + let commit = session.new_upload_commit_blocking()?; + let (handle, _cleaner) = commit.upload_file_blocking(Some("stream.bin".into()), 1024, Sha256Policy::Compute)?; + assert!(handle.status().is_err()); Ok(()) } diff --git a/xet_pkg/tests/test_legacy_data_client.rs b/xet_pkg/tests/test_legacy_data_client.rs new file mode 100644 index 00000000..d0753099 --- /dev/null +++ b/xet_pkg/tests/test_legacy_data_client.rs @@ -0,0 +1,325 @@ +use std::fs; +use std::sync::Arc; + +use async_trait::async_trait; +use more_asserts::assert_le; +use tempfile::TempDir; +use tokio::sync::Mutex; +use xet::legacy::progress_tracking::{ItemProgressUpdate, ProgressUpdate, TrackingProgressUpdater}; +use xet::legacy::{Sha256Policy, XetFileInfo, data_client}; +use xet_client::cas_client::LocalTestServerBuilder; + +/// A test `TrackingProgressUpdater` that records all updates. +#[derive(Debug, Default)] +struct RecordingUpdater { + updates: Mutex>, +} + +#[async_trait] +impl TrackingProgressUpdater for RecordingUpdater { + async fn register_updates(&self, update: ProgressUpdate) { + self.updates.lock().await.push(update); + } +} + +impl RecordingUpdater { + async fn total_item_updates(&self) -> Vec { + let updates = self.updates.lock().await; + updates.iter().flat_map(|u| u.item_updates.clone()).collect() + } +} + +fn make_endpoint(server: &xet_client::cas_client::LocalTestServer) -> Option { + Some(server.http_endpoint().to_string()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_upload_bytes_and_download_roundtrip() { + let server = LocalTestServerBuilder::new().start().await; + let endpoint = make_endpoint(&server); + + let contents: Vec> = vec![b"hello world".to_vec(), b"foo bar baz".to_vec(), vec![0xAB; 4096]]; + let policies = vec![Sha256Policy::Compute; contents.len()]; + + let file_infos = + data_client::upload_bytes_async(contents.clone(), policies, endpoint.clone(), None, None, None, None) + .await + .unwrap(); + + assert_eq!(file_infos.len(), 3); + for info in &file_infos { + assert!(!info.hash.is_empty()); + assert!(info.file_size > 0); + } + + let download_dir = TempDir::new().unwrap(); + let download_pairs: Vec<(XetFileInfo, String)> = file_infos + .iter() + .enumerate() + .map(|(i, info)| { + let path = download_dir.path().join(format!("file_{i}")); + (info.clone(), path.to_string_lossy().to_string()) + }) + .collect(); + + let paths = data_client::download_async(download_pairs, endpoint, None, None, None, None) + .await + .unwrap(); + + assert_eq!(paths.len(), 3); + for (i, path) in paths.iter().enumerate() { + let downloaded = fs::read(path).unwrap(); + assert_eq!(downloaded, contents[i], "content mismatch for file {i}"); + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_upload_files_and_download_roundtrip() { + let server = LocalTestServerBuilder::new().start().await; + let endpoint = make_endpoint(&server); + + let src_dir = TempDir::new().unwrap(); + let file_data: Vec<(&str, Vec)> = vec![ + ("small.txt", b"small file content".to_vec()), + ("medium.bin", vec![0xCD; 8192]), + ("empty.txt", vec![]), + ]; + + let mut file_paths = Vec::new(); + let mut policies = Vec::new(); + for (name, data) in &file_data { + let path = src_dir.path().join(name); + fs::write(&path, data).unwrap(); + file_paths.push(path.to_string_lossy().to_string()); + policies.push(Sha256Policy::Compute); + } + + let file_infos = data_client::upload_async(file_paths, policies, endpoint.clone(), None, None, None, None) + .await + .unwrap(); + + assert_eq!(file_infos.len(), 3); + + let download_dir = TempDir::new().unwrap(); + let download_pairs: Vec<(XetFileInfo, String)> = file_infos + .iter() + .enumerate() + .map(|(i, info)| { + let path = download_dir.path().join(format!("out_{i}")); + (info.clone(), path.to_string_lossy().to_string()) + }) + .collect(); + + let paths = data_client::download_async(download_pairs, endpoint, None, None, None, None) + .await + .unwrap(); + + for (i, path) in paths.iter().enumerate() { + let downloaded = fs::read(path).unwrap(); + assert_eq!(downloaded, file_data[i].1, "content mismatch for {}", file_data[i].0); + } + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_upload_bytes_with_progress_updater() { + let server = LocalTestServerBuilder::new().start().await; + let endpoint = make_endpoint(&server); + + let contents: Vec> = vec![vec![0x42; 4096], vec![0x99; 8192]]; + let policies = vec![Sha256Policy::Compute; contents.len()]; + let updater = Arc::new(RecordingUpdater::default()); + + let file_infos = + data_client::upload_bytes_async(contents, policies, endpoint, None, None, Some(updater.clone()), None) + .await + .unwrap(); + + assert_eq!(file_infos.len(), 2); + + let updates = updater.updates.lock().await; + assert!(!updates.is_empty(), "should have received progress updates"); + + let last = updates.last().unwrap(); + assert_le!(last.total_bytes_completed, last.total_bytes); + + drop(updates); + let items = updater.total_item_updates().await; + assert!(!items.is_empty(), "should have received item-level updates"); + + let total_reported: u64 = items.iter().map(|u| u.total_bytes).max().unwrap_or(0); + assert!(total_reported > 0); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_download_with_per_file_progress_updaters() { + let server = LocalTestServerBuilder::new().start().await; + let endpoint = make_endpoint(&server); + + let contents: Vec> = vec![vec![0xAA; 2048], vec![0xBB; 4096]]; + let policies = vec![Sha256Policy::Compute; contents.len()]; + + let file_infos = + data_client::upload_bytes_async(contents.clone(), policies, endpoint.clone(), None, None, None, None) + .await + .unwrap(); + + let download_dir = TempDir::new().unwrap(); + let updater_a = Arc::new(RecordingUpdater::default()); + let updater_b = Arc::new(RecordingUpdater::default()); + + let download_pairs: Vec<(XetFileInfo, String)> = file_infos + .iter() + .enumerate() + .map(|(i, info)| { + let path = download_dir.path().join(format!("dl_{i}")); + (info.clone(), path.to_string_lossy().to_string()) + }) + .collect(); + + let updaters: Vec> = + vec![updater_a.clone() as Arc, updater_b.clone()]; + + let paths = data_client::download_async(download_pairs, endpoint, None, None, Some(updaters), None) + .await + .unwrap(); + + for (i, path) in paths.iter().enumerate() { + let downloaded = fs::read(path).unwrap(); + assert_eq!(downloaded, contents[i]); + } + + let updates_a = updater_a.updates.lock().await; + let updates_b = updater_b.updates.lock().await; + + assert!(!updates_a.is_empty(), "updater A should have received updates"); + assert!(!updates_b.is_empty(), "updater B should have received updates"); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_upload_files_with_progress_updater() { + let server = LocalTestServerBuilder::new().start().await; + let endpoint = make_endpoint(&server); + + let src_dir = TempDir::new().unwrap(); + let file_data: Vec<(&str, Vec)> = vec![("big_a.bin", vec![0x11; 16384]), ("big_b.bin", vec![0x22; 16384])]; + + let mut file_paths = Vec::new(); + let mut policies = Vec::new(); + for (name, data) in &file_data { + let path = src_dir.path().join(name); + fs::write(&path, data).unwrap(); + file_paths.push(path.to_string_lossy().to_string()); + policies.push(Sha256Policy::Compute); + } + + let updater = Arc::new(RecordingUpdater::default()); + + let file_infos = + data_client::upload_async(file_paths, policies, endpoint.clone(), None, None, Some(updater.clone()), None) + .await + .unwrap(); + + assert_eq!(file_infos.len(), 2); + + let updates = updater.updates.lock().await; + assert!(!updates.is_empty(), "should have received progress updates"); + + let last = updates.last().unwrap(); + assert_le!(last.total_bytes_completed, last.total_bytes); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_upload_download_large_files() { + let server = LocalTestServerBuilder::new().start().await; + let endpoint = make_endpoint(&server); + + let src_dir = TempDir::new().unwrap(); + + let large_data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); + let path = src_dir.path().join("large.bin"); + fs::write(&path, &large_data).unwrap(); + + let file_infos = data_client::upload_async( + vec![path.to_string_lossy().to_string()], + vec![Sha256Policy::Compute], + endpoint.clone(), + None, + None, + None, + None, + ) + .await + .unwrap(); + + assert_eq!(file_infos.len(), 1); + assert_eq!(file_infos[0].file_size, large_data.len() as u64); + + let download_dir = TempDir::new().unwrap(); + let out_path = download_dir.path().join("large_out.bin"); + + let paths = data_client::download_async( + vec![(file_infos[0].clone(), out_path.to_string_lossy().to_string())], + endpoint, + None, + None, + None, + None, + ) + .await + .unwrap(); + + let downloaded = fs::read(&paths[0]).unwrap(); + assert_eq!(downloaded, large_data); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_progress_updates_are_monotonic() { + let server = LocalTestServerBuilder::new().start().await; + let endpoint = make_endpoint(&server); + + let src_dir = TempDir::new().unwrap(); + let data = vec![0xFFu8; 32768]; + let path = src_dir.path().join("monotonic_test.bin"); + fs::write(&path, &data).unwrap(); + + let updater = Arc::new(RecordingUpdater::default()); + + data_client::upload_async( + vec![path.to_string_lossy().to_string()], + vec![Sha256Policy::Compute], + endpoint, + None, + None, + Some(updater.clone()), + None, + ) + .await + .unwrap(); + + let updates = updater.updates.lock().await; + + let mut prev_completed = 0u64; + let mut prev_total = 0u64; + for update in updates.iter() { + assert!( + update.total_bytes >= prev_total, + "total_bytes decreased: {} -> {}", + prev_total, + update.total_bytes + ); + assert!( + update.total_bytes_completed >= prev_completed, + "total_bytes_completed decreased: {} -> {}", + prev_completed, + update.total_bytes_completed + ); + assert_le!(update.total_bytes_completed, update.total_bytes); + prev_total = update.total_bytes; + prev_completed = update.total_bytes_completed; + } + } +} diff --git a/xet_runtime/src/config/groups/data.rs b/xet_runtime/src/config/groups/data.rs index 9ae7a659..be27f3b3 100644 --- a/xet_runtime/src/config/groups/data.rs +++ b/xet_runtime/src/config/groups/data.rs @@ -51,13 +51,25 @@ crate::config_group!({ /// Use the environment variable `HF_XET_DATA_PROGRESS_UPDATE_INTERVAL` to set this value. ref progress_update_interval : Duration = Duration::from_millis(200); - /// How large of a time window to use for aggregating the progress speed results. + /// Half-life duration for the exponentially weighted moving average used + /// to estimate progress completion speed. Older rate observations are + /// exponentially decayed with this half-life. /// /// The default value is 10sec. /// /// Use the environment variable `HF_XET_DATA_PROGRESS_UPDATE_SPEED_SAMPLING_WINDOW` to set this value. ref progress_update_speed_sampling_window: Duration = Duration::from_secs(10); + /// Minimum number of speed observations before reporting a rate. + /// Until this many updates have been recorded, the completion rate + /// is reported as unknown (None). This avoids displaying noisy + /// initial estimates. + /// + /// The default value is 4. + /// + /// Use the environment variable `HF_XET_DATA_PROGRESS_UPDATE_SPEED_MIN_OBSERVATIONS` to set this value. + ref progress_update_speed_min_observations: u32 = 4; + /// How often do we flush new xorb data to disk on a long running upload session? /// /// The default value is 20sec. diff --git a/xet_runtime/src/utils/unique_id.rs b/xet_runtime/src/utils/unique_id.rs index 598798a2..07d8637e 100644 --- a/xet_runtime/src/utils/unique_id.rs +++ b/xet_runtime/src/utils/unique_id.rs @@ -1,21 +1,68 @@ +use std::fmt; use std::sync::atomic::{AtomicU64, Ordering}; -#[derive(PartialEq, Eq, Clone, Copy)] +static NEXT_ID: AtomicU64 = AtomicU64::new(1); + +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct UniqueId(u64); +impl UniqueId { + pub fn new() -> Self { + Self(NEXT_ID.fetch_add(1, Ordering::Relaxed)) + } + + pub fn null() -> Self { + Self(0) + } +} + impl Default for UniqueId { fn default() -> Self { Self::new() } } -impl UniqueId { - pub fn new() -> Self { - static UNIQUE_COUNTER: AtomicU64 = AtomicU64::new(1); - Self(UNIQUE_COUNTER.fetch_add(1, Ordering::Relaxed)) - } - - pub fn null() -> Self { - Self(0) +impl fmt::Display for UniqueId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use super::*; + + #[test] + fn test_unique_id_basics() { + let id1 = UniqueId::new(); + let id2 = UniqueId::new(); + assert_ne!(id1, id2); + + let cloned = id1; + assert_eq!(id1, cloned); + } + + #[test] + fn test_unique_id_display() { + let id = UniqueId::new(); + let s = id.to_string(); + assert!(!s.is_empty()); + } + + #[test] + fn test_unique_id_hash() { + let id = UniqueId::new(); + let mut map = HashMap::new(); + map.insert(id, 42); + assert_eq!(map[&id], 42); + } + + #[test] + fn test_unique_id_null() { + let null_id = UniqueId::null(); + let new_id = UniqueId::new(); + assert_ne!(null_id, new_id); } }