From cacd7132187d1fcd8ebb1966f3e3c45ab4d50fb6 Mon Sep 17 00:00:00 2001 From: Di Xiao Date: Wed, 11 Mar 2026 16:21:27 -0700 Subject: [PATCH] Rework the interface for session task to get result from registered upload (#690) This PR updates the interface for retrieving per-task results after UploadCommit::commit() or DownloadGroup::finish(). The problem with the previous interface is that commit() and finish() return a vector of FileMetadata or DownloadResult, making it difficult for users to associate each result with a specific task. The new interface uses `task_id` as a strong binding bridge: ## Upload per-task result access patterns After commit() completes, there are two equivalent ways to retrieve a per-task FileMetadata result: 1. Lookup in the global result map: ``` let commit = session.new_upload_commit()?; let handle = commit.upload_from_path(src)?; let results = commit.commit()?; let result = results.get(&handle.task_id) ``` 2. Direct access from the handle: ``` let commit = session.new_upload_commit()?; let handle = commit.upload_from_path(src)?; commit.commit()?; // handle.result() is populated by commit() via the shared Arc. let result = handle.result() ``` ## Download per-task result access patterns The pattern is similar to the above. ## Why not put results in a vector in the same order as tasks are registered to the commit instance? After a commit instance is created, it can be cloned (since it is itself an Arc wrapping an internal struct) and sent to different threads. When multiple threads are registering tasks, there is no static registration order that a program can observe upfront. --- xet_pkg/examples/example.rs | 24 ++- xet_pkg/src/xet_session/download_group.rs | 172 ++++++++++++++------- xet_pkg/src/xet_session/mod.rs | 56 ++++--- xet_pkg/src/xet_session/progress.rs | 178 ++++++++++++++++++++-- xet_pkg/src/xet_session/upload_commit.rs | 162 ++++++++++++++++---- 5 files changed, 470 insertions(+), 122 deletions(-) diff --git a/xet_pkg/examples/example.rs b/xet_pkg/examples/example.rs index cf226c78..06253be1 100644 --- a/xet_pkg/examples/example.rs +++ b/xet_pkg/examples/example.rs @@ -7,7 +7,9 @@ use std::time::Duration; use anyhow::Result; use clap::{Parser, Subcommand}; -use xet::xet_session::{FileMetadata, TaskHandle, TaskStatus, XetFileInfo, XetSessionBuilder}; +use xet::xet_session::{ + DownloadTaskHandle, FileMetadata, TaskStatus, UploadTaskHandle, XetFileInfo, XetSessionBuilder, +}; #[derive(Parser)] #[clap(name = "session-demo", about = "XetSession API demo")] @@ -58,7 +60,7 @@ fn upload_files(files: Vec, endpoint: Option) -> Result<()> { // Enqueue all uploads; each starts immediately in the background. let n_files = files.len(); - let handles: Vec = files + let handles: Vec = files .iter() .map(|f| commit.upload_from_path(f.clone())) .collect::>()?; @@ -80,13 +82,17 @@ fn upload_files(files: Vec, endpoint: Option) -> Result<()> { }); // Block until all uploads finish and metadata is finalized. - let metadata: Vec<_> = commit.commit()?.into_iter().filter_map(|m| m.ok()).collect(); + let results = commit.commit()?; - for m in &metadata { + for m in results.values().filter_map(|m| m.as_ref().as_ref().ok()) { println!(" {} -> {} ({} bytes)", m.tracking_name.as_deref().unwrap_or("?"), m.hash, m.file_size); } // Persist metadata so it can be passed to the `download` subcommand. + let metadata: Vec<_> = results + .into_values() + .filter_map(|m| m.as_ref().as_ref().ok().cloned()) + .collect(); std::fs::write("upload_metadata.json", serde_json::to_string_pretty(&metadata)?)?; Ok(()) @@ -105,7 +111,7 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option< // Enqueue all downloads; each starts immediately in the background. let n_files = metadata.len(); - let handles: Vec = metadata + let handles: Vec = metadata .iter() .map(|m| { let dest = output_dir.join(m.tracking_name.as_deref().unwrap_or("file")); @@ -136,10 +142,12 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option< }); // Block until all downloads finish. - let results: Vec<_> = group.finish()?.into_iter().filter_map(|m| m.ok()).collect(); + let results = group.finish()?; - for r in &results { - println!(" {} ({} bytes)", r.dest_path.display(), r.file_info.file_size); + for (_task_id, result) in &results { + if let Ok(r) = result.as_ref() { + println!(" {} ({} bytes)", r.dest_path.display(), r.file_info.file_size); + } } Ok(()) diff --git a/xet_pkg/src/xet_session/download_group.rs b/xet_pkg/src/xet_session/download_group.rs index e61979ec..f8733190 100644 --- a/xet_pkg/src/xet_session/download_group.rs +++ b/xet_pkg/src/xet_session/download_group.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::path::PathBuf; -use std::sync::{Arc, Mutex, MutexGuard, RwLock}; +use std::sync::{Arc, Mutex, MutexGuard, OnceLock, RwLock}; use tokio::task::JoinHandle; use ulid::Ulid; @@ -11,7 +11,7 @@ use xet_runtime::core::XetRuntime; use super::common::{GroupState, create_translator_config}; use super::errors::SessionError; -use super::progress::{GroupProgress, ProgressSnapshot, TaskHandle, TaskStatus}; +use super::progress::{DownloadTaskHandle, GroupProgress, ProgressSnapshot, TaskHandle, TaskStatus}; use super::session::XetSession; /// Groups related file downloads into a single unit of work. @@ -89,7 +89,7 @@ impl DownloadGroup { /// * `dest_path` – Local path where the downloaded file will be written. Parent directories are created /// automatically. /// - /// Returns a [`TaskHandle`] that can be used to poll status and per-file + /// Returns a [`DownloadTaskHandle`] that can be used to poll status and per-file /// progress without taking the GIL. /// /// # Errors @@ -101,7 +101,7 @@ impl DownloadGroup { &self, file_info: XetFileInfo, dest_path: PathBuf, - ) -> Result { + ) -> Result { self.session.check_alive()?; self.inner.start_download_file_to_path(file_info, dest_path) } @@ -122,13 +122,21 @@ impl DownloadGroup { /// Wait for all downloads to complete and return their results. /// - /// Blocks until every queued download finishes (or fails). Returns one - /// [`DownloadResult`] entry per download. + /// Blocks until every queued download finishes (or fails). Returns a + /// `HashMap` keyed by task ID (the [`Ulid`] returned by + /// [`download_file_to_path`](Self::download_file_to_path)), where each + /// value is [`DownloadResult`] (= `Arc>`). A single failed download + /// does not prevent the others from being collected. + /// + /// Per-task results can also be read directly from the + /// [`DownloadTaskHandle`] returned by `download_file_to_path` via + /// [`result`](DownloadTaskHandle::result) after this method returns. /// /// Consumes `self` — subsequent calls on any clone will return /// [`SessionError::AlreadyFinished`] (or a channel-closed error if the /// background worker has already exited). - pub fn finish(self) -> Result>, SessionError> { + pub fn finish(self) -> Result, SessionError> { let inner = self.inner.clone(); self.session .runtime @@ -136,11 +144,19 @@ impl DownloadGroup { } } +/// Per-file result type returned by [`DownloadGroup::finish`]. +/// +/// The `Arc` lets the same value be stored in both the `finish()` return map +/// and the per-task [`DownloadTaskHandle`] without requiring the inner +/// `Result` to be `Clone`. +pub type DownloadResult = Arc>; + /// Handle for a single download task tracked internally by DownloadGroup. pub(crate) struct InnerDownloadTaskHandle { status: Arc>, dest_path: PathBuf, join_handle: JoinHandle>, + result: Arc>, } /// All shared state owned by a single DownloadGroup instance. @@ -214,7 +230,7 @@ impl DownloadGroupInner { self: &Arc, file_info: XetFileInfo, dest_path: PathBuf, - ) -> Result { + ) -> Result { // Hold the state lock guard for the duration of this function so finish() will not run // when a download task is registering. let state = self.state.lock()?; @@ -223,10 +239,14 @@ impl DownloadGroupInner { let tracking_id = Ulid::new(); let status = Arc::new(Mutex::new(TaskStatus::Queued)); - let task_handle = TaskHandle { - status: Some(status.clone()), - group_progress: self.progress.clone(), - tracking_id, + let result: Arc> = Arc::new(OnceLock::new()); + let task_handle = DownloadTaskHandle { + inner: TaskHandle { + status: Some(status.clone()), + group_progress: self.progress.clone(), + task_id: tracking_id, + }, + result: result.clone(), }; let Some(download_session) = self.download_session.lock()?.clone() else { @@ -245,6 +265,7 @@ impl DownloadGroupInner { status, dest_path, join_handle, + result, }; self.active_tasks.write()?.insert(tracking_id, handle); @@ -253,7 +274,7 @@ impl DownloadGroupInner { } /// Handle a `Finish` command from the public API. - async fn handle_finish(self: &Arc) -> Result>, SessionError> { + async fn handle_finish(self: &Arc) -> Result, SessionError> { // Mark as not accepting new tasks { let mut state_guard = self.state.lock()?; @@ -266,19 +287,27 @@ impl DownloadGroupInner { // Wait for all downloads to complete let active_tasks = std::mem::take(&mut *self.active_tasks.write()?); - let mut results = Vec::new(); + let mut results = HashMap::new(); let mut join_err = None; // Join all tasks first and then propogate errors. - for (_task_id, handle) in active_tasks { + for (task_id, handle) in active_tasks { match handle.join_handle.await.map_err(SessionError::from) { Ok(Ok(file_info)) => { - results.push(Ok(DownloadResult { + let result = Arc::new(Ok(DownloadedFile { dest_path: handle.dest_path, file_info, })); + results.insert(task_id, result.clone()); + // Update result to the external task handle, this is the only place setting + // the result, so no error will happen. + let _ = handle.result.set(result); }, Ok(Err(task_err)) => { - results.push(Err(task_err)); + let result: Arc> = Arc::new(Err(task_err)); + results.insert(task_id, result.clone()); + // Update result to the external task handle, this is the only place setting + // the result, so no error will happen. + let _ = handle.result.set(result); }, Err(e) => { if join_err.is_none() { @@ -316,30 +345,9 @@ impl DownloadGroupInner { } } -/// A progress snapshot for a single queued download. -/// -/// Returned by [`DownloadGroup::get_progress`]. -#[derive(Clone, Debug)] -pub struct DownloadProgress { - /// Unique identifier for this download task. - pub task_id: Ulid, - /// Local path where the file will be written. - pub dest_path: PathBuf, - /// Content-addressed hash of the file being downloaded. - pub file_hash: String, - /// Number of bytes downloaded so far. - pub bytes_completed: u64, - /// Total file size in bytes (0 if not yet known). - pub bytes_total: u64, - /// Current lifecycle state of the task. - pub status: TaskStatus, - /// Instantaneous download throughput in bytes per second. - pub speed_bps: f64, -} - /// Per-file result returned by [`DownloadGroup::finish`]. #[derive(Clone, Debug)] -pub struct DownloadResult { +pub struct DownloadedFile { /// Local path where the file was written. pub dest_path: PathBuf, /// Xet file hash and size of the downloaded file. @@ -353,6 +361,7 @@ mod tests { use tempfile::{TempDir, tempdir}; use super::*; + use crate::xet_session::progress::UploadTaskHandle; use crate::xet_session::session::XetSession; fn local_session(temp: &TempDir) -> Result> { @@ -362,12 +371,12 @@ mod tests { fn upload_bytes(session: &XetSession, data: &[u8], name: &str) -> Result> { let commit = session.new_upload_commit()?; - commit.upload_bytes(data.to_vec(), Some(name.into()))?; + let handle = commit.upload_bytes(data.to_vec(), Some(name.into()))?; let results = commit.commit()?; - let m = &results[0]; + let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); Ok(XetFileInfo { - hash: m.as_ref().unwrap().hash.clone(), - file_size: m.as_ref().unwrap().file_size, + hash: meta.hash.clone(), + file_size: meta.file_size, }) } @@ -544,28 +553,25 @@ mod tests { let data_a = b"First file content"; let data_b = b"Second file content - different"; - // Upload both files in one commit; use tracking_name to locate each result. + // Upload both files; capture handles so results can be retrieved by task_id. let commit = session.new_upload_commit()?; - commit.upload_bytes(data_a.to_vec(), Some("a.bin".into()))?; - commit.upload_bytes(data_b.to_vec(), Some("b.bin".into()))?; + let handle_a = commit.upload_bytes(data_a.to_vec(), Some("a.bin".into()))?; + let handle_b = commit.upload_bytes(data_b.to_vec(), Some("b.bin".into()))?; let results = commit.commit()?; - let find_info = |name: &str| -> XetFileInfo { - let m = results - .iter() - .find(|r| r.as_ref().unwrap().tracking_name.as_deref() == Some(name)) - .unwrap(); + let to_file_info = |handle: &UploadTaskHandle| -> XetFileInfo { + let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); XetFileInfo { - hash: m.as_ref().unwrap().hash.clone(), - file_size: m.as_ref().unwrap().file_size, + hash: meta.hash.clone(), + file_size: meta.file_size, } }; let dest_a = temp.path().join("a_out.bin"); let dest_b = temp.path().join("b_out.bin"); let group = session.new_download_group()?; - group.download_file_to_path(find_info("a.bin"), dest_a.clone())?; - group.download_file_to_path(find_info("b.bin"), dest_b.clone())?; + group.download_file_to_path(to_file_info(&handle_a), dest_a.clone())?; + group.download_file_to_path(to_file_info(&handle_b), dest_b.clone())?; group.finish()?; assert_eq!(std::fs::read(&dest_a)?, data_a); @@ -600,6 +606,62 @@ mod tests { Ok(()) } + // ── Per-task result access patterns ────────────────────────────────────── + // + // After finish() completes there are two equivalent ways to retrieve a + // per-task DownloadResult: + // + // 1. HashMap lookup: `finish_results.get(&handle.task_id)` + // 2. Direct handle: `handle.result()` (on DownloadTaskHandle) + + #[test] + // Pattern 1: per-task result is accessible via task_id in the finish() HashMap. + fn test_download_result_accessible_via_task_id_in_finish_map() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let data = b"result via task_id in finish map"; + let file_info = upload_bytes(&session, data, "file.bin")?; + let dest = temp.path().join("out.bin"); + let group = session.new_download_group()?; + let handle = group.download_file_to_path(file_info, dest)?; + let results = group.finish()?; + let result = results.get(&handle.task_id).expect("task_id must be present in results"); + assert_eq!(result.as_ref().as_ref().unwrap().file_info.file_size, data.len() as u64); + Ok(()) + } + + #[test] + // DownloadTaskHandle::result() returns None before finish() is called. + fn test_download_result_none_before_finish() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let file_info = upload_bytes(&session, b"some data", "file.bin")?; + let dest = temp.path().join("out.bin"); + let group = session.new_download_group()?; + let handle = group.download_file_to_path(file_info, dest)?; + assert!(handle.result().is_none(), "result must be None before finish()"); + group.finish()?; + Ok(()) + } + + #[test] + // DownloadTaskHandle::result() returns Some after finish() completes. + fn test_download_result_some_after_finish() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let data = b"download result test data"; + let file_info = upload_bytes(&session, data, "file.bin")?; + let dest = temp.path().join("out.bin"); + let group = session.new_download_group()?; + let handle = group.download_file_to_path(file_info.clone(), dest)?; + group.finish()?; + let result = handle.result().expect("result must be set after finish()"); + let dl = result.as_ref().as_ref().unwrap(); + assert_eq!(dl.file_info.file_size, data.len() as u64); + assert_eq!(dl.file_info.hash, file_info.hash); + Ok(()) + } + // ── Mutex guard / concurrency test ─────────────────────────────────────── // // `download_file_to_path` holds `self.state` for its entire execution so diff --git a/xet_pkg/src/xet_session/mod.rs b/xet_pkg/src/xet_session/mod.rs index 94d868db..85a88097 100644 --- a/xet_pkg/src/xet_session/mod.rs +++ b/xet_pkg/src/xet_session/mod.rs @@ -20,29 +20,38 @@ //! files with [`upload_from_path`](UploadCommit::upload_from_path) or //! [`upload_bytes`](UploadCommit::upload_bytes), then call //! [`commit`](UploadCommit::commit) to wait for all transfers to finish and -//! receive a [`Vec`] of [`FileMetadata`]. +//! receive a `HashMap` keyed by task ID +//! (`UploadResult` = `Arc>`). +//! Per-task results can also be read directly from the returned +//! [`UploadTaskHandle`] via [`result`](UploadTaskHandle::result) after +//! `commit()` returns. //! //! ## Downloads //! //! Create a [`DownloadGroup`] with [`XetSession::new_download_group`], queue //! files with [`download_file_to_path`](DownloadGroup::download_file_to_path), -//! then call [`finish`](DownloadGroup::finish) to wait for all transfers. +//! then call [`finish`](DownloadGroup::finish) to wait for all transfers and +//! receive a `HashMap` keyed by task ID +//! (`DownloadResult` = `Arc>`). +//! Per-task results can also be read directly from the returned +//! [`DownloadTaskHandle`] via [`result`](DownloadTaskHandle::result) after +//! `finish()` returns. //! //! ## Progress tracking //! //! Both [`UploadCommit`] and [`DownloadGroup`] expose //! [`get_progress`](UploadCommit::get_progress), which returns a -//! [`ProgressSnapshot`](progress::ProgressSnapshot) without acquiring a lock -//! on the calling thread (useful for Python bindings that must release the -//! GIL). Poll it from a background thread while the main thread blocks in -//! `commit()` / `finish()`. +//! [`ProgressSnapshot`] without acquiring a lock on the calling thread +//! (useful for Python bindings that must release the GIL). Poll it from a +//! background thread while the main thread blocks in `commit()` / `finish()`. //! //! ## Error handling //! //! All public methods return `Result<_, `[`SessionError`]`>`. -//! [`commit`](UploadCommit::commit) and [`finish`](DownloadGroup::finish) -//! return `Vec>` so a single failed file does not -//! discard the results of all others. +//! [`commit`](UploadCommit::commit) returns `HashMap` +//! keyed by task ID, and [`finish`](DownloadGroup::finish) returns +//! `HashMap` keyed by task ID, so a single failed +//! file does not discard all others. //! //! # Quick start //! @@ -57,18 +66,29 @@ //! //! // 2. Upload //! let commit = session.new_upload_commit()?; -//! commit.upload_from_path("file.bin".into())?; -//! let metadata = commit.commit()?; +//! let handle = commit.upload_from_path("file.bin".into())?; +//! commit.commit()?; +//! // Access result directly from the handle (populated by commit()) +//! // UploadResult = Arc> +//! let m = handle.result().unwrap(); // Option +//! let m = m.as_ref().as_ref().unwrap(); // &FileMetadata //! //! // 3. Download //! let group = session.new_download_group()?; -//! let m = metadata[0].as_ref().unwrap(); //! let info = XetFileInfo { //! hash: m.hash.clone(), //! file_size: m.file_size, //! }; -//! group.download_file_to_path(info, "out/file.bin".into())?; -//! group.finish()?; +//! let dl_handle = group.download_file_to_path(info, "out/file.bin".into())?; +//! let finish_results = group.finish()?; +//! // Pattern 1: look up by task ID in the returned HashMap +//! // DownloadResult = Arc> +//! let r1 = finish_results.get(&dl_handle.task_id).unwrap(); // &DownloadResult +//! let r1 = r1.as_ref().as_ref().unwrap(); // &DownloadedFile +//! // Pattern 2: read directly from the handle (populated by finish()) +//! let r2 = dl_handle.result().unwrap(); // DownloadResult +//! let r2 = r2.as_ref().as_ref().unwrap(); // &DownloadedFile +//! //! # Ok::<(), xet::xet_session::SessionError>(()) //! ``` @@ -79,11 +99,13 @@ mod progress; mod session; mod upload_commit; -pub use download_group::{DownloadGroup, DownloadProgress, DownloadResult}; +pub use download_group::{DownloadGroup, DownloadResult, DownloadedFile}; pub use errors::SessionError; -pub use progress::{TaskHandle, TaskStatus}; +pub use progress::{ + DownloadTaskHandle, FileProgress, ProgressSnapshot, TaskHandle, TaskStatus, TotalProgressSnapshot, UploadTaskHandle, +}; pub use session::{XetSession, XetSessionBuilder}; -pub use upload_commit::{FileMetadata, UploadCommit}; +pub use upload_commit::{FileMetadata, UploadCommit, UploadResult}; pub use xet_data::processing::XetFileInfo; // Re-export XetConfig for convenience pub use xet_runtime::config::XetConfig; diff --git a/xet_pkg/src/xet_session/progress.rs b/xet_pkg/src/xet_session/progress.rs index 6e96dd91..373ca805 100644 --- a/xet_pkg/src/xet_session/progress.rs +++ b/xet_pkg/src/xet_session/progress.rs @@ -1,14 +1,17 @@ //! Progress tracking for upload commits and download groups. use std::collections::HashMap; +use std::ops::Deref; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, OnceLock}; use async_trait::async_trait; use ulid::Ulid; use xet_data::progress_tracking::{ProgressUpdate, TrackingProgressUpdater}; use super::SessionError; +use super::download_group::DownloadResult; +use super::upload_commit::UploadResult; /// Lifecycle state of a single upload or download task. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -29,7 +32,36 @@ pub enum TaskStatus { pub struct TaskHandle { pub(crate) status: Option>>, pub(crate) group_progress: Arc, - pub(crate) tracking_id: Ulid, + /// Id of the task, can be used to retrive per-task progress and result. + pub task_id: Ulid, +} + +#[derive(Debug)] +pub struct UploadTaskHandle { + pub(crate) inner: TaskHandle, + pub(crate) result: Arc>, +} + +impl Deref for UploadTaskHandle { + type Target = TaskHandle; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +#[derive(Debug)] +pub struct DownloadTaskHandle { + pub(crate) inner: TaskHandle, + pub(crate) result: Arc>, +} + +impl Deref for DownloadTaskHandle { + type Target = TaskHandle; + + fn deref(&self) -> &Self::Target { + &self.inner + } } impl TaskHandle { @@ -42,7 +74,19 @@ impl TaskHandle { } pub fn progress(&self) -> Result { - self.group_progress.file(self.tracking_id) + self.group_progress.file(self.task_id) + } +} + +impl UploadTaskHandle { + pub fn result(&self) -> Option { + self.result.get().cloned() + } +} + +impl DownloadTaskHandle { + pub fn result(&self) -> Option { + self.result.get().cloned() } } @@ -206,9 +250,13 @@ impl TrackingProgressUpdater for GroupProgress { #[cfg(test)] mod tests { + use std::path::PathBuf; + + use xet_data::processing::XetFileInfo; use xet_data::progress_tracking::ItemProgressUpdate; use super::*; + use crate::xet_session::{DownloadedFile, FileMetadata}; // ── GroupProgress unit tests ───────────────────────────────────────────── @@ -240,13 +288,13 @@ mod tests { // Per-file bytes_completed uses max semantics: a stale/lower update does not reduce progress. async fn test_register_updates_per_file_bytes_completed_never_decreases() { let p = GroupProgress::new(); - let tracking_id = Ulid::new(); + let task_id = Ulid::new(); p.register_updates(ProgressUpdate { total_bytes: 100, total_bytes_completed: 80, item_updates: vec![ItemProgressUpdate { - tracking_id, + tracking_id: task_id, item_name: "file.bin".into(), total_bytes: 100, bytes_completed: 80, @@ -261,7 +309,7 @@ mod tests { total_bytes: 100, total_bytes_completed: 40, item_updates: vec![ItemProgressUpdate { - tracking_id, + tracking_id: task_id, item_name: "file.bin".into(), total_bytes: 100, bytes_completed: 40, // lower than previously seen @@ -272,7 +320,7 @@ mod tests { .await; let snapshot = p.snapshot().unwrap(); - let file = snapshot.file(tracking_id).unwrap(); + let file = snapshot.file(task_id).unwrap(); // Max semantics: should still report 80, not the lower 40. assert_eq!(file.bytes_completed, 80); } @@ -286,23 +334,133 @@ mod tests { let handle = TaskHandle { status: None, group_progress: progress, - tracking_id: Ulid::new(), + task_id: Ulid::new(), }; assert!(handle.status().is_err()); } #[test] - // TaskHandle::progress() returns InvalidTaskID when the tracking_id has no registered progress. + // TaskHandle::progress() returns InvalidTaskID when the task_id has no registered progress. fn test_task_handle_progress_for_unknown_id_returns_error() { let progress = Arc::new(GroupProgress::new()); let handle = TaskHandle { status: None, group_progress: progress, - tracking_id: Ulid::new(), // not registered in GroupProgress + task_id: Ulid::new(), // not registered in GroupProgress }; assert!(matches!(handle.progress(), Err(SessionError::InvalidTaskID(_)))); } + // ── UploadTaskHandle unit tests ────────────────────────────────────────── + // + // `UploadTaskHandle` wraps a `TaskHandle` and adds a `result` Arc that is + // shared with the internal `InnerUploadTaskHandle` inside `UploadCommit`. + // After `commit()` completes, the internal handle writes the per-file + // `UploadResult` (= `Arc>`) into that + // shared Arc so callers can read it directly from the task handle without + // touching the `commit()` return value. + // + // There are therefore two equivalent ways to retrieve a per-task result: + // 1. `commit()` returns `HashMap`; look up the task using `handle.task_id`. + // 2. Call `handle.result()` directly after `commit()` returns. + // + // The tests below exercise the `result` Arc mechanics in isolation; see + // `upload_commit.rs` for end-to-end integration tests of both patterns. + + #[test] + // UploadTaskHandle::result() returns None before the result Arc is populated. + fn test_upload_task_handle_result_none_before_commit() { + let progress = Arc::new(GroupProgress::new()); + let handle = UploadTaskHandle { + inner: TaskHandle { + status: None, + group_progress: progress, + task_id: Ulid::new(), + }, + result: Arc::new(OnceLock::new()), + }; + assert!(handle.result().is_none()); + } + + #[test] + // UploadTaskHandle::result() returns the value once the shared Arc is populated. + fn test_upload_task_handle_result_some_after_result_set() { + let progress = Arc::new(GroupProgress::new()); + let result_arc = Arc::new(OnceLock::new()); + let handle = UploadTaskHandle { + inner: TaskHandle { + status: None, + group_progress: progress, + task_id: Ulid::new(), + }, + result: result_arc.clone(), + }; + + // Simulate commit() writing the result. + let metadata = Arc::new(Ok(FileMetadata { + tracking_name: Some("file.bin".to_string()), + hash: "abc123".to_string(), + file_size: 42, + })); + result_arc.set(metadata).unwrap(); + + let result = handle.result().unwrap(); + let meta = result.as_ref().as_ref().unwrap(); + assert_eq!(meta.file_size, 42); + assert_eq!(meta.hash, "abc123"); + } + + // ── DownloadTaskHandle unit tests ──────────────────────────────────────── + // + // `DownloadTaskHandle` follows the same Arc-sharing pattern as + // `UploadTaskHandle`. Its `result` field holds a `DownloadResult` + // (= `Arc>`), populated by `finish()`. + + #[test] + // DownloadTaskHandle::result() returns None before finish() populates the result Arc. + fn test_download_task_handle_result_none_before_finish() { + let progress = Arc::new(GroupProgress::new()); + let handle = DownloadTaskHandle { + inner: TaskHandle { + status: None, + group_progress: progress, + task_id: Ulid::new(), + }, + result: Arc::new(OnceLock::new()), + }; + assert!(handle.result().is_none()); + } + + #[test] + // DownloadTaskHandle::result() returns the value once the shared Arc is populated. + fn test_download_task_handle_result_some_after_result_set() { + let progress = Arc::new(GroupProgress::new()); + let result_arc = Arc::new(OnceLock::new()); + let handle = DownloadTaskHandle { + inner: TaskHandle { + status: None, + group_progress: progress, + task_id: Ulid::new(), + }, + result: result_arc.clone(), + }; + + // Simulate finish() writing the result. + let download_result = Arc::new(Ok(DownloadedFile { + dest_path: PathBuf::from("out/file.bin"), + file_info: XetFileInfo { + hash: "def456".to_string(), + file_size: 99, + }, + })); + result_arc.set(download_result).unwrap(); + + let result = handle.result().unwrap(); + let dl = result.as_ref().as_ref().unwrap(); + assert_eq!(dl.file_info.file_size, 99); + assert_eq!(dl.dest_path, PathBuf::from("out/file.bin")); + } + // ── Full register_updates test ─────────────────────────────────────────── #[tokio::test] diff --git a/xet_pkg/src/xet_session/upload_commit.rs b/xet_pkg/src/xet_session/upload_commit.rs index 374dc448..a039ac40 100644 --- a/xet_pkg/src/xet_session/upload_commit.rs +++ b/xet_pkg/src/xet_session/upload_commit.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::path::PathBuf; -use std::sync::{Arc, Mutex, MutexGuard, RwLock}; +use std::sync::{Arc, Mutex, MutexGuard, OnceLock, RwLock}; use tokio::task::JoinHandle; use ulid::Ulid; @@ -12,7 +12,7 @@ use xet_runtime::core::XetRuntime; use super::common::{GroupState, create_translator_config}; use super::errors::SessionError; -use super::progress::{GroupProgress, ProgressSnapshot, TaskHandle, TaskStatus}; +use super::progress::{GroupProgress, ProgressSnapshot, TaskHandle, TaskStatus, UploadTaskHandle}; use super::session::XetSession; /// Groups related file uploads into a single atomic commit. @@ -84,7 +84,7 @@ impl UploadCommit { /// Queue a file for upload, starting the transfer immediately if system resource permits. /// - /// Returns a [`TaskHandle`] that can be used to poll status and per-file + /// Returns an [`UploadTaskHandle`] that can be used to poll status and per-file /// progress without taking the GIL. /// /// # Errors @@ -92,7 +92,7 @@ impl UploadCommit { /// Returns [`SessionError::Aborted`] if the session has been aborted, or /// [`SessionError::AlreadyCommitted`] if [`commit`](Self::commit) has /// already been called. - pub fn upload_from_path(&self, file_path: PathBuf) -> Result { + pub fn upload_from_path(&self, file_path: PathBuf) -> Result { self.session.check_alive()?; // Use the absolute path in case the process current working directory changes @@ -130,6 +130,8 @@ impl UploadCommit { /// /// - `file_name`: optional name used for progress/telemetry reporting. /// - `file_size`: expected size in bytes (used for progress tracking; `0` is valid if unknown). + /// # Returns [`TaskHandle`] because the handle isn't expected to hold any result, and instead + /// the user is expected to get upload result from the returned [`SingleFileCleaner`]. pub fn upload_file( &self, file_name: Option, @@ -142,8 +144,12 @@ impl UploadCommit { /// Queue raw bytes for upload, starting the transfer immediately if system resource permits. /// - /// Returns a [`TaskHandle`]. See [`upload_from_path`](Self::upload_from_path) for details. - pub fn upload_bytes(&self, bytes: Vec, tracking_name: Option) -> Result { + /// Returns an [`UploadTaskHandle`]. See [`upload_from_path`](Self::upload_from_path) for details. + pub fn upload_bytes( + &self, + bytes: Vec, + tracking_name: Option, + ) -> Result { self.session.check_alive()?; self.inner.start_upload_bytes(bytes, tracking_name) } @@ -165,12 +171,14 @@ impl UploadCommit { /// Wait for all uploads to complete and push metadata to the CAS server. /// /// Blocks until every queued upload finishes (or fails), then finalises - /// the upload session. Returns one [`FileMetadata`] entry per uploaded - /// file. + /// the upload session. Returns a `HashMap` keyed by task ID where each + /// value is [`UploadResult`] (= `Arc>`). A single failed upload does not prevent the + /// others from being collected. /// /// Consumes `self` — subsequent calls on any clone will return /// [`SessionError::AlreadyCommitted`]. - pub fn commit(self) -> Result>, SessionError> { + pub fn commit(self) -> Result, SessionError> { let inner = self.inner.clone(); self.session .runtime @@ -178,11 +186,19 @@ impl UploadCommit { } } +/// Per-file result type returned by [`UploadCommit::commit`]. +/// +/// The `Arc` lets the same value be stored in both the `commit()` return map +/// and the per-task [`UploadTaskHandle`] without requiring the inner +/// `Result` to be `Clone`. +pub type UploadResult = Arc>; + /// Handle for a single upload task tracked internally by UploadCommit. pub(crate) struct InnerUploadTaskHandle { status: Arc>, tracking_name: Option, join_handle: JoinHandle>, + result: Arc>, } /// All shared state owned by a single UploadCommit instance. @@ -279,7 +295,7 @@ impl UploadCommitInner { }) } - fn start_upload_file_from_path(&self, file_path: PathBuf) -> Result { + fn start_upload_file_from_path(&self, file_path: PathBuf) -> Result { // Hold the state lock guard for the duration of this function so commit() will not run // when an upload task is registering. let state = self.state.lock()?; @@ -287,10 +303,14 @@ impl UploadCommitInner { let tracking_id = Ulid::new(); let status = Arc::new(Mutex::new(TaskStatus::Queued)); - let task_handle = TaskHandle { - status: Some(status.clone()), - group_progress: self.progress.clone(), - tracking_id, + let result: Arc> = Arc::new(OnceLock::new()); + let task_handle = UploadTaskHandle { + inner: TaskHandle { + status: Some(status.clone()), + group_progress: self.progress.clone(), + task_id: tracking_id, + }, + result: result.clone(), }; let Some(upload_session) = self.upload_session.lock()?.clone() else { @@ -304,6 +324,7 @@ impl UploadCommitInner { status, tracking_name: file_path.to_str().map(|s| s.to_owned()), join_handle, + result, }; self.active_tasks.write()?.insert(tracking_id, handle); @@ -327,7 +348,7 @@ impl UploadCommitInner { let task_handle = TaskHandle { status: None, // upload directly managed by user - not internally managed group_progress: self.progress.clone(), - tracking_id, + task_id: tracking_id, }; let Some(upload_session) = self.upload_session.lock()?.clone() else { @@ -345,7 +366,11 @@ impl UploadCommitInner { } /// Handle an `UploadBytes` command from the public API. - fn start_upload_bytes(&self, bytes: Vec, tracking_name: Option) -> Result { + fn start_upload_bytes( + &self, + bytes: Vec, + tracking_name: Option, + ) -> Result { // Hold the state lock guard for the duration of this function so commit() will not run // when an upload task is registering. let state = self.state.lock()?; @@ -353,10 +378,14 @@ impl UploadCommitInner { let tracking_id = Ulid::new(); let status = Arc::new(Mutex::new(TaskStatus::Queued)); - let task_handle = TaskHandle { - status: Some(status.clone()), - group_progress: self.progress.clone(), - tracking_id, + let result: Arc> = Arc::new(OnceLock::new()); + let task_handle = UploadTaskHandle { + inner: TaskHandle { + status: Some(status.clone()), + group_progress: self.progress.clone(), + task_id: tracking_id, + }, + result: result.clone(), }; let Some(upload_session) = self.upload_session.lock()?.clone() else { @@ -369,6 +398,7 @@ impl UploadCommitInner { status, tracking_name, join_handle, + result, }; self.active_tasks.write()?.insert(tracking_id, handle); @@ -377,7 +407,7 @@ impl UploadCommitInner { } /// Handle a `Commit` command from the public API. - async fn handle_commit(&self) -> Result>, SessionError> { + async fn handle_commit(&self) -> Result, SessionError> { // Mark as not accepting new tasks { let mut state_guard = self.state.lock()?; @@ -392,20 +422,28 @@ impl UploadCommitInner { // The guard is dropped immediately so the lock is not held across any `.await`. let active_tasks = std::mem::take(&mut *self.active_tasks.write()?); - let mut results = Vec::new(); + let mut results = HashMap::new(); let mut join_err = None; // Join all tasks first and then propogate errors. - for (_task_id, handle) in active_tasks { + for (task_id, handle) in active_tasks { match handle.join_handle.await.map_err(SessionError::from) { Ok(Ok(file_info)) => { - results.push(Ok(FileMetadata { + let result = Arc::new(Ok(FileMetadata { tracking_name: handle.tracking_name, hash: file_info.hash().to_string(), file_size: file_info.file_size(), })); + results.insert(task_id, result.clone()); + // Update result to the external task handle, this is the only place setting + // the result, so no error will happen. + let _ = handle.result.set(result); }, Ok(Err(task_err)) => { - results.push(Err(task_err)); + let result = Arc::new(Err(task_err)); + results.insert(task_id, result.clone()); + // Update result to the external task handle, this is the only place setting + // the result, so no error will happen. + let _ = handle.result.set(result); }, Err(e) => { if join_err.is_none() { @@ -648,11 +686,12 @@ mod tests { let session = local_session(&temp)?; let data = b"Hello, upload commit round-trip!"; let commit = session.new_upload_commit()?; - commit.upload_bytes(data.to_vec(), Some("hello.bin".into()))?; + let task_handle = commit.upload_bytes(data.to_vec(), Some("hello.bin".into()))?; let results = commit.commit()?; assert_eq!(results.len(), 1); - assert_eq!(results[0].as_ref().unwrap().file_size, data.len() as u64); - assert!(!results[0].as_ref().unwrap().hash.is_empty()); + let meta = results.get(&task_handle.task_id).unwrap().as_ref().as_ref().unwrap(); + assert_eq!(meta.file_size, data.len() as u64); + assert!(!meta.hash.is_empty()); Ok(()) } @@ -665,11 +704,70 @@ mod tests { let data = b"file path upload content"; std::fs::write(&src, data)?; let commit = session.new_upload_commit()?; - commit.upload_from_path(src)?; + let handle = commit.upload_from_path(src)?; + commit.commit()?; + let meta = handle.result().unwrap(); + let meta = meta.as_ref().as_ref().unwrap(); + assert_eq!(meta.file_size, data.len() as u64); + assert!(!meta.hash.is_empty()); + Ok(()) + } + + // ── Per-task result access patterns ────────────────────────────────────── + // + // After commit() completes there are two equivalent ways to retrieve a + // per-task FileMetadata result: + // + // 1. HashMap lookup: `commit_results.get(&handle.task_id)` + // 2. Direct handle: `handle.result()` (only on UploadTaskHandle, not the plain TaskHandle returned by + // upload_file) + // + // Both patterns are exercised by the tests below. + + #[test] + // UploadTaskHandle::result() returns None before commit() is called. + fn test_upload_result_none_before_commit() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let src = temp.path().join("data.bin"); + std::fs::write(&src, b"content")?; + let commit = session.new_upload_commit()?; + let handle = commit.upload_from_path(src)?; + assert!(handle.result().is_none(), "result must be None before commit()"); + commit.commit()?; + Ok(()) + } + + #[test] + // Pattern 1: per-task result is accessible via task_id in the commit() HashMap. + fn test_upload_result_accessible_via_task_id_in_commit_map() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let data = b"result via task_id"; + let src = temp.path().join("data.bin"); + std::fs::write(&src, data)?; + let commit = session.new_upload_commit()?; + let handle = commit.upload_from_path(src)?; let results = commit.commit()?; - assert_eq!(results.len(), 1); - assert_eq!(results[0].as_ref().unwrap().file_size, data.len() as u64); - assert!(!results[0].as_ref().unwrap().hash.is_empty()); + let result = results.get(&handle.task_id).expect("task_id must be present in results"); + assert_eq!(result.as_ref().as_ref().unwrap().file_size, data.len() as u64); + Ok(()) + } + + #[test] + // Pattern 2: per-task result is accessible directly from the UploadTaskHandle after commit(). + fn test_upload_result_accessible_via_handle_after_commit() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let data = b"result via handle"; + let src = temp.path().join("data.bin"); + std::fs::write(&src, data)?; + let commit = session.new_upload_commit()?; + let handle = commit.upload_from_path(src)?; + commit.commit()?; + // handle.result() is populated by commit() via the shared Arc. + let result = handle.result().expect("result must be set after commit"); + assert_eq!(result.as_ref().as_ref().unwrap().file_size, data.len() as u64); Ok(()) }