mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
Hash table with pass-through hasher for MerkleHashes (#611)
Currently, Rust's `HashMap` uses a randomized hasher for its keys, which protects against hash-collision attacks. However, we don't need that protection in the client code, and a MerkleHash is already a cryptographic hash. This PR adds a `MerkleHashMap` type that passes the hash straight through to the `HashMap`, providing a substantial speedup:

```
=================================================================
PERFORMANCE SUMMARY (times in ms, lower is better)
=================================================================
Test                                  HashMap         PassThrough
-----------------------------------------------------------------
--- 100K ---
 Insert                                   2.1                 0.7
 Lookup                                   2.1                 1.3
 Insert+Lookup                            4.4                 1.6
 Serialize                                1.6                 0.9
 Deserialize                              4.3                 1.2

--- 10M ---
 Insert                                 433.2               204.1
 Lookup                                 615.3               255.5
 Insert+Lookup                          951.6               460.4
 Serialize                              117.2                93.4
 Deserialize                            599.5                89.3

=================================================================
```

It also replaces `HashMap<MerkleHash, ...>` everywhere in the code to provide an across-the-board improvement.
This commit is contained in:
5
Cargo.lock
generated
5
Cargo.lock
generated
@@ -3096,12 +3096,14 @@ name = "merklehash"
|
||||
version = "0.14.5"
|
||||
dependencies = [
|
||||
"base64 0.22.1",
|
||||
"bincode",
|
||||
"blake3",
|
||||
"getrandom 0.3.3",
|
||||
"heed",
|
||||
"rand 0.9.1",
|
||||
"safe-transmute",
|
||||
"serde",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3935,6 +3937,7 @@ dependencies = [
|
||||
"merklehash",
|
||||
"more-asserts",
|
||||
"tokio",
|
||||
"utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5846,6 +5849,7 @@ name = "utils"
|
||||
version = "0.14.5"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"bytes",
|
||||
"ctor",
|
||||
"derivative",
|
||||
@@ -5857,6 +5861,7 @@ dependencies = [
|
||||
"merklehash",
|
||||
"pin-project",
|
||||
"rand 0.9.1",
|
||||
"serde",
|
||||
"serial_test",
|
||||
"shellexpand",
|
||||
"tempfile",
|
||||
|
||||
@@ -7,6 +7,7 @@ use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDB
|
||||
use mdb_shard::shard_in_memory::MDBInMemoryShard;
|
||||
use merklehash::{MerkleHash, compute_data_hash, file_hash_with_salt};
|
||||
use rand::prelude::*;
|
||||
use utils::MerkleHashMap;
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::interface::Client;
|
||||
@@ -37,7 +38,7 @@ pub struct RandomFileContents {
|
||||
/// The complete file data.
|
||||
pub data: Bytes,
|
||||
/// The RawXorbData for each XORB that was created, keyed by XORB hash.
|
||||
pub xorbs: HashMap<MerkleHash, RawXorbData>,
|
||||
pub xorbs: MerkleHashMap<RawXorbData>,
|
||||
/// Information about each term in file order.
|
||||
pub terms: Vec<FileTermReference>,
|
||||
}
|
||||
|
||||
@@ -32,6 +32,7 @@ use rand::Rng;
|
||||
use tempfile::TempDir;
|
||||
use tokio::time::{Duration, Instant};
|
||||
use tracing::{error, info, warn};
|
||||
use utils::MerkleHashMap;
|
||||
use utils::serialization_utils::read_u32;
|
||||
|
||||
use super::direct_access_client::DirectAccessClient;
|
||||
@@ -661,7 +662,7 @@ impl Client for LocalClient {
|
||||
byte_range: FileRange,
|
||||
}
|
||||
|
||||
let mut fetch_info_map: HashMap<MerkleHash, Vec<FetchInfoIntermediate>> = HashMap::new();
|
||||
let mut fetch_info_map: MerkleHashMap<Vec<FetchInfoIntermediate>> = MerkleHashMap::new();
|
||||
|
||||
while s_idx < file_info.segments.len() && cumulative_bytes < file_range.end {
|
||||
let mut segment = file_info.segments[s_idx].clone();
|
||||
|
||||
@@ -22,6 +22,7 @@ use rand::Rng;
|
||||
use tokio::sync::RwLock;
|
||||
use tokio::time::{Duration, Instant};
|
||||
use tracing::{error, info};
|
||||
use utils::MerkleHashMap;
|
||||
|
||||
use super::client_testing_utils::{FileTermReference, RandomFileContents};
|
||||
use super::direct_access_client::DirectAccessClient;
|
||||
@@ -57,11 +58,11 @@ enum XorbStorage {
|
||||
/// In-memory client for testing purposes. Stores all data in memory using hash tables.
|
||||
pub struct MemoryClient {
|
||||
/// XORBs stored by hash
|
||||
xorbs: RwLock<HashMap<MerkleHash, XorbStorage>>,
|
||||
xorbs: RwLock<MerkleHashMap<XorbStorage>>,
|
||||
/// In-memory shard for file reconstruction info
|
||||
shard: RwLock<MDBInMemoryShard>,
|
||||
/// Global dedup lookup: chunk_hash -> shard bytes
|
||||
global_dedup: RwLock<HashMap<MerkleHash, Bytes>>,
|
||||
global_dedup: RwLock<MerkleHashMap<Bytes>>,
|
||||
/// Upload concurrency controller
|
||||
upload_concurrency_controller: Arc<AdaptiveConcurrencyController>,
|
||||
/// URL expiration in milliseconds
|
||||
@@ -74,9 +75,9 @@ impl MemoryClient {
|
||||
/// Create a new in-memory client.
|
||||
pub fn new() -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
xorbs: RwLock::new(HashMap::new()),
|
||||
xorbs: RwLock::new(MerkleHashMap::new()),
|
||||
shard: RwLock::new(MDBInMemoryShard::default()),
|
||||
global_dedup: RwLock::new(HashMap::new()),
|
||||
global_dedup: RwLock::new(MerkleHashMap::new()),
|
||||
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload("memory_uploads"),
|
||||
url_expiration_ms: AtomicU64::new(u64::MAX),
|
||||
random_ms_delay_window: (AtomicU64::new(0), AtomicU64::new(0)),
|
||||
@@ -205,7 +206,7 @@ impl MemoryClient {
|
||||
Ok(RandomFileContents {
|
||||
file_hash,
|
||||
data: Bytes::from(file_data),
|
||||
xorbs: std::collections::HashMap::new(),
|
||||
xorbs: MerkleHashMap::new(),
|
||||
terms: term_infos,
|
||||
})
|
||||
}
|
||||
@@ -232,9 +233,9 @@ impl MemoryClient {
|
||||
impl Default for MemoryClient {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
xorbs: RwLock::new(HashMap::new()),
|
||||
xorbs: RwLock::new(MerkleHashMap::new()),
|
||||
shard: RwLock::new(MDBInMemoryShard::default()),
|
||||
global_dedup: RwLock::new(HashMap::new()),
|
||||
global_dedup: RwLock::new(MerkleHashMap::new()),
|
||||
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload("memory_uploads"),
|
||||
url_expiration_ms: AtomicU64::new(u64::MAX),
|
||||
random_ms_delay_window: (AtomicU64::new(0), AtomicU64::new(0)),
|
||||
@@ -715,7 +716,7 @@ impl Client for MemoryClient {
|
||||
byte_range: FileRange,
|
||||
}
|
||||
|
||||
let mut fetch_info_map: HashMap<MerkleHash, Vec<FetchInfoIntermediate>> = HashMap::new();
|
||||
let mut fetch_info_map: MerkleHashMap<Vec<FetchInfoIntermediate>> = MerkleHashMap::new();
|
||||
|
||||
let xorbs = self.xorbs.read().await;
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::collections::HashMap;
|
||||
use std::result::Result;
|
||||
|
||||
use mdb_shard::file_structs::{
|
||||
@@ -8,6 +7,7 @@ use mdb_shard::hash_is_global_dedup_eligible;
|
||||
use merklehash::{MerkleHash, file_hash};
|
||||
use more_asserts::{debug_assert_le, debug_assert_lt};
|
||||
use progress_tracking::upload_tracking::FileXorbDependency;
|
||||
use utils::MerkleHashMap;
|
||||
|
||||
use crate::Chunk;
|
||||
use crate::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
|
||||
@@ -30,7 +30,7 @@ pub struct FileDeduper<DataInterfaceType: DeduplicationDataInterface> {
|
||||
new_data_size: usize,
|
||||
|
||||
/// A hashmap allowing deduplication against the current chunk.
|
||||
new_data_hash_lookup: HashMap<MerkleHash, usize>,
|
||||
new_data_hash_lookup: MerkleHashMap<usize>,
|
||||
|
||||
/// The current chunk hashes for this file.
|
||||
chunk_hashes: Vec<(MerkleHash, u64)>,
|
||||
@@ -62,7 +62,7 @@ impl<DataInterfaceType: DeduplicationDataInterface> FileDeduper<DataInterfaceTyp
|
||||
file_id,
|
||||
new_data: Vec::new(),
|
||||
new_data_size: 0,
|
||||
new_data_hash_lookup: HashMap::new(),
|
||||
new_data_hash_lookup: MerkleHashMap::new(),
|
||||
chunk_hashes: Vec::new(),
|
||||
file_info: Vec::new(),
|
||||
internally_referencing_entries: Vec::new(),
|
||||
|
||||
4
hf_xet/Cargo.lock
generated
4
hf_xet/Cargo.lock
generated
@@ -2508,6 +2508,7 @@ dependencies = [
|
||||
"merklehash",
|
||||
"more-asserts",
|
||||
"tokio",
|
||||
"utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4077,6 +4078,7 @@ name = "utils"
|
||||
version = "0.14.5"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"bytes",
|
||||
"ctor",
|
||||
"derivative",
|
||||
@@ -4086,6 +4088,8 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"merklehash",
|
||||
"pin-project",
|
||||
"rand 0.9.2",
|
||||
"serde",
|
||||
"shellexpand",
|
||||
"thiserror 2.0.15",
|
||||
"tokio",
|
||||
|
||||
5
hf_xet_thin_wasm/Cargo.lock
generated
5
hf_xet_thin_wasm/Cargo.lock
generated
@@ -1825,6 +1825,7 @@ name = "utils"
|
||||
version = "0.14.5"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"bytes",
|
||||
"ctor",
|
||||
"derivative",
|
||||
@@ -1833,8 +1834,12 @@ dependencies = [
|
||||
"futures",
|
||||
"lazy_static",
|
||||
"merklehash",
|
||||
"more-asserts",
|
||||
"pin-project",
|
||||
"rand",
|
||||
"serde",
|
||||
"shellexpand",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
|
||||
105
hf_xet_wasm/Cargo.lock
generated
105
hf_xet_wasm/Cargo.lock
generated
@@ -293,7 +293,6 @@ dependencies = [
|
||||
"cas_object",
|
||||
"cas_types",
|
||||
"chrono",
|
||||
"chunk_cache",
|
||||
"clap",
|
||||
"deduplication",
|
||||
"error_printer",
|
||||
@@ -407,27 +406,6 @@ dependencies = [
|
||||
"windows-link 0.2.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "chunk_cache"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"cas_types",
|
||||
"crc32fast",
|
||||
"error_printer",
|
||||
"file_utils",
|
||||
"merklehash",
|
||||
"mockall",
|
||||
"once_cell",
|
||||
"rand 0.9.2",
|
||||
"thiserror 2.0.16",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"utils",
|
||||
"xet_runtime",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.47"
|
||||
@@ -584,15 +562,6 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-queue"
|
||||
version = "0.3.12"
|
||||
@@ -745,12 +714,6 @@ dependencies = [
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "downcast"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1"
|
||||
|
||||
[[package]]
|
||||
name = "dtor"
|
||||
version = "0.0.6"
|
||||
@@ -881,12 +844,6 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fragile"
|
||||
version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28dd6caf6059519a65843af8fe2a3ae298b14b80179855aeb4adc2c1934ee619"
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.31"
|
||||
@@ -1793,32 +1750,6 @@ dependencies = [
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mockall"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39a6bfcc6c8c7eed5ee98b9c3e33adc726054389233e201c95dab2d41a3839d2"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.3",
|
||||
"downcast",
|
||||
"fragile",
|
||||
"mockall_derive",
|
||||
"predicates",
|
||||
"predicates-tree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mockall_derive"
|
||||
version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25ca3004c2efe9011bd4e461bd8256445052b9615405b4f7ea43fc8ca5c20898"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.3",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "more-asserts"
|
||||
version = "0.3.1"
|
||||
@@ -2088,32 +2019,6 @@ dependencies = [
|
||||
"zerocopy",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "predicates"
|
||||
version = "3.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"predicates-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "predicates-core"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa"
|
||||
|
||||
[[package]]
|
||||
name = "predicates-tree"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c"
|
||||
dependencies = [
|
||||
"predicates-core",
|
||||
"termtree",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.101"
|
||||
@@ -2131,6 +2036,7 @@ dependencies = [
|
||||
"merklehash",
|
||||
"more-asserts",
|
||||
"tokio",
|
||||
"utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2916,12 +2822,6 @@ dependencies = [
|
||||
"windows-sys 0.60.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "termtree"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683"
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.69"
|
||||
@@ -3339,6 +3239,7 @@ name = "utils"
|
||||
version = "0.14.5"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"bincode",
|
||||
"bytes",
|
||||
"ctor",
|
||||
"derivative",
|
||||
@@ -3348,6 +3249,8 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"merklehash",
|
||||
"pin-project",
|
||||
"rand 0.9.2",
|
||||
"serde",
|
||||
"shellexpand",
|
||||
"thiserror 2.0.16",
|
||||
"tokio",
|
||||
|
||||
@@ -7,7 +7,7 @@ use std::sync::atomic::AtomicBool;
|
||||
use merklehash::{HMACKey, MerkleHash};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, instrument, trace, warn};
|
||||
use utils::RwTaskLock;
|
||||
use utils::{MerkleHashMap, RwTaskLock, TruncatedMerkleHashMap};
|
||||
use xet_runtime::xet_config;
|
||||
|
||||
use crate::cas_structs::*;
|
||||
@@ -36,7 +36,7 @@ struct ChunkCacheElement {
|
||||
struct KeyedShardCollection {
|
||||
hmac_key: HMACKey,
|
||||
shard_list: Vec<Arc<MDBShardFile>>,
|
||||
chunk_lookup: HashMap<u64, ChunkCacheElement>,
|
||||
chunk_lookup: TruncatedMerkleHashMap<ChunkCacheElement>,
|
||||
}
|
||||
|
||||
impl KeyedShardCollection {
|
||||
@@ -53,8 +53,8 @@ impl KeyedShardCollection {
|
||||
#[derive(Default)]
|
||||
struct ShardBookkeeper {
|
||||
shard_collections: Vec<KeyedShardCollection>,
|
||||
collection_by_key: HashMap<HMACKey, usize>,
|
||||
shard_lookup_by_shard_hash: HashMap<MerkleHash, (usize, usize)>,
|
||||
collection_by_key: MerkleHashMap<usize>,
|
||||
shard_lookup_by_shard_hash: MerkleHashMap<(usize, usize)>,
|
||||
|
||||
// We cap the number of chunks indexed for dedup; beyond those, we simply drop the search.
|
||||
total_indexed_chunks: usize,
|
||||
@@ -66,7 +66,7 @@ impl ShardBookkeeper {
|
||||
// we always try to dedup locally first.
|
||||
Self {
|
||||
shard_collections: vec![KeyedShardCollection::new(HMACKey::default())],
|
||||
collection_by_key: HashMap::from([(HMACKey::default(), 0)]),
|
||||
collection_by_key: MerkleHashMap::from([(HMACKey::default(), 0)]),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
// The shard structure for the in memory querying
|
||||
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::collections::BTreeMap;
|
||||
use std::io::{BufWriter, Write};
|
||||
use std::mem::size_of;
|
||||
use std::path::{Path, PathBuf};
|
||||
@@ -9,6 +9,7 @@ use std::time::Duration;
|
||||
|
||||
use merklehash::{HashedWrite, MerkleHash};
|
||||
use tracing::debug;
|
||||
use utils::MerkleHashMap;
|
||||
|
||||
use crate::cas_structs::*;
|
||||
use crate::error::Result;
|
||||
@@ -21,7 +22,7 @@ use crate::utils::{shard_file_name, temp_shard_file_name};
|
||||
pub struct MDBInMemoryShard {
|
||||
pub cas_content: BTreeMap<MerkleHash, Arc<MDBCASInfo>>,
|
||||
pub file_content: BTreeMap<MerkleHash, MDBFileInfo>,
|
||||
pub chunk_hash_lookup: HashMap<MerkleHash, (Arc<MDBCASInfo>, u64)>,
|
||||
pub chunk_hash_lookup: MerkleHashMap<(Arc<MDBCASInfo>, u64)>,
|
||||
current_shard_file_size: u64,
|
||||
}
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::collections::HashSet;
|
||||
use std::io::{Cursor, Read, Write, copy};
|
||||
use std::mem::size_of;
|
||||
|
||||
@@ -8,6 +8,7 @@ use futures_util::io::AsyncReadExt;
|
||||
use itertools::Itertools;
|
||||
use merklehash::MerkleHash;
|
||||
use more_asserts::debug_assert_lt;
|
||||
use utils::MerkleHashMap;
|
||||
|
||||
use crate::cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader, MDBCASInfoView};
|
||||
use crate::error::{MDBShardError, Result};
|
||||
@@ -301,8 +302,8 @@ impl MDBMinimalShard {
|
||||
|
||||
/// Return a lookup of xorb hash to starting chunk indices for all the files present in the
|
||||
/// shard. These are the chunks that are useful for global dedup.
|
||||
fn file_start_entries(&self) -> HashMap<MerkleHash, Vec<usize>> {
|
||||
let mut file_start_entries = HashMap::<MerkleHash, Vec<usize>>::new();
|
||||
fn file_start_entries(&self) -> MerkleHashMap<Vec<usize>> {
|
||||
let mut file_start_entries = MerkleHashMap::<Vec<usize>>::new();
|
||||
|
||||
for f_idx in 0..self.num_files() {
|
||||
let Some(fv) = self.file(f_idx) else {
|
||||
|
||||
@@ -17,5 +17,9 @@ heed = { workspace = true }
|
||||
[target.'cfg(target_family = "wasm")'.dependencies]
|
||||
getrandom = { workspace = true, features = ["wasm_js"] }
|
||||
|
||||
[dev-dependencies]
|
||||
bincode = { workspace = true }
|
||||
tokio = { workspace = true, features = ["rt-multi-thread", "macros"] }
|
||||
|
||||
[features]
|
||||
strict = []
|
||||
|
||||
@@ -5,6 +5,7 @@ edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
merklehash = { path = "../merklehash" }
|
||||
utils = { path = "../utils" }
|
||||
|
||||
async-trait = { workspace = true }
|
||||
more-asserts = { workspace = true }
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
use std::collections::BTreeSet;
|
||||
use std::collections::hash_map::Entry as HashMapEntry;
|
||||
use std::collections::{BTreeSet, HashMap};
|
||||
use std::mem::take;
|
||||
use std::sync::Arc;
|
||||
|
||||
use merklehash::MerkleHash;
|
||||
use more_asserts::{debug_assert_ge, debug_assert_le};
|
||||
use tokio::sync::Mutex;
|
||||
use utils::MerkleHashMap;
|
||||
|
||||
use crate::{ItemProgressUpdate, ProgressUpdate, TrackingProgressUpdater};
|
||||
|
||||
@@ -57,7 +58,7 @@ struct FileDependency {
|
||||
/// Mapping of xorb_hash -> (number of completed bytes / number of bytes of the file contained in that xorb). Only
|
||||
/// xorbs that are not uploaded yet are tracked here.
|
||||
/// Once an xorb is uploaded, we remove it from here (and add to `completed_bytes`).
|
||||
remaining_xorbs_parts: HashMap<MerkleHash, XorbPartCompletionStats>,
|
||||
remaining_xorbs_parts: MerkleHashMap<XorbPartCompletionStats>,
|
||||
}
|
||||
|
||||
/// Tracks all files and all xorbs, allowing you to register file
|
||||
@@ -68,7 +69,7 @@ struct CompletionTrackerImpl {
|
||||
/// List of all files being tracked.
|
||||
files: Vec<FileDependency>,
|
||||
/// Map of xorb hash -> its dependency info (which files rely on it).
|
||||
xorbs: HashMap<MerkleHash, XorbDependency>,
|
||||
xorbs: MerkleHashMap<XorbDependency>,
|
||||
|
||||
/// Keep track of the totals across all xorbs.
|
||||
total_upload_bytes: u64,
|
||||
@@ -99,7 +100,7 @@ impl CompletionTrackerImpl {
|
||||
name: name.into(),
|
||||
total_bytes: n_bytes,
|
||||
completed_bytes: 0,
|
||||
remaining_xorbs_parts: HashMap::new(),
|
||||
remaining_xorbs_parts: MerkleHashMap::new(),
|
||||
};
|
||||
|
||||
// Insert it into our files vector.
|
||||
|
||||
@@ -7,6 +7,10 @@ edition = "2024"
|
||||
name = "utils"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "benchmark_hashmaps"
|
||||
path = "src/data_structures/bin/benchmark_hashmaps.rs"
|
||||
|
||||
[dependencies]
|
||||
error_printer = { path = "../error_printer" }
|
||||
merklehash = { path = "../merklehash" }
|
||||
@@ -19,6 +23,7 @@ duration-str = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
pin-project = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
shellexpand = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, features = [
|
||||
@@ -31,6 +36,8 @@ tokio = { workspace = true, features = [
|
||||
tracing = { workspace = true }
|
||||
|
||||
[target.'cfg(not(target_family = "wasm"))'.dependencies]
|
||||
bincode = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
tokio-util = { workspace = true, features = ["io"] }
|
||||
|
||||
[target.'cfg(not(target_family = "wasm"))'.dev-dependencies]
|
||||
@@ -41,10 +48,11 @@ xet_runtime = { path = "../xet_runtime" }
|
||||
web-time = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
bincode = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
serial_test = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
|
||||
[features]
|
||||
strict = []
|
||||
elevated_information_level = []
|
||||
elevated_information_level = []
|
||||
|
||||
342
utils/src/data_structures/bin/benchmark_hashmaps.rs
Normal file
342
utils/src/data_structures/bin/benchmark_hashmaps.rs
Normal file
@@ -0,0 +1,342 @@
|
||||
//! Benchmark comparing HashMap and PassThroughHashMap performance.
|
||||
//!
|
||||
//! Run with: cargo run --bin benchmark_hashmaps --release
|
||||
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use merklehash::{DataHash, MerkleHash};
|
||||
use rand::Rng;
|
||||
use utils::data_structures::PassThroughHashMap;
|
||||
|
||||
// Type aliases for cleaner code
|
||||
type MerklePassThroughHashMap = PassThroughHashMap<DataHash, u64>;
|
||||
|
||||
const OPERATION_COUNTS: &[usize] = &[100_000, 1_000_000, 10_000_000];
|
||||
|
||||
// Data structure identifiers for columns
|
||||
const DS_HASHMAP: &str = "HashMap";
|
||||
const DS_PASSTHROUGH: &str = "PassThrough";
|
||||
|
||||
/// Results storage for summary table
|
||||
#[derive(Default)]
|
||||
struct BenchmarkResults {
|
||||
/// Map of (test_name, size, data_structure) -> duration
|
||||
results: BTreeMap<(String, usize, String), Duration>,
|
||||
}
|
||||
|
||||
impl BenchmarkResults {
|
||||
fn record(&mut self, test: &str, size: usize, ds: &str, duration: Duration) {
|
||||
self.results.insert((test.to_string(), size, ds.to_string()), duration);
|
||||
}
|
||||
|
||||
fn print_summary_table(&self) {
|
||||
let data_structures = [DS_HASHMAP, DS_PASSTHROUGH];
|
||||
let tests = ["Insert", "Lookup", "Insert+Lookup", "Serialize", "Deserialize"];
|
||||
|
||||
println!("\n{}", "=".repeat(65));
|
||||
println!("PERFORMANCE SUMMARY (times in ms, lower is better)");
|
||||
println!("{}", "=".repeat(65));
|
||||
|
||||
// Print header
|
||||
print!("{:<25}", "Test");
|
||||
for ds in &data_structures {
|
||||
print!("{:>20}", ds);
|
||||
}
|
||||
println!();
|
||||
println!("{}", "-".repeat(65));
|
||||
|
||||
for &size in OPERATION_COUNTS {
|
||||
println!("--- {} ---", format_count(size));
|
||||
for test in &tests {
|
||||
print!(" {:<23}", test);
|
||||
for ds in &data_structures {
|
||||
if let Some(duration) = self.results.get(&(test.to_string(), size, ds.to_string())) {
|
||||
print!("{:>20}", format_duration_ms(*duration));
|
||||
} else {
|
||||
print!("{:>20}", "-");
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
println!();
|
||||
}
|
||||
println!("{}", "=".repeat(65));
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_merkle_keys(count: usize) -> Vec<DataHash> {
|
||||
let mut rng = rand::rng();
|
||||
(0..count)
|
||||
.map(|_| {
|
||||
let mut bytes = [0u8; 32];
|
||||
rng.fill(&mut bytes);
|
||||
MerkleHash::from(bytes)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Formats an operation count as a compact label such as `100K` or `10M`.
///
/// Counts below 1000 are printed verbatim. Larger counts are scaled to
/// thousands (`K`) or millions (`M`). Exact multiples of the unit print as a
/// whole number (`1_000_000` -> `1M`); any other value keeps one decimal
/// digit, truncated rather than rounded (`1_500_000` -> `1.5M`,
/// `999_999` -> `999.9K`), so distinct sizes are not collapsed onto the same
/// label by integer division.
fn format_count(count: usize) -> String {
    let (unit, suffix) = if count >= 1_000_000 {
        (1_000_000, "M")
    } else if count >= 1_000 {
        (1_000, "K")
    } else {
        return count.to_string();
    };

    let whole = count / unit;
    let remainder = count % unit;
    if remainder == 0 {
        format!("{whole}{suffix}")
    } else {
        // `remainder < unit <= 1_000_000`, so the multiplication cannot overflow.
        let tenths = remainder * 10 / unit;
        format!("{whole}.{tenths}{suffix}")
    }
}
|
||||
|
||||
/// Renders a duration with two decimals in the largest unit that is non-zero:
/// seconds (`s`), then milliseconds (`ms`), otherwise microseconds (`µs`).
fn format_duration(d: Duration) -> String {
    let secs = d.as_secs_f64();
    let (value, unit) = if d.as_secs() > 0 {
        (secs, "s")
    } else if d.as_millis() > 0 {
        (secs * 1000.0, "ms")
    } else {
        (secs * 1_000_000.0, "µs")
    };
    format!("{:.2}{}", value, unit)
}
|
||||
|
||||
/// Renders a duration as a bare millisecond count with one decimal place
/// (no unit suffix), as used by the summary table.
fn format_duration_ms(d: Duration) -> String {
    let millis = d.as_secs_f64() * 1000.0;
    format!("{millis:.1}")
}
|
||||
|
||||
/// Formats the throughput `count / d` as operations per second, scaled to
/// `B/s` (billions), `M/s` (millions), `K/s` (thousands) or plain `/s`, with
/// two decimals.
///
/// A zero duration has no meaningful throughput: dividing by it would yield
/// infinity (or NaN when `count` is also zero) and print `infB/s` / `NaN/s`.
/// In that case `"n/a"` is returned instead.
fn format_ops_per_sec(count: usize, d: Duration) -> String {
    let secs = d.as_secs_f64();
    if secs == 0.0 {
        return "n/a".to_string();
    }

    let ops = count as f64 / secs;
    if ops >= 1_000_000_000.0 {
        format!("{:.2}B/s", ops / 1_000_000_000.0)
    } else if ops >= 1_000_000.0 {
        format!("{:.2}M/s", ops / 1_000_000.0)
    } else if ops >= 1_000.0 {
        format!("{:.2}K/s", ops / 1_000.0)
    } else {
        format!("{:.2}/s", ops)
    }
}
|
||||
|
||||
fn print_result(name: &str, count: usize, duration: Duration) {
|
||||
println!(" {:<40} {:>12} ({:>12})", name, format_duration(duration), format_ops_per_sec(count, duration));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Benchmarks
|
||||
// ============================================================================
|
||||
|
||||
fn bench_insert(keys: &[DataHash], map_type: &str) -> Duration {
|
||||
match map_type {
|
||||
DS_HASHMAP => {
|
||||
let start = Instant::now();
|
||||
let mut map: HashMap<DataHash, u64> = HashMap::with_capacity(keys.len());
|
||||
for (i, key) in keys.iter().enumerate() {
|
||||
map.insert(*key, i as u64);
|
||||
}
|
||||
std::hint::black_box(&map);
|
||||
start.elapsed()
|
||||
},
|
||||
DS_PASSTHROUGH => {
|
||||
let start = Instant::now();
|
||||
let mut map: MerklePassThroughHashMap = PassThroughHashMap::with_capacity(keys.len());
|
||||
for (i, key) in keys.iter().enumerate() {
|
||||
map.insert(*key, i as u64);
|
||||
}
|
||||
std::hint::black_box(&map);
|
||||
start.elapsed()
|
||||
},
|
||||
_ => Duration::ZERO,
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_lookup(keys: &[DataHash], map_type: &str) -> Duration {
|
||||
// First build the map
|
||||
match map_type {
|
||||
DS_HASHMAP => {
|
||||
let mut map: HashMap<DataHash, u64> = HashMap::with_capacity(keys.len());
|
||||
for (i, key) in keys.iter().enumerate() {
|
||||
map.insert(*key, i as u64);
|
||||
}
|
||||
let start = Instant::now();
|
||||
let mut sum = 0u64;
|
||||
for key in keys {
|
||||
if let Some(v) = map.get(key) {
|
||||
sum = sum.wrapping_add(*v);
|
||||
}
|
||||
}
|
||||
std::hint::black_box(sum);
|
||||
start.elapsed()
|
||||
},
|
||||
DS_PASSTHROUGH => {
|
||||
let mut map: MerklePassThroughHashMap = PassThroughHashMap::with_capacity(keys.len());
|
||||
for (i, key) in keys.iter().enumerate() {
|
||||
map.insert(*key, i as u64);
|
||||
}
|
||||
let start = Instant::now();
|
||||
let mut sum = 0u64;
|
||||
for key in keys {
|
||||
if let Some(v) = map.get(key) {
|
||||
sum = sum.wrapping_add(*v);
|
||||
}
|
||||
}
|
||||
std::hint::black_box(sum);
|
||||
start.elapsed()
|
||||
},
|
||||
_ => Duration::ZERO,
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_insert_then_lookup(keys: &[DataHash], map_type: &str) -> Duration {
|
||||
match map_type {
|
||||
DS_HASHMAP => {
|
||||
let start = Instant::now();
|
||||
let mut map: HashMap<DataHash, u64> = HashMap::with_capacity(keys.len());
|
||||
for (i, key) in keys.iter().enumerate() {
|
||||
map.insert(*key, i as u64);
|
||||
}
|
||||
let mut sum = 0u64;
|
||||
for key in keys {
|
||||
if let Some(v) = map.get(key) {
|
||||
sum = sum.wrapping_add(*v);
|
||||
}
|
||||
}
|
||||
std::hint::black_box(sum);
|
||||
start.elapsed()
|
||||
},
|
||||
DS_PASSTHROUGH => {
|
||||
let start = Instant::now();
|
||||
let mut map: MerklePassThroughHashMap = PassThroughHashMap::with_capacity(keys.len());
|
||||
for (i, key) in keys.iter().enumerate() {
|
||||
map.insert(*key, i as u64);
|
||||
}
|
||||
let mut sum = 0u64;
|
||||
for key in keys {
|
||||
if let Some(v) = map.get(key) {
|
||||
sum = sum.wrapping_add(*v);
|
||||
}
|
||||
}
|
||||
std::hint::black_box(sum);
|
||||
start.elapsed()
|
||||
},
|
||||
_ => Duration::ZERO,
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_serialize(keys: &[DataHash], map_type: &str) -> Duration {
|
||||
match map_type {
|
||||
DS_HASHMAP => {
|
||||
let mut map: HashMap<DataHash, u64> = HashMap::with_capacity(keys.len());
|
||||
for (i, key) in keys.iter().enumerate() {
|
||||
map.insert(*key, i as u64);
|
||||
}
|
||||
let start = Instant::now();
|
||||
let bytes = bincode::serialize(&map).unwrap();
|
||||
std::hint::black_box(&bytes);
|
||||
start.elapsed()
|
||||
},
|
||||
DS_PASSTHROUGH => {
|
||||
let mut map: MerklePassThroughHashMap = PassThroughHashMap::with_capacity(keys.len());
|
||||
for (i, key) in keys.iter().enumerate() {
|
||||
map.insert(*key, i as u64);
|
||||
}
|
||||
let start = Instant::now();
|
||||
let bytes = bincode::serialize(&map).unwrap();
|
||||
std::hint::black_box(&bytes);
|
||||
start.elapsed()
|
||||
},
|
||||
_ => Duration::ZERO,
|
||||
}
|
||||
}
|
||||
|
||||
fn bench_deserialize(keys: &[DataHash], map_type: &str) -> Duration {
|
||||
match map_type {
|
||||
DS_HASHMAP => {
|
||||
let mut map: HashMap<DataHash, u64> = HashMap::with_capacity(keys.len());
|
||||
for (i, key) in keys.iter().enumerate() {
|
||||
map.insert(*key, i as u64);
|
||||
}
|
||||
let bytes = bincode::serialize(&map).unwrap();
|
||||
let start = Instant::now();
|
||||
let map2: HashMap<DataHash, u64> = bincode::deserialize(&bytes).unwrap();
|
||||
std::hint::black_box(&map2);
|
||||
start.elapsed()
|
||||
},
|
||||
DS_PASSTHROUGH => {
|
||||
let mut map: MerklePassThroughHashMap = PassThroughHashMap::with_capacity(keys.len());
|
||||
for (i, key) in keys.iter().enumerate() {
|
||||
map.insert(*key, i as u64);
|
||||
}
|
||||
let bytes = bincode::serialize(&map).unwrap();
|
||||
let start = Instant::now();
|
||||
let map2: MerklePassThroughHashMap = bincode::deserialize(&bytes).unwrap();
|
||||
std::hint::black_box(&map2);
|
||||
start.elapsed()
|
||||
},
|
||||
_ => Duration::ZERO,
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Benchmark Runner
|
||||
// ============================================================================
|
||||
|
||||
fn run_benchmarks(count: usize, results: &mut BenchmarkResults) {
|
||||
println!("\n{}", "=".repeat(80));
|
||||
println!("Benchmarking {} operations", count);
|
||||
println!("{}", "=".repeat(80));
|
||||
|
||||
println!("\n--- Generating keys ---");
|
||||
let keys = generate_merkle_keys(count);
|
||||
println!(" Generated {} MerkleHash keys", count);
|
||||
|
||||
println!("\n--- Benchmarks ---");
|
||||
|
||||
for ds in [DS_HASHMAP, DS_PASSTHROUGH] {
|
||||
let d = bench_insert(&keys, ds);
|
||||
print_result(&format!("{} Insert", ds), count, d);
|
||||
results.record("Insert", count, ds, d);
|
||||
}
|
||||
|
||||
for ds in [DS_HASHMAP, DS_PASSTHROUGH] {
|
||||
let d = bench_lookup(&keys, ds);
|
||||
print_result(&format!("{} Lookup", ds), count, d);
|
||||
results.record("Lookup", count, ds, d);
|
||||
}
|
||||
|
||||
for ds in [DS_HASHMAP, DS_PASSTHROUGH] {
|
||||
let d = bench_insert_then_lookup(&keys, ds);
|
||||
print_result(&format!("{} Insert+Lookup", ds), count, d);
|
||||
results.record("Insert+Lookup", count, ds, d);
|
||||
}
|
||||
|
||||
for ds in [DS_HASHMAP, DS_PASSTHROUGH] {
|
||||
let d = bench_serialize(&keys, ds);
|
||||
print_result(&format!("{} Serialize", ds), count, d);
|
||||
results.record("Serialize", count, ds, d);
|
||||
}
|
||||
|
||||
for ds in [DS_HASHMAP, DS_PASSTHROUGH] {
|
||||
let d = bench_deserialize(&keys, ds);
|
||||
print_result(&format!("{} Deserialize", ds), count, d);
|
||||
results.record("Deserialize", count, ds, d);
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
println!("HashMap Benchmark Suite");
|
||||
println!("=======================");
|
||||
println!();
|
||||
println!("Comparing:");
|
||||
println!(" - HashMap: std::collections::HashMap");
|
||||
println!(" - PassThrough: PassThroughHashMap (optimized hasher for MerkleHash keys)");
|
||||
println!();
|
||||
println!("Key type: MerkleHash (32 bytes)");
|
||||
println!("Value type: u64");
|
||||
|
||||
let mut results = BenchmarkResults::default();
|
||||
|
||||
for &count in OPERATION_COUNTS {
|
||||
run_benchmarks(count, &mut results);
|
||||
}
|
||||
|
||||
// Print summary table
|
||||
results.print_summary_table();
|
||||
|
||||
println!("\nBenchmark complete!");
|
||||
}
|
||||
17
utils/src/data_structures/mod.rs
Normal file
17
utils/src/data_structures/mod.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
mod passthrough_hasher;
|
||||
mod passthrough_hashmap;
|
||||
|
||||
use merklehash::MerkleHash;
|
||||
pub use passthrough_hasher::U64HashExtractable;
|
||||
pub use passthrough_hashmap::PassThroughHashMap;
|
||||
|
||||
/// A HashMap specialized for `MerkleHash` keys using passthrough hashing.
|
||||
///
|
||||
/// This is a type alias for `PassThroughHashMap<MerkleHash, Value>`.
|
||||
pub type MerkleHashMap<Value> = PassThroughHashMap<MerkleHash, Value>;
|
||||
|
||||
/// A HashMap specialized for `u64` keys using passthrough hashing.
|
||||
///
|
||||
/// This is useful when the key is already a truncated hash value (e.g., the first 8 bytes
|
||||
/// of a larger hash), and we want to avoid re-hashing.
|
||||
pub type TruncatedMerkleHashMap<Value> = PassThroughHashMap<u64, Value>;
|
||||
143
utils/src/data_structures/passthrough_hasher.rs
Normal file
143
utils/src/data_structures/passthrough_hasher.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
use std::hash::{BuildHasher, Hasher};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use merklehash::MerkleHash;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A key type that can cheaply hand out a ready-made `u64` hash value.
///
/// Keys implementing this trait are meant to be used with `U64DirectHasher`, which
/// takes a 64-bit value as the hash directly instead of running a hash function.
pub trait U64HashExtractable {
    /// Returns the `u64` that represents this key's hash.
    fn u64_hash_value(&self) -> u64;
}
|
||||
|
||||
impl U64HashExtractable for u64 {
|
||||
fn u64_hash_value(&self) -> u64 {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
impl U64HashExtractable for MerkleHash {
    /// Uses the element at index 0 of the hash as the hash value.
    fn u64_hash_value(&self) -> u64 {
        self[0]
    }
}
|
||||
|
||||
/// A hasher that maps directly to a u64 hash value for types implementing U64HashExtractable.
|
||||
///
|
||||
/// This hasher is designed to work with types that already contain high-quality hash values
|
||||
/// (like cryptographic hashes), avoiding extra computation and supporting direct bucket selection.
|
||||
///
|
||||
/// Note: This is not resistant to attacks where all keys are designed to hash to the same bucket
|
||||
/// and thus make the lookup time linear, but in our use case, where it's already a cryptographic
|
||||
/// hash, this gives us two advantages:
|
||||
/// - Speedup for lookup.
|
||||
/// - Consistent serialization order, so MUCH faster deserialization.
|
||||
#[derive(Clone, Copy, Serialize, Deserialize)]
|
||||
pub struct U64DirectHasher<T: U64HashExtractable> {
|
||||
state: u64,
|
||||
#[serde(skip)]
|
||||
_phantom: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T: U64HashExtractable> Default for U64DirectHasher<T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
state: 0,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: U64HashExtractable> Hasher for U64DirectHasher<T> {
    /// Returns the most recently written 64-bit value, unchanged.
    fn finish(&self) -> u64 {
        self.state
    }

    /// Replaces the state with the first eight bytes of `bytes`, in native byte order.
    ///
    /// Each call overwrites the previous state instead of mixing into it, so only the
    /// last write issued by a key's `Hash` implementation determines the hash.
    ///
    /// `Hasher::write` is a safe method and may be handed a slice of any length, so
    /// the bytes are copied with bounds-checked slice operations. An input shorter
    /// than eight bytes is zero-padded rather than read past its end.
    fn write(&mut self, bytes: &[u8]) {
        let mut head = [0u8; 8];
        let len = bytes.len().min(head.len());
        head[..len].copy_from_slice(&bytes[..len]);
        self.state = u64::from_ne_bytes(head);
    }

    /// Stores `i` directly as the hash state.
    fn write_u64(&mut self, i: u64) {
        self.state = i;
    }
}
|
||||
|
||||
impl<T: U64HashExtractable> BuildHasher for U64DirectHasher<T> {
    type Hasher = Self;

    /// Hands out a fresh, zero-initialized hasher; the builder itself holds no seed.
    fn build_hasher(&self) -> Self::Hasher {
        Self::default()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::hash::Hash;

    use merklehash::DataHash;

    use super::*;

    /// Feeds `key` through a fresh `U64DirectHasher<MerkleHash>` and returns the result.
    fn direct_hash_of(key: &DataHash) -> u64 {
        let mut hasher = U64DirectHasher::<MerkleHash>::default();
        key.hash(&mut hasher);
        hasher.finish()
    }

    #[test]
    fn test_hash_uses_first_u64() {
        // Same first word, different remaining words: the hashes must agree.
        let zero_tail = DataHash::from([0x1234567890ABCDEF, 0, 0, 0]);
        let ones_tail = DataHash::from([
            0x1234567890ABCDEF,
            0xFFFFFFFFFFFFFFFF,
            0xFFFFFFFFFFFFFFFF,
            0xFFFFFFFFFFFFFFFF,
        ]);

        assert_eq!(direct_hash_of(&zero_tail), direct_hash_of(&ones_tail));
    }

    #[test]
    fn test_hash_different_first_u64() {
        // Different first word, identical remaining words: the hashes must differ.
        let lhs = DataHash::from([0x1234567890ABCDEF, 0, 0, 0]);
        let rhs = DataHash::from([0xFEDCBA0987654321, 0, 0, 0]);

        assert_ne!(direct_hash_of(&lhs), direct_hash_of(&rhs));
    }

    #[test]
    fn test_u64_hash_extractable() {
        let plain: u64 = 0x1234567890ABCDEF;
        assert_eq!(plain.u64_hash_value(), 0x1234567890ABCDEF);

        let merkle = MerkleHash::from([0xFEDCBA0987654321, 0x1111111111111111, 0, 0]);
        assert_eq!(merkle.u64_hash_value(), 0xFEDCBA0987654321);
    }

    #[test]
    fn test_u64_passthrough_hasher() {
        let mut first_hasher = U64DirectHasher::<u64>::default();
        let mut second_hasher = U64DirectHasher::<u64>::default();

        let first: u64 = 0x1234567890ABCDEF;
        let second: u64 = 0x1234567890ABCDEF;

        first.hash(&mut first_hasher);
        second.hash(&mut second_hasher);

        // Equal inputs hash equally, and the hash is the input itself.
        assert_eq!(first_hasher.finish(), second_hasher.finish());
        assert_eq!(first_hasher.finish(), first);
    }
}
|
||||
484
utils/src/data_structures/passthrough_hashmap.rs
Normal file
484
utils/src/data_structures/passthrough_hashmap.rs
Normal file
@@ -0,0 +1,484 @@
|
||||
use std::collections::HashMap;
|
||||
use std::hash::Hash;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::passthrough_hasher::{U64DirectHasher, U64HashExtractable};
|
||||
|
||||
/// A HashMap wrapper optimized for keys that implement `U64HashExtractable`.
|
||||
///
|
||||
/// This structure uses `U64DirectHasher` which avoids extra hash computation
|
||||
/// by directly using the u64 value extracted from the key. This is particularly
|
||||
/// efficient for cryptographic hashes (like `MerkleHash`) where the first 8 bytes
|
||||
/// already provide excellent distribution.
|
||||
///
|
||||
/// # Behavior
|
||||
///
|
||||
/// This type implements `Deref` and `DerefMut` to the underlying `HashMap`, so it
|
||||
/// behaves exactly like a standard `HashMap`. All `HashMap` methods are available
|
||||
/// directly on this type:
|
||||
///
|
||||
/// ```ignore
|
||||
/// let mut map: PassThroughHashMap<MerkleHash, String> = PassThroughHashMap::new();
|
||||
///
|
||||
/// // All HashMap methods work directly:
|
||||
/// map.insert(key, "value".to_string());
|
||||
/// map.get(&key);
|
||||
/// map.contains_key(&key);
|
||||
/// map.remove(&key);
|
||||
/// map.len();
|
||||
/// map.is_empty();
|
||||
/// map.iter();
|
||||
/// map.entry(key).or_insert("default".to_string());
|
||||
/// // ... and all other HashMap methods
|
||||
/// ```
|
||||
///
|
||||
/// # Type Parameters
|
||||
/// - `Key`: The key type (must implement `U64HashExtractable + Hash + Eq`)
|
||||
/// - `Value`: The value type stored in the map
|
||||
pub struct PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq,
|
||||
{
|
||||
inner: HashMap<Key, Value, U64DirectHasher<Key>>,
|
||||
}
|
||||
|
||||
impl<Key, Value> PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: HashMap::with_hasher(U64DirectHasher::default()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_capacity(capacity: usize) -> Self {
|
||||
Self {
|
||||
inner: HashMap::with_capacity_and_hasher(capacity, U64DirectHasher::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value> Default for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq,
|
||||
{
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value> std::fmt::Debug for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq + std::fmt::Debug,
|
||||
Value: std::fmt::Debug,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_map().entries(self.inner.iter()).finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value> Clone for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq + Clone,
|
||||
Value: Clone,
|
||||
{
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
inner: self.inner.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value> std::iter::FromIterator<(Key, Value)> for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq,
|
||||
{
|
||||
fn from_iter<I: IntoIterator<Item = (Key, Value)>>(iter: I) -> Self {
|
||||
let mut map = Self::new();
|
||||
for (k, v) in iter {
|
||||
map.insert(k, v);
|
||||
}
|
||||
map
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value> IntoIterator for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq,
|
||||
{
|
||||
type Item = (Key, Value);
|
||||
type IntoIter = std::collections::hash_map::IntoIter<Key, Value>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.inner.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides immutable access to all `HashMap` methods.
|
||||
///
|
||||
/// This allows using `PassThroughHashMap` exactly like a `HashMap`:
|
||||
/// `get`, `contains_key`, `len`, `is_empty`, `iter`, `keys`, `values`, etc.
|
||||
impl<Key, Value> Deref for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq,
|
||||
{
|
||||
type Target = HashMap<Key, Value, U64DirectHasher<Key>>;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.inner
|
||||
}
|
||||
}
|
||||
|
||||
/// Provides mutable access to all `HashMap` methods.
|
||||
///
|
||||
/// This allows using `PassThroughHashMap` exactly like a `HashMap`:
|
||||
/// `insert`, `remove`, `clear`, `get_mut`, `entry`, `iter_mut`, etc.
|
||||
impl<Key, Value> DerefMut for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq,
|
||||
{
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value> From<HashMap<Key, Value, U64DirectHasher<Key>>> for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq,
|
||||
{
|
||||
fn from(inner: HashMap<Key, Value, U64DirectHasher<Key>>) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value> From<PassThroughHashMap<Key, Value>> for HashMap<Key, Value, U64DirectHasher<Key>>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq,
|
||||
{
|
||||
fn from(val: PassThroughHashMap<Key, Value>) -> Self {
|
||||
val.inner
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value, const N: usize> From<[(Key, Value); N]> for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq,
|
||||
{
|
||||
fn from(arr: [(Key, Value); N]) -> Self {
|
||||
arr.into_iter().collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<Key, Value> Serialize for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq + Serialize,
|
||||
Value: Serialize,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
use serde::ser::SerializeSeq;
|
||||
|
||||
let mut seq = serializer.serialize_seq(Some(self.inner.len()))?;
|
||||
for (key, value) in self.inner.iter() {
|
||||
seq.serialize_element(&(key, value))?;
|
||||
}
|
||||
seq.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, Key, Value> Deserialize<'de> for PassThroughHashMap<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq + Deserialize<'de>,
|
||||
Value: Deserialize<'de>,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
struct PassThroughVisitor<Key, Value> {
|
||||
_marker: std::marker::PhantomData<(Key, Value)>,
|
||||
}
|
||||
|
||||
impl<'de, Key, Value> serde::de::Visitor<'de> for PassThroughVisitor<Key, Value>
|
||||
where
|
||||
Key: U64HashExtractable + Hash + Eq + Deserialize<'de>,
|
||||
Value: Deserialize<'de>,
|
||||
{
|
||||
type Value = PassThroughHashMap<Key, Value>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("a sequence of key-value pairs")
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: serde::de::SeqAccess<'de>,
|
||||
{
|
||||
let capacity = seq.size_hint().unwrap_or(0);
|
||||
let mut map = PassThroughHashMap::with_capacity(capacity);
|
||||
|
||||
while let Some((key, value)) = seq.next_element()? {
|
||||
map.insert(key, value);
|
||||
}
|
||||
|
||||
Ok(map)
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_seq(PassThroughVisitor {
|
||||
_marker: std::marker::PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use merklehash::{DataHash, compute_data_hash};

    use super::*;

    type MerkleHashMap<Value> = PassThroughHashMap<DataHash, Value>;

    /// The two keys shared by most tests: the hashes of `b"test1"` and `b"test2"`.
    fn key_pair() -> (DataHash, DataHash) {
        (compute_data_hash(b"test1"), compute_data_hash(b"test2"))
    }

    #[test]
    fn test_new() {
        let map: MerkleHashMap<String> = MerkleHashMap::new();
        assert!(map.is_empty());
        assert_eq!(map.len(), 0);
    }

    #[test]
    fn test_default() {
        let map: MerkleHashMap<i32> = MerkleHashMap::default();
        assert!(map.is_empty());
    }

    #[test]
    fn test_with_capacity() {
        let map: MerkleHashMap<u64> = MerkleHashMap::with_capacity(100);
        assert!(map.is_empty());
    }

    #[test]
    fn test_insert_and_get() {
        let (first, second) = key_pair();
        let mut map = MerkleHashMap::new();

        assert_eq!(map.insert(first, "value1"), None);
        assert_eq!(map.insert(second, "value2"), None);
        assert_eq!(map.len(), 2);

        assert_eq!(map.get(&first), Some(&"value1"));
        assert_eq!(map.get(&second), Some(&"value2"));
    }

    #[test]
    fn test_insert_overwrite() {
        let key = compute_data_hash(b"test");
        let mut map = MerkleHashMap::new();

        // The second insert returns the value it replaced.
        assert_eq!(map.insert(key, "value1"), None);
        assert_eq!(map.insert(key, "value2"), Some("value1"));
        assert_eq!(map.get(&key), Some(&"value2"));
    }

    #[test]
    fn test_contains_key() {
        let (present, absent) = key_pair();
        let mut map = MerkleHashMap::new();

        map.insert(present, 42);
        assert!(map.contains_key(&present));
        assert!(!map.contains_key(&absent));
    }

    #[test]
    fn test_remove() {
        let (first, second) = key_pair();
        let mut map = MerkleHashMap::new();

        map.insert(first, "value1");
        map.insert(second, "value2");
        assert_eq!(map.len(), 2);

        assert_eq!(map.remove(&first), Some("value1"));
        assert_eq!(map.len(), 1);
        assert_eq!(map.get(&first), None);
        assert_eq!(map.get(&second), Some(&"value2"));
    }

    #[test]
    fn test_remove_nonexistent() {
        let key = compute_data_hash(b"test");
        let mut map: MerkleHashMap<&str> = MerkleHashMap::new();

        assert_eq!(map.remove(&key), None);
    }

    #[test]
    fn test_clear() {
        let (first, second) = key_pair();
        let mut map = MerkleHashMap::new();

        map.insert(first, "value1");
        map.insert(second, "value2");
        assert_eq!(map.len(), 2);

        map.clear();
        assert!(map.is_empty());
        assert_eq!(map.len(), 0);
    }

    #[test]
    fn test_iter() {
        let (first, second) = key_pair();
        let mut map = MerkleHashMap::new();

        map.insert(first, 10);
        map.insert(second, 20);

        let entries: Vec<_> = map.iter().collect();
        assert_eq!(entries.len(), 2);

        let seen: Vec<_> = entries.iter().map(|(_, value)| **value).collect();
        assert!(seen.contains(&10));
        assert!(seen.contains(&20));
    }

    #[test]
    fn test_iter_mut() {
        let key = compute_data_hash(b"test");
        let mut map = MerkleHashMap::new();

        map.insert(key, 10);
        *map.get_mut(&key).unwrap() = 20;
        assert_eq!(map.get(&key), Some(&20));
    }

    #[test]
    fn test_keys() {
        let (first, second) = key_pair();
        let mut map = MerkleHashMap::new();

        map.insert(first, "value1");
        map.insert(second, "value2");

        let stored_keys: Vec<_> = map.keys().collect();
        assert_eq!(stored_keys.len(), 2);
        assert!(stored_keys.contains(&&first));
        assert!(stored_keys.contains(&&second));
    }

    #[test]
    fn test_values() {
        let (first, second) = key_pair();
        let mut map = MerkleHashMap::new();

        map.insert(first, "value1");
        map.insert(second, "value2");

        let stored_values: Vec<_> = map.values().collect();
        assert_eq!(stored_values.len(), 2);
        assert!(stored_values.contains(&&"value1"));
        assert!(stored_values.contains(&&"value2"));
    }

    #[test]
    fn test_entry() {
        let key = compute_data_hash(b"test");
        let mut map = MerkleHashMap::new();

        map.entry(key).or_insert("default");
        assert_eq!(map.get(&key), Some(&"default"));

        map.entry(key).and_modify(|slot| *slot = "modified");
        assert_eq!(map.get(&key), Some(&"modified"));
    }

    #[test]
    fn test_from_hashmap() {
        let (first, second) = key_pair();
        let mut plain = HashMap::with_hasher(U64DirectHasher::<DataHash>::default());

        plain.insert(first, 10);
        plain.insert(second, 20);

        let wrapped: MerkleHashMap<i32> = MerkleHashMap::from(plain);
        assert_eq!(wrapped.len(), 2);
        assert_eq!(wrapped.get(&first), Some(&10));
        assert_eq!(wrapped.get(&second), Some(&20));
    }

    #[test]
    fn test_into_hashmap() {
        let (first, second) = key_pair();
        let mut wrapped = MerkleHashMap::new();

        wrapped.insert(first, 10);
        wrapped.insert(second, 20);

        let plain: HashMap<DataHash, i32, U64DirectHasher<DataHash>> = wrapped.into();
        assert_eq!(plain.len(), 2);
        assert_eq!(plain.get(&first), Some(&10));
        assert_eq!(plain.get(&second), Some(&20));
    }

    #[test]
    fn test_serialize_deserialize() {
        let (first, second) = key_pair();
        let third = compute_data_hash(b"test3");
        let mut original = MerkleHashMap::new();

        original.insert(first, 10);
        original.insert(second, 20);
        original.insert(third, 30);

        let bytes = bincode::serialize(&original).unwrap();
        let restored: MerkleHashMap<i32> = bincode::deserialize(&bytes).unwrap();

        assert_eq!(restored.len(), 3);
        assert_eq!(restored.get(&first), Some(&10));
        assert_eq!(restored.get(&second), Some(&20));
        assert_eq!(restored.get(&third), Some(&30));
    }

    #[test]
    fn test_serialize_deserialize_empty() {
        let original: MerkleHashMap<i32> = MerkleHashMap::new();

        let bytes = bincode::serialize(&original).unwrap();
        let restored: MerkleHashMap<i32> = bincode::deserialize(&bytes).unwrap();

        assert!(restored.is_empty());
        assert_eq!(restored.len(), 0);
    }

    #[test]
    fn test_u64_key_hashmap() {
        type TruncatedHashMap<Value> = PassThroughHashMap<u64, Value>;

        let mut map: TruncatedHashMap<String> = TruncatedHashMap::new();
        map.insert(12345, "value1".to_string());
        map.insert(67890, "value2".to_string());

        assert_eq!(map.get(&12345), Some(&"value1".to_string()));
        assert_eq!(map.get(&67890), Some(&"value2".to_string()));
        assert_eq!(map.len(), 2);
    }
}
|
||||
@@ -1,6 +1,8 @@
|
||||
#![cfg_attr(feature = "strict", deny(warnings))]
|
||||
|
||||
pub mod async_iterator;
|
||||
pub mod data_structures;
|
||||
pub use data_structures::{MerkleHashMap, PassThroughHashMap, TruncatedMerkleHashMap, U64HashExtractable};
|
||||
pub mod async_read;
|
||||
pub mod auth;
|
||||
pub mod errors;
|
||||
|
||||
Reference in New Issue
Block a user