diff --git a/Cargo.lock b/Cargo.lock index 8a9d0058..79a2a6ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3096,12 +3096,14 @@ name = "merklehash" version = "0.14.5" dependencies = [ "base64 0.22.1", + "bincode", "blake3", "getrandom 0.3.3", "heed", "rand 0.9.1", "safe-transmute", "serde", + "tokio", ] [[package]] @@ -3935,6 +3937,7 @@ dependencies = [ "merklehash", "more-asserts", "tokio", + "utils", ] [[package]] @@ -5846,6 +5849,7 @@ name = "utils" version = "0.14.5" dependencies = [ "async-trait", + "bincode", "bytes", "ctor", "derivative", @@ -5857,6 +5861,7 @@ dependencies = [ "merklehash", "pin-project", "rand 0.9.1", + "serde", "serial_test", "shellexpand", "tempfile", diff --git a/cas_client/src/simulation/client_testing_utils.rs b/cas_client/src/simulation/client_testing_utils.rs index 52a6e20e..737cca20 100644 --- a/cas_client/src/simulation/client_testing_utils.rs +++ b/cas_client/src/simulation/client_testing_utils.rs @@ -7,6 +7,7 @@ use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDB use mdb_shard::shard_in_memory::MDBInMemoryShard; use merklehash::{MerkleHash, compute_data_hash, file_hash_with_salt}; use rand::prelude::*; +use utils::MerkleHashMap; use crate::error::Result; use crate::interface::Client; @@ -37,7 +38,7 @@ pub struct RandomFileContents { /// The complete file data. pub data: Bytes, /// The RawXorbData for each XORB that was created, keyed by XORB hash. - pub xorbs: HashMap, + pub xorbs: MerkleHashMap, /// Information about each term in file order. pub terms: Vec, } diff --git a/cas_client/src/simulation/local_client.rs b/cas_client/src/simulation/local_client.rs index f349c584..ef603b6c 100644 --- a/cas_client/src/simulation/local_client.rs +++ b/cas_client/src/simulation/local_client.rs @@ -32,6 +32,7 @@ use rand::Rng; use tempfile::TempDir; use tokio::time::{Duration, Instant}; use tracing::{error, info, warn}; +use utils::MerkleHashMap; use utils::serialization_utils::read_u32; use super::direct_access_client::DirectAccessClient; @@ -661,7 +662,7 @@ impl Client for LocalClient { byte_range: FileRange, } - let mut fetch_info_map: HashMap> = HashMap::new(); + let mut fetch_info_map: MerkleHashMap> = MerkleHashMap::new(); while s_idx < file_info.segments.len() && cumulative_bytes < file_range.end { let mut segment = file_info.segments[s_idx].clone(); diff --git a/cas_client/src/simulation/memory_client.rs b/cas_client/src/simulation/memory_client.rs index d4c308e3..6288a825 100644 --- a/cas_client/src/simulation/memory_client.rs +++ b/cas_client/src/simulation/memory_client.rs @@ -22,6 +22,7 @@ use rand::Rng; use tokio::sync::RwLock; use tokio::time::{Duration, Instant}; use tracing::{error, info}; +use utils::MerkleHashMap; use super::client_testing_utils::{FileTermReference, RandomFileContents}; use super::direct_access_client::DirectAccessClient; @@ -57,11 +58,11 @@ enum XorbStorage { /// In-memory client for testing purposes. Stores all data in memory using hash tables. pub struct MemoryClient { /// XORBs stored by hash - xorbs: RwLock>, + xorbs: RwLock>, /// In-memory shard for file reconstruction info shard: RwLock, /// Global dedup lookup: chunk_hash -> shard bytes - global_dedup: RwLock>, + global_dedup: RwLock>, /// Upload concurrency controller upload_concurrency_controller: Arc, /// URL expiration in milliseconds @@ -74,9 +75,9 @@ impl MemoryClient { /// Create a new in-memory client. pub fn new() -> Arc { Arc::new(Self { - xorbs: RwLock::new(HashMap::new()), + xorbs: RwLock::new(MerkleHashMap::new()), shard: RwLock::new(MDBInMemoryShard::default()), - global_dedup: RwLock::new(HashMap::new()), + global_dedup: RwLock::new(MerkleHashMap::new()), upload_concurrency_controller: AdaptiveConcurrencyController::new_upload("memory_uploads"), url_expiration_ms: AtomicU64::new(u64::MAX), random_ms_delay_window: (AtomicU64::new(0), AtomicU64::new(0)), @@ -205,7 +206,7 @@ impl MemoryClient { Ok(RandomFileContents { file_hash, data: Bytes::from(file_data), - xorbs: std::collections::HashMap::new(), + xorbs: MerkleHashMap::new(), terms: term_infos, }) } @@ -232,9 +233,9 @@ impl MemoryClient { impl Default for MemoryClient { fn default() -> Self { Self { - xorbs: RwLock::new(HashMap::new()), + xorbs: RwLock::new(MerkleHashMap::new()), shard: RwLock::new(MDBInMemoryShard::default()), - global_dedup: RwLock::new(HashMap::new()), + global_dedup: RwLock::new(MerkleHashMap::new()), upload_concurrency_controller: AdaptiveConcurrencyController::new_upload("memory_uploads"), url_expiration_ms: AtomicU64::new(u64::MAX), random_ms_delay_window: (AtomicU64::new(0), AtomicU64::new(0)), @@ -715,7 +716,7 @@ impl Client for MemoryClient { byte_range: FileRange, } - let mut fetch_info_map: HashMap> = HashMap::new(); + let mut fetch_info_map: MerkleHashMap> = MerkleHashMap::new(); let xorbs = self.xorbs.read().await; diff --git a/deduplication/src/file_deduplication.rs b/deduplication/src/file_deduplication.rs index d01d3567..fbce6e72 100644 --- a/deduplication/src/file_deduplication.rs +++ b/deduplication/src/file_deduplication.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::result::Result; use mdb_shard::file_structs::{ @@ -8,6 +7,7 @@ use mdb_shard::hash_is_global_dedup_eligible; use merklehash::{MerkleHash, file_hash}; use more_asserts::{debug_assert_le, debug_assert_lt}; use progress_tracking::upload_tracking::FileXorbDependency; +use utils::MerkleHashMap; use crate::Chunk; use crate::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS}; @@ -30,7 +30,7 @@ pub struct FileDeduper { new_data_size: usize, /// A hashmap allowing deduplication against the current chunk. - new_data_hash_lookup: HashMap, + new_data_hash_lookup: MerkleHashMap, /// The current chunk hashes for this file. chunk_hashes: Vec<(MerkleHash, u64)>, @@ -62,7 +62,7 @@ impl FileDeduper>, - chunk_lookup: HashMap, + chunk_lookup: TruncatedMerkleHashMap, } impl KeyedShardCollection { @@ -53,8 +53,8 @@ impl KeyedShardCollection { #[derive(Default)] struct ShardBookkeeper { shard_collections: Vec, - collection_by_key: HashMap, - shard_lookup_by_shard_hash: HashMap, + collection_by_key: MerkleHashMap, + shard_lookup_by_shard_hash: MerkleHashMap<(usize, usize)>, // We cap the number of chunks indexed for dedup; beyond those, we simply drop the search. total_indexed_chunks: usize, @@ -66,7 +66,7 @@ impl ShardBookkeeper { // we always try to dedup locally first. Self { shard_collections: vec![KeyedShardCollection::new(HMACKey::default())], - collection_by_key: HashMap::from([(HMACKey::default(), 0)]), + collection_by_key: MerkleHashMap::from([(HMACKey::default(), 0)]), ..Default::default() } } diff --git a/mdb_shard/src/shard_in_memory.rs b/mdb_shard/src/shard_in_memory.rs index 4240b8ce..41cd901b 100644 --- a/mdb_shard/src/shard_in_memory.rs +++ b/mdb_shard/src/shard_in_memory.rs @@ -1,6 +1,6 @@ // The shard structure for the in memory querying -use std::collections::{BTreeMap, HashMap}; +use std::collections::BTreeMap; use std::io::{BufWriter, Write}; use std::mem::size_of; use std::path::{Path, PathBuf}; @@ -9,6 +9,7 @@ use std::time::Duration; use merklehash::{HashedWrite, MerkleHash}; use tracing::debug; +use utils::MerkleHashMap; use crate::cas_structs::*; use crate::error::Result; @@ -21,7 +22,7 @@ use crate::utils::{shard_file_name, temp_shard_file_name}; pub struct MDBInMemoryShard { pub cas_content: BTreeMap>, pub file_content: BTreeMap, - pub chunk_hash_lookup: HashMap, u64)>, + pub chunk_hash_lookup: MerkleHashMap<(Arc, u64)>, current_shard_file_size: u64, } diff --git a/mdb_shard/src/streaming_shard.rs b/mdb_shard/src/streaming_shard.rs index f4a62a79..e7d374dc 100644 --- a/mdb_shard/src/streaming_shard.rs +++ b/mdb_shard/src/streaming_shard.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::io::{Cursor, Read, Write, copy}; use std::mem::size_of; @@ -8,6 +8,7 @@ use futures_util::io::AsyncReadExt; use itertools::Itertools; use merklehash::MerkleHash; use more_asserts::debug_assert_lt; +use utils::MerkleHashMap; use crate::cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader, MDBCASInfoView}; use crate::error::{MDBShardError, Result}; @@ -301,8 +302,8 @@ impl MDBMinimalShard { /// Return a lookup of xorb hash to starting chunk indices for all the files present in the /// shard. These are the chunks that are useful for global dedup. - fn file_start_entries(&self) -> HashMap> { - let mut file_start_entries = HashMap::>::new(); + fn file_start_entries(&self) -> MerkleHashMap> { + let mut file_start_entries = MerkleHashMap::>::new(); for f_idx in 0..self.num_files() { let Some(fv) = self.file(f_idx) else { diff --git a/merklehash/Cargo.toml b/merklehash/Cargo.toml index 88a048e7..1f5b973b 100644 --- a/merklehash/Cargo.toml +++ b/merklehash/Cargo.toml @@ -17,5 +17,9 @@ heed = { workspace = true } [target.'cfg(target_family = "wasm")'.dependencies] getrandom = { workspace = true, features = ["wasm_js"] } +[dev-dependencies] +bincode = { workspace = true } +tokio = { workspace = true, features = ["rt-multi-thread", "macros"] } + [features] strict = [] diff --git a/progress_tracking/Cargo.toml b/progress_tracking/Cargo.toml index 38944923..12c78fda 100644 --- a/progress_tracking/Cargo.toml +++ b/progress_tracking/Cargo.toml @@ -5,6 +5,7 @@ edition = "2024" [dependencies] merklehash = { path = "../merklehash" } +utils = { path = "../utils" } async-trait = { workspace = true } more-asserts = { workspace = true } diff --git a/progress_tracking/src/upload_tracking.rs b/progress_tracking/src/upload_tracking.rs index e2ea0fab..957adc6f 100644 --- a/progress_tracking/src/upload_tracking.rs +++ b/progress_tracking/src/upload_tracking.rs @@ -1,11 +1,12 @@ +use std::collections::BTreeSet; use std::collections::hash_map::Entry as HashMapEntry; -use std::collections::{BTreeSet, HashMap}; use std::mem::take; use std::sync::Arc; use merklehash::MerkleHash; use more_asserts::{debug_assert_ge, debug_assert_le}; use tokio::sync::Mutex; +use utils::MerkleHashMap; use crate::{ItemProgressUpdate, ProgressUpdate, TrackingProgressUpdater}; @@ -57,7 +58,7 @@ struct FileDependency { /// Mapping of xorb_hash -> (number of completed bytes / number of bytes of the file contained in that xorb). Only /// xorbs that are not uploaded yet are tracked here. /// Once an xorb is uploaded, we remove it from here (and add to `completed_bytes`). - remaining_xorbs_parts: HashMap, + remaining_xorbs_parts: MerkleHashMap, } /// Tracks all files and all xorbs, allowing you to register file @@ -68,7 +69,7 @@ struct CompletionTrackerImpl { /// List of all files being tracked. files: Vec, /// Map of xorb hash -> its dependency info (which files rely on it). - xorbs: HashMap, + xorbs: MerkleHashMap, /// Keep track of the totals across all xorbs. total_upload_bytes: u64, @@ -99,7 +100,7 @@ impl CompletionTrackerImpl { name: name.into(), total_bytes: n_bytes, completed_bytes: 0, - remaining_xorbs_parts: HashMap::new(), + remaining_xorbs_parts: MerkleHashMap::new(), }; // Insert it into our files vector. diff --git a/utils/Cargo.toml b/utils/Cargo.toml index 112ed97b..6ea9ee85 100644 --- a/utils/Cargo.toml +++ b/utils/Cargo.toml @@ -7,6 +7,10 @@ edition = "2024" name = "utils" path = "src/lib.rs" +[[bin]] +name = "benchmark_hashmaps" +path = "src/data_structures/bin/benchmark_hashmaps.rs" + [dependencies] error_printer = { path = "../error_printer" } merklehash = { path = "../merklehash" } @@ -19,6 +23,7 @@ duration-str = { workspace = true } futures = { workspace = true } lazy_static = { workspace = true } pin-project = { workspace = true } +serde = { workspace = true } shellexpand = { workspace = true } thiserror = { workspace = true } tokio = { workspace = true, features = [ @@ -31,6 +36,8 @@ tokio = { workspace = true, features = [ tracing = { workspace = true } [target.'cfg(not(target_family = "wasm"))'.dependencies] +bincode = { workspace = true } +rand = { workspace = true } tokio-util = { workspace = true, features = ["io"] } [target.'cfg(not(target_family = "wasm"))'.dev-dependencies] @@ -41,10 +48,11 @@ xet_runtime = { path = "../xet_runtime" } web-time = { workspace = true } [dev-dependencies] +bincode = { workspace = true } +futures-util = { workspace = true } rand = { workspace = true } serial_test = { workspace = true } -futures-util = { workspace = true } [features] strict = [] -elevated_information_level = [] \ No newline at end of file +elevated_information_level = [] diff --git a/utils/src/data_structures/bin/benchmark_hashmaps.rs b/utils/src/data_structures/bin/benchmark_hashmaps.rs new file mode 100644 index 00000000..5940d994 --- /dev/null +++ b/utils/src/data_structures/bin/benchmark_hashmaps.rs @@ -0,0 +1,342 @@ +//! Benchmark comparing HashMap and PassThroughHashMap performance. +//! +//! Run with: cargo run --bin benchmark_hashmaps --release + +use std::collections::{BTreeMap, HashMap}; +use std::time::{Duration, Instant}; + +use merklehash::{DataHash, MerkleHash}; +use rand::Rng; +use utils::data_structures::PassThroughHashMap; + +// Type aliases for cleaner code +type MerklePassThroughHashMap = PassThroughHashMap; + +const OPERATION_COUNTS: &[usize] = &[100_000, 1_000_000, 10_000_000]; + +// Data structure identifiers for columns +const DS_HASHMAP: &str = "HashMap"; +const DS_PASSTHROUGH: &str = "PassThrough"; + +/// Results storage for summary table +#[derive(Default)] +struct BenchmarkResults { + /// Map of (test_name, size, data_structure) -> duration + results: BTreeMap<(String, usize, String), Duration>, +} + +impl BenchmarkResults { + fn record(&mut self, test: &str, size: usize, ds: &str, duration: Duration) { + self.results.insert((test.to_string(), size, ds.to_string()), duration); + } + + fn print_summary_table(&self) { + let data_structures = [DS_HASHMAP, DS_PASSTHROUGH]; + let tests = ["Insert", "Lookup", "Insert+Lookup", "Serialize", "Deserialize"]; + + println!("\n{}", "=".repeat(65)); + println!("PERFORMANCE SUMMARY (times in ms, lower is better)"); + println!("{}", "=".repeat(65)); + + // Print header + print!("{:<25}", "Test"); + for ds in &data_structures { + print!("{:>20}", ds); + } + println!(); + println!("{}", "-".repeat(65)); + + for &size in OPERATION_COUNTS { + println!("--- {} ---", format_count(size)); + for test in &tests { + print!(" {:<23}", test); + for ds in &data_structures { + if let Some(duration) = self.results.get(&(test.to_string(), size, ds.to_string())) { + print!("{:>20}", format_duration_ms(*duration)); + } else { + print!("{:>20}", "-"); + } + } + println!(); + } + println!(); + } + println!("{}", "=".repeat(65)); + } +} + +fn generate_merkle_keys(count: usize) -> Vec { + let mut rng = rand::rng(); + (0..count) + .map(|_| { + let mut bytes = [0u8; 32]; + rng.fill(&mut bytes); + MerkleHash::from(bytes) + }) + .collect() +} + +fn format_count(count: usize) -> String { + if count >= 1_000_000 { + format!("{}M", count / 1_000_000) + } else if count >= 1_000 { + format!("{}K", count / 1_000) + } else { + format!("{}", count) + } +} + +fn format_duration(d: Duration) -> String { + if d.as_secs() > 0 { + format!("{:.2}s", d.as_secs_f64()) + } else if d.as_millis() > 0 { + format!("{:.2}ms", d.as_secs_f64() * 1000.0) + } else { + format!("{:.2}µs", d.as_secs_f64() * 1_000_000.0) + } +} + +fn format_duration_ms(d: Duration) -> String { + format!("{:.1}", d.as_secs_f64() * 1000.0) +} + +fn format_ops_per_sec(count: usize, d: Duration) -> String { + let ops = count as f64 / d.as_secs_f64(); + if ops >= 1_000_000_000.0 { + format!("{:.2}B/s", ops / 1_000_000_000.0) + } else if ops >= 1_000_000.0 { + format!("{:.2}M/s", ops / 1_000_000.0) + } else if ops >= 1_000.0 { + format!("{:.2}K/s", ops / 1_000.0) + } else { + format!("{:.2}/s", ops) + } +} + +fn print_result(name: &str, count: usize, duration: Duration) { + println!(" {:<40} {:>12} ({:>12})", name, format_duration(duration), format_ops_per_sec(count, duration)); +} + +// ============================================================================ +// Benchmarks +// ============================================================================ + +fn bench_insert(keys: &[DataHash], map_type: &str) -> Duration { + match map_type { + DS_HASHMAP => { + let start = Instant::now(); + let mut map: HashMap = HashMap::with_capacity(keys.len()); + for (i, key) in keys.iter().enumerate() { + map.insert(*key, i as u64); + } + std::hint::black_box(&map); + start.elapsed() + }, + DS_PASSTHROUGH => { + let start = Instant::now(); + let mut map: MerklePassThroughHashMap = PassThroughHashMap::with_capacity(keys.len()); + for (i, key) in keys.iter().enumerate() { + map.insert(*key, i as u64); + } + std::hint::black_box(&map); + start.elapsed() + }, + _ => Duration::ZERO, + } +} + +fn bench_lookup(keys: &[DataHash], map_type: &str) -> Duration { + // First build the map + match map_type { + DS_HASHMAP => { + let mut map: HashMap = HashMap::with_capacity(keys.len()); + for (i, key) in keys.iter().enumerate() { + map.insert(*key, i as u64); + } + let start = Instant::now(); + let mut sum = 0u64; + for key in keys { + if let Some(v) = map.get(key) { + sum = sum.wrapping_add(*v); + } + } + std::hint::black_box(sum); + start.elapsed() + }, + DS_PASSTHROUGH => { + let mut map: MerklePassThroughHashMap = PassThroughHashMap::with_capacity(keys.len()); + for (i, key) in keys.iter().enumerate() { + map.insert(*key, i as u64); + } + let start = Instant::now(); + let mut sum = 0u64; + for key in keys { + if let Some(v) = map.get(key) { + sum = sum.wrapping_add(*v); + } + } + std::hint::black_box(sum); + start.elapsed() + }, + _ => Duration::ZERO, + } +} + +fn bench_insert_then_lookup(keys: &[DataHash], map_type: &str) -> Duration { + match map_type { + DS_HASHMAP => { + let start = Instant::now(); + let mut map: HashMap = HashMap::with_capacity(keys.len()); + for (i, key) in keys.iter().enumerate() { + map.insert(*key, i as u64); + } + let mut sum = 0u64; + for key in keys { + if let Some(v) = map.get(key) { + sum = sum.wrapping_add(*v); + } + } + std::hint::black_box(sum); + start.elapsed() + }, + DS_PASSTHROUGH => { + let start = Instant::now(); + let mut map: MerklePassThroughHashMap = PassThroughHashMap::with_capacity(keys.len()); + for (i, key) in keys.iter().enumerate() { + map.insert(*key, i as u64); + } + let mut sum = 0u64; + for key in keys { + if let Some(v) = map.get(key) { + sum = sum.wrapping_add(*v); + } + } + std::hint::black_box(sum); + start.elapsed() + }, + _ => Duration::ZERO, + } +} + +fn bench_serialize(keys: &[DataHash], map_type: &str) -> Duration { + match map_type { + DS_HASHMAP => { + let mut map: HashMap = HashMap::with_capacity(keys.len()); + for (i, key) in keys.iter().enumerate() { + map.insert(*key, i as u64); + } + let start = Instant::now(); + let bytes = bincode::serialize(&map).unwrap(); + std::hint::black_box(&bytes); + start.elapsed() + }, + DS_PASSTHROUGH => { + let mut map: MerklePassThroughHashMap = PassThroughHashMap::with_capacity(keys.len()); + for (i, key) in keys.iter().enumerate() { + map.insert(*key, i as u64); + } + let start = Instant::now(); + let bytes = bincode::serialize(&map).unwrap(); + std::hint::black_box(&bytes); + start.elapsed() + }, + _ => Duration::ZERO, + } +} + +fn bench_deserialize(keys: &[DataHash], map_type: &str) -> Duration { + match map_type { + DS_HASHMAP => { + let mut map: HashMap = HashMap::with_capacity(keys.len()); + for (i, key) in keys.iter().enumerate() { + map.insert(*key, i as u64); + } + let bytes = bincode::serialize(&map).unwrap(); + let start = Instant::now(); + let map2: HashMap = bincode::deserialize(&bytes).unwrap(); + std::hint::black_box(&map2); + start.elapsed() + }, + DS_PASSTHROUGH => { + let mut map: MerklePassThroughHashMap = PassThroughHashMap::with_capacity(keys.len()); + for (i, key) in keys.iter().enumerate() { + map.insert(*key, i as u64); + } + let bytes = bincode::serialize(&map).unwrap(); + let start = Instant::now(); + let map2: MerklePassThroughHashMap = bincode::deserialize(&bytes).unwrap(); + std::hint::black_box(&map2); + start.elapsed() + }, + _ => Duration::ZERO, + } +} + +// ============================================================================ +// Main Benchmark Runner +// ============================================================================ + +fn run_benchmarks(count: usize, results: &mut BenchmarkResults) { + println!("\n{}", "=".repeat(80)); + println!("Benchmarking {} operations", count); + println!("{}", "=".repeat(80)); + + println!("\n--- Generating keys ---"); + let keys = generate_merkle_keys(count); + println!(" Generated {} MerkleHash keys", count); + + println!("\n--- Benchmarks ---"); + + for ds in [DS_HASHMAP, DS_PASSTHROUGH] { + let d = bench_insert(&keys, ds); + print_result(&format!("{} Insert", ds), count, d); + results.record("Insert", count, ds, d); + } + + for ds in [DS_HASHMAP, DS_PASSTHROUGH] { + let d = bench_lookup(&keys, ds); + print_result(&format!("{} Lookup", ds), count, d); + results.record("Lookup", count, ds, d); + } + + for ds in [DS_HASHMAP, DS_PASSTHROUGH] { + let d = bench_insert_then_lookup(&keys, ds); + print_result(&format!("{} Insert+Lookup", ds), count, d); + results.record("Insert+Lookup", count, ds, d); + } + + for ds in [DS_HASHMAP, DS_PASSTHROUGH] { + let d = bench_serialize(&keys, ds); + print_result(&format!("{} Serialize", ds), count, d); + results.record("Serialize", count, ds, d); + } + + for ds in [DS_HASHMAP, DS_PASSTHROUGH] { + let d = bench_deserialize(&keys, ds); + print_result(&format!("{} Deserialize", ds), count, d); + results.record("Deserialize", count, ds, d); + } +} + +fn main() { + println!("HashMap Benchmark Suite"); + println!("======================="); + println!(); + println!("Comparing:"); + println!(" - HashMap: std::collections::HashMap"); + println!(" - PassThrough: PassThroughHashMap (optimized hasher for MerkleHash keys)"); + println!(); + println!("Key type: MerkleHash (32 bytes)"); + println!("Value type: u64"); + + let mut results = BenchmarkResults::default(); + + for &count in OPERATION_COUNTS { + run_benchmarks(count, &mut results); + } + + // Print summary table + results.print_summary_table(); + + println!("\nBenchmark complete!"); +} diff --git a/utils/src/data_structures/mod.rs b/utils/src/data_structures/mod.rs new file mode 100644 index 00000000..4ea27a4b --- /dev/null +++ b/utils/src/data_structures/mod.rs @@ -0,0 +1,17 @@ +mod passthrough_hasher; +mod passthrough_hashmap; + +use merklehash::MerkleHash; +pub use passthrough_hasher::U64HashExtractable; +pub use passthrough_hashmap::PassThroughHashMap; + +/// A HashMap specialized for `MerkleHash` keys using passthrough hashing. +/// +/// This is a type alias for `PassThroughHashMap`. +pub type MerkleHashMap = PassThroughHashMap; + +/// A HashMap specialized for `u64` keys using passthrough hashing. +/// +/// This is useful when the key is already a truncated hash value (e.g., the first 8 bytes +/// of a larger hash), and we want to avoid re-hashing. +pub type TruncatedMerkleHashMap = PassThroughHashMap; diff --git a/utils/src/data_structures/passthrough_hasher.rs b/utils/src/data_structures/passthrough_hasher.rs new file mode 100644 index 00000000..8e626340 --- /dev/null +++ b/utils/src/data_structures/passthrough_hasher.rs @@ -0,0 +1,143 @@ +use std::hash::{BuildHasher, Hasher}; +use std::marker::PhantomData; + +use merklehash::MerkleHash; +use serde::{Deserialize, Serialize}; + +/// Trait for types that can efficiently provide a u64 hash value. +/// Types implementing this trait are optimized for use with `U64DirectHasher`, +/// which avoids extra computation by directly using the u64 value as the hash. +pub trait U64HashExtractable { + fn u64_hash_value(&self) -> u64; +} + +impl U64HashExtractable for u64 { + fn u64_hash_value(&self) -> u64 { + *self + } +} + +impl U64HashExtractable for MerkleHash { + fn u64_hash_value(&self) -> u64 { + self[0] + } +} + +/// A hasher that maps directly to a u64 hash value for types implementing U64HashExtractable. +/// +/// This hasher is designed to work with types that already contain high-quality hash values +/// (like cryptographic hashes), avoiding extra computation and supporting direct bucket selection. +/// +/// Note: This is not resistant to attacks where all keys are designed to hash to the same bucket +/// and thus make the lookup time linear, but in our use case, where it's already a cryptographic +/// hash, this gives us two advantages: +/// - Speedup for lookup. +/// - Consistent serialization order, so MUCH faster deserialization. +#[derive(Clone, Copy, Serialize, Deserialize)] +pub struct U64DirectHasher { + state: u64, + #[serde(skip)] + _phantom: PhantomData, +} + +impl Default for U64DirectHasher { + fn default() -> Self { + Self { + state: 0, + _phantom: PhantomData, + } + } +} + +impl Hasher for U64DirectHasher { + fn finish(&self) -> u64 { + self.state + } + + fn write(&mut self, bytes: &[u8]) { + debug_assert!(bytes.len() >= 8); + + unsafe { + let dest = &mut self.state as *mut u64 as *mut u8; + bytes.as_ptr().copy_to_nonoverlapping(dest, 8); + } + } + + fn write_u64(&mut self, i: u64) { + self.state = i; + } +} + +impl BuildHasher for U64DirectHasher { + type Hasher = U64DirectHasher; + + fn build_hasher(&self) -> Self::Hasher { + U64DirectHasher::default() + } +} + +#[cfg(test)] +mod tests { + use std::hash::Hash; + + use merklehash::DataHash; + + use super::*; + + #[test] + fn test_hash_uses_first_u64() { + let mut hasher1 = U64DirectHasher::::default(); + let mut hasher2 = U64DirectHasher::::default(); + + let hash1 = DataHash::from([0x1234567890ABCDEF, 0, 0, 0]); + let hash2 = DataHash::from([ + 0x1234567890ABCDEF, + 0xFFFFFFFFFFFFFFFF, + 0xFFFFFFFFFFFFFFFF, + 0xFFFFFFFFFFFFFFFF, + ]); + + hash1.hash(&mut hasher1); + hash2.hash(&mut hasher2); + + assert_eq!(hasher1.finish(), hasher2.finish()); + } + + #[test] + fn test_hash_different_first_u64() { + let mut hasher1 = U64DirectHasher::::default(); + let mut hasher2 = U64DirectHasher::::default(); + + let hash1 = DataHash::from([0x1234567890ABCDEF, 0, 0, 0]); + let hash2 = DataHash::from([0xFEDCBA0987654321, 0, 0, 0]); + + hash1.hash(&mut hasher1); + hash2.hash(&mut hasher2); + + assert_ne!(hasher1.finish(), hasher2.finish()); + } + + #[test] + fn test_u64_hash_extractable() { + let value: u64 = 0x1234567890ABCDEF; + assert_eq!(value.u64_hash_value(), 0x1234567890ABCDEF); + + let hash = MerkleHash::from([0xFEDCBA0987654321, 0x1111111111111111, 0, 0]); + assert_eq!(hash.u64_hash_value(), 0xFEDCBA0987654321); + } + + #[test] + fn test_u64_passthrough_hasher() { + let mut hasher1 = U64DirectHasher::::default(); + let mut hasher2 = U64DirectHasher::::default(); + + let value1: u64 = 0x1234567890ABCDEF; + let value2: u64 = 0x1234567890ABCDEF; + + value1.hash(&mut hasher1); + value2.hash(&mut hasher2); + + assert_eq!(hasher1.finish(), hasher2.finish()); + assert_eq!(hasher1.finish(), value1); + } +} diff --git a/utils/src/data_structures/passthrough_hashmap.rs b/utils/src/data_structures/passthrough_hashmap.rs new file mode 100644 index 00000000..a7fdc65e --- /dev/null +++ b/utils/src/data_structures/passthrough_hashmap.rs @@ -0,0 +1,484 @@ +use std::collections::HashMap; +use std::hash::Hash; +use std::ops::{Deref, DerefMut}; + +use serde::{Deserialize, Serialize}; + +use super::passthrough_hasher::{U64DirectHasher, U64HashExtractable}; + +/// A HashMap wrapper optimized for keys that implement `U64HashExtractable`. +/// +/// This structure uses `U64DirectHasher` which avoids extra hash computation +/// by directly using the u64 value extracted from the key. This is particularly +/// efficient for cryptographic hashes (like `MerkleHash`) where the first 8 bytes +/// already provide excellent distribution. +/// +/// # Behavior +/// +/// This type implements `Deref` and `DerefMut` to the underlying `HashMap`, so it +/// behaves exactly like a standard `HashMap`. All `HashMap` methods are available +/// directly on this type: +/// +/// ```ignore +/// let mut map: PassThroughHashMap = PassThroughHashMap::new(); +/// +/// // All HashMap methods work directly: +/// map.insert(key, "value".to_string()); +/// map.get(&key); +/// map.contains_key(&key); +/// map.remove(&key); +/// map.len(); +/// map.is_empty(); +/// map.iter(); +/// map.entry(key).or_insert("default".to_string()); +/// // ... and all other HashMap methods +/// ``` +/// +/// # Type Parameters +/// - `Key`: The key type (must implement `U64HashExtractable + Hash + Eq`) +/// - `Value`: The value type stored in the map +pub struct PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq, +{ + inner: HashMap>, +} + +impl PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq, +{ + pub fn new() -> Self { + Self { + inner: HashMap::with_hasher(U64DirectHasher::default()), + } + } + + pub fn with_capacity(capacity: usize) -> Self { + Self { + inner: HashMap::with_capacity_and_hasher(capacity, U64DirectHasher::default()), + } + } +} + +impl Default for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq, +{ + fn default() -> Self { + Self::new() + } +} + +impl std::fmt::Debug for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq + std::fmt::Debug, + Value: std::fmt::Debug, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_map().entries(self.inner.iter()).finish() + } +} + +impl Clone for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq + Clone, + Value: Clone, +{ + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + } + } +} + +impl std::iter::FromIterator<(Key, Value)> for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq, +{ + fn from_iter>(iter: I) -> Self { + let mut map = Self::new(); + for (k, v) in iter { + map.insert(k, v); + } + map + } +} + +impl IntoIterator for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq, +{ + type Item = (Key, Value); + type IntoIter = std::collections::hash_map::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.inner.into_iter() + } +} + +/// Provides immutable access to all `HashMap` methods. +/// +/// This allows using `PassThroughHashMap` exactly like a `HashMap`: +/// `get`, `contains_key`, `len`, `is_empty`, `iter`, `keys`, `values`, etc. +impl Deref for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq, +{ + type Target = HashMap>; + + fn deref(&self) -> &Self::Target { + &self.inner + } +} + +/// Provides mutable access to all `HashMap` methods. +/// +/// This allows using `PassThroughHashMap` exactly like a `HashMap`: +/// `insert`, `remove`, `clear`, `get_mut`, `entry`, `iter_mut`, etc. +impl DerefMut for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.inner + } +} + +impl From>> for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq, +{ + fn from(inner: HashMap>) -> Self { + Self { inner } + } +} + +impl From> for HashMap> +where + Key: U64HashExtractable + Hash + Eq, +{ + fn from(val: PassThroughHashMap) -> Self { + val.inner + } +} + +impl From<[(Key, Value); N]> for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq, +{ + fn from(arr: [(Key, Value); N]) -> Self { + arr.into_iter().collect() + } +} + +impl Serialize for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq + Serialize, + Value: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeSeq; + + let mut seq = serializer.serialize_seq(Some(self.inner.len()))?; + for (key, value) in self.inner.iter() { + seq.serialize_element(&(key, value))?; + } + seq.end() + } +} + +impl<'de, Key, Value> Deserialize<'de> for PassThroughHashMap +where + Key: U64HashExtractable + Hash + Eq + Deserialize<'de>, + Value: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct PassThroughVisitor { + _marker: std::marker::PhantomData<(Key, Value)>, + } + + impl<'de, Key, Value> serde::de::Visitor<'de> for PassThroughVisitor + where + Key: U64HashExtractable + Hash + Eq + Deserialize<'de>, + Value: Deserialize<'de>, + { + type Value = PassThroughHashMap; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a sequence of key-value pairs") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let capacity = seq.size_hint().unwrap_or(0); + let mut map = PassThroughHashMap::with_capacity(capacity); + + while let Some((key, value)) = seq.next_element()? { + map.insert(key, value); + } + + Ok(map) + } + } + + deserializer.deserialize_seq(PassThroughVisitor { + _marker: std::marker::PhantomData, + }) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use merklehash::{DataHash, compute_data_hash}; + + use super::*; + + type MerkleHashMap = PassThroughHashMap; + + #[test] + fn test_new() { + let table: MerkleHashMap = MerkleHashMap::new(); + assert!(table.is_empty()); + assert_eq!(table.len(), 0); + } + + #[test] + fn test_default() { + let table: MerkleHashMap = MerkleHashMap::default(); + assert!(table.is_empty()); + } + + #[test] + fn test_with_capacity() { + let table: MerkleHashMap = MerkleHashMap::with_capacity(100); + assert!(table.is_empty()); + } + + #[test] + fn test_insert_and_get() { + let mut table = MerkleHashMap::new(); + let hash1 = compute_data_hash(b"test1"); + let hash2 = compute_data_hash(b"test2"); + + assert_eq!(table.insert(hash1, "value1"), None); + assert_eq!(table.insert(hash2, "value2"), None); + assert_eq!(table.len(), 2); + + assert_eq!(table.get(&hash1), Some(&"value1")); + assert_eq!(table.get(&hash2), Some(&"value2")); + } + + #[test] + fn test_insert_overwrite() { + let mut table = MerkleHashMap::new(); + let hash = compute_data_hash(b"test"); + + assert_eq!(table.insert(hash, "value1"), None); + assert_eq!(table.insert(hash, "value2"), Some("value1")); + assert_eq!(table.get(&hash), Some(&"value2")); + } + + #[test] + fn test_contains_key() { + let mut table = MerkleHashMap::new(); + let hash1 = compute_data_hash(b"test1"); + let hash2 = compute_data_hash(b"test2"); + + table.insert(hash1, 42); + assert!(table.contains_key(&hash1)); + assert!(!table.contains_key(&hash2)); + } + + #[test] + fn test_remove() { + let mut table = MerkleHashMap::new(); + let hash1 = compute_data_hash(b"test1"); + let hash2 = compute_data_hash(b"test2"); + + table.insert(hash1, "value1"); + table.insert(hash2, "value2"); + assert_eq!(table.len(), 2); + + assert_eq!(table.remove(&hash1), Some("value1")); + assert_eq!(table.len(), 1); + assert_eq!(table.get(&hash1), None); + assert_eq!(table.get(&hash2), Some(&"value2")); + } + + #[test] + fn test_remove_nonexistent() { + let mut table: MerkleHashMap<&str> = MerkleHashMap::new(); + let hash = compute_data_hash(b"test"); + + assert_eq!(table.remove(&hash), None); + } + + #[test] + fn test_clear() { + let mut table = MerkleHashMap::new(); + let hash1 = compute_data_hash(b"test1"); + let hash2 = compute_data_hash(b"test2"); + + table.insert(hash1, "value1"); + table.insert(hash2, "value2"); + assert_eq!(table.len(), 2); + + table.clear(); + assert!(table.is_empty()); + assert_eq!(table.len(), 0); + } + + #[test] + fn test_iter() { + let mut table = MerkleHashMap::new(); + let hash1 = compute_data_hash(b"test1"); + let hash2 = compute_data_hash(b"test2"); + + table.insert(hash1, 10); + table.insert(hash2, 20); + + let items: Vec<_> = table.iter().collect(); + assert_eq!(items.len(), 2); + + let values: Vec<_> = items.iter().map(|(_, v)| **v).collect(); + assert!(values.contains(&10)); + assert!(values.contains(&20)); + } + + #[test] + fn test_iter_mut() { + let mut table = MerkleHashMap::new(); + let hash = compute_data_hash(b"test"); + + table.insert(hash, 10); + *table.get_mut(&hash).unwrap() = 20; + assert_eq!(table.get(&hash), Some(&20)); + } + + #[test] + fn test_keys() { + let mut table = MerkleHashMap::new(); + let hash1 = compute_data_hash(b"test1"); + let hash2 = compute_data_hash(b"test2"); + + table.insert(hash1, "value1"); + table.insert(hash2, "value2"); + + let keys: Vec<_> = table.keys().collect(); + assert_eq!(keys.len(), 2); + assert!(keys.contains(&&hash1)); + assert!(keys.contains(&&hash2)); + } + + #[test] + fn test_values() { + let mut table = MerkleHashMap::new(); + let hash1 = compute_data_hash(b"test1"); + let hash2 = compute_data_hash(b"test2"); + + table.insert(hash1, "value1"); + table.insert(hash2, "value2"); + + let values: Vec<_> = table.values().collect(); + assert_eq!(values.len(), 2); + assert!(values.contains(&&"value1")); + assert!(values.contains(&&"value2")); + } + + #[test] + fn test_entry() { + let mut table = MerkleHashMap::new(); + let hash = compute_data_hash(b"test"); + + table.entry(hash).or_insert("default"); + assert_eq!(table.get(&hash), Some(&"default")); + + table.entry(hash).and_modify(|v| *v = "modified"); + assert_eq!(table.get(&hash), Some(&"modified")); + } + + #[test] + fn test_from_hashmap() { + let mut hashmap = HashMap::with_hasher(U64DirectHasher::::default()); + let hash1 = compute_data_hash(b"test1"); + let hash2 = compute_data_hash(b"test2"); + + hashmap.insert(hash1, 10); + hashmap.insert(hash2, 20); + + let table: MerkleHashMap = MerkleHashMap::from(hashmap); + assert_eq!(table.len(), 2); + assert_eq!(table.get(&hash1), Some(&10)); + assert_eq!(table.get(&hash2), Some(&20)); + } + + #[test] + fn test_into_hashmap() { + let mut table = MerkleHashMap::new(); + let hash1 = compute_data_hash(b"test1"); + let hash2 = compute_data_hash(b"test2"); + + table.insert(hash1, 10); + table.insert(hash2, 20); + + let hashmap: HashMap> = table.into(); + assert_eq!(hashmap.len(), 2); + assert_eq!(hashmap.get(&hash1), Some(&10)); + assert_eq!(hashmap.get(&hash2), Some(&20)); + } + + #[test] + fn test_serialize_deserialize() { + let mut table = MerkleHashMap::new(); + let hash1 = compute_data_hash(b"test1"); + let hash2 = compute_data_hash(b"test2"); + let hash3 = compute_data_hash(b"test3"); + + table.insert(hash1, 10); + table.insert(hash2, 20); + table.insert(hash3, 30); + + let serialized = bincode::serialize(&table).unwrap(); + let deserialized: MerkleHashMap = bincode::deserialize(&serialized).unwrap(); + + assert_eq!(deserialized.len(), 3); + assert_eq!(deserialized.get(&hash1), Some(&10)); + assert_eq!(deserialized.get(&hash2), Some(&20)); + assert_eq!(deserialized.get(&hash3), Some(&30)); + } + + #[test] + fn test_serialize_deserialize_empty() { + let table: MerkleHashMap = MerkleHashMap::new(); + + let serialized = bincode::serialize(&table).unwrap(); + let deserialized: MerkleHashMap = bincode::deserialize(&serialized).unwrap(); + + assert!(deserialized.is_empty()); + assert_eq!(deserialized.len(), 0); + } + + #[test] + fn test_u64_key_hashmap() { + type TruncatedHashMap = PassThroughHashMap; + + let mut table: TruncatedHashMap = TruncatedHashMap::new(); + table.insert(12345, "value1".to_string()); + table.insert(67890, "value2".to_string()); + + assert_eq!(table.get(&12345), Some(&"value1".to_string())); + assert_eq!(table.get(&67890), Some(&"value2".to_string())); + assert_eq!(table.len(), 2); + } +} diff --git a/utils/src/lib.rs b/utils/src/lib.rs index 808a764f..3252dbb9 100644 --- a/utils/src/lib.rs +++ b/utils/src/lib.rs @@ -1,6 +1,8 @@ #![cfg_attr(feature = "strict", deny(warnings))] pub mod async_iterator; +pub mod data_structures; +pub use data_structures::{MerkleHashMap, PassThroughHashMap, TruncatedMerkleHashMap, U64HashExtractable}; pub mod async_read; pub mod auth; pub mod errors;