From 9332ff28b745623ef92ecc90c273ef5a2039c30d Mon Sep 17 00:00:00 2001 From: Hoyt Koepke Date: Fri, 9 Jan 2026 12:39:52 -0800 Subject: [PATCH] Mock CAS server built on LocalClient for testing and simulation. (#602) This PR adds a fully functional CAS server built around a LocalClient instance. This allows full testing of the RemoteClient interface without hitting the actual CAS backend. For testing, it can either be run as a standalone executable, or it can be started using a LocalTestServer instance that exposes both a RemoteClient interface as a client and direct access to the state through a stored LocalClient instance. Numerous tests are added to cover existing functionality as well as the functioning of the new server. (Also, it exposed that when using a lot of tests with wiremock or this server, the testing would often hit a "Too many open files" error; this was fixed by consolidating these tests to reduce the number of separate testing servers running at once.) --- Cargo.lock | 24 + Cargo.toml | 1 + cas_client/Cargo.toml | 10 +- cas_client/src/client_testing_utils.rs | 143 ++++-- cas_client/src/lib.rs | 2 + cas_client/src/local_client.rs | 201 ++++---- cas_client/src/local_server/handlers.rs | 522 +++++++++++++++++++ cas_client/src/local_server/main.rs | 104 ++++ cas_client/src/local_server/mod.rs | 20 + cas_client/src/local_server/server.rs | 637 ++++++++++++++++++++++++ cas_client/src/retry_wrapper.rs | 178 ++++--- cas_client/tests/reconstruction.rs | 200 ++++++++ hf_xet/Cargo.lock | 23 + hf_xet_wasm/Cargo.lock | 95 +++- 14 files changed, 1935 insertions(+), 225 deletions(-) create mode 100644 cas_client/src/local_server/handlers.rs create mode 100644 cas_client/src/local_server/main.rs create mode 100644 cas_client/src/local_server/mod.rs create mode 100644 cas_client/src/local_server/server.rs create mode 100644 cas_client/tests/reconstruction.rs diff --git a/Cargo.lock b/Cargo.lock index f66ba9ee..eb707854 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -411,10 
+411,13 @@ checksum = "5b098575ebe77cb6d14fc7f32749631a6e44edbef6b796f89b020e99ba20d425" dependencies = [ "axum-core", "bytes", + "form_urlencoded", "futures-util", "http 1.3.1", "http-body 1.0.1", "http-body-util", + "hyper 1.8.1", + "hyper-util", "itoa", "matchit", "memchr", @@ -422,10 +425,15 @@ dependencies = [ "percent-encoding", "pin-project-lite", "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", "sync_wrapper", + "tokio", "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -444,6 +452,7 @@ dependencies = [ "sync_wrapper", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -668,6 +677,8 @@ dependencies = [ "anyhow", "approx", "async-trait", + "axum", + "base64 0.22.1", "bytes", "cas_object", "cas_types", @@ -679,6 +690,7 @@ dependencies = [ "error_printer", "file_utils", "futures", + "futures-util", "heed", "http 1.3.1", "httpmock", @@ -700,6 +712,7 @@ dependencies = [ "thiserror 2.0.12", "tokio", "tokio-retry", + "tower-http", "tracing", "tracing-log", "tracing-subscriber", @@ -4782,6 +4795,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_regex" version = "1.1.0" diff --git a/Cargo.toml b/Cargo.toml index c0734f43..237c830a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,7 @@ debug = 1 [workspace.dependencies] anyhow = "1" +axum = "0.8" async-trait = "0.1" base64 = "0.22" bincode = "1.3" diff --git a/cas_client/Cargo.toml b/cas_client/Cargo.toml index 9cc2ae17..6e95ab56 100644 --- a/cas_client/Cargo.toml +++ b/cas_client/Cargo.toml @@ -19,6 +19,7 @@ xet_runtime = { path = "../xet_runtime" } anyhow = { workspace = true } async-trait = { workspace = true } +base64 = { workspace = true } bytes = { workspace = true } chrono = 
{ workspace = true } clap = { workspace = true } @@ -68,8 +69,11 @@ native-tls-vendored = ["reqwest/native-tls-vendored"] [target.'cfg(not(target_family = "wasm"))'.dependencies] +axum = { workspace = true } +futures-util = { workspace = true } heed = { workspace = true } hyper = { workspace = true } +tower-http = { version = "0.6", features = ["cors"] } warp = { workspace = true } [target.'cfg(target_family = "wasm")'.dependencies] @@ -82,7 +86,11 @@ rand_distr = { workspace = true } tracing-test = { workspace = true } wiremock = { workspace = true } -# Test binaries for adaptive concurrency testing +# Local CAS server binary - wraps LocalClient with HTTP API for testing +[[bin]] +name = "local_cas_server" +path = "src/local_server/main.rs" + [[bin]] name = "simulation_server" path = "tests/adaptive_concurrency/src/simulation_server.rs" diff --git a/cas_client/src/client_testing_utils.rs b/cas_client/src/client_testing_utils.rs index a9947d38..4a398357 100644 --- a/cas_client/src/client_testing_utils.rs +++ b/cas_client/src/client_testing_utils.rs @@ -11,6 +11,73 @@ use rand::prelude::*; use crate::error::Result; use crate::interface::Client; +/// Information about a term (segment) in the file, referencing an XORB and chunk range. +#[derive(Clone, Debug)] +pub struct FileTermReference { + /// The XORB hash this term references. + pub xorb_hash: MerkleHash, + /// Start chunk index (inclusive) within the XORB. + pub chunk_start: u32, + /// End chunk index (exclusive) within the XORB. + pub chunk_end: u32, + /// The data for this term (concatenated chunk data). + pub data: Vec, + /// The chunk hashes for this term. + pub chunk_hashes: Vec, +} + +/// Complete information about a randomly generated file for testing purposes. +/// +/// Contains all the metadata needed to verify that reconstruction and fetching +/// operations return correct data. +#[derive(Clone, Debug)] +pub struct RandomFileContents { + /// The file hash (used for reconstruction queries). 
+ pub file_hash: MerkleHash, + /// The complete file data. + pub data: Vec, + /// The RawXorbData for each XORB that was created, keyed by XORB hash. + pub xorbs: HashMap, + /// Information about each term in file order. + pub terms: Vec, +} + +impl RandomFileContents { + /// Verifies that the given data matches the expected data for a specific term. + /// + /// This compares the provided bytes directly with the stored data for the + /// term at the given index; an out-of-range index returns `false`. + /// + /// # Arguments + /// * `term_index` - The index of the term (0-based) in the terms list + /// * `data` - The data to verify against the expected term data + /// + /// # Returns + /// `true` if the data matches the expected term data, `false` otherwise. + pub fn term_matches(&self, term_index: usize, data: &[u8]) -> bool { + if term_index >= self.terms.len() { + return false; + } + let term = &self.terms[term_index]; + term.data == data + } + + /// Returns the expected data for a specific term. + pub fn term_data(&self, term_index: usize) -> Option<&[u8]> { + self.terms.get(term_index).map(|t| t.data.as_slice()) + } + + /// Returns the XORB hash for a specific term. + pub fn term_xorb_hash(&self, term_index: usize) -> Option { + self.terms.get(term_index).map(|t| t.xorb_hash) + } + + /// Returns the chunk range for a specific term. + pub fn term_chunk_range(&self, term_index: usize) -> Option<(u32, u32)> { + self.terms.get(term_index).map(|t| (t.chunk_start, t.chunk_end)) + } +} + /// A trait that adds testing utility functions to the Client interface. #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] @@ -18,16 +85,18 @@ use crate::interface::Client; pub trait ClientTestingUtils: Client + Send + Sync { /// Insert a random file into the local CAS. /// - /// This function is used to test the local CAS client. + /// This function generates a random file with the given term specification. 
+ /// Each term is defined as `(xorb_seed, (chunk_start, chunk_end))` where: + /// - `xorb_seed` determines the random data for that XORB + /// - `chunk_start` and `chunk_end` define the range of chunks to include /// - /// It generates a random file with a given number of chunks and chunk size. - /// It then creates all the xorbs and shard definitions needed, returning - /// the file data and the file hash. + /// Returns a `RandomFileContents` struct containing all the metadata needed + /// to verify reconstruction and fetching operations. async fn upload_random_file( &self, term_spec: &[(u64, (u64, u64))], chunk_size: usize, - ) -> Result<(Vec, MerkleHash)> { + ) -> Result { let mut xorb_num_chunks = HashMap::::new(); for &(xorb_seed, (_chunk_idx_start, chunk_idx_end)) in term_spec { @@ -35,22 +104,15 @@ pub trait ClientTestingUtils: Client + Send + Sync { *c = (*c).max(chunk_idx_end); } - // Track the data so that we can reconstruct the whole file later. - let mut xorb_data = HashMap::>::new(); let mut shard = MDBInMemoryShard::default(); - - let mut xorb_hash = HashMap::::new(); + let mut xorb_data = HashMap::::new(); for (&xorb_seed, n_chunks) in xorb_num_chunks.iter() { let mut rng = SmallRng::seed_from_u64(xorb_seed); - let n_chunks = *n_chunks as usize; let mut chunks = Vec::with_capacity(n_chunks); for _idx in 0..n_chunks { - // duplicate the range so that compression kicks in; - // copy the second part of the chunk from the first part. - let n = rng.random_range((chunk_size / 2 + 1)..chunk_size); let n_left = chunk_size - n; @@ -68,52 +130,59 @@ pub trait ClientTestingUtils: Client + Send + Sync { }); } - // Create RawXorbData from the generated chunks. - // file_boundaries indicates where new files start; use [0] for single file. let raw_xorb = RawXorbData::from_chunks(&chunks, vec![0]); - // Record the xorb data. - xorb_data.insert(xorb_seed, chunks); - - // Add it to the shard. shard.add_cas_block(raw_xorb.cas_info.clone())?; - // Record the hash. 
- xorb_hash.insert(xorb_seed, raw_xorb.hash()); - - // Build SerializedCasObject let serialized_xorb = SerializedCasObject::from_xorb(raw_xorb.clone(), None, true)?; - // upload the xorb let upload_permit = self.acquire_upload_permit().await?; self.upload_xorb("default", serialized_xorb, None, upload_permit).await?; + + xorb_data.insert(xorb_seed, raw_xorb); } - // Now, build the file info and file data. + // Build the file info and file data from RawXorbData. let mut file_segments = Vec::new(); let mut file_data = Vec::new(); let mut chunk_file_hashes = Vec::new(); + let mut term_infos = Vec::new(); for &(xorb_seed, (chunk_idx_start, chunk_idx_end)) in term_spec { - let xorb_hash = xorb_hash.get(&xorb_seed).unwrap(); + let raw_xorb = xorb_data.get(&xorb_seed).unwrap(); + let xorb_h = raw_xorb.hash(); let (c_lb, c_ub) = (chunk_idx_start as usize, chunk_idx_end as usize); - let chunks = &xorb_data.get(&xorb_seed).unwrap()[c_lb..c_ub]; let mut n_bytes = 0; + let mut term_data = Vec::new(); + let mut term_chunk_hashes = Vec::new(); - for chunk in chunks { - file_data.extend_from_slice(&chunk.data); - n_bytes += chunk.data.len(); - chunk_file_hashes.push((chunk.hash, chunk.data.len() as u64)); + for i in c_lb..c_ub { + let chunk_bytes = &raw_xorb.data[i]; + let chunk_hash = raw_xorb.cas_info.chunks[i].chunk_hash; + + file_data.extend_from_slice(chunk_bytes); + term_data.extend_from_slice(chunk_bytes); + n_bytes += chunk_bytes.len(); + chunk_file_hashes.push((chunk_hash, chunk_bytes.len() as u64)); + term_chunk_hashes.push(chunk_hash); } file_segments.push(FileDataSequenceEntry::new( - *xorb_hash, + xorb_h, n_bytes, chunk_idx_start as usize, chunk_idx_end as usize, )); + + term_infos.push(FileTermReference { + xorb_hash: xorb_h, + chunk_start: chunk_idx_start as u32, + chunk_end: chunk_idx_end as u32, + data: term_data, + chunk_hashes: term_chunk_hashes, + }); } let file_hash = file_hash_with_salt(&chunk_file_hashes, &[0; 32]); @@ -128,7 +197,15 @@ pub trait 
ClientTestingUtils: Client + Send + Sync { let upload_permit = self.acquire_upload_permit().await?; self.upload_shard(shard.to_bytes()?.into(), upload_permit).await?; - Ok((file_data, file_hash)) + // Convert xorb_data from seed-keyed to hash-keyed + let xorbs = xorb_data.into_values().map(|x| (x.hash(), x)).collect(); + + Ok(RandomFileContents { + file_hash, + data: file_data, + xorbs, + terms: term_infos, + }) } } diff --git a/cas_client/src/lib.rs b/cas_client/src/lib.rs index 16638b57..cd5ae608 100644 --- a/cas_client/src/lib.rs +++ b/cas_client/src/lib.rs @@ -20,6 +20,8 @@ mod interface; #[cfg(not(target_family = "wasm"))] mod local_client; #[cfg(not(target_family = "wasm"))] +pub mod local_server; +#[cfg(not(target_family = "wasm"))] mod output_provider; pub mod remote_client; pub mod retry_wrapper; diff --git a/cas_client/src/local_client.rs b/cas_client/src/local_client.rs index cca76053..95fb50b8 100644 --- a/cas_client/src/local_client.rs +++ b/cas_client/src/local_client.rs @@ -1271,10 +1271,10 @@ mod tests { // Create segments: xorb 1 chunks 0-2, then chunks 2-4 (adjacent) let term_spec = &[(1, (0, 2)), (1, (2, 4))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, 2048).await.unwrap(); + let file = client.upload_random_file(term_spec, 2048).await.unwrap(); // Verify reconstruction merges adjacent ranges - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 2); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1285,7 +1285,7 @@ mod tests { assert_eq!(fetch_infos[0].range.end, 4); // Verify file retrieval - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } #[tokio::test] @@ -1294,15 +1294,15 @@ mod tests { // Create file with segments from different 
xorbs let term_spec = &[(1, (0, 3)), (2, (0, 2)), (1, (3, 5))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, 2048).await.unwrap(); + let file = client.upload_random_file(term_spec, 2048).await.unwrap(); // Verify reconstruction - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 3); assert_eq!(reconstruction.fetch_info.len(), 2); // Verify file retrieval - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } /// Tests that overlapping chunk ranges within the same xorb are correctly merged @@ -1315,9 +1315,9 @@ mod tests { // Test 1: Simple overlapping ranges [0,3) and [1,4) -> merged to [0,4) { let term_spec = &[(1, (0, 3)), (1, (1, 4))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 2); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1327,15 +1327,15 @@ mod tests { assert_eq!(fetch_infos[0].range.start, 0); assert_eq!(fetch_infos[0].range.end, 4); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 2: Subset range - second range is fully contained in first [0,5) and [1,3) -> [0,5) { let term_spec = &[(1, (0, 5)), (1, (1, 3))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, 
chunk_size).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 2); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1345,15 +1345,15 @@ mod tests { assert_eq!(fetch_infos[0].range.start, 0); assert_eq!(fetch_infos[0].range.end, 5); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 3: Second range ends before first range end [0,5) and [2,4) -> [0,5) { let term_spec = &[(1, (0, 5)), (1, (2, 4))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 2); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1363,15 +1363,15 @@ mod tests { assert_eq!(fetch_infos[0].range.start, 0); assert_eq!(fetch_infos[0].range.end, 5); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 4: Multiple overlapping ranges forming a chain [0,2), [1,4), [3,6) -> [0,6) { let term_spec = &[(1, (0, 2)), (1, (1, 4)), (1, (3, 6))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); 
assert_eq!(reconstruction.terms.len(), 3); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1381,15 +1381,15 @@ mod tests { assert_eq!(fetch_infos[0].range.start, 0); assert_eq!(fetch_infos[0].range.end, 6); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 5: Ranges that interleave in a non-monotonic way [0,5), [1,3), [2,4) -> [0,5) { let term_spec = &[(1, (0, 5)), (1, (1, 3)), (1, (2, 4))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 3); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1399,15 +1399,15 @@ mod tests { assert_eq!(fetch_infos[0].range.start, 0); assert_eq!(fetch_infos[0].range.end, 5); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 6: Non-contiguous ranges should NOT be merged [0,2) and [4,6) -> two separate ranges { let term_spec = &[(1, (0, 2)), (1, (4, 6))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 2); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1419,15 +1419,15 @@ mod tests { assert_eq!(fetch_infos[1].range.start, 4); assert_eq!(fetch_infos[1].range.end, 6); - 
assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 7: Touch at boundary (adjacent) [0,3) and [3,5) -> [0,5) { let term_spec = &[(1, (0, 3)), (1, (3, 5))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 2); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1437,15 +1437,15 @@ mod tests { assert_eq!(fetch_infos[0].range.start, 0); assert_eq!(fetch_infos[0].range.end, 5); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 8: Large range followed by small contained range [0,10) and [4,6) -> [0,10) { let term_spec = &[(1, (0, 10)), (1, (4, 6))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 2); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1455,15 +1455,15 @@ mod tests { assert_eq!(fetch_infos[0].range.start, 0); assert_eq!(fetch_infos[0].range.end, 10); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 9: Same range repeated multiple times [2,5), [2,5), [2,5) -> [2,5) { let term_spec = &[(1, (2, 5)), 
(1, (2, 5)), (1, (2, 5))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 3); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1473,16 +1473,16 @@ mod tests { assert_eq!(fetch_infos[0].range.start, 2); assert_eq!(fetch_infos[0].range.end, 5); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 10: Mixed overlapping and non-contiguous in complex pattern // [0,3), [2,4), [6,8), [7,10) -> [0,4) and [6,10) { let term_spec = &[(1, (0, 3)), (1, (2, 4)), (1, (6, 8)), (1, (7, 10))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 4); assert_eq!(reconstruction.fetch_info.len(), 1); @@ -1494,7 +1494,7 @@ mod tests { assert_eq!(fetch_infos[1].range.start, 6); assert_eq!(fetch_infos[1].range.end, 10); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } } @@ -1502,14 +1502,14 @@ mod tests { async fn test_range_requests() { let client = LocalClient::temporary().await.unwrap(); let term_spec = &[(1, (0, 5))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, 2048).await.unwrap(); - let total_file_size = file_data.len() as u64; + 
let file = client.upload_random_file(term_spec, 2048).await.unwrap(); + let total_file_size = file.data.len() as u64; // Test get_reconstruction range behaviors { // Partial out-of-range truncates let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(total_file_size / 2, total_file_size + 1000))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(total_file_size / 2, total_file_size + 1000))) .await .unwrap() .unwrap(); @@ -1518,19 +1518,22 @@ mod tests { // Entire range out of bounds returns error let result = client - .get_reconstruction(&file_hash, Some(FileRange::new(total_file_size + 100, total_file_size + 1000))) + .get_reconstruction( + &file.file_hash, + Some(FileRange::new(total_file_size + 100, total_file_size + 1000)), + ) .await; assert!(matches!(result.unwrap_err(), CasClientError::InvalidRange)); // Start equals file size returns error let result = client - .get_reconstruction(&file_hash, Some(FileRange::new(total_file_size, total_file_size + 100))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(total_file_size, total_file_size + 100))) .await; assert!(matches!(result.unwrap_err(), CasClientError::InvalidRange)); // Valid range within bounds succeeds let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(0, total_file_size / 2))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(0, total_file_size / 2))) .await .unwrap() .unwrap(); @@ -1539,7 +1542,7 @@ mod tests { // End exactly at file size succeeds let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(0, total_file_size))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(0, total_file_size))) .await .unwrap() .unwrap(); @@ -1552,37 +1555,37 @@ mod tests { // Partial out-of-range truncates let partial_start = total_file_size / 2; let data = client - .get_file_data(&file_hash, Some(FileRange::new(partial_start, total_file_size + 1000))) + .get_file_data(&file.file_hash, 
Some(FileRange::new(partial_start, total_file_size + 1000))) .await .unwrap(); - assert_eq!(data, &file_data[partial_start as usize..]); + assert_eq!(data, &file.data[partial_start as usize..]); // Entire range out of bounds returns error let result = client - .get_file_data(&file_hash, Some(FileRange::new(total_file_size + 100, total_file_size + 1000))) + .get_file_data(&file.file_hash, Some(FileRange::new(total_file_size + 100, total_file_size + 1000))) .await; assert!(matches!(result.unwrap_err(), CasClientError::InvalidRange)); // Start equals file size returns error let result = client - .get_file_data(&file_hash, Some(FileRange::new(total_file_size, total_file_size + 100))) + .get_file_data(&file.file_hash, Some(FileRange::new(total_file_size, total_file_size + 100))) .await; assert!(matches!(result.unwrap_err(), CasClientError::InvalidRange)); // Valid range within bounds let valid_end = total_file_size / 2; let data = client - .get_file_data(&file_hash, Some(FileRange::new(0, valid_end))) + .get_file_data(&file.file_hash, Some(FileRange::new(0, valid_end))) .await .unwrap(); - assert_eq!(data, &file_data[..valid_end as usize]); + assert_eq!(data, &file.data[..valid_end as usize]); // End exactly at file size let data = client - .get_file_data(&file_hash, Some(FileRange::new(0, total_file_size))) + .get_file_data(&file.file_hash, Some(FileRange::new(0, total_file_size))) .await .unwrap(); - assert_eq!(data, file_data); + assert_eq!(data, file.data); } } @@ -1592,30 +1595,35 @@ mod tests { let client = LocalClient::temporary().await.unwrap(); let term_spec = &[(1, (0, 5))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, 2048).await.unwrap(); + let file = client.upload_random_file(term_spec, 2048).await.unwrap(); // Test that sequential writer correctly wraps get_file_data let buffer = ThreadSafeBuffer::default(); let bytes_written = client .clone() - .get_file_with_sequential_writer(&file_hash, None, buffer.clone().into(), None) + 
.get_file_with_sequential_writer(&file.file_hash, None, buffer.clone().into(), None) .await .unwrap(); - assert_eq!(bytes_written as usize, file_data.len()); - assert_eq!(buffer.value(), file_data); + assert_eq!(bytes_written as usize, file.data.len()); + assert_eq!(buffer.value(), file.data); // Test with range let buffer2 = ThreadSafeBuffer::default(); - let half = file_data.len() as u64 / 2; + let half = file.data.len() as u64 / 2; let bytes_written2 = client .clone() - .get_file_with_sequential_writer(&file_hash, Some(FileRange::new(0, half)), buffer2.clone().into(), None) + .get_file_with_sequential_writer( + &file.file_hash, + Some(FileRange::new(0, half)), + buffer2.clone().into(), + None, + ) .await .unwrap(); assert_eq!(bytes_written2, half); - assert_eq!(buffer2.value(), &file_data[..half as usize]); + assert_eq!(buffer2.value(), &file.data[..half as usize]); } #[tokio::test] @@ -1624,62 +1632,65 @@ mod tests { // Test 1: Single segment with 3 chunks { - let (file_data, file_hash) = client.upload_random_file(&[(1, (0, 3))], 2048).await.unwrap(); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + let file = client.upload_random_file(&[(1, (0, 3))], 2048).await.unwrap(); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 2: Multiple segments from the same xorb { let term_spec = &[(1, (0, 2)), (1, (2, 4)), (1, (4, 6))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, 2048).await.unwrap(); + let file = client.upload_random_file(term_spec, 2048).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 3); assert_eq!(reconstruction.fetch_info.len(), 1); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, 
None).await.unwrap(), file.data); } // Test 3: Segments from different xorbs { let term_spec = &[(1, (0, 3)), (2, (0, 2)), (3, (0, 4))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, 2048).await.unwrap(); + let file = client.upload_random_file(term_spec, 2048).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 3); assert_eq!(reconstruction.fetch_info.len(), 3); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } // Test 4: Partial range retrieval { let term_spec = &[(1, (0, 5)), (2, (0, 5))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, 2048).await.unwrap(); - let half = file_data.len() as u64 / 2; + let file = client.upload_random_file(term_spec, 2048).await.unwrap(); + let half = file.data.len() as u64 / 2; // First half - let first_half = client.get_file_data(&file_hash, Some(FileRange::new(0, half))).await.unwrap(); - assert_eq!(first_half, &file_data[..half as usize]); + let first_half = client + .get_file_data(&file.file_hash, Some(FileRange::new(0, half))) + .await + .unwrap(); + assert_eq!(first_half, &file.data[..half as usize]); // Second half let second_half = client - .get_file_data(&file_hash, Some(FileRange::new(half, file_data.len() as u64))) + .get_file_data(&file.file_hash, Some(FileRange::new(half, file.data.len() as u64))) .await .unwrap(); - assert_eq!(second_half, &file_data[half as usize..]); + assert_eq!(second_half, &file.data[half as usize..]); } // Test 5: Overlapping chunk references from same xorb { let term_spec = &[(1, (0, 3)), (1, (1, 4)), (1, (2, 5))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, 2048).await.unwrap(); + let file = 
client.upload_random_file(term_spec, 2048).await.unwrap(); - let reconstruction = client.get_reconstruction(&file_hash, None).await.unwrap().unwrap(); + let reconstruction = client.get_reconstruction(&file.file_hash, None).await.unwrap().unwrap(); assert_eq!(reconstruction.terms.len(), 3); assert_eq!(reconstruction.fetch_info.len(), 1); - assert_eq!(client.get_file_data(&file_hash, None).await.unwrap(), file_data); + assert_eq!(client.get_file_data(&file.file_hash, None).await.unwrap(), file.data); } } @@ -1692,12 +1703,12 @@ mod tests { // Create a file with 5 chunks of 2048 bytes each = 10240 total bytes let chunk_size: usize = 2048; let term_spec = &[(1, (0, 5))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let total_file_size = file_data.len() as u64; + let total_file_size = file.data.len() as u64; assert_eq!(total_file_size, (5 * chunk_size) as u64); - let query_file_size = client.get_file_size(&file_hash).await.unwrap(); + let query_file_size = client.get_file_size(&file.file_hash).await.unwrap(); assert_eq!(query_file_size, total_file_size); // Test 1: Range starting in the middle of chunk 1 should skip chunk 0 @@ -1706,7 +1717,7 @@ mod tests { let start = chunk_size as u64 + 500; // Middle of chunk 1 let end = total_file_size; let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1727,7 +1738,7 @@ mod tests { let start = (chunk_size * 2) as u64; // Start of chunk 2 let end = total_file_size; let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1744,7 +1755,7 @@ mod tests { let start = 0u64; let end = (chunk_size * 2) as u64 + 500; // Middle of chunk 
2 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1764,7 +1775,7 @@ mod tests { let start = (chunk_size * 2) as u64 + 100; // Inside chunk 2 let end = (chunk_size * 2) as u64 + 500; // Still inside chunk 2 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1784,7 +1795,7 @@ mod tests { let start = chunk_size as u64 - 100; // Near end of chunk 0 let end = chunk_size as u64 + 100; // Near start of chunk 1 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1803,7 +1814,7 @@ mod tests { let start = (chunk_size * 2) as u64 + delta; // Start of chunk 2 let end = (chunk_size * 4) as u64 - delta; // End of chunk 3 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1820,7 +1831,7 @@ mod tests { let start = (chunk_size * 2) as u64 - 1; // Start of chunk 2 let end = (chunk_size * 4) as u64 + 1; // One byte of chunk 4 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1844,9 +1855,9 @@ mod tests { // Total: 16384 bytes let chunk_size = 2048usize; let term_spec = &[(1, (0, 4)), (2, (0, 4))]; - let (file_data, file_hash) = client.upload_random_file(term_spec, chunk_size).await.unwrap(); + let file = client.upload_random_file(term_spec, chunk_size).await.unwrap(); - let total_file_size = file_data.len() as u64; + let total_file_size = file.data.len() as u64; 
assert_eq!(total_file_size, (8 * chunk_size) as u64); // Test 1: Range that skips first chunk of first xorb @@ -1854,7 +1865,7 @@ mod tests { let start = chunk_size as u64 + 500; // Middle of chunk 1 in xorb 1 let end = total_file_size; let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1879,7 +1890,7 @@ mod tests { let start = chunk_size as u64; // Start of chunk 1 in xorb 1 let end = (chunk_size * 3) as u64; // End of chunk 2 in xorb 1 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1900,7 +1911,7 @@ mod tests { let start = xorb1_size + chunk_size as u64; // Start of chunk 1 in xorb 2 let end = xorb1_size + (chunk_size * 3) as u64; // End of chunk 2 in xorb 2 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1921,7 +1932,7 @@ mod tests { let start = (chunk_size * 2) as u64; // Start of chunk 2 in xorb 1 let end = xorb1_size + (chunk_size * 2) as u64 + 500; // Middle of chunk 2 in xorb 2 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1946,7 +1957,7 @@ mod tests { let start = chunk_size as u64 + delta; // Start of chunk 1 +/- delta let end = (chunk_size * 3) as u64 - delta; // End of chunk 2 -/+ delta let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1964,7 +1975,7 @@ mod tests { let start = chunk_size as u64 - 1; // 1 byte before chunk 1 let end = 
(chunk_size * 3) as u64 + 1; // 1 byte into chunk 3 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -1985,7 +1996,7 @@ mod tests { let start = (chunk_size * 2) as u64 + delta; // Chunk 2 in xorb 1 let end = xorb1_size + (chunk_size * 2) as u64 - delta; // Chunk 1 end in xorb 2 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); @@ -2009,7 +2020,7 @@ mod tests { let start = xorb1_size - 1; // 1 byte before xorb 2 let end = xorb1_size + (chunk_size * 2) as u64 + 1; // 1 byte into chunk 2 of xorb 2 let response = client - .get_reconstruction(&file_hash, Some(FileRange::new(start, end))) + .get_reconstruction(&file.file_hash, Some(FileRange::new(start, end))) .await .unwrap() .unwrap(); diff --git a/cas_client/src/local_server/handlers.rs b/cas_client/src/local_server/handlers.rs new file mode 100644 index 00000000..13c36898 --- /dev/null +++ b/cas_client/src/local_server/handlers.rs @@ -0,0 +1,522 @@ +//! HTTP Request Handlers for the Local CAS Server +//! +//! This module contains all the Axum request handlers that bridge HTTP requests +//! to `LocalClient` operations. Each handler corresponds to an endpoint in the +//! CAS REST API that `RemoteClient` expects. +//! +//! # Handler Pattern +//! +//! All handlers follow this pattern: +//! 1. Extract request data (path parameters, headers, body) +//! 2. Call the appropriate `LocalClient` method +//! 3. Convert the result to an HTTP response +//! +//! Errors are mapped to appropriate HTTP status codes via `error_to_response`. 
+ +use std::path::PathBuf; +use std::sync::Arc; + +use axum::Json; +use axum::body::Body; +use axum::extract::{Path, State}; +use axum::http::header::HOST; +use axum::http::{HeaderMap, HeaderValue, StatusCode}; +use axum::response::{IntoResponse, Response}; +use base64::Engine; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use bytes::Bytes; +use cas_types::{ + CASReconstructionFetchInfo, FileRange, HexKey, HexMerkleHash, UploadShardResponse, UploadShardResponseType, + UploadXorbResponse, +}; +use futures_util::StreamExt; +use http::header::RANGE; +use merklehash::MerkleHash; + +use crate::error::CasClientError; +use crate::{Client, LocalClient}; + +/// Represents the different forms a Range header can take. +pub enum FileRangeVariant { + /// Standard byte range: bytes=start-end (inclusive end, converted to exclusive) + Normal(FileRange), + /// Open-ended range: bytes=start- (from start to end of file) + OpenRHS(u64), + /// Suffix range: bytes=-N (last N bytes of file) + Suffix(u64), +} + +/// Parses an HTTP Range header into a FileRangeVariant. +/// +/// Supports the following formats per RFC 7233: +/// - `bytes=0-499` - First 500 bytes +/// - `bytes=500-` - From byte 500 to end +/// - `bytes=-500` - Last 500 bytes +/// +/// Returns `Ok(None)` if no Range header is present. 
+fn parse_range_header(range_header: Option<&HeaderValue>) -> Result, (StatusCode, String)> { + let Some(range_header) = range_header else { + return Ok(None); + }; + + const RANGE_PREFIX: &str = "bytes="; + let range_str = range_header + .to_str() + .map_err(|e| (StatusCode::RANGE_NOT_SATISFIABLE, format!("Invalid range header: {e}")))?; + + if !range_str.starts_with(RANGE_PREFIX) { + return Err((StatusCode::RANGE_NOT_SATISFIABLE, format!("Range header doesn't start with {RANGE_PREFIX}"))); + } + + let split = range_str[RANGE_PREFIX.len()..].splitn(2, '-').collect::>(); + if split.len() != 2 { + return Err((StatusCode::RANGE_NOT_SATISFIABLE, "Invalid range syntax".to_string())); + } + + let start_value = if split[0].is_empty() { + None + } else { + Some( + split[0] + .parse::() + .map_err(|e| (StatusCode::RANGE_NOT_SATISFIABLE, format!("Invalid range start: {e}")))?, + ) + }; + let end_value = if split[1].is_empty() { + None + } else { + Some( + split[1] + .parse::() + .map_err(|e| (StatusCode::RANGE_NOT_SATISFIABLE, format!("Invalid range end: {e}")))?, + ) + }; + + match (start_value, end_value) { + (None, None) => Err((StatusCode::RANGE_NOT_SATISFIABLE, "Invalid range syntax".to_string())), + (Some(start), Some(end)) => { + if start > end { + Err((StatusCode::RANGE_NOT_SATISFIABLE, "Range start > end".to_string())) + } else { + // HTTP ranges are inclusive on both ends; FileRange uses exclusive end + Ok(Some(FileRangeVariant::Normal(FileRange::new(start, end + 1)))) + } + }, + (Some(start), None) => Ok(Some(FileRangeVariant::OpenRHS(start))), + (None, Some(suffix_len)) => Ok(Some(FileRangeVariant::Suffix(suffix_len))), + } +} + +/// Maps CasClientError to appropriate HTTP status codes. 
+fn error_to_response(e: CasClientError) -> Response { + let status = match &e { + CasClientError::XORBNotFound(_) | CasClientError::FileNotFound(_) => StatusCode::NOT_FOUND, + CasClientError::InvalidRange => StatusCode::RANGE_NOT_SATISFIABLE, + CasClientError::InvalidArguments => StatusCode::BAD_REQUEST, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + (status, e.to_string()).into_response() +} + +/// Encodes term data (file path) into a URL-safe base64 string. +/// +/// The term encodes the local file path that the LocalClient uses. +/// This allows the fetch_term endpoint to retrieve the data. +fn encode_term(file_path: &str) -> String { + URL_SAFE_NO_PAD.encode(file_path.as_bytes()) +} + +/// Decodes a URL-safe base64 term string back into file path. +fn decode_term(term: &str) -> Result { + let bytes = URL_SAFE_NO_PAD.decode(term).map_err(|e| format!("Invalid base64: {e}"))?; + let file_path = String::from_utf8(bytes).map_err(|e| format!("Invalid UTF-8: {e}"))?; + Ok(PathBuf::from(file_path)) +} + +/// Extracts the base URL from request headers (Host header). +fn get_base_url(headers: &HeaderMap) -> String { + headers + .get(HOST) + .and_then(|h| h.to_str().ok()) + .map(|host| format!("http://{host}")) + .unwrap_or_else(|| "http://localhost".to_string()) +} + +/// Transforms fetch_info URLs from local file paths to HTTP URLs. +/// +/// LocalClient generates URLs in a local format. This function transforms them +/// into proper HTTP URLs that point to the /v1/fetch_term endpoint. +fn transform_fetch_info_urls( + fetch_info: &mut std::collections::HashMap>, + base_url: &str, +) { + for fetch_infos in fetch_info.values_mut() { + for fi in fetch_infos.iter_mut() { + // The original URL from LocalClient is in the format: + // "/path/to/file":start:end:timestamp + // We extract the file path and encode it for the HTTP URL. + // The byte range is already in url_range, so we just need the file path. 
+ + // Parse the local URL format to extract the file path + let file_path = extract_file_path_from_local_url(&fi.url); + + // Create the HTTP URL with the encoded term + let encoded_term = encode_term(&file_path); + fi.url = format!("{base_url}/v1/fetch_term?term={encoded_term}"); + } + } +} + +/// Extracts the file path from LocalClient's URL format. +/// +/// LocalClient generates URLs like: "/path/to/file":start:end:timestamp +/// This extracts just the file path portion. +fn extract_file_path_from_local_url(local_url: &str) -> String { + // The format is: "path":start:end:timestamp + // We need to extract the path, which is quoted + let mut parts = local_url.rsplitn(4, ':').collect::>(); + parts.reverse(); + + if !parts.is_empty() { + // Remove the quotes from the path + parts[0].trim_matches('"').to_string() + } else { + local_url.to_string() + } +} + +/// GET /v1/reconstructions/{file_id} +/// +/// Returns reconstruction information for a file, including: +/// - List of terms (chunks) needed to reconstruct the file +/// - Fetch info with URLs/locations for each XORB +/// +/// Supports Range header for partial file reconstruction. +/// +/// The URLs in fetch_info are transformed from local file paths to HTTP URLs +/// that point to the /v1/fetch_term endpoint. 
+pub async fn get_reconstruction( + State(state): State>, + Path(HexMerkleHash(file_id)): Path, + headers: HeaderMap, +) -> Response { + let base_url = get_base_url(&headers); + + let range = match parse_range_header(headers.get(RANGE)) { + Ok(Some(FileRangeVariant::Normal(range))) => Some(range), + Ok(Some(FileRangeVariant::OpenRHS(start))) => { + let file_size = match state.get_file_size(&file_id).await { + Ok(size) => size, + Err(e) => return error_to_response(e), + }; + Some(FileRange::new(start, file_size)) + }, + Ok(Some(FileRangeVariant::Suffix(suffix))) => { + let file_size = match state.get_file_size(&file_id).await { + Ok(size) => size, + Err(e) => return error_to_response(e), + }; + Some(FileRange::new(file_size.saturating_sub(suffix), file_size)) + }, + Ok(None) => None, + Err((status, msg)) => return (status, msg).into_response(), + }; + + match state.get_reconstruction(&file_id, range).await { + Ok(Some(mut response)) => { + transform_fetch_info_urls(&mut response.fetch_info, &base_url); + Json(response).into_response() + }, + Ok(None) => (StatusCode::RANGE_NOT_SATISFIABLE, "Range not satisfiable").into_response(), + Err(e) => error_to_response(e), + } +} + +/// GET /reconstructions?file_id=...&file_id=... +/// +/// Batch query for reconstruction information for multiple files using query parameters. +/// This is the format used by RemoteClient. +/// Query params: file_id (repeated for each file hash as hex string) +/// Response: Map of file ID -> reconstruction info +/// +/// The URLs in fetch_info are transformed from local file paths to HTTP URLs. 
+pub async fn batch_get_reconstruction( + State(state): State>, + uri: axum::http::Uri, + headers: HeaderMap, +) -> Response { + let base_url = get_base_url(&headers); + + // Parse repeated file_id query parameters + let file_id_strings: Vec = uri + .query() + .unwrap_or("") + .split('&') + .filter_map(|param| { + let (key, value) = param.split_once('=')?; + if key == "file_id" { + Some(value.to_string()) + } else { + None + } + }) + .collect(); + + let file_ids: Vec = file_id_strings + .iter() + .filter_map(|hex| MerkleHash::from_hex(hex).ok()) + .collect(); + + if file_ids.is_empty() && !file_id_strings.is_empty() { + return (StatusCode::BAD_REQUEST, "Invalid file_id format").into_response(); + } + + match state.batch_get_reconstruction(&file_ids).await { + Ok(mut response) => { + transform_fetch_info_urls(&mut response.fetch_info, &base_url); + Json(response).into_response() + }, + Err(e) => error_to_response(e), + } +} + +/// GET /v1/fetch_term?term= +/// +/// Fetches XORB data based on an encoded term. +/// The term is a URL-safe base64-encoded file path. +/// Supports Range header for partial downloads. +/// +/// This endpoint is called by RemoteClient when fetching reconstruction terms. 
+pub async fn fetch_term(State(_state): State>, uri: axum::http::Uri, headers: HeaderMap) -> Response { + // Extract 'term' query parameter + let term = uri.query().unwrap_or("").split('&').find_map(|param| { + let (key, value) = param.split_once('=')?; + if key == "term" { Some(value.to_string()) } else { None } + }); + + let Some(term) = term else { + return (StatusCode::BAD_REQUEST, "Missing 'term' query parameter").into_response(); + }; + + let file_path = match decode_term(&term) { + Ok(p) => p, + Err(e) => return (StatusCode::BAD_REQUEST, format!("Invalid term: {e}")).into_response(), + }; + + // Read the file directly from disk + let data = match std::fs::read(&file_path) { + Ok(d) => d, + Err(e) => { + if e.kind() == std::io::ErrorKind::NotFound { + return (StatusCode::NOT_FOUND, "Term data not found").into_response(); + } + return (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to read data: {e}")).into_response(); + }, + }; + + // Apply range if specified + let range = match parse_range_header(headers.get(RANGE)) { + Ok(Some(FileRangeVariant::Normal(range))) => Some(range), + Ok(Some(FileRangeVariant::OpenRHS(start))) => Some(FileRange::new(start, data.len() as u64)), + Ok(Some(FileRangeVariant::Suffix(suffix))) => { + let len = data.len() as u64; + Some(FileRange::new(len.saturating_sub(suffix), len)) + }, + Ok(None) => None, + Err((status, msg)) => return (status, msg).into_response(), + }; + + let response_data = if let Some(range) = range { + let start = range.start as usize; + let end = (range.end as usize).min(data.len()); + if start >= data.len() { + return (StatusCode::RANGE_NOT_SATISFIABLE, "Range start out of bounds").into_response(); + } + data[start..end].to_vec() + } else { + data + }; + + (StatusCode::OK, response_data).into_response() +} + +/// GET /v1/chunks/{prefix}/{hash} +/// +/// Query for a global deduplication shard by chunk hash. +/// Returns the shard data if found, 404 otherwise. 
+pub async fn get_dedup_info_by_chunk(State(state): State>, Path(key): Path) -> Response { + match state.query_for_global_dedup_shard(&key.prefix, &key.hash).await { + Ok(Some(data)) => (StatusCode::OK, data).into_response(), + Ok(None) => (StatusCode::NOT_FOUND, "Shard not found").into_response(), + Err(e) => error_to_response(e), + } +} + +/// HEAD /v1/xorbs/{prefix}/{hash} +/// +/// Check if a XORB exists in the store. +/// Returns 200 if found, 404 otherwise. +pub async fn head_xorb(State(state): State>, Path(key): Path) -> Response { + match state.get_file_reconstruction_info(&key.hash).await { + Ok(Some(_)) => { + let mut headers = HeaderMap::new(); + headers.insert(http::header::CONTENT_LENGTH, HeaderValue::from(0)); + (StatusCode::OK, headers).into_response() + }, + Ok(None) => (StatusCode::NOT_FOUND, "XORB not found").into_response(), + Err(e) => error_to_response(e), + } +} + +/// POST /v1/xorbs/{prefix}/{hash} +/// +/// Upload a XORB (content-addressed block) to the store. +/// Request body: Serialized CAS object data +/// Response: JSON indicating if the XORB was newly inserted +pub async fn post_xorb(State(state): State>, Path(key): Path, body: Body) -> Response { + let data = match collect_body(body).await { + Ok(d) => d, + Err(e) => return (StatusCode::BAD_REQUEST, e).into_response(), + }; + + let cas_object = cas_object::SerializedCasObject { + hash: key.hash, + serialized_data: data.to_vec(), + raw_num_bytes: data.len() as u64, + num_chunks: 0, + footer_start: None, + }; + + let permit = match state.acquire_upload_permit().await { + Ok(p) => p, + Err(e) => return error_to_response(e), + }; + + match state.upload_xorb(&key.prefix, cas_object, None, permit).await { + Ok(_) => Json(UploadXorbResponse { was_inserted: true }).into_response(), + Err(e) => error_to_response(e), + } +} + +/// POST /v1/shards +/// +/// Upload a shard (deduplication index) to the store. 
+/// Request body: Raw shard data +/// Response: JSON indicating if the shard was newly inserted or already existed +pub async fn post_shard(State(state): State>, body: Body) -> Response { + let data = match collect_body(body).await { + Ok(d) => d, + Err(e) => return (StatusCode::BAD_REQUEST, e).into_response(), + }; + + let permit = match state.acquire_upload_permit().await { + Ok(p) => p, + Err(e) => return error_to_response(e), + }; + + match state.upload_shard(data, permit).await { + Ok(was_new) => { + let result = if was_new { + UploadShardResponseType::SyncPerformed + } else { + UploadShardResponseType::Exists + }; + Json(UploadShardResponse { result }).into_response() + }, + Err(e) => error_to_response(e), + } +} + +/// HEAD /v1/files/{file_id} +/// +/// Get the size of a file. +/// Returns Content-Length header with file size if found, 404 otherwise. +pub async fn head_file( + State(state): State>, + Path(HexMerkleHash(file_id)): Path, +) -> Response { + match state.get_file_size(&file_id).await { + Ok(size) => { + let mut headers = HeaderMap::new(); + headers.insert(http::header::CONTENT_LENGTH, HeaderValue::from(size)); + (StatusCode::OK, headers).into_response() + }, + Err(e) => error_to_response(e), + } +} + +/// GET /v1/get_xorb/{prefix}/{hash}/ +/// +/// Download XORB data directly. +/// Supports Range header for partial downloads. 
+pub async fn get_file_term_data( + State(state): State>, + Path((_prefix, hash_str)): Path<(String, String)>, + headers: HeaderMap, +) -> Response { + let hash = match MerkleHash::from_hex(&hash_str) { + Ok(h) => h, + Err(_) => return (StatusCode::BAD_REQUEST, "Invalid hash").into_response(), + }; + + let range = match parse_range_header(headers.get(RANGE)) { + Ok(Some(FileRangeVariant::Normal(range))) => Some(range), + Ok(Some(_)) => return (StatusCode::RANGE_NOT_SATISFIABLE, "Unsupported range type").into_response(), + Ok(None) => None, + Err((status, msg)) => return (status, msg).into_response(), + }; + + match state.get_file_data(&hash, range).await { + Ok(data) => (StatusCode::OK, data).into_response(), + Err(e) => error_to_response(e), + } +} + +/// GET /health +/// +/// Health check endpoint. Always returns 200 OK with no-cache headers. +/// Used by load balancers and monitoring systems to verify server availability. +pub async fn health_check() -> Response { + let mut headers = HeaderMap::new(); + headers.insert( + http::header::CACHE_CONTROL, + HeaderValue::from_static("no-store, no-cache, must-revalidate, proxy-revalidate"), + ); + headers.insert(http::header::PRAGMA, HeaderValue::from_static("no-cache")); + headers.insert(http::header::EXPIRES, HeaderValue::from_static("0")); + + (StatusCode::OK, headers).into_response() +} + +/// Collects the entire request body into a Bytes buffer. 
+async fn collect_body(body: Body) -> Result { + let mut stream = body.into_data_stream(); + let mut data = Vec::new(); + while let Some(chunk) = stream.next().await { + match chunk { + Ok(c) => data.extend_from_slice(&c), + Err(e) => return Err(format!("Error reading body: {e}")), + } + } + Ok(Bytes::from(data)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_encode_decode_term() { + let file_path = "/tmp/test/data/xorbs/abc123def456.xorb"; + let encoded = encode_term(file_path); + let decoded = decode_term(&encoded).unwrap(); + assert_eq!(decoded.to_str().unwrap(), file_path); + } + + #[test] + fn test_extract_file_path_from_local_url() { + let local_url = "\"/tmp/test/data/xorbs/abc123.xorb\":100:200:1234567890"; + let file_path = extract_file_path_from_local_url(local_url); + assert_eq!(file_path, "/tmp/test/data/xorbs/abc123.xorb"); + } +} diff --git a/cas_client/src/local_server/main.rs b/cas_client/src/local_server/main.rs new file mode 100644 index 00000000..0afbce2f --- /dev/null +++ b/cas_client/src/local_server/main.rs @@ -0,0 +1,104 @@ +//! Local CAS Server Binary +//! +//! This binary provides a local HTTP server that wraps `LocalClient`, exposing +//! the same REST API that `RemoteClient` expects from a remote CAS server. +//! +//! # Purpose +//! +//! The local CAS server enables: +//! - **Testing**: Run integration tests against a local server using `RemoteClient` without needing a remote backend. +//! - **Development**: Develop and debug CAS client interactions locally. +//! - **Offline workflows**: Store and retrieve CAS objects without network access. +//! +//! # Usage +//! +//! ```bash +//! # Start with default settings (port 8080, data in ./local_cas_data) +//! local_cas_server +//! +//! # Specify custom data directory and port +//! local_cas_server --data-directory /path/to/storage --port 9000 +//! +//! # Bind to all interfaces +//! local_cas_server --host 0.0.0.0 --port 8080 +//! ``` +//! +//! # API Endpoints +//! +//! 
The server exposes the following endpoints (compatible with `RemoteClient`): +//! +//! - `GET /health` - Health check endpoint +//! - `GET /v1/reconstructions/{file_id}` - Get file reconstruction info +//! - `GET /v1/reconstructions` - Batch query for multiple file reconstructions (also served at `/reconstructions`) +//! - `GET /v1/chunks/{prefix}/{hash}` - Query for global deduplication shard +//! - `HEAD /v1/xorbs/{prefix}/{hash}` - Check if XORB exists +//! - `POST /v1/xorbs/{prefix}/{hash}` - Upload a XORB +//! - `POST /shards` - Upload a shard +//! - `HEAD /v1/files/{file_id}` - Get file size +//! - `GET /v1/get_xorb/{prefix}/{hash}/` - Download XORB data +//! +//! # Environment Variables +//! +//! - `RUST_LOG` - Control logging verbosity (e.g., `RUST_LOG=info` or `RUST_LOG=debug`) + +use std::path::PathBuf; + +use cas_client::local_server::{LocalServer, LocalServerConfig}; +use clap::Parser; +use tracing_subscriber::EnvFilter; + +/// A local HTTP server that wraps LocalClient for testing and development. +/// +/// This server exposes the same REST API as the remote CAS server, allowing +/// RemoteClient to connect and interact with locally stored CAS objects. +/// Useful for integration testing, development, and offline workflows. +#[derive(Parser, Debug)] +#[command(name = "local_cas_server")] +#[command(version, about, long_about = None)] +struct Args { + /// Directory where CAS data (XORBs, shards, indices) will be stored. + /// + /// This directory will be created if it doesn't exist. All CAS objects + /// uploaded to this server will be persisted here. Multiple server + /// instances can share the same directory for read operations, but + /// concurrent writes should be avoided. + #[arg(short, long, default_value = "./local_cas_data")] + data_directory: PathBuf, + + /// Network interface to bind the server to. + /// + /// Use "127.0.0.1" (default) for local-only access, or "0.0.0.0" to + /// accept connections from any interface.
+ #[arg(long, default_value = "127.0.0.1")] + host: String, + + /// TCP port number for the HTTP server. + /// + /// The server will listen on this port for incoming HTTP requests. + /// Make sure this port is not already in use by another process. + #[arg(short, long, default_value = "8080")] + port: u16, +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + // Initialize tracing with environment filter (respects RUST_LOG) + tracing_subscriber::fmt().with_env_filter(EnvFilter::from_default_env()).init(); + + let args = Args::parse(); + + let config = LocalServerConfig { + data_directory: args.data_directory, + host: args.host, + port: args.port, + }; + + tracing::info!("Starting local CAS server with config: {:?}", config); + tracing::info!("Data directory: {:?}", config.data_directory); + tracing::info!("Listening on: {}:{}", config.host, config.port); + + let server = LocalServer::new(config)?; + server.run().await?; + + Ok(()) +} diff --git a/cas_client/src/local_server/mod.rs b/cas_client/src/local_server/mod.rs new file mode 100644 index 00000000..a4169104 --- /dev/null +++ b/cas_client/src/local_server/mod.rs @@ -0,0 +1,20 @@ +//! Local CAS Server Module +//! +//! This module provides an HTTP server that wraps `LocalClient`, exposing the same +//! REST API that `RemoteClient` expects from a remote CAS server. This enables: +//! +//! - **Integration testing**: Test `RemoteClient` against a local server +//! - **Development**: Debug CAS operations without network dependencies +//! - **Offline workflows**: Store and retrieve CAS objects locally +//! +//! # Components +//! +//! - [`LocalServer`]: The main server struct that manages the HTTP listener +//! - [`LocalServerConfig`]: Configuration for the server (host, port, data directory) +//! - [`LocalTestServer`]: A test utility that starts a server and provides both remote and local client access +//! 
- `handlers`: HTTP request handlers for each API endpoint + +mod handlers; +mod server; + +pub use server::{LocalServer, LocalServerConfig, LocalTestServer}; diff --git a/cas_client/src/local_server/server.rs b/cas_client/src/local_server/server.rs new file mode 100644 index 00000000..c7253ced --- /dev/null +++ b/cas_client/src/local_server/server.rs @@ -0,0 +1,637 @@ +//! Local CAS Server Implementation +//! +//! This module provides `LocalServer`, an HTTP server that wraps `LocalClient` +//! and exposes its functionality via REST API endpoints compatible with `RemoteClient`. +//! +//! # Architecture +//! +//! The server uses Axum as its HTTP framework and shares an `Arc` +//! across all request handlers. Routes are organized to match the API expected +//! by `RemoteClient`, with some legacy route aliases for compatibility. +//! +//! # Example +//! +//! ```no_run +//! use cas_client::local_server::{LocalServer, LocalServerConfig}; +//! +//! #[tokio::main] +//! async fn main() -> anyhow::Result<()> { +//! let config = LocalServerConfig { +//! data_directory: "./data".into(), +//! host: "127.0.0.1".to_string(), +//! port: 8080, +//! }; +//! let server = LocalServer::new(config)?; +//! server.run().await?; +//! Ok(()) +//! } +//! ``` + +use std::net::{SocketAddr, TcpListener as StdTcpListener}; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +use axum::Router; +use axum::routing::{get, head, post}; +use tokio::net::TcpListener; +use tokio::sync::oneshot; +use tower_http::cors::CorsLayer; + +use super::handlers; +use crate::error::{CasClientError, Result}; +use crate::{LocalClient, RemoteClient}; + +/// Configuration for the local CAS server. +#[derive(Clone, Debug)] +pub struct LocalServerConfig { + /// Directory where CAS data (XORBs, shards, indices) will be stored. + pub data_directory: PathBuf, + /// Network interface to bind to (e.g., "127.0.0.1" or "0.0.0.0"). + pub host: String, + /// TCP port number for the HTTP server. 
+ pub port: u16, +} + +impl Default for LocalServerConfig { + fn default() -> Self { + Self { + data_directory: PathBuf::from("./local_cas_data"), + host: "127.0.0.1".to_string(), + port: 8080, + } + } +} + +/// A local HTTP server that wraps `LocalClient` and exposes CAS operations via REST API. +/// +/// This server implements the same API that `RemoteClient` expects, making it useful for: +/// - Integration testing without a remote backend +/// - Local development and debugging +/// - Offline CAS workflows +pub struct LocalServer { + config: LocalServerConfig, + client: Arc, +} + +impl LocalServer { + /// Creates a new server with the given configuration. + /// + /// This will create a new `LocalClient` pointing to the configured data directory. + /// The directory will be created if it doesn't exist. + pub fn new(config: LocalServerConfig) -> Result { + let client = LocalClient::new(&config.data_directory)?; + Ok(Self { config, client }) + } + + /// Creates a server from an existing `LocalClient`. + /// + /// Useful when you want to share a `LocalClient` instance between the server + /// and other code (e.g., for testing where you want to verify server behavior + /// against direct client access). + pub fn from_client(client: Arc, host: String, port: u16) -> Self { + Self { + config: LocalServerConfig { + data_directory: PathBuf::new(), + host, + port, + }, + client, + } + } + + /// Returns a clone of the underlying LocalClient. + pub fn client(&self) -> Arc { + self.client.clone() + } + + /// Returns the server's bind address as "host:port". + pub fn addr(&self) -> String { + format!("{}:{}", self.config.host, self.config.port) + } + + /// Builds the Axum router with all CAS API routes. 
+ /// + /// Routes follow the pattern used by RemoteClient: + /// - `/v1/` prefixed routes for chunks, xorbs, reconstructions, and files + /// - Root-level `/reconstructions` for batch queries and `/shards` for uploads + fn create_router(&self) -> Router { + Router::new() + .route("/health", get(handlers::health_check)) + .nest( + "/v1", + Router::new() + .route("/reconstructions", get(handlers::batch_get_reconstruction)) + .route("/reconstructions/{file_id}", get(handlers::get_reconstruction)) + .route("/chunks/{prefix}/{hash}", get(handlers::get_dedup_info_by_chunk)) + .route("/xorbs/{prefix}/{hash}", head(handlers::head_xorb).post(handlers::post_xorb)) + .route("/files/{file_id}", head(handlers::head_file)) + .route("/get_xorb/{prefix}/{hash}/", get(handlers::get_file_term_data)) + .route("/fetch_term", get(handlers::fetch_term)), + ) + // Routes used by RemoteClient without /v1/ prefix + .route("/reconstructions", get(handlers::batch_get_reconstruction)) + .route("/shards", post(handlers::post_shard)) + .layer(CorsLayer::very_permissive()) + .with_state(self.client.clone()) + } + + /// Runs the server, listening for incoming HTTP requests. + /// + /// This method blocks until the server is shut down via signal (Ctrl+C on Unix). + /// For programmatic shutdown, use `run_until_stopped` instead. + pub async fn run(&self) -> Result<()> { + let addr: SocketAddr = self + .addr() + .parse() + .map_err(|e| CasClientError::Other(format!("Failed to parse address: {e}")))?; + + let listener = TcpListener::bind(addr) + .await + .map_err(|e| CasClientError::Other(format!("Failed to bind to {addr}: {e}")))?; + + tracing::info!("Local CAS server listening on {}", addr); + + let router = self.create_router(); + + axum::serve(listener, router.into_make_service()) + .with_graceful_shutdown(shutdown_signal()) + .await + .map_err(|e| CasClientError::Other(format!("Server error: {e}"))) + } + + /// Runs the server until a shutdown signal is received on the provided channel. 
+ /// + /// This is useful for tests where you want programmatic control over server lifecycle. + pub async fn run_until_stopped(&self, shutdown_rx: tokio::sync::oneshot::Receiver<()>) -> Result<()> { + let addr: SocketAddr = self + .addr() + .parse() + .map_err(|e| CasClientError::Other(format!("Failed to parse address: {e}")))?; + + let listener = TcpListener::bind(addr) + .await + .map_err(|e| CasClientError::Other(format!("Failed to bind to {addr}: {e}")))?; + + tracing::info!("Local CAS server listening on {}", addr); + + let router = self.create_router(); + + axum::serve(listener, router.into_make_service()) + .with_graceful_shutdown(async { + let _ = shutdown_rx.await; + }) + .await + .map_err(|e| CasClientError::Other(format!("Server error: {e}"))) + } +} + +/// Waits for a shutdown signal (currently blocks forever as there's no SIGTERM handling). +async fn shutdown_signal() { + std::future::pending::<()>().await +} + +/// A test server that wraps `LocalServer` and provides easy access to both +/// `RemoteClient` (for HTTP interactions) and `LocalClient` (for direct state access). +/// +/// This is useful for integration tests where you want to verify that operations +/// through the HTTP API produce the same results as direct client access. +/// +/// The server runs as a spawned tokio task and automatically shuts down when dropped +/// (no explicit shutdown call needed). 
+/// +/// # Example +/// +/// ```ignore +/// let server = LocalTestServer::start().await; +/// +/// // Upload via RemoteClient +/// let file = server.remote_client().upload_random_file(&[(1, (0, 5))], 123).await?; +/// +/// // Verify via LocalClient +/// let stored = server.local_client().get_file_data(&file.file_hash, None).await?; +/// assert_eq!(file.data, stored); +/// // Server automatically shuts down when dropped +/// ``` +pub struct LocalTestServer { + endpoint: String, + server_shutdown_tx: Option>, + remote_client: Arc, + local_client: Arc, +} + +impl LocalTestServer { + /// Starts a new test server with a fresh temporary data directory. + /// + /// The server listens on a randomly assigned available port on localhost. + pub async fn start() -> Self { + let local_client = LocalClient::temporary().await.unwrap(); + Self::start_with_client(local_client).await + } + + /// Starts a new test server using an existing `LocalClient`. + /// + /// Useful when you need to pre-populate the client with data before starting the server. + pub async fn start_with_client(local_client: Arc) -> Self { + let port = Self::find_available_port(); + let host = "127.0.0.1".to_string(); + let endpoint = format!("http://{}:{}", host, port); + + let server = LocalServer::from_client(local_client.clone(), host, port); + let (shutdown_tx, shutdown_rx) = oneshot::channel(); + + tokio::spawn(async move { + let _ = server.run_until_stopped(shutdown_rx).await; + }); + + tokio::time::sleep(Duration::from_millis(50)).await; + + let remote_client = RemoteClient::new(&endpoint, &None, &None, "test-session", false, "test-agent"); + + Self { + endpoint, + server_shutdown_tx: Some(shutdown_tx), + remote_client, + local_client, + } + } + + /// Returns the HTTP endpoint URL (e.g., "http://127.0.0.1:12345"). + pub fn endpoint(&self) -> &str { + &self.endpoint + } + + /// Returns the `RemoteClient` configured to connect to this test server. 
+ pub fn remote_client(&self) -> &Arc { + &self.remote_client + } + + /// Returns the underlying `LocalClient` for direct state access. + pub fn local_client(&self) -> &Arc { + &self.local_client + } + + fn find_available_port() -> u16 { + StdTcpListener::bind("127.0.0.1:0").unwrap().local_addr().unwrap().port() + } +} + +impl Drop for LocalTestServer { + fn drop(&mut self) { + if let Some(tx) = self.server_shutdown_tx.take() { + let _ = tx.send(()); + } + } +} + +#[cfg(test)] +mod tests { + use cas_types::FileRange; + + use super::*; + use crate::Client; + use crate::client_testing_utils::ClientTestingUtils; + + const CHUNK_SIZE: usize = 123; + + /// Verifies basic server operations: upload, reconstruction (full/range/batch/multi-xorb), + /// file info, dedup queries, and fetch_term endpoint. + async fn check_basic_correctness(server: &LocalTestServer) { + // Upload via RemoteClient, verify via LocalClient + let file = server + .remote_client() + .upload_random_file(&[(1, (0, 5))], CHUNK_SIZE) + .await + .unwrap(); + let local_data = server.local_client().get_file_data(&file.file_hash, None).await.unwrap(); + assert_eq!(file.data, local_data); + + // Full file reconstruction - compare remote and local + let remote_recon = server + .remote_client() + .get_reconstruction(&file.file_hash, None) + .await + .unwrap() + .unwrap(); + let local_recon = server + .local_client() + .get_reconstruction(&file.file_hash, None) + .await + .unwrap() + .unwrap(); + assert_eq!(remote_recon.terms.len(), local_recon.terms.len()); + assert_eq!(remote_recon.offset_into_first_range, local_recon.offset_into_first_range); + for (remote_term, local_term) in remote_recon.terms.iter().zip(local_recon.terms.iter()) { + assert_eq!(remote_term.hash, local_term.hash); + assert_eq!(remote_term.range, local_term.range); + } + + // Range reconstruction + let file_size = file.data.len() as u64; + let range = FileRange::new(file_size / 4, file_size * 3 / 4); + let range_recon = server + 
.remote_client() + .get_reconstruction(&file.file_hash, Some(range)) + .await + .unwrap(); + assert!(range_recon.is_some()); + + // Upload multi-xorb file + let term_spec = &[(1, (0, 3)), (2, (0, 2)), (1, (3, 5))]; + let multi_file = server.local_client().upload_random_file(term_spec, CHUNK_SIZE).await.unwrap(); + let multi_recon = server + .remote_client() + .get_reconstruction(&multi_file.file_hash, None) + .await + .unwrap() + .unwrap(); + assert_eq!(multi_recon.terms.len(), 3); + + // Batch reconstruction + let file_ids = vec![file.file_hash, multi_file.file_hash]; + let batch_result = server.remote_client().batch_get_reconstruction(&file_ids).await.unwrap(); + assert_eq!(batch_result.files.len(), 2); + + // File info (MDBFileInfo) + let (remote_mdb, _) = server + .remote_client() + .get_file_reconstruction_info(&file.file_hash) + .await + .unwrap() + .unwrap(); + let (local_mdb, _) = server + .local_client() + .get_file_reconstruction_info(&file.file_hash) + .await + .unwrap() + .unwrap(); + assert_eq!(remote_mdb.file_size(), local_mdb.file_size()); + + // Dedup query - use chunk hash from RandomFileContents + let first_chunk_hash = file.terms[0].chunk_hashes[0]; + let shard_result = server + .remote_client() + .query_for_global_dedup_shard("default", &first_chunk_hash) + .await + .unwrap(); + let local_shard = server + .local_client() + .query_for_global_dedup_shard("default", &first_chunk_hash) + .await + .unwrap(); + assert!(shard_result.is_some()); + assert_eq!(shard_result.unwrap(), local_shard.unwrap()); + + // Fetch term endpoint - verify URLs are HTTP and data can be fetched + let http_client = reqwest::Client::new(); + for fetch_infos in remote_recon.fetch_info.values() { + for fi in fetch_infos { + assert!(fi.url.starts_with("http://")); + assert!(fi.url.contains("/fetch_term?term=")); + let response = http_client.get(&fi.url).send().await.unwrap(); + assert!(response.status().is_success()); + assert!(!response.bytes().await.unwrap().is_empty()); + } 
+ } + + // Fetch term with range request + let first_fi = &remote_recon.fetch_info.values().next().unwrap()[0]; + let full_data = http_client.get(&first_fi.url).send().await.unwrap().bytes().await.unwrap(); + if full_data.len() > 100 { + let range_resp = http_client + .get(&first_fi.url) + .header(reqwest::header::RANGE, "bytes=0-99") + .send() + .await + .unwrap(); + assert!(range_resp.status().is_success()); + let range_data = range_resp.bytes().await.unwrap(); + assert_eq!(range_data.len(), 100); + assert_eq!(&range_data[..], &full_data[..100]); + } + } + + /// Tests that invalid requests return appropriate error responses. + async fn check_error_handling(server: &LocalTestServer) { + let http_client = reqwest::Client::new(); + + // Nonexistent file hash for reconstruction + let fake_hash = + merklehash::MerkleHash::from_hex("d760aaf4beb07581956e24c847c47f1abd2e419166aa68259035bc412232e9da") + .unwrap(); + let result = server.remote_client().get_reconstruction(&fake_hash, None).await; + assert!(result.is_err() || result.unwrap().is_none()); + + // Nonexistent file for file info + let file_info = server.remote_client().get_file_reconstruction_info(&fake_hash).await; + assert!(file_info.is_err() || file_info.unwrap().is_none()); + + // Invalid fetch_term (valid base64 but nonexistent path) + let invalid_term_url = format!("{}/v1/fetch_term?term=aW52YWxpZF9wYXRo", server.endpoint()); + let response = http_client.get(&invalid_term_url).send().await.unwrap(); + assert!(response.status().is_client_error() || response.status().is_server_error()); + + // Malformed base64 in fetch_term + let malformed_url = format!("{}/v1/fetch_term?term=not-valid-base64!!!", server.endpoint()); + let response = http_client.get(&malformed_url).send().await.unwrap(); + assert_eq!(response.status(), reqwest::StatusCode::BAD_REQUEST); + } + + /// Verifies that reconstruction responses contain valid HTTP URLs. 
+ async fn check_url_transformation(server: &LocalTestServer) { + let http_client = reqwest::Client::new(); + + // Single XORB file + let file1 = server + .local_client() + .upload_random_file(&[(1, (0, 5))], CHUNK_SIZE) + .await + .unwrap(); + + // Multi-XORB file + let term_spec = &[(1, (0, 3)), (2, (0, 2)), (1, (3, 5))]; + let multi_file = server.local_client().upload_random_file(term_spec, CHUNK_SIZE).await.unwrap(); + + // Verify single XORB URLs are HTTP + let recon1 = server + .remote_client() + .get_reconstruction(&file1.file_hash, None) + .await + .unwrap() + .unwrap(); + for (hash, fetch_infos) in &recon1.fetch_info { + for fi in fetch_infos { + assert!( + fi.url.starts_with("http://") || fi.url.starts_with("https://"), + "URL for hash {} should be HTTP, got: {}", + hash, + fi.url + ); + assert!(fi.url.contains("/fetch_term?term=")); + assert!(!fi.url.contains("\":")); + } + } + + // Verify multi-XORB file has HTTP URLs for all XORBs + let multi_recon = server + .remote_client() + .get_reconstruction(&multi_file.file_hash, None) + .await + .unwrap() + .unwrap(); + assert!(multi_recon.fetch_info.len() >= 2); + for fetch_infos in multi_recon.fetch_info.values() { + for fi in fetch_infos { + assert!(fi.url.starts_with("http://")); + } + } + + // Verify partial range reconstruction has HTTP URLs + let file_size = multi_file.data.len() as u64; + let range = FileRange::new(file_size / 4, file_size * 3 / 4); + let range_recon = server + .remote_client() + .get_reconstruction(&multi_file.file_hash, Some(range)) + .await + .unwrap() + .unwrap(); + for fetch_infos in range_recon.fetch_info.values() { + for fi in fetch_infos { + assert!(fi.url.starts_with("http://")); + assert!(fi.url.contains("/fetch_term?term=")); + } + } + + // Verify all term URLs are fetchable + for term in &recon1.terms { + let fetch_infos = recon1.fetch_info.get(&term.hash).unwrap(); + for fi in fetch_infos { + let response = http_client.get(&fi.url).send().await.unwrap(); + 
assert!(response.status().is_success()); + assert!(!response.bytes().await.unwrap().is_empty()); + } + } + } + + /// Verifies reconstruction term hashes match the uploaded file's expected terms. + async fn check_reconstruction_term_hashes_match(server: &LocalTestServer) { + // Upload a multi-term file + let term_spec = &[(1, (0, 3)), (2, (0, 2)), (1, (3, 5))]; + let file = server.local_client().upload_random_file(term_spec, CHUNK_SIZE).await.unwrap(); + + // Get reconstruction via remote client + let recon = server + .remote_client() + .get_reconstruction(&file.file_hash, None) + .await + .unwrap() + .unwrap(); + + // Verify term count matches + assert_eq!(recon.terms.len(), file.terms.len()); + + // Verify each term's XORB hash matches + for (i, recon_term) in recon.terms.iter().enumerate() { + let expected_term = &file.terms[i]; + assert_eq!(recon_term.hash.0, expected_term.xorb_hash, "Term {} XORB hash mismatch", i); + } + } + + /// Verifies that reconstruction data can be fetched and downloaded file matches expected data. 
+ async fn check_downloaded_terms_match_expected_data(server: &LocalTestServer) { + // Upload a file with known term structure + let term_spec = &[(1, (0, 4)), (2, (0, 3))]; + let file = server.local_client().upload_random_file(term_spec, CHUNK_SIZE).await.unwrap(); + + // Get reconstruction + let recon = server + .remote_client() + .get_reconstruction(&file.file_hash, None) + .await + .unwrap() + .unwrap(); + + // Verify term count and XORB hashes match + assert_eq!(recon.terms.len(), file.terms.len()); + for (term_idx, recon_term) in recon.terms.iter().enumerate() { + let expected_term = &file.terms[term_idx]; + assert_eq!(recon_term.hash.0, expected_term.xorb_hash); + + // Verify fetch_info exists for each XORB + let fetch_infos = recon.fetch_info.get(&recon_term.hash).unwrap(); + assert!(!fetch_infos.is_empty()); + } + + // Verify the complete file can be retrieved correctly via LocalClient + let retrieved_data = server.local_client().get_file_data(&file.file_hash, None).await.unwrap(); + assert_eq!(retrieved_data, file.data); + + // Verify term_matches works correctly for each term + for (term_idx, term) in file.terms.iter().enumerate() { + assert!(file.term_matches(term_idx, &term.data)); + } + } + + /// Verifies that the complete file can be reconstructed by concatenating term data. 
+ async fn check_complete_file_reconstruction(server: &LocalTestServer) { + // Upload a multi-term file + let term_spec = &[(1, (0, 3)), (2, (0, 2)), (1, (3, 5))]; + let file = server.local_client().upload_random_file(term_spec, CHUNK_SIZE).await.unwrap(); + + // Reconstruct file from term data + let mut reconstructed = Vec::new(); + for term in &file.terms { + reconstructed.extend_from_slice(&term.data); + } + + // Verify it matches the original file data + assert_eq!(reconstructed, file.data); + assert!(file.term_matches(0, &file.terms[0].data)); + assert!(file.term_matches(1, &file.terms[1].data)); + assert!(file.term_matches(2, &file.terms[2].data)); + + // Verify term_matches returns false for wrong data + assert!(!file.term_matches(0, &file.terms[1].data)); + } + + /// Verifies chunk hashes in RandomFileContents match expected values. + async fn check_chunk_hashes_correctness(server: &LocalTestServer) { + let file = server + .local_client() + .upload_random_file(&[(1, (0, 3))], CHUNK_SIZE) + .await + .unwrap(); + + // Verify we have the expected number of chunks + assert_eq!(file.terms.len(), 1); + assert_eq!(file.terms[0].chunk_hashes.len(), 3); + + // Verify chunk hashes match the RawXorbData cas_info (keyed by xorb hash) + let xorb_hash = file.terms[0].xorb_hash; + let raw_xorb = file.xorbs.get(&xorb_hash).unwrap(); + assert_eq!(raw_xorb.cas_info.chunks.len(), 3); + for (i, chunk_hash) in file.terms[0].chunk_hashes.iter().enumerate() { + assert_eq!(*chunk_hash, raw_xorb.cas_info.chunks[i].chunk_hash); + } + } + + /// Main test that runs all server checks with a single shared server instance. 
+ #[tokio::test] + async fn test_local_server() { + // Verify server creation works + let temp_client = LocalClient::temporary().await.unwrap(); + let temp_server = LocalServer::from_client(temp_client.clone(), "127.0.0.1".to_string(), 0); + assert!(temp_server.client().get_all_entries().unwrap().is_empty()); + + // Start test server for HTTP operations + let server = LocalTestServer::start().await; + + check_basic_correctness(&server).await; + check_error_handling(&server).await; + check_url_transformation(&server).await; + check_reconstruction_term_hashes_match(&server).await; + check_downloaded_terms_match_expected_data(&server).await; + check_complete_file_reconstruction(&server).await; + check_chunk_hashes_correctness(&server).await; + } +} diff --git a/cas_client/src/retry_wrapper.rs b/cas_client/src/retry_wrapper.rs index 0a961560..7e0542ad 100644 --- a/cas_client/src/retry_wrapper.rs +++ b/cas_client/src/retry_wrapper.rs @@ -530,24 +530,22 @@ mod tests { ClientBuilder::new(reqwest::Client::new()).build() } - #[tokio::test] - async fn test_success_first_try() { - let server = MockServer::start().await; - - Mock::given(method("GET")) + async fn check_success_first_try(server: &MockServer) { + let _guard = Mock::given(method("GET")) .and(path("/success")) .respond_with(ResponseTemplate::new(200)) .expect(1) - .mount(&server) + .mount_as_scoped(server) .await; let client = make_client(); let counter = Arc::new(AtomicU32::new(0)); let counter_ = counter.clone(); + let server_uri = server.uri(); - let result = connection_wrapper("test_success_first_try") + let result = connection_wrapper("check_success_first_try") .run(move |_partial_report_fn| { - let url = format!("{}/success", server.uri()); + let url = format!("{}/success", server_uri); counter_.fetch_add(1, Ordering::Relaxed); client.clone().get(&url).send() }) @@ -557,32 +555,30 @@ mod tests { assert_eq!(counter.load(Ordering::SeqCst), 1); } - #[tokio::test] - async fn test_retry_then_success() { - let 
server = MockServer::start().await; - + async fn check_retry_then_success(server: &MockServer) { // First two return 500 - Mock::given(method("GET")) + let _guard1 = Mock::given(method("GET")) .and(path("/flaky")) .respond_with(ResponseTemplate::new(500)) .up_to_n_times(2) - .mount(&server) + .mount_as_scoped(server) .await; // Third returns 200 - Mock::given(method("GET")) + let _guard2 = Mock::given(method("GET")) .and(path("/flaky")) .respond_with(ResponseTemplate::new(200).set_body_string("Recovered")) - .mount(&server) + .mount_as_scoped(server) .await; let client = make_client(); let counter = Arc::new(AtomicU32::new(0)); let counter_ = counter.clone(); + let server_uri = server.uri(); - let result = connection_wrapper("test_retry_then_success") + let result = connection_wrapper("check_retry_then_success") .run(move |_partial_report_fn| { - let url = format!("{}/flaky", server.uri()); + let url = format!("{}/flaky", server_uri); counter_.fetch_add(1, Ordering::Relaxed); client.clone().get(url).send() }) @@ -590,161 +586,144 @@ mod tests { assert!(result.is_ok()); assert_eq!(&result.unwrap().bytes().await.unwrap()[..], b"Recovered"); - assert_eq!(counter.load(Ordering::SeqCst), 3); // handle() only called on retry attempts + assert_eq!(counter.load(Ordering::SeqCst), 3); } - #[tokio::test] - async fn test_retry_limit_exceeded() { - let server = MockServer::start().await; - + async fn check_retry_limit_exceeded(server: &MockServer) { // Always return 500 - Mock::given(method("GET")) + let _guard = Mock::given(method("GET")) .and(path("/fail")) .respond_with(ResponseTemplate::new(500)) .expect(4) // 1 initial + 3 retries - .mount(&server) + .mount_as_scoped(server) .await; let client = make_client(); let counter = Arc::new(AtomicU32::new(0)); let counter_ = counter.clone(); + let server_uri = server.uri(); - let result = connection_wrapper("test_retry_limit_exceeded") + let result = connection_wrapper("check_retry_limit_exceeded") .with_max_attempts(3) .run(move 
|_partial_report_fn| { - let url = format!("{}/fail", server.uri()); + let url = format!("{}/fail", server_uri); counter_.fetch_add(1, Ordering::Relaxed); client.clone().get(&url).send() }) .await; assert!(result.is_err()); - assert_eq!(counter.load(Ordering::SeqCst), 4); // 3 retries attempted + assert_eq!(counter.load(Ordering::SeqCst), 4); } - #[tokio::test] - async fn test_non_retryable_status() { - let server = MockServer::start().await; - + async fn check_non_retryable_status(server: &MockServer) { // Respond with a 400 Bad Request - Mock::given(method("GET")) - .and(path("/bad")) + let _guard = Mock::given(method("GET")) + .and(path("/bad_request")) .respond_with(ResponseTemplate::new(400)) .expect(1) - .mount(&server) + .mount_as_scoped(server) .await; let client = make_client(); let counter = Arc::new(AtomicU32::new(0)); let counter_ = counter.clone(); + let server_uri = server.uri(); - let result = connection_wrapper("test_non_retryable_status") + let result = connection_wrapper("check_non_retryable_status") .run(move |_partial_report_fn| { - let url = format!("{}/bad", server.uri()); + let url = format!("{}/bad_request", server_uri); counter_.fetch_add(1, Ordering::Relaxed); client.clone().get(&url).send() }) .await; assert!(result.is_err()); - assert_eq!(counter.load(Ordering::SeqCst), 1); // strategy called once + assert_eq!(counter.load(Ordering::SeqCst), 1); } - #[tokio::test] - async fn test_429_retry_if_specified() { - // Ensures that 429 does in fact retry unless told not to. 
- - let server = MockServer::start().await; - + async fn check_429_retry_if_specified(server: &MockServer) { // Respond with a 429 too many requests - Mock::given(method("GET")) - .and(path("/bad")) + let _guard = Mock::given(method("GET")) + .and(path("/rate_limit")) .respond_with(ResponseTemplate::new(429)) .expect(4) - .mount(&server) + .mount_as_scoped(server) .await; let client = make_client(); let counter = Arc::new(AtomicU32::new(0)); let counter_ = counter.clone(); + let server_uri = server.uri(); - let result = connection_wrapper("test_429_retry_if_specified") + let result = connection_wrapper("check_429_retry_if_specified") .with_max_attempts(3) .run(move |_partial_report_fn| { - let url = format!("{}/bad", server.uri()); + let url = format!("{}/rate_limit", server_uri); counter_.fetch_add(1, Ordering::Relaxed); client.clone().get(&url).send() }) .await; assert!(result.is_err()); - assert_eq!(counter.load(Ordering::SeqCst), 4); // strategy called once + assert_eq!(counter.load(Ordering::SeqCst), 4); } - #[tokio::test] - async fn test_429_retry_no_retry() { - // Ensures that 429 does in fact retry unless told not to. 
- - let server = MockServer::start().await; - + async fn check_429_no_retry(server: &MockServer) { // Respond with a 429 too many requests - Mock::given(method("GET")) - .and(path("/bad")) + let _guard = Mock::given(method("GET")) + .and(path("/rate_limit_no_retry")) .respond_with(ResponseTemplate::new(429)) .expect(1) - .mount(&server) + .mount_as_scoped(server) .await; let client = make_client(); let counter = Arc::new(AtomicU32::new(0)); let counter_ = counter.clone(); + let server_uri = server.uri(); - let result = connection_wrapper("test_429_no_retry") + let result = connection_wrapper("check_429_no_retry") .with_max_attempts(3) .with_429_no_retry() .run(move |_partial_report_fn| { - let url = format!("{}/bad", server.uri()); + let url = format!("{}/rate_limit_no_retry", server_uri); counter_.fetch_add(1, Ordering::Relaxed); client.clone().get(&url).send() }) .await; assert!(result.is_err()); - assert_eq!(counter.load(Ordering::SeqCst), 1); // strategy called once + assert_eq!(counter.load(Ordering::SeqCst), 1); } - // Testing the JSON parsing #[derive(Serialize, Deserialize, PartialEq, Debug)] struct JsonData { text: String, number: u64, } - #[tokio::test] - async fn test_json_reserialization() { - // Ensures that 429 does in fact retry unless told not to. 
+ async fn check_json_reserialization(server: &MockServer) { let data = JsonData { text: "test".into(), number: 42, }; - let server = MockServer::start().await; - - // Respond with a 429 too many requests - Mock::given(method("GET")) - .and(path("/bad")) + let _guard = Mock::given(method("GET")) + .and(path("/json")) .respond_with(ResponseTemplate::new(StatusCode::OK).set_body_json(&data)) .expect(1) - .mount(&server) + .mount_as_scoped(server) .await; let client = make_client(); let counter = Arc::new(AtomicU32::new(0)); let counter_ = counter.clone(); + let server_uri = server.uri(); - let ret_data: JsonData = connection_wrapper("test_json") + let ret_data: JsonData = connection_wrapper("check_json_reserialization") .run_and_extract_json(move |_partial_report_fn| { - let url = format!("{}/bad", server.uri()); + let url = format!("{}/json", server_uri); counter_.fetch_add(1, Ordering::Relaxed); client.clone().get(&url).send() }) @@ -752,12 +731,10 @@ mod tests { .unwrap(); assert_eq!(ret_data, data); - assert_eq!(counter.load(Ordering::SeqCst), 1); // strategy called once + assert_eq!(counter.load(Ordering::SeqCst), 1); } - #[tokio::test] - async fn test_json_unexpected_eof_retry() { - // Ensures that 429 does in fact retry unless told not to. 
+ async fn check_json_unexpected_eof_retry(server: &MockServer) { let data = JsonData { text: "test".into(), number: 42, @@ -765,31 +742,30 @@ mod tests { let json_data = serde_json::to_string(&data).unwrap(); - let server = MockServer::start().await; - - // Respond with a 429 too many requests - Mock::given(method("GET")) + // First response truncated to simulate unexpected EOF + let _guard1 = Mock::given(method("GET")) .and(path("/json_flaky")) - .respond_with(ResponseTemplate::new(StatusCode::OK).set_body_string(&json_data[..json_data.len() - 5])) // Truncate to simulate unexpected EOF + .respond_with(ResponseTemplate::new(StatusCode::OK).set_body_string(&json_data[..json_data.len() - 5])) .up_to_n_times(1) - .mount(&server) + .mount_as_scoped(server) .await; - // Respond with a 429 too many requests - Mock::given(method("GET")) + // Second response with full data + let _guard2 = Mock::given(method("GET")) .and(path("/json_flaky")) - .respond_with(ResponseTemplate::new(StatusCode::OK).set_body_string(&json_data)) // Full length + .respond_with(ResponseTemplate::new(StatusCode::OK).set_body_string(&json_data)) .expect(1) - .mount(&server) + .mount_as_scoped(server) .await; let client = make_client(); let counter = Arc::new(AtomicU32::new(0)); let counter_ = counter.clone(); + let server_uri = server.uri(); - let ret_data: JsonData = connection_wrapper("test_json_unexpected_eof") + let ret_data: JsonData = connection_wrapper("check_json_unexpected_eof_retry") .run_and_extract_json(move |_partial_report_fn| { - let url = format!("{}/json_flaky", server.uri()); + let url = format!("{}/json_flaky", server_uri); counter_.fetch_add(1, Ordering::Relaxed); client.clone().get(&url).send() }) @@ -797,6 +773,26 @@ mod tests { .unwrap(); assert_eq!(ret_data, data); - assert_eq!(counter.load(Ordering::SeqCst), 2); // strategy called twice + assert_eq!(counter.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn test_retry_wrapper() { + let server = 
MockServer::start().await; + + // To avoid "Too many open files" error, we start one server here + // and then have each check below use the same server with scoped + // mocks. When running each of these tests on its own, it seemed + // we would hit sporadic "Too many open files" errors when the + // wiremock server was starting. + + check_success_first_try(&server).await; + check_retry_then_success(&server).await; + check_retry_limit_exceeded(&server).await; + check_non_retryable_status(&server).await; + check_429_retry_if_specified(&server).await; + check_429_no_retry(&server).await; + check_json_reserialization(&server).await; + check_json_unexpected_eof_retry(&server).await; } } diff --git a/cas_client/tests/reconstruction.rs b/cas_client/tests/reconstruction.rs new file mode 100644 index 00000000..ddf9db11 --- /dev/null +++ b/cas_client/tests/reconstruction.rs @@ -0,0 +1,200 @@ +//! Integration tests for file reconstruction using RemoteClient against a local test server. +//! +//! These tests verify that the two reconstruction routines in RemoteClient +//! (`get_file_with_sequential_writer` and `get_file_with_parallel_writer`) +//! correctly download and reconstruct files of various sizes and configurations. + +use std::sync::Arc; + +use cas_client::client_testing_utils::{ClientTestingUtils, RandomFileContents}; +use cas_client::local_server::LocalTestServer; +use cas_client::{Client, SeekingOutputProvider, sequential_output_from_filepath}; +use cas_types::FileRange; +use tempfile::NamedTempFile; + +/// Small chunk size for testing - produces more terms per file. +const CHUNK_SIZE: usize = 579; + +/// Helper to run sequential reconstruction and return the data. 
+async fn reconstruct_sequential( + client: &Arc, + file_hash: &merklehash::MerkleHash, + byte_range: Option, +) -> Vec { + let temp_file = NamedTempFile::new().unwrap(); + let output = sequential_output_from_filepath(temp_file.path()).unwrap(); + + client + .clone() + .get_file_with_sequential_writer(file_hash, byte_range, output, None) + .await + .unwrap(); + + std::fs::read(temp_file.path()).unwrap() +} + +/// Helper to run parallel reconstruction and return the data. +async fn reconstruct_parallel( + client: &Arc, + file_hash: &merklehash::MerkleHash, + byte_range: Option, +) -> Vec { + let temp_file = NamedTempFile::new().unwrap(); + let output = SeekingOutputProvider::new_file_provider(temp_file.path().to_path_buf()); + + client + .clone() + .get_file_with_parallel_writer(file_hash, byte_range, output, None) + .await + .unwrap(); + + std::fs::read(temp_file.path()).unwrap() +} + +/// Uploads a file with the given term specification. +async fn upload_file(server: &LocalTestServer, term_spec: &[(u64, (u64, u64))]) -> RandomFileContents { + server.local_client().upload_random_file(term_spec, CHUNK_SIZE).await.unwrap() +} + +/// Tests both sequential and parallel reconstruction, verifying correctness. 
+async fn check_reconstruction(server: &LocalTestServer, file: &RandomFileContents, range: Option) { + let expected_data = match range { + Some(r) => &file.data[r.start as usize..r.end as usize], + None => &file.data[..], + }; + + let sequential_result = reconstruct_sequential(server.remote_client(), &file.file_hash, range).await; + assert_eq!(sequential_result, expected_data, "Sequential reconstruction mismatch"); + + let parallel_result = reconstruct_parallel(server.remote_client(), &file.file_hash, range).await; + assert_eq!(parallel_result, expected_data, "Parallel reconstruction mismatch"); + + assert_eq!(sequential_result, parallel_result, "Sequential and parallel results differ"); +} + +// ============================================================================ +// Single-term file tests +// ============================================================================ + +/// Tests reconstruction of a single-term file with few chunks. +async fn check_single_term_full_file(server: &LocalTestServer) { + let file = upload_file(server, &[(1, (0, 3))]).await; + check_reconstruction(server, &file, None).await; +} + +/// Tests reconstruction of a single-term file with many chunks. +async fn check_single_term_many_chunks(server: &LocalTestServer) { + let file = upload_file(server, &[(1, (0, 20))]).await; + check_reconstruction(server, &file, None).await; +} + +// ============================================================================ +// Multi-term file tests (multiple XORBs) +// ============================================================================ + +/// Tests reconstruction of a multi-term file. +async fn check_multi_term_full_file(server: &LocalTestServer) { + let file = upload_file(server, &[(1, (0, 2)), (2, (0, 3)), (1, (2, 4))]).await; + check_reconstruction(server, &file, None).await; +} + +/// Tests reconstruction of a file with many terms. 
+async fn check_many_terms(server: &LocalTestServer) { + let term_spec: Vec<(u64, (u64, u64))> = (0..10).map(|i| (i, (0, 2))).collect(); + let file = upload_file(server, &term_spec).await; + check_reconstruction(server, &file, None).await; +} + +// ============================================================================ +// XORB reuse tests (same XORB referenced multiple times) +// ============================================================================ + +/// Tests reconstruction when the same XORB is referenced multiple times. +async fn check_xorb_reuse(server: &LocalTestServer) { + let file = upload_file(server, &[(1, (0, 2)), (2, (0, 2)), (1, (2, 4)), (2, (2, 4)), (1, (0, 2))]).await; + check_reconstruction(server, &file, None).await; +} + +// ============================================================================ +// Range request tests - partial file downloads +// ============================================================================ + +/// Tests range reconstruction from the start of the file. +async fn check_range_from_start(server: &LocalTestServer) { + let file = upload_file(server, &[(1, (0, 5))]).await; + let range_end = file.data.len() as u64 / 2; + check_reconstruction(server, &file, Some(FileRange::new(0, range_end))).await; +} + +/// Tests range reconstruction from the middle of the file. +async fn check_range_middle(server: &LocalTestServer) { + let file = upload_file(server, &[(1, (0, 6))]).await; + let file_len = file.data.len() as u64; + check_reconstruction(server, &file, Some(FileRange::new(file_len / 4, file_len * 3 / 4))).await; +} + +/// Tests range reconstruction to the end of the file. +async fn check_range_to_end(server: &LocalTestServer) { + let file = upload_file(server, &[(1, (0, 5))]).await; + let file_len = file.data.len() as u64; + check_reconstruction(server, &file, Some(FileRange::new(file_len / 2, file_len))).await; +} + +/// Tests range reconstruction spanning multiple terms. 
+async fn check_range_spanning_terms(server: &LocalTestServer) {
+    let file = upload_file(server, &[(1, (0, 3)), (2, (0, 2)), (3, (0, 3))]).await;
+    let term1_size = file.terms[0].data.len() as u64;
+    let term2_size = file.terms[1].data.len() as u64;
+    // Starts halfway into the first term and ends halfway into the second.
+    // NOTE(review): assumes terms are laid out back to back in `file.data`, in order — confirm.
+    check_reconstruction(server, &file, Some(FileRange::new(term1_size / 2, term1_size + term2_size / 2))).await;
+}
+
+/// Tests range reconstruction in the middle of a multi-term file.
+async fn check_range_multi_term_middle(server: &LocalTestServer) {
+    let file = upload_file(server, &[(1, (0, 4)), (2, (0, 3)), (3, (0, 2))]).await;
+    let file_len = file.data.len() as u64;
+    // From one quarter to three quarters of the file length.
+    check_reconstruction(server, &file, Some(FileRange::new(file_len / 4, file_len * 3 / 4))).await;
+}
+
+// ============================================================================
+// Edge cases
+// ============================================================================
+
+/// Tests reconstruction of a small byte range.
+async fn check_small_range(server: &LocalTestServer) {
+    let file = upload_file(server, &[(1, (0, 4))]).await;
+    // The fixed window below needs at least 200 bytes; fail with a clear message rather than
+    // an opaque slice-index panic inside `check_reconstruction`.
+    assert!(file.data.len() >= 200, "test file too small for range 100..200: {} bytes", file.data.len());
+    check_reconstruction(server, &file, Some(FileRange::new(100, 200))).await;
+}
+
+/// Tests reconstruction of a single byte.
+async fn check_single_byte_range(server: &LocalTestServer) {
+    let file = upload_file(server, &[(1, (0, 3))]).await;
+    // Same precondition as above, for the one-byte window at offset 50.
+    assert!(file.data.len() >= 51, "test file too small for range 50..51: {} bytes", file.data.len());
+    check_reconstruction(server, &file, Some(FileRange::new(50, 51))).await;
+}
+
+/// Main test that runs all reconstruction checks with a single shared server.
+#[tokio::test]
+async fn test_reconstruction_with_local_server() {
+    // One server is started here and shared by every check below. Per the commit message, running
+    // many separate test servers at once was hitting "Too many open files", so the checks are
+    // consolidated into this single test.
+    let server = LocalTestServer::start().await;
+
+    // Single-term file tests
+    check_single_term_full_file(&server).await;
+    check_single_term_many_chunks(&server).await;
+
+    // Multi-term file tests
+    check_multi_term_full_file(&server).await;
+    check_many_terms(&server).await;
+
+    // XORB reuse tests
+    check_xorb_reuse(&server).await;
+
+    // Range request tests
+    check_range_from_start(&server).await;
+    check_range_middle(&server).await;
+    check_range_to_end(&server).await;
+    check_range_spanning_terms(&server).await;
+    check_range_multi_term_middle(&server).await;
+
+    // Edge cases
+    check_small_range(&server).await;
+    check_single_byte_range(&server).await;
+}
diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock index 36d08657..8bea0251 100644 --- a/hf_xet/Cargo.lock +++ b/hf_xet/Cargo.lock @@ -171,10 +171,13 @@ checksum = "021e862c184ae977658b36c4500f7feac3221ca5da43e3f25bd04ab6c79a29b5" dependencies = [ "axum-core", "bytes", + "form_urlencoded", "futures-util", "http 1.3.1", "http-body 1.0.1", "http-body-util", + "hyper 1.7.0", + "hyper-util", "itoa", "matchit", "memchr", @@ -183,10 +186,15 @@ dependencies = [ "pin-project-lite", "rustversion", "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", "sync_wrapper", + "tokio", "tower", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -205,6 +213,7 @@ dependencies = [ "sync_wrapper", "tower-layer", "tower-service", + "tracing", ] [[package]] @@ -307,6 +316,8 @@ version = "0.14.5" dependencies = [ "anyhow", "async-trait", + "axum", + "base64 0.22.1", "bytes", "cas_object", "cas_types", @@ -318,6 +329,7 @@ dependencies = [ "error_printer", "file_utils", "futures", + "futures-util", "heed", "http 1.3.1", "hyper 1.7.0", @@ -337,6 +349,7 @@ dependencies = [ "thiserror 2.0.15", "tokio", "tokio-retry", + "tower-http", "tracing", "tracing-log", "tracing-subscriber", @@ -3309,6 +3322,16 @@ dependencies = [ "serde", ] 
+[[package]] +name = "serde_path_to_error" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a" +dependencies = [ + "itoa", + "serde", +] + [[package]] name = "serde_repr" version = "0.1.20" diff --git a/hf_xet_wasm/Cargo.lock b/hf_xet_wasm/Cargo.lock index b2a9abd9..552006e6 100644 --- a/hf_xet_wasm/Cargo.lock +++ b/hf_xet_wasm/Cargo.lock @@ -135,6 +135,58 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core", + "bytes", + "form_urlencoded", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.7.0", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" +dependencies = [ + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.75" @@ -235,6 +287,8 @@ version = "0.14.5" dependencies = [ "anyhow", "async-trait", + "axum", + "base64 0.22.1", "bytes", "cas_object", "cas_types", @@ -246,6 +300,7 @@ dependencies = [ "error_printer", "file_utils", "futures", + "futures-util", "heed", "http 1.3.1", "hyper 1.7.0", 
@@ -265,6 +320,7 @@ dependencies = [ "thiserror 2.0.16", "tokio", "tokio-retry", + "tower-http", "tracing", "tracing-log", "tracing-subscriber", @@ -1221,6 +1277,7 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "httparse", + "httpdate", "itoa", "pin-project-lite", "pin-utils", @@ -1627,6 +1684,12 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matchit" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" + [[package]] name = "matrixmultiply" version = "0.3.10" @@ -2540,10 +2603,11 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ + "serde_core", "serde_derive", ] @@ -2559,10 +2623,19 @@ dependencies = [ ] [[package]] -name = "serde_derive" -version = "1.0.219" +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -2581,6 +2654,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_repr" version = "0.1.20" 
@@ -3043,6 +3127,7 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", ] [[package]]