diff --git a/api_changes/update_260318_optional_file_size.md b/api_changes/update_260318_optional_file_size.md new file mode 100644 index 00000000..961faa79 --- /dev/null +++ b/api_changes/update_260318_optional_file_size.md @@ -0,0 +1,75 @@ +# `XetFileInfo.file_size` is now `Option<u64>` + +**Date**: 2026-03-18 +**Crate**: `xet-data` (`xet_data::processing::XetFileInfo`) + +## What changed + +`XetFileInfo.file_size` changed from `u64` to `Option<u64>`. +The `file_size()` accessor now returns `Option<u64>`. + +Downstream API surfaces that consume `XetFileInfo` were updated accordingly: + +- `xet_pkg` session/examples/tests now construct download `XetFileInfo` values + with `file_size: Some(size)` when known. +- `hf_xet::PyXetDownloadInfo.file_size` is now `Option<u64>`, and converting + from Python download metadata supports hash-only downloads. +- `hf_xet::PyPointerFile.filesize` getter also returns `Option<u64>` now. +- `hf_xet::PyXetUploadInfo` still exposes `file_size: u64`; upload metadata is + expected to always provide a known size. + +## Why + +The download path no longer requires callers to know the file size upfront. +When `file_size` is `None`, the reconstruction discovers the actual size +from the remote and progress tracking updates incrementally. 
+ +## Migration + +### Struct literal construction + +```rust +// Before +XetFileInfo { hash: h, file_size: s, sha256: None } + +// After +XetFileInfo { hash: h, file_size: Some(s), sha256: None } +``` + +### Using the constructor (no change needed) + +```rust +// XetFileInfo::new wraps in Some internally +XetFileInfo::new(hash, file_size) +``` + +### New: hash-only constructor for unknown size + +```rust +XetFileInfo::new_hash_only(hash) +``` + +### Reading file_size + +```rust +// Before +let size: u64 = info.file_size(); + +// After +let size: Option<u64> = info.file_size(); +// or when size is known to be present: +let size: u64 = info.file_size().expect("size should be set"); +``` + +### Serde + +`Some(n)` serializes as `n` (backward compatible). +`None` omits the field. Missing field deserializes as `None`. + +## New error variant + +`DataError::SizeMismatch { expected, actual }` is returned when a download +completes but the actual byte count differs from the specified `file_size`. + +This check runs after full-file reconstruction and works for both larger and +smaller actual byte counts relative to the caller-provided value. 
diff --git a/hf_xet/src/lib.rs b/hf_xet/src/lib.rs index d005c43a..9e956e74 100644 --- a/hf_xet/src/lib.rs +++ b/hf_xet/src/lib.rs @@ -332,13 +332,14 @@ pub struct PyXetDownloadInfo { #[pyo3(get)] hash: String, #[pyo3(get)] - file_size: u64, + file_size: Option, } #[pymethods] impl PyXetDownloadInfo { #[new] - pub fn new(destination_path: String, hash: String, file_size: u64) -> Self { + #[pyo3(signature = (destination_path, hash, file_size=None))] + pub fn new(destination_path: String, hash: String, file_size: Option) -> Self { Self { destination_path, hash, @@ -351,7 +352,8 @@ impl PyXetDownloadInfo { } fn __repr__(&self) -> String { - format!("PyXetDownloadInfo({}, {}, {})", self.destination_path, self.hash, self.file_size) + let size_str = self.file_size.map_or("None".to_string(), |s| s.to_string()); + format!("PyXetDownloadInfo({}, {}, {})", self.destination_path, self.hash, size_str) } } @@ -365,7 +367,7 @@ pub struct PyPointerFile {} impl PyPointerFile { #[new] pub fn new(path: String, hash: String, filesize: u64) -> (Self, PyXetDownloadInfo) { - (PyPointerFile {}, PyXetDownloadInfo::new(path, hash, filesize)) + (PyPointerFile {}, PyXetDownloadInfo::new(path, hash, Some(filesize))) } fn __str__(&self) -> String { @@ -374,7 +376,8 @@ impl PyPointerFile { fn __repr__(self_: PyRef<'_, Self>) -> String { let super_ = self_.as_super(); - format!("PyPointerFile({}, {}, {})", super_.destination_path, super_.hash, super_.file_size) + let size_str = super_.file_size.map_or("None".to_string(), |s| s.to_string()); + format!("PyPointerFile({}, {}, {})", super_.destination_path, super_.hash, size_str) } #[getter] @@ -388,7 +391,7 @@ impl PyPointerFile { } #[getter] - fn filesize(self_: PyRef<'_, Self>) -> u64 { + fn filesize(self_: PyRef<'_, Self>) -> Option { self_.as_super().file_size } } @@ -436,7 +439,7 @@ impl From for PyXetUploadInfo { fn from(xf: XetFileInfo) -> Self { Self { hash: xf.hash().to_owned(), - file_size: xf.file_size(), + file_size: 
xf.file_size().expect("upload metadata must always include a known file size"), sha256: xf.sha256().map(str::to_owned), } } @@ -444,7 +447,11 @@ impl From for PyXetUploadInfo { impl From for (XetFileInfo, DestinationPath) { fn from(pf: PyXetDownloadInfo) -> Self { - (XetFileInfo::new(pf.hash, pf.file_size), pf.destination_path) + let file_info = match pf.file_size { + Some(size) => XetFileInfo::new(pf.hash, size), + None => XetFileInfo::new_hash_only(pf.hash), + }; + (file_info, pf.destination_path) } } diff --git a/xet_data/src/error.rs b/xet_data/src/error.rs index 74804628..ec4654e7 100644 --- a/xet_data/src/error.rs +++ b/xet_data/src/error.rs @@ -74,6 +74,8 @@ pub enum DataError { #[error("Invalid operation: {0}")] InvalidOperation(String), + #[error("File size mismatch: expected {expected} bytes but downloaded {actual} bytes")] + SizeMismatch { expected: u64, actual: u64 }, #[error("Auth error: {0}")] AuthError(#[from] AuthError), diff --git a/xet_data/src/file_reconstruction/file_reconstructor.rs b/xet_data/src/file_reconstruction/file_reconstructor.rs index a3fc3e25..0d2953b0 100644 --- a/xet_data/src/file_reconstruction/file_reconstructor.rs +++ b/xet_data/src/file_reconstruction/file_reconstructor.rs @@ -345,7 +345,9 @@ impl FileReconstructor { #[cfg(debug_assertions)] if !_is_streaming && let Some(updater) = run_state.progress_updater() { updater.assert_complete(); - if let Some(byte_range) = byte_range { + if let Some(byte_range) = byte_range + && byte_range.end < u64::MAX + { assert_eq!(updater.total_bytes_completed(), byte_range.end - byte_range.start); } } diff --git a/xet_data/src/file_reconstruction/reconstruction_terms/manager.rs b/xet_data/src/file_reconstruction/reconstruction_terms/manager.rs index 74cf9ff6..9c9c9192 100644 --- a/xet_data/src/file_reconstruction/reconstruction_terms/manager.rs +++ b/xet_data/src/file_reconstruction/reconstruction_terms/manager.rs @@ -167,6 +167,10 @@ impl ReconstructionTermManager { 
self.known_final_byte_position .store(self.prefetched_byte_position, Ordering::Relaxed); + if let Some(progress_updater) = &self.progress_updater { + progress_updater.update_item_size(self.total_bytes_reported, true); + } + info!( file_hash = %self.file_hash, prefetched_byte_position = self.prefetched_byte_position, diff --git a/xet_data/src/processing/bin/xtool.rs b/xet_data/src/processing/bin/xtool.rs index 8b7b4107..49d4317c 100644 --- a/xet_data/src/processing/bin/xtool.rs +++ b/xet_data/src/processing/bin/xtool.rs @@ -144,7 +144,12 @@ impl Command { eprintln!("\n\nClean results:"); for (xf, new_bytes) in clean_ret { - println!("{}: {} bytes -> {} bytes", xf.hash(), xf.file_size(), new_bytes); + println!( + "{}: {} bytes -> {} bytes", + xf.hash(), + xf.file_size().map_or("?".to_string(), |s| s.to_string()), + new_bytes + ); } eprintln!("Transmitted {total_bytes_trans} bytes in total."); diff --git a/xet_data/src/processing/data_client.rs b/xet_data/src/processing/data_client.rs index ce064930..aa8c36ed 100644 --- a/xet_data/src/processing/data_client.rs +++ b/xet_data/src/processing/data_client.rs @@ -273,7 +273,7 @@ mod tests { assert!(result.is_ok()); let file_info = result.unwrap(); - assert_eq!(file_info.file_size(), 0); + assert_eq!(file_info.file_size(), Some(0)); assert!(!file_info.hash().is_empty()); } @@ -289,7 +289,7 @@ mod tests { assert!(result.is_ok()); let file_info = result.unwrap(); - assert_eq!(file_info.file_size(), content.len() as u64); + assert_eq!(file_info.file_size(), Some(content.len() as u64)); assert!(!file_info.hash().is_empty()); } @@ -349,8 +349,8 @@ mod tests { let file_infos = result.unwrap(); assert_eq!(file_infos.len(), 2); - assert_eq!(file_infos[0].file_size(), 18); - assert_eq!(file_infos[1].file_size(), 19); + assert_eq!(file_infos[0].file_size(), Some(18)); + assert_eq!(file_infos[1].file_size(), Some(19)); assert_ne!(file_infos[0].hash(), file_infos[1].hash()); } @@ -373,7 +373,7 @@ mod tests { let result1 = 
hash_single_file(file_path_str.clone(), 8 * 1024 * 1024); assert!(result1.is_ok()); let file_info1 = result1.unwrap(); - assert_eq!(file_info1.file_size(), file_size as u64); + assert_eq!(file_info1.file_size(), Some(file_size as u64)); assert!(!file_info1.hash().is_empty()); // Hash with 4MB buffer size - file is exactly 4x buffer size diff --git a/xet_data/src/processing/file_cleaner.rs b/xet_data/src/processing/file_cleaner.rs index dfc5b891..219d19d4 100644 --- a/xet_data/src/processing/file_cleaner.rs +++ b/xet_data/src/processing/file_cleaner.rs @@ -214,7 +214,7 @@ impl SingleFileCleaner { let file_info = XetFileInfo { hash: file_hash.hex(), - file_size: deduplication_metrics.total_bytes, + file_size: Some(deduplication_metrics.total_bytes), sha256: sha256.map(|s| s.hex()), }; diff --git a/xet_data/src/processing/file_download_session.rs b/xet_data/src/processing/file_download_session.rs index 7780c15f..799def6e 100644 --- a/xet_data/src/processing/file_download_session.rs +++ b/xet_data/src/processing/file_download_session.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; use std::io::Write; -use std::ops::Range; +use std::ops::{Bound, RangeBounds}; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; @@ -110,6 +110,16 @@ impl FileDownloadSession { let progress_updater = self.progress.new_item(id, name); let reconstructor = self.setup_reconstructor(file_info, None, Some(progress_updater))?; let n_bytes = reconstructor.reconstruct_to_file(write_path, None).await?; + // Caller is responsible for cleaning up the file on error (consistent + // with other error paths); see download_group.rs error handling. 
+ if let Some(expected_size) = file_info.file_size() + && n_bytes != expected_size + { + return Err(DataError::SizeMismatch { + expected: expected_size, + actual: n_bytes, + }); + } prometheus_metrics::FILTER_BYTES_SMUDGED.inc_by(n_bytes); Ok(n_bytes) } @@ -117,26 +127,46 @@ impl FileDownloadSession { /// Downloads a byte range of a file and writes it to the provided writer. /// /// The provided `source_range` is interpreted against the original file; output - /// starts at the writer's current position. + /// starts at the writer's current position. Accepts any `RangeBounds`: + /// `4..12`, `5..`, `..100`, or `..` (full file). /// /// This path does not acquire the session-level file download semaphore. #[instrument(skip_all, name = "FileDownloadSession::download_to_writer", - fields(hash = file_info.hash(), range_start = source_range.start, range_end = source_range.end))] + fields(hash = file_info.hash(), range_start = tracing::field::Empty, range_end = tracing::field::Empty))] pub async fn download_to_writer( &self, file_info: &XetFileInfo, - source_range: Range, + source_range: impl RangeBounds, writer: W, ) -> Result<(UniqueID, u64)> { self.check_not_finalized()?; - let range = FileRange::new(source_range.start, source_range.end); + let range = range_bounds_to_file_range(&source_range)?; + if let Some(ref r) = range { + let span = tracing::Span::current(); + span.record("range_start", r.start); + span.record("range_end", r.end); + } let id = UniqueID::new(); let name = Arc::from(""); let progress_updater = self.progress.new_item(id, name); - let reconstructor = self.setup_reconstructor(file_info, Some(range), Some(progress_updater))?; + let reconstructor = self.setup_reconstructor(file_info, range, Some(progress_updater))?; let n_bytes = reconstructor.reconstruct_to_writer(writer).await?; - prometheus_metrics::FILTER_BYTES_SMUDGED.inc_by(n_bytes); + let expected_size = match range { + Some(r) if r.end < u64::MAX => Some(r.end - r.start), + None => 
file_info.file_size(), + _ => None, + }; + if let Some(expected) = expected_size + && n_bytes != expected + { + return Err(DataError::SizeMismatch { + expected, + actual: n_bytes, + }); + } + + prometheus_metrics::FILTER_BYTES_SMUDGED.inc_by(n_bytes); Ok((id, n_bytes)) } @@ -150,13 +180,27 @@ impl FileDownloadSession { /// This path does not acquire the session-level file download semaphore. #[instrument(skip_all, name = "FileDownloadSession::download_stream", fields(hash = file_info.hash()))] pub async fn download_stream(&self, file_info: &XetFileInfo) -> Result<(UniqueID, DownloadStream)> { - self.check_not_finalized()?; - let id = UniqueID::new(); - let progress_updater = self.progress.new_item(id, "stream"); - let reconstructor = self.setup_reconstructor(file_info, None, Some(progress_updater))?; - Ok((id, reconstructor.reconstruct_to_stream())) + self.download_stream_range(file_info, ..).await } + /// Creates a streaming download of a byte range of a file. + /// + /// Accepts any `RangeBounds`: `4..12`, `5..`, `..100`, or `..` (full file). + /// + /// This path does not acquire the session-level file download semaphore. 
+ #[instrument(skip_all, name = "FileDownloadSession::download_stream_range", fields(hash = file_info.hash()))] + pub async fn download_stream_range( + &self, + file_info: &XetFileInfo, + range: impl RangeBounds, + ) -> Result<(UniqueID, DownloadStream)> { + self.check_not_finalized()?; + let file_range = range_bounds_to_file_range(&range)?; + let id = UniqueID::new(); + let progress_updater = self.progress.new_item(id, "stream"); + let reconstructor = self.setup_reconstructor(file_info, file_range, Some(progress_updater))?; + Ok((id, reconstructor.reconstruct_to_stream())) + } fn check_not_finalized(&self) -> Result<()> { if self.finalized.load(Ordering::Acquire) { return Err(DataError::InvalidOperation("FileDownloadSession already finalized".to_string())); @@ -181,12 +225,39 @@ impl FileDownloadSession { progress_updater: Option>, ) -> Result { let file_id = file_info.merkle_hash()?; - let effective_range = range.unwrap_or_else(|| FileRange::new(0, file_info.file_size())); - let size = effective_range.end - effective_range.start; - if let Some(ref updater) = progress_updater { - updater.update_item_size(size, true); + + let mut reconstructor = FileReconstructor::new(&self.client, file_id); + + match range { + Some(range) if range.end < u64::MAX => { + // Fully bounded range: we know the exact download size upfront. + let size = range.end - range.start; + if let Some(ref updater) = progress_updater { + updater.update_item_size(size, true); + } + reconstructor = reconstructor.with_byte_range(range); + }, + Some(range) => { + // Open-ended range (end == u64::MAX): pass the range to set the + // start position, but let ReconstructionTermManager discover + // the actual end and finalize progress incrementally. + reconstructor = reconstructor.with_byte_range(range); + }, + None if file_info.file_size().is_some() => { + // Full file with caller-provided size. Set progress upfront so + // UI consumers get percentage-based progress. 
SizeMismatch is + // validated after reconstruction in download_file_with_id. + if let Some(ref updater) = progress_updater { + updater.update_item_size(file_info.file_size().unwrap(), true); + } + }, + None => { + // Full file with unknown size: the reconstructor uses + // FileRange::full() internally and ReconstructionTermManager + // discovers the size incrementally. + }, } - let mut reconstructor = FileReconstructor::new(&self.client, file_id).with_byte_range(effective_range); + if let Some(updater) = progress_updater { reconstructor = reconstructor.with_progress_updater(updater); } @@ -194,6 +265,34 @@ impl FileDownloadSession { } } +/// Converts any `RangeBounds` into an `Option`. +/// +/// Returns `None` for the unbounded range `..` (equivalent to full file), +/// and `Some(FileRange)` otherwise. Open-ended ranges use `u64::MAX` as +/// the end sentinel (matching `FileRange::full()`). +/// +/// Returns an error for inverted ranges where `start > end`. +fn range_bounds_to_file_range(range: &impl RangeBounds) -> Result> { + let start = match range.start_bound() { + Bound::Included(&s) => s, + Bound::Excluded(&s) => s.saturating_add(1), + Bound::Unbounded => 0, + }; + let end = match range.end_bound() { + Bound::Included(&e) => e.saturating_add(1), + Bound::Excluded(&e) => e, + Bound::Unbounded => u64::MAX, + }; + if start > end { + return Err(DataError::InvalidOperation(format!("Invalid range: start ({start}) > end ({end})"))); + } + if start == 0 && end == u64::MAX { + Ok(None) + } else { + Ok(Some(FileRange::new(start, end))) + } +} + #[cfg(test)] mod tests { use std::fs::{read, write}; @@ -670,4 +769,367 @@ mod tests { }) .unwrap(); } + + // ==================== Range Download Tests ==================== + + #[test] + fn test_download_to_writer_range_from() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = 
b"0123456789abcdef"; + + let xfi = upload_data(&cas_path, original_data).await; + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let out_path = temp.path().join("range_from.bin"); + let file = std::fs::File::create(&out_path).unwrap(); + let (_id, n_bytes) = session.download_to_writer(&xfi, 4.., file).await.unwrap(); + + assert_eq!(n_bytes, 12); + assert_eq!(read(&out_path).unwrap(), &original_data[4..]); + }) + .unwrap(); + } + + #[test] + fn test_download_to_writer_range_to() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"0123456789abcdef"; + + let xfi = upload_data(&cas_path, original_data).await; + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let out_path = temp.path().join("range_to.bin"); + let file = std::fs::File::create(&out_path).unwrap(); + let (_id, n_bytes) = session.download_to_writer(&xfi, ..8, file).await.unwrap(); + + assert_eq!(n_bytes, 8); + assert_eq!(read(&out_path).unwrap(), &original_data[..8]); + }) + .unwrap(); + } + + #[test] + fn test_download_to_writer_full_range() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"0123456789abcdef"; + + let xfi = upload_data(&cas_path, original_data).await; + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let out_path = temp.path().join("full_range.bin"); + let file = std::fs::File::create(&out_path).unwrap(); + let (_id, n_bytes) = session.download_to_writer(&xfi, .., file).await.unwrap(); + + assert_eq!(n_bytes, original_data.len() as 
u64); + assert_eq!(read(&out_path).unwrap(), original_data); + }) + .unwrap(); + } + + #[test] + fn test_download_to_writer_range_inclusive() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"0123456789abcdef"; + + let xfi = upload_data(&cas_path, original_data).await; + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let out_path = temp.path().join("range_incl.bin"); + let file = std::fs::File::create(&out_path).unwrap(); + let (_id, n_bytes) = session.download_to_writer(&xfi, 2..=5, file).await.unwrap(); + + assert_eq!(n_bytes, 4); + assert_eq!(read(&out_path).unwrap(), &original_data[2..=5]); + }) + .unwrap(); + } + + // ==================== Range Stream Tests ==================== + + #[test] + fn test_download_stream_range_bounded() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"0123456789abcdef"; + + let xfi = upload_data(&cas_path, original_data).await; + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let (_id, mut stream) = session.download_stream_range(&xfi, 4..12).await.unwrap(); + + let mut collected = Vec::new(); + while let Some(chunk) = stream.next().await.unwrap() { + collected.extend_from_slice(&chunk); + } + + assert_eq!(collected, &original_data[4..12]); + }) + .unwrap(); + } + + #[test] + fn test_download_stream_range_from() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"0123456789abcdef"; + + let xfi = upload_data(&cas_path, 
original_data).await; + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let (_id, mut stream) = session.download_stream_range(&xfi, 10..).await.unwrap(); + + let mut collected = Vec::new(); + while let Some(chunk) = stream.next().await.unwrap() { + collected.extend_from_slice(&chunk); + } + + assert_eq!(collected, &original_data[10..]); + }) + .unwrap(); + } + + #[test] + fn test_download_stream_range_to() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"0123456789abcdef"; + + let xfi = upload_data(&cas_path, original_data).await; + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let (_id, mut stream) = session.download_stream_range(&xfi, ..6).await.unwrap(); + + let mut collected = Vec::new(); + while let Some(chunk) = stream.next().await.unwrap() { + collected.extend_from_slice(&chunk); + } + + assert_eq!(collected, &original_data[..6]); + }) + .unwrap(); + } + + // ==================== Download with unknown file size ==================== + + #[test] + fn test_download_file_unknown_size() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"File with unknown size test"; + + let xfi = upload_data(&cas_path, original_data).await; + let xfi_no_size = XetFileInfo::new_hash_only(xfi.hash().to_string()); + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let out_path = temp.path().join("output_unknown.txt"); + let (_id, n_bytes) = session.download_file(&xfi_no_size, &out_path).await.unwrap(); + + assert_eq!(n_bytes, 
original_data.len() as u64); + assert_eq!(read(&out_path).unwrap(), original_data); + }) + .unwrap(); + } + + #[test] + fn test_download_stream_unknown_size() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"Stream with unknown size test"; + + let xfi = upload_data(&cas_path, original_data).await; + let xfi_no_size = XetFileInfo::new_hash_only(xfi.hash().to_string()); + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let (_id, mut stream) = session.download_stream(&xfi_no_size).await.unwrap(); + + let mut collected = Vec::new(); + while let Some(chunk) = stream.next().await.unwrap() { + collected.extend_from_slice(&chunk); + } + + assert_eq!(collected, original_data); + }) + .unwrap(); + } + + #[cfg(not(debug_assertions))] + #[test] + fn test_download_file_size_mismatch_error() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"Size mismatch test data"; + + let xfi = upload_data(&cas_path, original_data).await; + let wrong_size_xfi = XetFileInfo::new(xfi.hash().to_string(), 999); + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let out_path = temp.path().join("output_mismatch.txt"); + let err = session.download_file(&wrong_size_xfi, &out_path).await.unwrap_err(); + + assert!( + matches!(err, DataError::SizeMismatch { expected: 999, .. 
}), + "Expected SizeMismatch error, got: {err:?}" + ); + }) + .unwrap(); + } + + // ==================== range_bounds_to_file_range unit tests ==================== + + #[test] + fn test_range_bounds_conversion() { + use super::range_bounds_to_file_range; + + assert_eq!(range_bounds_to_file_range(&(..)).unwrap(), None); + assert_eq!(range_bounds_to_file_range(&(0..100)).unwrap(), Some(FileRange::new(0, 100))); + assert_eq!(range_bounds_to_file_range(&(5..)).unwrap(), Some(FileRange::new(5, u64::MAX))); + assert_eq!(range_bounds_to_file_range(&(..50)).unwrap(), Some(FileRange::new(0, 50))); + assert_eq!(range_bounds_to_file_range(&(10..=19)).unwrap(), Some(FileRange::new(10, 20))); + } + + #[test] + fn test_range_bounds_inverted_range_errors() { + use super::range_bounds_to_file_range; + + let result = range_bounds_to_file_range(&(10..5)); + assert!(result.is_err()); + } + + #[test] + fn test_download_to_writer_empty_range() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"0123456789abcdef"; + + let xfi = upload_data(&cas_path, original_data).await; + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let out_path = temp.path().join("empty_range.bin"); + let file = std::fs::File::create(&out_path).unwrap(); + let (_id, n_bytes) = session.download_to_writer(&xfi, 5..5, file).await.unwrap(); + + assert_eq!(n_bytes, 0); + assert_eq!(read(&out_path).unwrap(), &[] as &[u8]); + }) + .unwrap(); + } + + #[test] + fn test_download_to_writer_inverted_range_errors() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"0123456789abcdef"; + + let xfi = upload_data(&cas_path, original_data).await; + + let config = 
TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let out_path = temp.path().join("inverted_range.bin"); + let file = std::fs::File::create(&out_path).unwrap(); + let result = session.download_to_writer(&xfi, 10..5, file).await; + + assert!(result.is_err()); + }) + .unwrap(); + } + + #[cfg(not(debug_assertions))] + #[test] + fn test_download_to_writer_range_start_beyond_file_size_errors() { + let runtime = get_threadpool(); + runtime + .clone() + .external_run_async_task(async { + let temp = tempdir().unwrap(); + let cas_path = temp.path().join("cas"); + let original_data = b"0123456789abcdef"; + + let xfi = upload_data(&cas_path, original_data).await; + + let config = TranslatorConfig::local_config(&cas_path).unwrap(); + let session = FileDownloadSession::new(config.into()).await.unwrap(); + + let out_path = temp.path().join("beyond_size.bin"); + let file = std::fs::File::create(&out_path).unwrap(); + let result = session.download_to_writer(&xfi, 100000.., file).await; + + assert!(result.is_err()); + }) + .unwrap(); + } } diff --git a/xet_data/src/processing/test_utils.rs b/xet_data/src/processing/test_utils.rs index 41a17b54..deeea48b 100644 --- a/xet_data/src/processing/test_utils.rs +++ b/xet_data/src/processing/test_utils.rs @@ -335,7 +335,7 @@ impl HydrateDehydrateTest { let entry = entry.unwrap(); let out_filename = self.dest_dir.join(entry.file_name()); let xf: XetFileInfo = serde_json::from_reader(File::open(entry.path()).unwrap()).unwrap(); - let file_size = xf.file_size(); + let file_size = xf.file_size().expect("file size required for partitioned hydration"); let out_file = File::create(&out_filename).unwrap(); out_file.set_len(file_size).unwrap(); diff --git a/xet_data/src/processing/xet_file.rs b/xet_data/src/processing/xet_file.rs index ff6fa3a7..f3bc13dc 100644 --- a/xet_data/src/processing/xet_file.rs +++ b/xet_data/src/processing/xet_file.rs @@ -8,8 +8,9 @@ pub struct 
XetFileInfo { /// The Merkle hash of the file pub hash: String, - /// The size of the file - pub file_size: u64, + /// The size of the file, if known. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub file_size: Option, /// The SHA-256 hash of the file, if available. #[serde(default, skip_serializing_if = "Option::is_none")] @@ -17,7 +18,7 @@ pub struct XetFileInfo { } impl XetFileInfo { - /// Creates a new `XetFileInfo` instance. + /// Creates a new `XetFileInfo` instance with a known size. /// /// # Arguments /// @@ -26,20 +27,29 @@ impl XetFileInfo { pub fn new(hash: String, file_size: u64) -> Self { Self { hash, - file_size, + file_size: Some(file_size), sha256: None, } } - /// Creates a new `XetFileInfo` instance with a SHA-256 hash. + /// Creates a new `XetFileInfo` instance with a SHA-256 hash and known size. pub fn new_with_sha256(hash: String, file_size: u64, sha256: String) -> Self { Self { hash, - file_size, + file_size: Some(file_size), sha256: Some(sha256), } } + /// Creates a new `XetFileInfo` with only a hash and no known size. + pub fn new_hash_only(hash: String) -> Self { + Self { + hash, + file_size: None, + sha256: None, + } + } + /// Returns the Merkle hash of the file. pub fn hash(&self) -> &str { &self.hash @@ -50,8 +60,8 @@ impl XetFileInfo { MerkleHash::from_hex(&self.hash).log_error("Error parsing hash value for file info") } - /// Returns the size of the file. - pub fn file_size(&self) -> u64 { + /// Returns the size of the file, if known. + pub fn file_size(&self) -> Option { self.file_size } diff --git a/xet_data/tests/test_range_downloads.rs b/xet_data/tests/test_range_downloads.rs new file mode 100644 index 00000000..ce5e6771 --- /dev/null +++ b/xet_data/tests/test_range_downloads.rs @@ -0,0 +1,213 @@ +//! Integration tests for range-based downloads using a LocalTestServer. +//! +//! Exercises all range variants (`start..end`, `start..`, `..end`, `..`) across +//! the three download paths: file, writer, and streaming. 
+ +#[cfg(test)] +mod tests { + use std::fs; + use std::sync::Arc; + + use tempfile::TempDir; + use xet_client::cas_client::{LocalTestServer, LocalTestServerBuilder}; + use xet_data::processing::configurations::TranslatorConfig; + use xet_data::processing::{FileDownloadSession, FileUploadSession, Sha256Policy, XetFileInfo}; + + async fn upload_bytes(upload_session: &Arc, name: &str, data: &[u8]) -> XetFileInfo { + let (_id, mut cleaner) = upload_session + .start_clean(Some(name.into()), data.len() as u64, Sha256Policy::Compute) + .unwrap(); + cleaner.add_data(data).await.unwrap(); + let (xfi, _metrics) = cleaner.finish().await.unwrap(); + xfi + } + + struct TestHarness { + _server: LocalTestServer, + _base_dir: TempDir, + session: Arc, + xfi: XetFileInfo, + data: Vec, + } + + async fn setup() -> TestHarness { + let server = LocalTestServerBuilder::new().start().await; + let base_dir = TempDir::new().unwrap(); + let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap()); + + let data: Vec = (0..=255u8).cycle().take(8192).collect(); + let upload_session = FileUploadSession::new(config.clone()).await.unwrap(); + let xfi = upload_bytes(&upload_session, "range_test", &data).await; + upload_session.finalize().await.unwrap(); + + let download_session = FileDownloadSession::new(config).await.unwrap(); + TestHarness { + _server: server, + _base_dir: base_dir, + session: download_session, + xfi, + data, + } + } + + // ── Writer helpers ─────────────────────────────────────────────────────── + + async fn writer_download( + session: &FileDownloadSession, + xfi: &XetFileInfo, + range: impl std::ops::RangeBounds, + ) -> (u64, Vec) { + let tmp = tempfile::NamedTempFile::new().unwrap(); + let path = tmp.path().to_path_buf(); + // Keep the NamedTempFile alive so the path remains valid. 
+ let file = tmp.reopen().unwrap(); + let (_id, n_bytes) = session.download_to_writer(xfi, range, file).await.unwrap(); + let contents = fs::read(&path).unwrap(); + (n_bytes, contents) + } + + // ── Stream helpers ─────────────────────────────────────────────────────── + + async fn stream_download( + session: &FileDownloadSession, + xfi: &XetFileInfo, + range: impl std::ops::RangeBounds, + ) -> Vec { + let (_id, mut stream) = session.download_stream_range(xfi, range).await.unwrap(); + let mut collected = Vec::new(); + while let Some(chunk) = stream.next().await.unwrap() { + collected.extend_from_slice(&chunk); + } + collected + } + + // ── download_to_writer with various range types ────────────────────────── + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_writer_bounded_range() { + let h = setup().await; + let (n_bytes, buf) = writer_download(&h.session, &h.xfi, 100..200).await; + assert_eq!(n_bytes, 100); + assert_eq!(buf, h.data[100..200]); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_writer_range_from() { + let h = setup().await; + let (n_bytes, buf) = writer_download(&h.session, &h.xfi, 8000..).await; + assert_eq!(n_bytes, 192); + assert_eq!(buf, h.data[8000..]); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_writer_range_to() { + let h = setup().await; + let (n_bytes, buf) = writer_download(&h.session, &h.xfi, ..128).await; + assert_eq!(n_bytes, 128); + assert_eq!(buf, h.data[..128]); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_writer_full_range() { + let h = setup().await; + let (n_bytes, buf) = writer_download(&h.session, &h.xfi, ..).await; + assert_eq!(n_bytes, h.data.len() as u64); + assert_eq!(buf, h.data); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_writer_inclusive_range() { + let h = setup().await; + let (n_bytes, buf) = writer_download(&h.session, &h.xfi, 
50..=149).await; + assert_eq!(n_bytes, 100); + assert_eq!(buf, h.data[50..=149]); + } + + // ── download_stream_range with various range types ─────────────────────── + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_stream_bounded_range() { + let h = setup().await; + let collected = stream_download(&h.session, &h.xfi, 100..200).await; + assert_eq!(collected, h.data[100..200]); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_stream_range_from() { + let h = setup().await; + let collected = stream_download(&h.session, &h.xfi, 8000..).await; + assert_eq!(collected, h.data[8000..]); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_stream_range_to() { + let h = setup().await; + let collected = stream_download(&h.session, &h.xfi, ..128).await; + assert_eq!(collected, h.data[..128]); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_stream_full_range() { + let h = setup().await; + let collected = stream_download(&h.session, &h.xfi, ..).await; + assert_eq!(collected, h.data); + } + + // ── download_file with unknown size ────────────────────────────────────── + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_download_file_unknown_size() { + let h = setup().await; + let xfi_no_size = XetFileInfo::new_hash_only(h.xfi.hash().to_string()); + + let out_path = h._base_dir.path().join("unknown_size.bin"); + let (_id, n_bytes) = h.session.download_file(&xfi_no_size, &out_path).await.unwrap(); + + assert_eq!(n_bytes, h.data.len() as u64); + assert_eq!(fs::read(&out_path).unwrap(), h.data); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_download_stream_unknown_size() { + let h = setup().await; + let xfi_no_size = XetFileInfo::new_hash_only(h.xfi.hash().to_string()); + let collected = stream_download(&h.session, &xfi_no_size, ..).await; + assert_eq!(collected, h.data); + } + + // ── 
size mismatch validation ───────────────────────────────────────────── + + // SizeMismatch is reported after reconstruction completes, but debug + // assertions inside the progress tracker fire first, so this test is only + // compiled when debug assertions are off (e.g. `cargo test --release`). + #[cfg(not(debug_assertions))] + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_download_file_size_mismatch() { + let h = setup().await; + let wrong_size = XetFileInfo::new(h.xfi.hash().to_string(), 42); + + let out_path = h._base_dir.path().join("mismatch.bin"); + let err = h.session.download_file(&wrong_size, &out_path).await.unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("mismatch"), "Expected size mismatch error, got: {msg}"); + } + + // ── range download with unknown file size ──────────────────────────────── + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_writer_range_from_unknown_size() { + let h = setup().await; + let xfi_no_size = XetFileInfo::new_hash_only(h.xfi.hash().to_string()); + let (n_bytes, buf) = writer_download(&h.session, &xfi_no_size, 8000..).await; + assert_eq!(n_bytes, 192); + assert_eq!(buf, h.data[8000..]); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_stream_range_from_unknown_size() { + let h = setup().await; + let xfi_no_size = XetFileInfo::new_hash_only(h.xfi.hash().to_string()); + let collected = stream_download(&h.session, &xfi_no_size, 8000..).await; + assert_eq!(collected, h.data[8000..]); + } +} diff --git a/xet_pkg/examples/example.rs b/xet_pkg/examples/example.rs index 3b113b01..5ae01b07 100644 --- a/xet_pkg/examples/example.rs +++ b/xet_pkg/examples/example.rs @@ -119,7 +119,7 @@ async fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: O .download_file_to_path( XetFileInfo { hash: m.hash.clone(), - file_size: m.file_size, + file_size: Some(m.file_size), sha256: m.sha256.clone(), }, dest, @@ -148,7 +148,7 @@ async
fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: O for (_task_id, result) in &results { if let Ok(r) = result.as_ref() { - println!(" {} ({} bytes)", r.dest_path.display(), r.file_info.file_size); + println!(" {} ({:?} bytes)", r.dest_path.display(), r.file_info.file_size); } } diff --git a/xet_pkg/examples/example_sync.rs b/xet_pkg/examples/example_sync.rs index b120b54c..e08014f0 100644 --- a/xet_pkg/examples/example_sync.rs +++ b/xet_pkg/examples/example_sync.rs @@ -114,7 +114,7 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option< handles.push(group.download_file_to_path_blocking( XetFileInfo { hash: m.hash.clone(), - file_size: m.file_size, + file_size: Some(m.file_size), sha256: m.sha256.clone(), }, dest, @@ -141,7 +141,7 @@ fn download_files(metadata_file: PathBuf, output_dir: PathBuf, endpoint: Option< for (_task_id, result) in &results { if let Ok(r) = result.as_ref() { - println!(" {} ({} bytes)", r.dest_path.display(), r.file_info.file_size); + println!(" {} ({:?} bytes)", r.dest_path.display(), r.file_info.file_size); } } diff --git a/xet_pkg/src/xet_session/download_group.rs b/xet_pkg/src/xet_session/download_group.rs index 35d4730d..68ecaebf 100644 --- a/xet_pkg/src/xet_session/download_group.rs +++ b/xet_pkg/src/xet_session/download_group.rs @@ -323,7 +323,7 @@ impl DownloadGroupInner { dest_path: handle.dest_path, file_info: XetFileInfo { hash: handle.file_info.hash, - file_size: n_bytes, + file_size: Some(n_bytes), sha256: None, }, })); @@ -412,7 +412,7 @@ mod tests { let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); Ok(XetFileInfo { hash: meta.hash.clone(), - file_size: meta.file_size, + file_size: Some(meta.file_size), sha256: meta.sha256.clone(), }) } @@ -535,7 +535,7 @@ mod tests { .download_file_to_path( XetFileInfo { hash: "abc123".to_string(), - file_size: 1024, + file_size: Some(1024), sha256: None, }, std::path::PathBuf::from("dest.bin"), @@ -556,7 +556,7 @@ 
mod tests { .download_file_to_path( XetFileInfo { hash: "abc123".to_string(), - file_size: 1024, + file_size: Some(1024), sha256: None, }, std::path::PathBuf::from("dest.bin"), @@ -576,7 +576,7 @@ mod tests { .download_file_to_path( XetFileInfo { hash: "abc123".to_string(), - file_size: 1024, + file_size: Some(1024), sha256: None, }, std::path::PathBuf::from("dest.bin"), @@ -628,7 +628,7 @@ mod tests { .download_file_to_path( XetFileInfo { hash: "abc123".to_string(), - file_size: 123, + file_size: Some(123), sha256: None, }, temp.path().join("missing.bin"), @@ -693,7 +693,7 @@ mod tests { let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); XetFileInfo { hash: meta.hash.clone(), - file_size: meta.file_size, + file_size: Some(meta.file_size), sha256: meta.sha256.clone(), } }; @@ -759,7 +759,7 @@ mod tests { let handle = group.download_file_to_path(file_info, dest).await.unwrap(); let results = group.finish().await.unwrap(); let result = results.get(&handle.task_id).expect("task_id must be present in results"); - assert_eq!(result.as_ref().as_ref().unwrap().file_info.file_size, data.len() as u64); + assert_eq!(result.as_ref().as_ref().unwrap().file_info.file_size, Some(data.len() as u64)); } #[tokio::test(flavor = "multi_thread")] @@ -788,7 +788,7 @@ mod tests { group.finish().await.unwrap(); let result = handle.result().expect("result must be set after finish()"); let dl = result.as_ref().as_ref().unwrap(); - assert_eq!(dl.file_info.file_size, data.len() as u64); + assert_eq!(dl.file_info.file_size, Some(data.len() as u64)); assert_eq!(dl.file_info.hash, file_info.hash); } @@ -815,7 +815,7 @@ mod tests { let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); let file_info = XetFileInfo { hash: meta.hash.clone(), - file_size: meta.file_size, + file_size: Some(meta.file_size), sha256: meta.sha256.clone(), }; @@ -846,7 +846,7 @@ mod tests { let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); let 
file_info = XetFileInfo { hash: meta.hash.clone(), - file_size: meta.file_size, + file_size: Some(meta.file_size), sha256: meta.sha256.clone(), }; @@ -877,7 +877,7 @@ mod tests { let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); let file_info = XetFileInfo { hash: meta.hash.clone(), - file_size: meta.file_size, + file_size: Some(meta.file_size), sha256: meta.sha256.clone(), }; @@ -905,7 +905,7 @@ mod tests { let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); Ok(XetFileInfo { hash: meta.hash.clone(), - file_size: meta.file_size, + file_size: Some(meta.file_size), sha256: meta.sha256.clone(), }) } @@ -943,7 +943,7 @@ mod tests { let meta = results.get(&handle.task_id).unwrap().as_ref().as_ref().unwrap(); XetFileInfo { hash: meta.hash.clone(), - file_size: meta.file_size, + file_size: Some(meta.file_size), sha256: meta.sha256.clone(), } }; @@ -1006,12 +1006,12 @@ mod tests { // Result should be available in the finish map by task id. let map_result = results.get(&handle.task_id).expect("task_id must be present in results"); - assert_eq!(map_result.as_ref().as_ref().unwrap().file_info.file_size, data.len() as u64); + assert_eq!(map_result.as_ref().as_ref().unwrap().file_info.file_size, Some(data.len() as u64)); // Result should also be available via the task handle. 
let result = handle.result().expect("result must be set after finish"); let dl = result.as_ref().as_ref().unwrap(); - assert_eq!(dl.file_info.file_size, data.len() as u64); + assert_eq!(dl.file_info.file_size, Some(data.len() as u64)); assert_eq!(dl.file_info.hash, file_info.hash); Ok(()) } @@ -1059,7 +1059,7 @@ mod tests { let group = session.new_download_group().await.unwrap(); let file_info = XetFileInfo { hash: String::new(), - file_size: 0, + file_size: Some(0), sha256: None, }; let err = group @@ -1082,7 +1082,7 @@ mod tests { let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); let file_info = XetFileInfo { hash: String::new(), - file_size: 0, + file_size: Some(0), sha256: None, }; let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { diff --git a/xet_pkg/src/xet_session/mod.rs b/xet_pkg/src/xet_session/mod.rs index b08082e6..027ad4b1 100644 --- a/xet_pkg/src/xet_session/mod.rs +++ b/xet_pkg/src/xet_session/mod.rs @@ -85,7 +85,7 @@ //! let group = session.new_download_group_blocking()?; //! let info = XetFileInfo { //! hash: m.hash.clone(), -//! file_size: m.file_size, +//! file_size: Some(m.file_size), //! sha256: m.sha256.clone(), //! }; //! let dl_handle = group.download_file_to_path_blocking(info, "out/file.bin".into())?; @@ -122,7 +122,7 @@ //! let group = session.new_download_group().await?; //! let info = XetFileInfo { //! hash: m.hash.clone(), -//! file_size: m.file_size, +//! file_size: Some(m.file_size), //! sha256: m.sha256.clone(), //! }; //! 
let dl_handle = group.download_file_to_path(info, "out/file.bin".into()).await?; diff --git a/xet_pkg/src/xet_session/tasks.rs b/xet_pkg/src/xet_session/tasks.rs index b5d51342..70341533 100644 --- a/xet_pkg/src/xet_session/tasks.rs +++ b/xet_pkg/src/xet_session/tasks.rs @@ -187,7 +187,7 @@ mod tests { dest_path: PathBuf::from("out/file.bin"), file_info: XetFileInfo { hash: "def456".to_string(), - file_size: 99, + file_size: Some(99), sha256: None, }, })); @@ -195,7 +195,7 @@ mod tests { let result = handle.result().unwrap(); let dl = result.as_ref().as_ref().unwrap(); - assert_eq!(dl.file_info.file_size, 99); + assert_eq!(dl.file_info.file_size, Some(99)); assert_eq!(dl.dest_path, PathBuf::from("out/file.bin")); } } diff --git a/xet_pkg/src/xet_session/upload_commit.rs b/xet_pkg/src/xet_session/upload_commit.rs index f944ed3c..c53b983b 100644 --- a/xet_pkg/src/xet_session/upload_commit.rs +++ b/xet_pkg/src/xet_session/upload_commit.rs @@ -544,7 +544,7 @@ impl UploadCommitInner { let result = Arc::new(Ok(FileMetadata { tracking_name: handle.tracking_name, hash: file_info.hash().to_string(), - file_size: file_info.file_size(), + file_size: file_info.file_size().expect("upload always produces a known file size"), sha256: file_info.sha256().map(str::to_owned), })); results.insert(task_id, result.clone()); @@ -1103,7 +1103,7 @@ mod tests { let (xfi, _) = cleaner.finish().await.unwrap(); let results = commit.commit().await.unwrap(); assert!(results.is_empty()); - assert_eq!(xfi.file_size, data.len() as u64); + assert_eq!(xfi.file_size, Some(data.len() as u64)); assert!(!xfi.hash.is_empty()); } @@ -1304,7 +1304,7 @@ mod tests { })?; let results = commit.commit_blocking()?; assert!(results.is_empty()); - assert_eq!(file_size, data.len() as u64); + assert_eq!(file_size, Some(data.len() as u64)); assert!(!hash.is_empty()); Ok(()) } diff --git a/xet_pkg/tests/test_legacy_data_client.rs b/xet_pkg/tests/test_legacy_data_client.rs index d0753099..75a9a943 100644 --- 
a/xet_pkg/tests/test_legacy_data_client.rs +++ b/xet_pkg/tests/test_legacy_data_client.rs @@ -53,7 +53,7 @@ mod tests { assert_eq!(file_infos.len(), 3); for info in &file_infos { assert!(!info.hash.is_empty()); - assert!(info.file_size > 0); + assert!(info.file_size.unwrap_or(0) > 0); } let download_dir = TempDir::new().unwrap(); @@ -256,7 +256,7 @@ mod tests { .unwrap(); assert_eq!(file_infos.len(), 1); - assert_eq!(file_infos[0].file_size, large_data.len() as u64); + assert_eq!(file_infos[0].file_size, Some(large_data.len() as u64)); let download_dir = TempDir::new().unwrap(); let out_path = download_dir.path().join("large_out.bin"); diff --git a/xet_pkg/tests/test_xet_session.rs b/xet_pkg/tests/test_xet_session.rs index c9793288..d7efbd51 100644 --- a/xet_pkg/tests/test_xet_session.rs +++ b/xet_pkg/tests/test_xet_session.rs @@ -43,7 +43,7 @@ fn sync_session(temp: &TempDir) -> XetSession { fn to_file_info(meta: &FileMetadata) -> XetFileInfo { XetFileInfo { hash: meta.hash.clone(), - file_size: meta.file_size, + file_size: Some(meta.file_size), sha256: meta.sha256.clone(), } } @@ -71,7 +71,7 @@ fn upload_bytes_sync(session: &XetSession, data: &[u8], name: &str) -> XetFileIn async fn assert_roundtrip_async(session: &XetSession, temp: &TempDir, data: &[u8], name: &str) { let file_info = upload_bytes_async(session, data, name).await; - assert_eq!(file_info.file_size, data.len() as u64); + assert_eq!(file_info.file_size, Some(data.len() as u64)); let dest = temp.path().join(format!("{name}.out")); let group = session.new_download_group().await.unwrap(); @@ -82,7 +82,7 @@ async fn assert_roundtrip_async(session: &XetSession, temp: &TempDir, data: &[u8 fn assert_roundtrip_sync(session: &XetSession, temp: &TempDir, data: &[u8], name: &str) { let file_info = upload_bytes_sync(session, data, name); - assert_eq!(file_info.file_size, data.len() as u64); + assert_eq!(file_info.file_size, Some(data.len() as u64)); let dest = temp.path().join(format!("{name}.out")); let group = 
session.new_download_group_blocking().unwrap(); @@ -380,6 +380,28 @@ async fn async_progress_tracking() { assert_eq!(report.total_bytes_completed, data.len() as u64); } +// ── Download with unknown file size ────────────────────────────────────── + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn async_download_unknown_size_roundtrip() { + let temp = tempdir().unwrap(); + let session = async_session(&temp).await; + let data = b"download with unknown size via xet_pkg"; + let file_info = upload_bytes_async(&session, data, "unknown_size.bin").await; + + let hash_only = XetFileInfo::new_hash_only(file_info.hash().to_string()); + + let dest = temp.path().join("unknown_size.out"); + let group = session.new_download_group().await.unwrap(); + group.download_file_to_path(hash_only, dest.clone()).await.unwrap(); + let results = group.finish().await.unwrap(); + + for result in results.values() { + let dl = result.as_ref().as_ref().unwrap(); + assert_eq!(dl.file_info.file_size, Some(data.len() as u64)); + } + assert_eq!(fs::read(&dest).unwrap(), data); +} #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_download_invalid_hash_fails() { let temp = tempdir().unwrap(); @@ -390,7 +412,7 @@ async fn async_download_invalid_hash_fails() { .download_file_to_path( XetFileInfo { hash: "nonexistent_hash_abc123".to_string(), - file_size: 100, + file_size: Some(100), sha256: None, }, temp.path().join("missing.bin"), @@ -859,7 +881,7 @@ async fn async_abort_rejects_download_on_existing_group() { .download_file_to_path( XetFileInfo { hash: "abc".to_string(), - file_size: 1, + file_size: Some(1), sha256: None, }, PathBuf::from("dest.bin"),