diff --git a/Cargo.lock b/Cargo.lock index 00a340ae..fac47419 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2311,6 +2311,7 @@ dependencies = [ "itertools 0.14.0", "lazy_static", "merklehash", + "more-asserts", "rand 0.9.1", "regex", "serde", diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock index 35eee6a4..d1ab4eb4 100644 --- a/hf_xet/Cargo.lock +++ b/hf_xet/Cargo.lock @@ -1864,6 +1864,7 @@ dependencies = [ "itertools 0.14.0", "lazy_static", "merklehash", + "more-asserts", "rand 0.9.2", "regex", "serde", diff --git a/mdb_shard/Cargo.toml b/mdb_shard/Cargo.toml index 3f2525ba..0eef5370 100644 --- a/mdb_shard/Cargo.toml +++ b/mdb_shard/Cargo.toml @@ -26,6 +26,7 @@ tempfile = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true } tracing = { workspace = true } +more-asserts = { workspace = true } [target.'cfg(target_family = "wasm")'.dependencies] uuid = { workspace = true, features = ["v4", "js"] } diff --git a/mdb_shard/src/cas_structs.rs b/mdb_shard/src/cas_structs.rs index fcd5a0d3..86e21b17 100644 --- a/mdb_shard/src/cas_structs.rs +++ b/mdb_shard/src/cas_structs.rs @@ -5,8 +5,12 @@ use bytes::Bytes; use merklehash::MerkleHash; use utils::serialization_utils::*; +use crate::hash_is_global_dedup_eligible; + pub const MDB_DEFAULT_CAS_FLAG: u32 = 0; +pub const MDB_CHUNK_WITH_GLOBAL_DEDUP_FLAG: u32 = 1 << 31; + /// Each CAS consists of a CASChunkSequenceHeader following /// a sequence of CASChunkSequenceEntry. 
@@ -89,7 +93,8 @@ pub struct CASChunkSequenceEntry { pub chunk_hash: MerkleHash, pub chunk_byte_range_start: u32, pub unpacked_segment_bytes: u32, - pub _unused: u64, + pub flags: u32, + pub _unused: u32, } impl CASChunkSequenceEntry { @@ -106,13 +111,31 @@ impl CASChunkSequenceEntry { chunk_hash, unpacked_segment_bytes: unpacked_segment_bytes.try_into().unwrap(), chunk_byte_range_start: chunk_byte_range_start.try_into().unwrap(), - #[cfg(test)] - _unused: 216944691646848u64, - #[cfg(not(test))] + flags: 0, _unused: 0, } } + /// Mark this chunk as a candidate for population in the global dedup table. + pub fn with_global_dedup_flag(self, is_global_dedup_chunk: bool) -> Self { + if is_global_dedup_chunk { + Self { + flags: self.flags | MDB_CHUNK_WITH_GLOBAL_DEDUP_FLAG, + ..self + } + } else { + Self { + flags: self.flags & !MDB_CHUNK_WITH_GLOBAL_DEDUP_FLAG, + ..self + } + } + } + + // Is this chunk eligible for a global dedup query? + pub fn is_global_dedup_eligible(&self) -> bool { + (self.flags & MDB_CHUNK_WITH_GLOBAL_DEDUP_FLAG) != 0 || hash_is_global_dedup_eligible(&self.chunk_hash) + } + pub fn serialize(&self, writer: &mut W) -> Result { let mut buf = [0u8; size_of::()]; { @@ -122,7 +145,8 @@ impl CASChunkSequenceEntry { write_hash(writer, &self.chunk_hash)?; write_u32(writer, self.chunk_byte_range_start)?; write_u32(writer, self.unpacked_segment_bytes)?; - write_u64(writer, self._unused)?; + write_u32(writer, self.flags)?; + write_u32(writer, self._unused)?; } writer.write_all(&buf[..])?; @@ -140,7 +164,8 @@ impl CASChunkSequenceEntry { chunk_hash: read_hash(reader)?, chunk_byte_range_start: read_u32(reader)?, unpacked_segment_bytes: read_u32(reader)?, - _unused: read_u64(reader)?, + flags: read_u32(reader)?, + _unused: read_u32(reader)?, }) } } @@ -253,4 +278,21 @@ impl MDBCASInfoView { writer.write_all(&self.data[..n_bytes])?; Ok(n_bytes) } + + #[inline] + pub fn serialize_with_chunk_rewrite( + &self, + writer: &mut W, + chunk_rewrite_fn: impl Fn(usize, 
CASChunkSequenceEntry) -> CASChunkSequenceEntry, + ) -> std::io::Result { + let mut n_out_bytes = 0; + n_out_bytes += self.header.serialize(writer)?; + + for idx in 0..self.num_entries() { + let rewritten_chunk = chunk_rewrite_fn(idx, self.chunk(idx)); + n_out_bytes += rewritten_chunk.serialize(writer)?; + } + + Ok(n_out_bytes) + } } diff --git a/mdb_shard/src/shard_format.rs b/mdb_shard/src/shard_format.rs index 91e5e8a0..cab40a91 100644 --- a/mdb_shard/src/shard_format.rs +++ b/mdb_shard/src/shard_format.rs @@ -1,4 +1,4 @@ -use std::collections::{BTreeMap, HashMap}; +use std::collections::BTreeMap; use std::io::{Read, Seek, SeekFrom, Write, copy}; use std::mem::size_of; use std::ops::Add; @@ -13,11 +13,11 @@ use tracing::debug; use utils::serialization_utils::*; use crate::cas_structs::*; -use crate::constants::*; use crate::error::{MDBShardError, Result}; use crate::file_structs::*; use crate::interpolation_search::search_on_sorted_u64s; use crate::shard_in_memory::MDBInMemoryShard; +use crate::streaming_shard::MDBMinimalShard; use crate::utils::{shard_expiry_time, truncate_hash}; // Same size for FileDataSequenceHeader and FileDataSequenceEntry @@ -931,44 +931,9 @@ impl MDBShardInfo { /// The chunk hashes are either multiple of 'hash_filter_modulues', /// or the hash of the first chunk of a file present in the shard. pub fn filter_cas_chunks_for_global_dedup(reader: &mut R) -> Result> { - let mut ret = Vec::new(); + let shard = MDBMinimalShard::from_reader(reader, true, true)?; - // First, go through and get all of the cas chunks. This allows us to form the lookup for the CAS block - // hashes later. 
- let shard = MDBShardInfo::load_from_reader(reader)?; - - let cas_chunks = shard.read_all_cas_blocks_full(reader)?; - let mut cas_block_lookup = HashMap::::with_capacity(cas_chunks.len()); - - for (i, cas_info) in cas_chunks.iter().enumerate() { - cas_block_lookup.insert(cas_info.metadata.cas_hash, i); - for chunk in cas_info.chunks.iter() { - if hash_is_global_dedup_eligible(&chunk.chunk_hash) { - ret.push(chunk.chunk_hash); - } - } - } - - // Now, go through all the files present, collecting the first chunks of the files. - // TODO: break this out into a utility if needed. - let files = shard.read_all_file_info_sections(reader)?; - - for fi in files { - let Some(entry) = fi.segments.first() else { - continue; - }; - - let Some(cas_block_index) = cas_block_lookup.get(&entry.cas_hash) else { - continue; - }; - - // Scan the cas entries to get the proper index - let first_chunk_hash = cas_chunks[*cas_block_index].chunks[entry.chunk_index_start as usize].chunk_hash; - - ret.push(first_chunk_hash); - } - - Ok(ret) + Ok(shard.global_dedup_eligible_chunks()) } /// Export the current shard as an hmac keyed shard, returning the number of bytes written diff --git a/mdb_shard/src/streaming_shard.rs b/mdb_shard/src/streaming_shard.rs index 9f11a662..b72a1bab 100644 --- a/mdb_shard/src/streaming_shard.rs +++ b/mdb_shard/src/streaming_shard.rs @@ -1,3 +1,4 @@ +use std::collections::{HashMap, HashSet}; use std::io::{Cursor, Read, Write, copy}; use std::mem::size_of; @@ -5,6 +6,8 @@ use bytes::Bytes; use futures::AsyncRead; use futures_util::io::AsyncReadExt; use itertools::Itertools; +use merklehash::MerkleHash; +use more_asserts::debug_assert_lt; use crate::cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader, MDBCASInfoView}; use crate::error::{MDBShardError, Result}; @@ -296,7 +299,43 @@ impl MDBMinimalShard { + size_of::() } - pub fn serialize(&self, writer: &mut W, with_verification: bool) -> Result { + /// Return a lookup of xorb hash to starting chunk indices for 
all the files present in the + /// shard. These are the chunks that are useful for global dedup. + fn file_start_entries(&self) -> HashMap> { + let mut file_start_entries = HashMap::>::new(); + + for f_idx in 0..self.num_files() { + let Some(fv) = self.file(f_idx) else { + break; + }; + + if fv.num_entries() > 0 { + let entry = fv.entry(0); + let cas_hash = entry.cas_hash; + let idx = entry.chunk_index_start; + + file_start_entries.entry(cas_hash).or_default().push(idx as usize); + } + } + + // Sort all the individual entries. + for v in file_start_entries.values_mut() { + v.sort_unstable(); + v.dedup(); + } + + file_start_entries + } + + /// Implementation for the xorb serialization function. Use one of the methods below + /// to directly access this. + fn serialize_impl( + &self, + writer: &mut W, + with_file_section: bool, + with_verification: bool, + xorb_filter_fn: impl Fn(&MDBCASInfoView) -> bool, + ) -> Result { let mut bytes = 0; bytes += MDBShardFileHeader::default().serialize(writer)?; @@ -307,22 +346,49 @@ impl MDBMinimalShard { let mut stored_bytes = 0; let mut materialized_bytes = 0; + // When adding in the global dedup flags based on the files present in the shard, we first need to get + // a lookup of which chunks occur at the start of a file. These are the ones for which we set the + // global dedup eligibility flag. + // + // In addition, we propagate the global dedup eligibility flag if it is already present. 
+ // + let file_start_chunks = self.file_start_entries(); + let fs_start = bytes as u64; - for file_info in &self.file_info_views { - for j in 0..file_info.num_entries() { - let segment_info = file_info.entry(j); - materialized_bytes += segment_info.unpacked_segment_bytes as u64; + + if with_file_section { + for file_info in &self.file_info_views { + for j in 0..file_info.num_entries() { + let segment_info = file_info.entry(j); + materialized_bytes += segment_info.unpacked_segment_bytes as u64; + } + bytes += file_info.serialize(writer, with_verification)?; } - bytes += file_info.serialize(writer, with_verification)?; } bytes += FileDataSequenceHeader::bookend().serialize(writer)?; let cs_start = bytes as u64; for cas_info in &self.cas_info_views { + // Skip any filtered sections. + if !xorb_filter_fn(cas_info) { + continue; + } + stored_bytes_on_disk += cas_info.header().num_bytes_on_disk as u64; stored_bytes += cas_info.header().num_bytes_in_cas as u64; - bytes += cas_info.serialize(writer)?; + if let Some(gde_indices) = file_start_chunks.get(&cas_info.cas_hash()) { + debug_assert!(gde_indices.is_sorted()); + bytes += cas_info.serialize_with_chunk_rewrite(writer, |idx, chunk| { + if gde_indices.binary_search(&idx).is_ok() { + chunk.with_global_dedup_flag(true) + } else { + chunk + } + })?; + } else { + bytes += cas_info.serialize(writer)?; + } } bytes += CASChunkSequenceHeader::bookend().serialize(writer)?; @@ -350,19 +416,86 @@ impl MDBMinimalShard { Ok(bytes) } + + /// Serialize out a shard without any of the file information and a subset of xorb data that is given + /// by the xorb_filter_fn. Global deduplication chunk information is preserved. 
+ pub fn serialize_xorb_subset_only( + &self, + writer: &mut W, + xorb_filter_fn: impl Fn(&MDBCASInfoView) -> bool, + ) -> Result { + self.serialize_impl(writer, false, false, xorb_filter_fn) + } + + /// Serialize out the given shard, sanitizing and updating the global dedup chunk flags and optionally + /// dropping the file verification section. + pub fn serialize(&self, writer: &mut W, with_verification: bool) -> Result { + self.serialize_impl(writer, true, with_verification, |_| true) + } + + /// Returns a list of all the global dedup eligible chunks, as given either by the hash value, file starts, or + /// the embedded global dedup flags. + pub fn global_dedup_eligible_chunks(&self) -> Vec { + // We need to get a list of all the chunk hashes that + // - References the first chunk of a file, or + // - hash_is_global_dedup_eligible(&hash) is true, or + // - has the global dedup flag set. + + let mut ret = HashSet::::new(); + + // To do the file lookup part efficiently, first scan through the files and record + // a lookup of xorb hash to offset. Thus when scanning through the xorb definitions, + // we can easily extract the hashes that match these indices. + let file_start_entries = self.file_start_entries(); + + for cas_idx in 0..self.num_cas() { + let Some(cas_view) = self.cas(cas_idx) else { + break; + }; + + let num_entries = cas_view.num_entries(); + + if let Some(fse) = file_start_entries.get(&cas_view.cas_hash()) { + for &c_idx in fse { + debug_assert_lt!(c_idx, num_entries); + + // Check bounds to be safe here to ensure things don't crash in production; would be + // an error and fail verification elsewhere. 
+ if c_idx < num_entries { + let chunk_hash = cas_view.chunk(c_idx).chunk_hash; + ret.insert(chunk_hash); + } + } + } + + for c_idx in 0..num_entries { + let chunk = cas_view.chunk(c_idx); + + if chunk.is_global_dedup_eligible() { + ret.insert(chunk.chunk_hash); + } + } + } + + Vec::from_iter(ret) + } } #[cfg(test)] mod tests { + use std::collections::{HashMap, HashSet}; use std::io::Cursor; use anyhow::Result; + use merklehash::MerkleHash; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; use super::MDBMinimalShard; use crate::MDBShardInfo; use crate::cas_structs::MDBCASInfo; use crate::file_structs::MDBFileInfo; - use crate::shard_file::test_routines::{convert_to_file, gen_random_shard}; + use crate::shard_file::test_routines::{convert_to_file, gen_random_shard, gen_random_shard_with_cas_references}; use crate::shard_in_memory::MDBInMemoryShard; fn verify_serialization(min_shard: &MDBMinimalShard, mem_shard: &MDBInMemoryShard) -> Result<()> { @@ -403,8 +536,18 @@ mod tests { assert_eq!(cas_info.len(), mem_cas_info.len(), "verification = {verification}"); + // Test for equality while ignoring the global dedup flag, as this gets modified on reserializing. for i in 0..cas_info.len() { - assert_eq!(&cas_info[i], mem_cas_info[i].as_ref(), "verification = {verification}"); + let c1 = &cas_info[i]; + let c2 = mem_cas_info[i].as_ref(); + + assert_eq!(c1.metadata, c2.metadata); + + for (ch1, ch2) in c1.chunks.iter().zip(c2.chunks.iter()) { + // Clear the global dedup one on the new serialized version, as it may have been set. 
+ let ch1 = ch1.clone().with_global_dedup_flag(false); + assert_eq!(&ch1, ch2); + } } } @@ -486,7 +629,7 @@ mod tests { } #[tokio::test] - async fn test_empty_shards() -> Result<()> { + async fn test_shards() -> Result<()> { let shard = gen_random_shard(0, &[], &[0], false, false)?; verify_minimal_shard(&shard).await?; @@ -508,4 +651,114 @@ mod tests { Ok(()) } + + async fn verify_minimal_shard_dedup_processing(mem_shard: &MDBInMemoryShard) { + verify_minimal_shard(mem_shard).await.unwrap(); + + // Additionally, verify that the exporting functions work properly. + let buffer = convert_to_file(mem_shard).unwrap(); + let min_shard = MDBMinimalShard::from_reader(&mut Cursor::new(&buffer), true, true).unwrap(); + + // Calculate the global_dedup chunks. + let ref_global_dedup_chunks: HashSet<_> = min_shard.global_dedup_eligible_chunks().into_iter().collect(); + + // Produce a new minimal shard without the file info. + let mut xorb_only_shard_buffer = Vec::::new(); + min_shard + .serialize_xorb_subset_only(&mut xorb_only_shard_buffer, |_| true) + .unwrap(); + + let xorb_only_shard = + MDBMinimalShard::from_reader(&mut Cursor::new(&xorb_only_shard_buffer), true, true).unwrap(); + + let global_dedup_chunks: HashSet<_> = xorb_only_shard.global_dedup_eligible_chunks().into_iter().collect(); + + // Now make sure these are the same. + assert_eq!(ref_global_dedup_chunks, global_dedup_chunks); + + // Now, exclude subsets of the xorbs for testing to make sure that the filtering works properly. + // + // We'll do the filtering by excluding the xorbs with index in the given shard list less + // than a given value in a set. + // + // Annoyingly, our test setup allows some duplication between the chunks in the xorbs, so we end up + // having to account for that in the tests by allowing a chunk to be in multiple xorbs. 
+ let mut chunk_hashes = HashMap::>::new(); + let mut xorb_map = HashMap::::new(); + + let mut rng = SmallRng::seed_from_u64(0); + + for xi in 0..min_shard.num_cas() { + let xorb = min_shard.cas(xi).unwrap(); + let group = rng.random_range(0..=3); + + xorb_map.insert(xorb.cas_hash(), group); + for ci in 0..xorb.num_entries() { + let chunk_hash = xorb.chunk(ci).chunk_hash; + if ref_global_dedup_chunks.contains(&chunk_hash) { + chunk_hashes.entry(chunk_hash).or_default().push(group); + } + } + } + + // Exclude xorbs with set index as given above. + for grp_set_threshhold in 1..4 { + let xorb_filter_fn = |xh| *xorb_map.get(&xh).unwrap() < grp_set_threshhold; + + // Get the reference set of xorbs. + let ref_filtered_xorbs: HashSet = + xorb_map.keys().filter(|&&xh| xorb_filter_fn(xh)).cloned().collect(); + + let ref_filtered_global_dedup_chunks: HashSet<_> = chunk_hashes + .iter() + .filter(|(_, grp_set)| grp_set.iter().any(|&grp| grp < grp_set_threshhold)) + .map(|(&ch, _)| ch) + .collect(); + + let mut xo_subset_shard_buffer = Vec::::new(); + min_shard + .serialize_xorb_subset_only(&mut xo_subset_shard_buffer, |xorb| xorb_filter_fn(xorb.cas_hash())) + .unwrap(); + + let xo_subset_shard = + MDBMinimalShard::from_reader(&mut Cursor::new(&xo_subset_shard_buffer), true, true).unwrap(); + + assert_eq!(xo_subset_shard.num_files(), 0); + assert_eq!(xo_subset_shard.num_cas(), ref_filtered_xorbs.len()); + + let xorbs_present: HashSet<_> = (0..xo_subset_shard.num_cas()) + .map(|i| xo_subset_shard.cas(i).unwrap().cas_hash()) + .collect(); + + assert_eq!(xorbs_present, ref_filtered_xorbs); + + let xo_global_dedup_chunks: HashSet<_> = + xo_subset_shard.global_dedup_eligible_chunks().into_iter().collect(); + + assert_eq!(ref_filtered_global_dedup_chunks, xo_global_dedup_chunks); + } + } + + // Tests to verify that all the shard filtering options are supported. 
+ #[tokio::test] + async fn test_shard_processing() { + let shard = gen_random_shard_with_cas_references(1, &[1], &[1], false, false).unwrap(); + verify_minimal_shard_dedup_processing(&shard).await; + + // Tests to make sure the async and non-async match. + let shard = gen_random_shard_with_cas_references(1, &[2], &[1, 1], false, false).unwrap(); + verify_minimal_shard_dedup_processing(&shard).await; + + let shard = gen_random_shard_with_cas_references(1, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6], false, false).unwrap(); + verify_minimal_shard_dedup_processing(&shard).await; + + let shard = gen_random_shard_with_cas_references(1, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6], true, false).unwrap(); + verify_minimal_shard_dedup_processing(&shard).await; + + let shard = gen_random_shard_with_cas_references(1, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6], false, true).unwrap(); + verify_minimal_shard_dedup_processing(&shard).await; + + let shard = gen_random_shard_with_cas_references(1, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6], true, true).unwrap(); + verify_minimal_shard_dedup_processing(&shard).await; + } }