Update config groups to handle more of the data management values. (#702)

This PR moves some config values that were part of the data
configuration into XetConfig, specifically the compression_policy,
staging_subdir, session_dir_name, and global_dedup_query_enabled. This
also consolidates the remaining values into a single struct with
endpoint and authentication information.
This commit is contained in:
Hoyt Koepke
2026-03-16 16:06:46 -07:00
committed by GitHub
parent ed182125fa
commit 69c714c01d
38 changed files with 950 additions and 373 deletions

View File

@@ -131,10 +131,12 @@ impl TransferAgent for XetAgent {
let headers = user_agent_headers;
let config =
default_config(cas_url, None, Some((token, token_expiry)), Some(token_refresher), Some(Arc::new(headers)))?
.disable_progress_aggregation()
.with_session_id(session_id); // upload one file at a time so no need for the heavy progress aggregator
let mut config =
default_config(cas_url, Some((token, token_expiry)), Some(token_refresher), Some(Arc::new(headers)))?
.disable_progress_aggregation();
if !session_id.is_empty() {
config.session.session_id = Some(session_id.to_owned());
}
let session = FileUploadSession::new(config.into(), Some(Arc::new(xet_updater))).await?;
let Some(file_path) = &req.path else {

View File

@@ -164,7 +164,7 @@ pub async fn clean_file(file: web_sys::File, endpoint: String, jwt_token: String
let config = TranslatorConfig {
data_config: DataConfig {
endpoint,
compression: Some(CompressionScheme::LZ4),
compression: CompressionScheme::LZ4,
auth: AuthConfig::maybe_new(Some(jwt_token), Some(expiration), None),
prefix: "default".to_owned(),
user_agent: USER_AGENT.to_string(),

View File

@@ -6,7 +6,7 @@ use xet_core_structures::xorb_object::CompressionScheme;
#[derive(Debug)]
pub struct DataConfig {
pub endpoint: String,
pub compression: Option<CompressionScheme>,
pub compression: CompressionScheme,
pub auth: Option<AuthConfig>,
pub prefix: String,
pub user_agent: String,

View File

@@ -55,7 +55,7 @@ impl XetSession {
let config = TranslatorConfig {
data_config: DataConfig {
endpoint,
compression: Some(CompressionScheme::LZ4),
compression: CompressionScheme::LZ4,
auth: Some(auth),
prefix: "default".to_owned(),
user_agent: USER_AGENT.to_string(),

View File

@@ -126,8 +126,8 @@ impl FileUploadSession {
}
// XORBs are sent without footer - the server/client reconstructs it from chunk data.
let compression_scheme = self.config.data_config.compression;
let xorb_obj = SerializedXorbObject::from_xorb(xorb, compression_scheme, false)?;
let xorb_obj =
SerializedXorbObject::from_xorb_with_compression(xorb, self.config.data_config.compression, false)?;
let Some(ref mut xorb_uploader) = *self.xorb_uploader.lock().await else {
return Err(DataProcessingError::internal("register xorb after drop"));

View File

@@ -757,7 +757,7 @@ mod tests {
let threadpool = XetRuntime::new().unwrap();
let client = RemoteClient::new(CAS_ENDPOINT, &None, "", false, None);
let xorb_obj = build_and_verify_xorb_object(raw_xorb, Some(CompressionScheme::LZ4));
let xorb_obj = build_and_verify_xorb_object(raw_xorb, CompressionScheme::LZ4);
// Act
let result = threadpool

View File

@@ -133,7 +133,7 @@ pub trait ClientTestingUtils: Client + Send + Sync {
shard.add_xorb_block(raw_xorb.xorb_info.clone())?;
let serialized_xorb = SerializedXorbObject::from_xorb(raw_xorb.clone(), None, true)?;
let serialized_xorb = SerializedXorbObject::from_xorb(raw_xorb.clone(), true)?;
let upload_permit = self.acquire_upload_permit().await?;
self.upload_xorb("default", serialized_xorb, None, upload_permit).await?;

View File

@@ -1071,6 +1071,7 @@ fn generate_v2_fetch_url(hash: &MerkleHash, ranges: &[XorbRangeDescriptor], time
}
#[cfg(test)]
mod tests {
use xet_core_structures::xorb_object::CompressionScheme;
use xet_core_structures::xorb_object::xorb_format_test_utils::{
ChunkSize, build_and_verify_xorb_object, build_raw_xorb,
};
@@ -1094,7 +1095,7 @@ mod tests {
async fn test_download_fetch_term_data_validation() {
// Setup: Create a client and upload a xorb
let xorb = build_raw_xorb(3, ChunkSize::Fixed(2048));
let xorb_obj = build_and_verify_xorb_object(xorb, None);
let xorb_obj = build_and_verify_xorb_object(xorb, CompressionScheme::Auto);
let hash = xorb_obj.hash;
let client = LocalClient::temporary().await.unwrap();

View File

@@ -73,19 +73,15 @@ fn kl_divergence(pv: &[f64], qv: &[f64]) -> f64 {
}
fn lz4_compress_size(data: &[u8]) -> usize {
serialize_chunk(
data,
&mut std::io::Empty::default(),
Some(xet_core_structures::xorb_object::CompressionScheme::LZ4),
)
.unwrap()
serialize_chunk(data, &mut std::io::Empty::default(), xet_core_structures::xorb_object::CompressionScheme::LZ4)
.unwrap()
}
fn bg4_lz4_compress_size(data: &[u8]) -> usize {
serialize_chunk(
data,
&mut std::io::Empty::default(),
Some(xet_core_structures::xorb_object::CompressionScheme::ByteGrouping4LZ4),
xet_core_structures::xorb_object::CompressionScheme::ByteGrouping4LZ4,
)
.unwrap()
}

View File

@@ -1,6 +1,7 @@
use std::borrow::Cow;
use std::fmt::Display;
use std::io::{Cursor, Read, Write, copy};
use std::str::FromStr;
use std::time::Instant;
use anyhow::anyhow;
@@ -19,12 +20,12 @@ pub static mut BG4_LZ4_DECOMPRESS_RUNTIME: f64 = 0.;
#[repr(u8)]
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
pub enum CompressionScheme {
#[default]
None = 0,
LZ4 = 1,
ByteGrouping4LZ4 = 2, // 4 byte groups
#[default]
Auto = 99,
}
pub const NUM_COMPRESSION_SCHEMES: usize = 3;
impl Display for CompressionScheme {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -34,9 +35,10 @@ impl Display for CompressionScheme {
impl From<&CompressionScheme> for &'static str {
fn from(value: &CompressionScheme) -> Self {
match value {
CompressionScheme::None => "none",
CompressionScheme::Auto => "auto",
CompressionScheme::LZ4 => "lz4",
CompressionScheme::ByteGrouping4LZ4 => "bg4-lz4",
CompressionScheme::None => "none",
}
}
}
@@ -55,14 +57,41 @@ impl TryFrom<u8> for CompressionScheme {
0 => Ok(CompressionScheme::None),
1 => Ok(CompressionScheme::LZ4),
2 => Ok(CompressionScheme::ByteGrouping4LZ4),
99 => Ok(CompressionScheme::Auto),
_ => Err(XorbObjectError::Format(anyhow!("cannot convert value {value} to CompressionScheme"))),
}
}
}
impl FromStr for CompressionScheme {
    type Err = XorbObjectError;

    /// Parses a compression policy name.
    ///
    /// The input is trimmed and lowercased before matching. The empty string and
    /// `"auto"` both map to `Auto`; `"none"`, `"lz4"` and `"bg4-lz4"` map to the
    /// corresponding concrete schemes.
    ///
    /// # Errors
    /// Returns `XorbObjectError::Format` for any other value; the message quotes the
    /// original (un-normalized) input.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let normalized = s.trim().to_lowercase();
        let scheme = match normalized.as_str() {
            "" | "auto" => CompressionScheme::Auto,
            "none" => CompressionScheme::None,
            "lz4" => CompressionScheme::LZ4,
            "bg4-lz4" => CompressionScheme::ByteGrouping4LZ4,
            _ => {
                return Err(XorbObjectError::Format(anyhow!(
                    "Invalid compression scheme '{s}'. Valid values are: auto, none, lz4, bg4-lz4."
                )));
            },
        };
        Ok(scheme)
    }
}
impl CompressionScheme {
/// Picks the scheme to apply to `data`.
///
/// `Auto` delegates to `CompressionScheme::choose_from_data(data)`; any other
/// variant is returned as-is without looking at `data`.
pub fn resolve_for_data(&self, data: &[u8]) -> Self {
    match *self {
        CompressionScheme::Auto => CompressionScheme::choose_from_data(data),
        concrete => concrete,
    }
}
pub fn compress_from_slice<'a>(&self, data: &'a [u8]) -> Result<Cow<'a, [u8]>> {
Ok(match self {
CompressionScheme::Auto => return self.resolve_for_data(data).compress_from_slice(data),
CompressionScheme::None => data.into(),
CompressionScheme::LZ4 => lz4_compress_from_slice(data).map(Cow::from)?,
CompressionScheme::ByteGrouping4LZ4 => bg4_lz4_compress_from_slice(data).map(Cow::from)?,
@@ -71,6 +100,9 @@ impl CompressionScheme {
pub fn decompress_from_slice<'a>(&self, data: &'a [u8]) -> Result<Cow<'a, [u8]>> {
Ok(match self {
CompressionScheme::Auto => {
return Err(XorbObjectError::Format(anyhow!("Cannot decompress with Auto scheme")));
},
CompressionScheme::None => data.into(),
CompressionScheme::LZ4 => lz4_decompress_from_slice(data).map(Cow::from)?,
CompressionScheme::ByteGrouping4LZ4 => bg4_lz4_decompress_from_slice(data).map(Cow::from)?,
@@ -79,6 +111,9 @@ impl CompressionScheme {
pub fn decompress_from_reader<R: Read, W: Write>(&self, reader: &mut R, writer: &mut W) -> Result<u64> {
Ok(match self {
CompressionScheme::Auto => {
return Err(XorbObjectError::Format(anyhow!("Cannot decompress with Auto scheme")));
},
CompressionScheme::None => copy(reader, writer)?,
CompressionScheme::LZ4 => lz4_decompress_from_reader(reader, writer)?,
CompressionScheme::ByteGrouping4LZ4 => bg4_lz4_decompress_from_reader(reader, writer)?,
@@ -169,19 +204,73 @@ mod tests {
use super::*;
#[test]
fn test_default_is_auto() {
    // `Default` must yield the auto-detect policy.
    let default_scheme = CompressionScheme::default();
    assert_eq!(default_scheme, CompressionScheme::Auto);
}
#[test]
fn test_to_str() {
    // Every variant converts into its canonical lowercase name.
    let expectations = [
        (CompressionScheme::Auto, "auto"),
        (CompressionScheme::None, "none"),
        (CompressionScheme::LZ4, "lz4"),
        (CompressionScheme::ByteGrouping4LZ4, "bg4-lz4"),
    ];
    for (scheme, expected_name) in expectations {
        assert_eq!(Into::<&str>::into(scheme), expected_name);
    }
}
#[test]
fn test_from_str() {
    // Accepted spellings, including the empty string and input needing trim/lowercase.
    let accepted = [
        ("", CompressionScheme::Auto),
        ("auto", CompressionScheme::Auto),
        ("none", CompressionScheme::None),
        ("lz4", CompressionScheme::LZ4),
        ("bg4-lz4", CompressionScheme::ByteGrouping4LZ4),
        (" LZ4 ", CompressionScheme::LZ4),
    ];
    for (text, expected) in accepted {
        assert_eq!(text.parse::<CompressionScheme>().unwrap(), expected);
    }

    // Unknown names are rejected.
    assert!("zstd".parse::<CompressionScheme>().is_err());
}
#[test]
fn test_display() {
    // `Display` output for each variant.
    let expectations = [
        (CompressionScheme::Auto, "auto"),
        (CompressionScheme::None, "none"),
        (CompressionScheme::LZ4, "lz4"),
        (CompressionScheme::ByteGrouping4LZ4, "bg4-lz4"),
    ];
    for (scheme, rendered) in expectations {
        assert_eq!(scheme.to_string(), rendered);
    }
}
#[test]
fn test_parse_with_config_enum() {
    use xet_runtime::utils::ConfigEnum;

    // Allowed values handed to every `ConfigEnum` built below.
    const ALLOWED: [&str; 5] = ["", "auto", "none", "lz4", "bg4-lz4"];

    // A `ConfigEnum` holding each of these values must parse into the matching scheme.
    let cases = [
        ("auto", CompressionScheme::Auto),
        ("lz4", CompressionScheme::LZ4),
        ("none", CompressionScheme::None),
    ];
    for (value, expected) in cases {
        let ce = ConfigEnum::new(value, &ALLOWED);
        let scheme: CompressionScheme = ce.parse().unwrap();
        assert_eq!(scheme, expected);
    }
}
#[test]
fn test_from_u8() {
    // Discriminant bytes that map to a variant.
    let known = [
        (0u8, CompressionScheme::None),
        (1u8, CompressionScheme::LZ4),
        (2u8, CompressionScheme::ByteGrouping4LZ4),
        (99u8, CompressionScheme::Auto),
    ];
    for (byte, expected) in known {
        assert_eq!(CompressionScheme::try_from(byte), Ok(expected));
    }

    // Bytes with no associated variant are rejected.
    for unknown in [3u8, 4u8] {
        assert!(CompressionScheme::try_from(unknown).is_err());
    }
}
#[test]
fn test_resolve_for_data() {
    let zeros = vec![0u8; 1024];

    // `Auto` must resolve to some concrete scheme for this input.
    assert_ne!(CompressionScheme::Auto.resolve_for_data(&zeros), CompressionScheme::Auto);

    // Explicit schemes pass through untouched.
    for explicit in [CompressionScheme::LZ4, CompressionScheme::None] {
        assert_eq!(explicit.resolve_for_data(&zeros), explicit);
    }
}
#[test]

View File

@@ -59,6 +59,7 @@ impl XorbChunkHeader {
}
/// Stores `compression_scheme` in the header as its `u8` discriminant.
///
/// Debug builds assert that the scheme is not `Auto`; release builds perform no check here.
pub fn set_compression_scheme(&mut self, compression_scheme: CompressionScheme) {
    debug_assert_ne!(compression_scheme, CompressionScheme::Auto);
    let discriminant = compression_scheme as u8;
    self.compression_scheme = discriminant;
}
@@ -114,9 +115,13 @@ fn convert_three_byte_num(buf: &[u8; 3]) -> u32 {
pub fn serialize_chunk<W: Write>(
chunk: &[u8],
w: &mut W,
compression_scheme: Option<CompressionScheme>,
compression_scheme: CompressionScheme,
) -> Result<usize, XorbObjectError> {
let compression_scheme = compression_scheme.unwrap_or_else(|| CompressionScheme::choose_from_data(chunk));
let compression_scheme = compression_scheme.resolve_for_data(chunk);
debug_assert_ne!(compression_scheme, CompressionScheme::Auto);
if compression_scheme == CompressionScheme::Auto {
return Err(XorbObjectError::Format(anyhow!("CompressionScheme::Auto cannot be serialized into xorb chunks")));
}
let compressed = compression_scheme.compress_from_slice(chunk)?;
@@ -309,6 +314,16 @@ mod tests {
assert_eq!(data_copy.as_slice(), data);
}
#[test]
fn test_auto_scheme_never_serialized_in_chunk_header() {
    // Serialize a chunk while requesting `Auto`.
    let payload = vec![0u8; 4096];
    let mut buffer = Vec::new();
    let _ = serialize_chunk(&payload, &mut buffer, CompressionScheme::Auto).unwrap();

    // The scheme recorded in the header must be a concrete one, never `Auto`.
    let mut reader = Cursor::new(buffer);
    let header = deserialize_chunk_header(&mut reader).unwrap();
    let recorded = header.get_compression_scheme().unwrap();
    assert_ne!(recorded, CompressionScheme::Auto);
}
const CASES: &[(u32, ChunkSize, CompressionScheme)] = &[
(2, ChunkSize::Fixed(16), CompressionScheme::None),
(10, ChunkSize::Fixed(16), CompressionScheme::None),

View File

@@ -151,7 +151,7 @@ mod tests {
let mut out = Vec::new();
for _ in 0..num_chunks {
let data = gen_random_bytes(rng, CHUNK_SIZE as u32);
serialize_chunk(&data, &mut out, Some(compression_scheme)).unwrap();
serialize_chunk(&data, &mut out, compression_scheme).unwrap();
}
out
}

View File

@@ -1282,9 +1282,18 @@ pub struct SerializedXorbObject {
impl SerializedXorbObject {
/// Builds the xorb from raw xorb data.
pub fn from_xorb(
///
/// The compression scheme is determined by `HF_XET_XORB_COMPRESSION_POLICY`:
/// auto-detect (default) or an explicit scheme (none, lz4, bg4-lz4).
pub fn from_xorb(xorb: RawXorbData, serialize_footer: bool) -> Result<Self, XorbObjectError> {
let compression_scheme: CompressionScheme = xet_config().xorb.compression_policy.parse()?;
Self::from_xorb_with_compression(xorb, compression_scheme, serialize_footer)
}
/// Builds the xorb from raw xorb data with an explicit compression scheme override.
pub fn from_xorb_with_compression(
xorb: RawXorbData,
compression_scheme: Option<CompressionScheme>,
compression_scheme: CompressionScheme,
serialize_footer: bool,
) -> Result<Self, XorbObjectError> {
let mut xorb_object_info = XorbObjectInfoV1::default();
@@ -1320,7 +1329,7 @@ impl SerializedXorbObject {
num_chunks
};
if compression_scheme.is_none() && num_chunks != 0 {
if compression_scheme == CompressionScheme::Auto && num_chunks != 0 {
debug_assert!(xorb.file_boundaries.is_sorted());
debug_assert_ge!(xorb.file_boundaries.len(), 0);
debug_assert_lt!(*xorb.file_boundaries.last().unwrap(), xorb.data.len());
@@ -1334,10 +1343,10 @@ impl SerializedXorbObject {
// Choose the compression scheme for this block.
let compression_scheme = CompressionScheme::choose_from_data(&xorb.data[s_idx]);
debug_assert_ne!(compression_scheme, CompressionScheme::Auto);
for chunk in &xorb.data[s_idx..n_idx] {
// now serialize chunk directly to writer (since chunks come first!)
serialize_chunk(chunk, &mut serialized_data, Some(compression_scheme))?;
serialize_chunk(chunk, &mut serialized_data, compression_scheme)?;
xorb_object_info.chunk_boundary_offsets.push(serialized_data.len() as u32);
}
s_idx = n_idx;
@@ -1345,7 +1354,6 @@ impl SerializedXorbObject {
}
} else {
for chunk in xorb.data {
// now serialize chunk directly to writer (since chunks come first!)
serialize_chunk(&chunk, &mut serialized_data, compression_scheme)?;
xorb_object_info.chunk_boundary_offsets.push(serialized_data.len() as u32);
}
@@ -1387,7 +1395,7 @@ pub mod test_utils {
hash: &MerkleHash,
data: &[u8],
chunk_and_boundaries: &[(MerkleHash, u32)],
compression_scheme: Option<CompressionScheme>,
compression_scheme: CompressionScheme,
) -> Result<(XorbObject, usize, u64), XorbObjectError> {
let mut xorb = XorbObject::default();
xorb.info.xorb_hash = *hash;
@@ -1437,7 +1445,7 @@ pub mod test_utils {
hash: &MerkleHash,
data: Vec<u8>,
chunk_and_boundaries: Vec<(MerkleHash, u32)>,
compression: Option<CompressionScheme>,
compression: CompressionScheme,
) -> Result<SerializedXorbObject, XorbObjectError> {
let mut writer = Cursor::new(Vec::new());
@@ -1455,7 +1463,7 @@ pub mod test_utils {
pub fn verify_serialized_xorb_object(
xorb: &RawXorbData,
compression_scheme: Option<CompressionScheme>,
compression_scheme: CompressionScheme,
xorb_obj: &SerializedXorbObject,
) {
let xorb_hash = xorb.hash();
@@ -1544,9 +1552,10 @@ pub mod test_utils {
pub fn build_and_verify_xorb_object(
xorb: RawXorbData,
compression_scheme: Option<CompressionScheme>,
compression_scheme: CompressionScheme,
) -> SerializedXorbObject {
let xorb_obj = SerializedXorbObject::from_xorb(xorb.clone(), compression_scheme, true).unwrap();
let xorb_obj =
SerializedXorbObject::from_xorb_with_compression(xorb.clone(), compression_scheme, true).unwrap();
verify_serialized_xorb_object(&xorb, compression_scheme, &xorb_obj);
@@ -1590,7 +1599,7 @@ pub mod test_utils {
// build chunk, create ChunkInfo and keep going
let bytes_written = serialize_chunk(&bytes, &mut writer, Some(compression_scheme)).unwrap();
let bytes_written = serialize_chunk(&bytes, &mut writer, compression_scheme).unwrap();
total_bytes += bytes_written as u32;
@@ -1832,7 +1841,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
CompressionScheme::LZ4
)
.is_ok()
);
@@ -1852,7 +1861,7 @@ mod tests {
&c.info.xorb_hash,
&c_bytes,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
CompressionScheme::None
)
.is_ok()
);
@@ -1890,7 +1899,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
CompressionScheme::LZ4
)
.is_ok()
);
@@ -1914,7 +1923,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
CompressionScheme::None
)
.is_ok()
);
@@ -1935,7 +1944,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
CompressionScheme::None
)
.is_ok()
);
@@ -1968,7 +1977,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
CompressionScheme::None
)
.is_ok()
);
@@ -2000,7 +2009,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
CompressionScheme::None
)
.is_ok()
);
@@ -2032,7 +2041,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
CompressionScheme::LZ4
)
.is_ok()
);
@@ -2063,7 +2072,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
CompressionScheme::LZ4
)
.is_ok()
);
@@ -2096,7 +2105,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
CompressionScheme::LZ4
)
.is_ok()
);
@@ -2128,7 +2137,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
CompressionScheme::LZ4
)
.is_ok()
);
@@ -2158,7 +2167,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
CompressionScheme::LZ4
)
.is_ok()
);
@@ -2190,7 +2199,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
CompressionScheme::LZ4
)
.is_ok()
);
@@ -2256,7 +2265,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4),
CompressionScheme::LZ4,
)
.is_ok()
);
@@ -2357,7 +2366,7 @@ mod tests {
&c.info.xorb_hash,
&raw_data,
&raw_chunk_boundaries,
Some(COMPRESSION_SCHEME)
COMPRESSION_SCHEME
)
.is_ok()
);
@@ -2418,7 +2427,7 @@ mod tests {
let bytes = gen_random_bytes(*size);
let chunk_hash = crate::merklehash::compute_data_hash(&bytes);
chunk_hashes_and_sizes.push((chunk_hash, bytes.len() as u64));
serialize_chunk(&bytes, &mut raw_data, Some(compression)).unwrap();
serialize_chunk(&bytes, &mut raw_data, compression).unwrap();
}
let expected_hash = crate::merklehash::xorb_hash(&chunk_hashes_and_sizes);
@@ -2470,7 +2479,7 @@ mod tests {
let bytes = gen_random_bytes(*size);
let chunk_hash = crate::merklehash::compute_data_hash(&bytes);
chunk_hashes_and_sizes.push((chunk_hash, bytes.len() as u64));
serialize_chunk(&bytes, &mut raw_data, Some(compression)).unwrap();
serialize_chunk(&bytes, &mut raw_data, compression).unwrap();
}
let expected_hash = crate::merklehash::xorb_hash(&chunk_hashes_and_sizes);
@@ -2495,7 +2504,7 @@ mod tests {
let bytes = gen_random_bytes(*size);
let chunk_hash = crate::merklehash::compute_data_hash(&bytes);
chunk_hashes_and_sizes.push((chunk_hash, bytes.len() as u64));
serialize_chunk(&bytes, &mut raw_data, Some(compression)).unwrap();
serialize_chunk(&bytes, &mut raw_data, compression).unwrap();
}
let mut output = Vec::new();
@@ -2516,7 +2525,7 @@ mod tests {
let mut raw_data = Vec::new();
let bytes = gen_random_bytes(1024);
let chunk_hash = crate::merklehash::compute_data_hash(&bytes);
serialize_chunk(&bytes, &mut raw_data, Some(CompressionScheme::LZ4)).unwrap();
serialize_chunk(&bytes, &mut raw_data, CompressionScheme::LZ4).unwrap();
let expected_hash = crate::merklehash::xorb_hash(&[(chunk_hash, 1024)]);
@@ -2533,4 +2542,12 @@ mod tests {
let deserialized = XorbObject::deserialize(&mut reader).unwrap();
assert_eq!(deserialized.info.num_chunks, 1);
}
#[test]
fn test_from_xorb_uses_default_config() {
    // `from_xorb` takes no explicit scheme, so this exercises the config-driven path.
    let raw = build_raw_xorb(4, ChunkSize::Fixed(1024));
    let serialized = SerializedXorbObject::from_xorb(raw, false).unwrap();

    // All four chunks were serialized and some bytes were produced.
    assert_eq!(serialized.num_chunks, 4);
    assert!(!serialized.serialized_data.is_empty());
}
}

View File

@@ -6,7 +6,7 @@ use std::sync::{Arc, OnceLock};
use anyhow::Result;
use clap::{Args, Parser, Subcommand};
use ulid::Ulid;
use xet_data::processing::configurations::*;
use xet_data::processing::configurations::TranslatorConfig;
use xet_data::processing::{FileUploadSession, Sha256Policy, XetFileInfo};
use xet_runtime::core::XetRuntime;
@@ -127,8 +127,6 @@ async fn smudge_file(arg: &SmudgeArg) -> Result<()> {
}
async fn smudge(_name: Arc<str>, mut reader: impl Read, output_path: PathBuf) -> Result<()> {
use xet_data::processing::configurations::TranslatorConfig;
let mut input = String::new();
reader.read_to_string(&mut input)?;

View File

@@ -16,6 +16,7 @@ use xet_core_structures::xorb_object::CompressionScheme;
use xet_data::processing::data_client::default_config;
use xet_data::processing::migration_tool::hub_client_token_refresher::HubClientTokenRefresher;
use xet_data::processing::migration_tool::migrate::migrate_files_impl;
use xet_runtime::config::XetConfig;
use xet_runtime::core::XetRuntime;
const DEFAULT_HF_ENDPOINT: &str = "https://huggingface.co";
@@ -127,16 +128,8 @@ impl Command {
let file_paths = walk_files(arg.files, arg.recursive);
eprintln!("Dedupping {} files...", file_paths.len());
let (all_file_info, clean_ret, total_bytes_trans) = migrate_files_impl(
file_paths,
None,
arg.sequential,
hub_client,
None,
arg.compression.and_then(|c| CompressionScheme::try_from(c).ok()),
!arg.migrate,
)
.await?;
let (all_file_info, clean_ret, total_bytes_trans) =
migrate_files_impl(file_paths, None, arg.sequential, hub_client, None, !arg.migrate).await?;
// Print file info for analysis
if !arg.migrate {
@@ -216,19 +209,12 @@ async fn query_reconstruction(
let config = default_config(
jwt_info.cas_url.clone(),
None,
Some((jwt_info.access_token, jwt_info.exp)),
Some(token_refresher),
Some(Arc::new(headers)),
)?;
let cas_storage_config = &config.data_config;
let remote_client = RemoteClient::new(
&jwt_info.cas_url,
&cas_storage_config.auth,
"",
true,
cas_storage_config.custom_headers.clone(),
);
let remote_client =
RemoteClient::new(&jwt_info.cas_url, &config.session.auth, "", true, config.session.custom_headers.clone());
// Use V1 directly so the query tool returns the raw QueryReconstructionResponse for inspection.
remote_client
@@ -239,7 +225,22 @@ async fn query_reconstruction(
fn main() -> Result<()> {
let cli = XCommand::parse();
let threadpool = XetRuntime::new()?;
let mut config = XetConfig::new();
if let Command::Dedup(ref arg) = cli.command
&& let Some(c) = arg.compression
{
let scheme = CompressionScheme::try_from(c).map_err(|_| {
anyhow::anyhow!("Invalid compression value {c}; expected one of: 0 (none), 1 (lz4), 2 (bg4-lz4), 99 (auto)")
})?;
config
.xorb
.compression_policy
.try_set(<&str>::from(scheme))
.map_err(|e| anyhow::anyhow!(e))?;
}
let threadpool = XetRuntime::new_with_config(config)?;
threadpool.external_run_async_task(async move { cli.run().await })??;
Ok(())

View File

@@ -1,196 +1,237 @@
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
use http::HeaderMap;
use tracing::info;
use xet_client::cas_client::auth::AuthConfig;
use xet_client::cas_client::remote_client::PREFIX_DEFAULT;
use xet_core_structures::xorb_object::CompressionScheme;
use xet_runtime::core::{xet_cache_root, xet_config};
use super::errors::Result;
/// Session-specific configuration that varies per upload/download session.
/// These are runtime values that cannot be configured via environment variables.
#[derive(Debug)]
pub enum Endpoint {
Server(String),
FileSystem(PathBuf),
InMemory,
}
#[derive(Debug)]
pub struct DataConfig {
pub endpoint: Endpoint,
pub compression: Option<CompressionScheme>,
pub struct SessionContext {
/// The endpoint URL. Use the `local://` prefix (configurable via `HF_XET_DATA_LOCAL_CAS_SCHEME`)
/// to specify a local filesystem path, or `memory://` for in-memory storage.
pub endpoint: String,
pub auth: Option<AuthConfig>,
pub prefix: String,
pub staging_directory: Option<PathBuf>,
pub custom_headers: Option<Arc<HeaderMap>>,
}
#[derive(Debug)]
pub struct GlobalDedupConfig {
pub global_dedup_policy: GlobalDedupPolicy,
}
#[derive(Debug)]
pub struct RepoInfo {
pub repo_paths: Vec<String>,
pub session_id: Option<String>,
}
#[derive(PartialEq, Default, Clone, Debug, Copy)]
pub enum GlobalDedupPolicy {
/// Never query for new shards using chunk hashes.
Never,
impl SessionContext {
/// Reports whether the endpoint string begins with the configured local CAS scheme
/// (`xet_config().data.local_cas_scheme`).
pub fn is_local(&self) -> bool {
    self.endpoint
        .strip_prefix(&xet_config().data.local_cas_scheme)
        .is_some()
}
/// Always query for new shards by chunks
#[default]
Always,
}
/// Returns the part of the endpoint after the configured local CAS scheme as a path,
/// or `None` when the endpoint does not start with that scheme.
pub fn local_path(&self) -> Option<PathBuf> {
    self.endpoint
        .strip_prefix(&xet_config().data.local_cas_scheme)
        .map(PathBuf::from)
}
impl FromStr for GlobalDedupPolicy {
type Err = std::io::Error;
/// Reports whether the endpoint is exactly the in-memory marker `memory://`.
pub fn is_memory(&self) -> bool {
    matches!(self.endpoint.as_str(), "memory://")
}
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"never" => Ok(GlobalDedupPolicy::Never),
"always" => Ok(GlobalDedupPolicy::Always),
_ => Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
format!("Invalid global dedup query policy, should be one of never, direct_only, always: {s}"),
)),
/// Builds a `SessionContext` whose endpoint is the configured local CAS scheme
/// followed by `base_dir`.
///
/// No auth, custom headers or session id are set, and `repo_paths` holds a single
/// empty string. The path is rendered with `Path::display`, which is lossy for
/// non-Unicode paths.
pub fn for_local_path(base_dir: impl AsRef<Path>) -> Self {
    let base = base_dir.as_ref().to_path_buf();
    let scheme = &xet_config().data.local_cas_scheme;
    Self {
        endpoint: format!("{}{}", scheme, base.display()),
        auth: None,
        custom_headers: None,
        repo_paths: vec![String::new()],
        session_id: None,
    }
}
/// Builds a `SessionContext` pointing at the in-memory endpoint `memory://`.
///
/// No auth, custom headers or session id are set, and `repo_paths` holds a single
/// empty string.
pub fn for_memory() -> Self {
    let endpoint = String::from("memory://");
    let repo_paths = vec![String::new()];
    Self {
        endpoint,
        auth: None,
        custom_headers: None,
        repo_paths,
        session_id: None,
    }
}
}
#[derive(Debug)]
pub struct ShardConfig {
pub prefix: String,
pub session_directory: PathBuf,
pub cache_directory: PathBuf,
pub global_dedup_policy: GlobalDedupPolicy,
}
#[derive(Debug)]
pub struct ProgressConfig {
pub aggregate: bool,
}
/// Main configuration for file upload/download operations.
/// Combines session-specific values with runtime-computed paths derived from the endpoint.
#[derive(Debug)]
pub struct TranslatorConfig {
pub data_config: DataConfig,
pub shard_config: ShardConfig,
pub repo_info: Option<RepoInfo>,
pub session_id: Option<String>,
pub progress_config: ProgressConfig,
pub session: SessionContext,
/// Directory for caching shard files.
pub shard_cache_directory: PathBuf,
/// Directory for session-specific shard files.
pub shard_session_directory: PathBuf,
/// Per-session override: when true, progress aggregation is disabled
/// regardless of the global `HF_XET_DATA_AGGREGATE_PROGRESS` config value.
pub force_disable_progress_aggregation: bool,
}
impl TranslatorConfig {
pub fn local_config(base_dir: impl AsRef<Path>) -> Result<Self> {
let path = base_dir.as_ref().join("xet");
std::fs::create_dir_all(&path)?;
/// Ensures `<base_dir>/xet` exists (creating missing parents) and returns that path.
///
/// Fails with the converted I/O error if the directory cannot be created.
fn create_base_xet_dir(base_dir: impl AsRef<Path>) -> Result<PathBuf> {
    let mut xet_dir = base_dir.as_ref().to_path_buf();
    xet_dir.push("xet");
    std::fs::create_dir_all(&xet_dir)?;
    Ok(xet_dir)
}
let translator_config = Self {
data_config: DataConfig {
endpoint: Endpoint::FileSystem(path.join("xorbs")),
compression: Default::default(),
auth: None,
prefix: PREFIX_DEFAULT.into(),
staging_directory: None,
custom_headers: None,
},
shard_config: ShardConfig {
prefix: PREFIX_DEFAULT.into(),
cache_directory: path.join("shard-cache"),
session_directory: path.join("shard-session"),
global_dedup_policy: Default::default(),
},
repo_info: Some(RepoInfo {
repo_paths: vec!["".into()],
}),
session_id: None,
progress_config: ProgressConfig { aggregate: true },
/// Creates a new TranslatorConfig from a SessionContext, computing all derived paths.
pub fn new(session: SessionContext) -> Result<Self> {
let config = xet_config();
let (shard_cache_directory, shard_session_directory) = if let Some(local_path) = session.local_path() {
let base_path = local_path.join("xet");
std::fs::create_dir_all(&base_path)?;
(base_path.join(&config.shard.cache_subdir), base_path.join(&config.session.dir_name))
} else if session.is_memory() {
let cache_path = xet_cache_root().join("memory");
std::fs::create_dir_all(&cache_path)?;
(cache_path.join(&config.shard.cache_subdir), cache_path.join(&config.session.dir_name))
} else {
let cache_path = compute_cache_path(&session.endpoint);
std::fs::create_dir_all(&cache_path)?;
let staging_directory = cache_path.join(&config.data.staging_subdir);
std::fs::create_dir_all(&staging_directory)?;
(cache_path.join(&config.shard.cache_subdir), staging_directory.join(&config.session.dir_name))
};
Ok(translator_config)
info!(
endpoint = %session.endpoint,
session_id = ?session.session_id,
shard_cache = %shard_cache_directory.display(),
shard_session = %shard_session_directory.display(),
"TranslatorConfig initialized"
);
Ok(Self {
session,
shard_cache_directory,
shard_session_directory,
force_disable_progress_aggregation: false,
})
}
/// Builds a `TranslatorConfig` for a local-filesystem session rooted at `base_dir`,
/// by passing `SessionContext::for_local_path(base_dir)` to `Self::new`.
pub fn local_config(base_dir: impl AsRef<Path>) -> Result<Self> {
    let session = SessionContext::for_local_path(base_dir);
    Self::new(session)
}
/// Creates a TranslatorConfig that uses in-memory storage for XORBs.
/// Shard data still uses file-based storage in the provided base directory.
pub fn memory_config(base_dir: impl AsRef<Path>) -> Result<Self> {
let path = base_dir.as_ref().join("xet");
std::fs::create_dir_all(&path)?;
let session = SessionContext::for_memory();
let config = xet_config();
let base_path = Self::create_base_xet_dir(base_dir)?;
let translator_config = Self {
data_config: DataConfig {
endpoint: Endpoint::InMemory,
compression: Default::default(),
auth: None,
prefix: PREFIX_DEFAULT.into(),
staging_directory: None,
custom_headers: None,
},
shard_config: ShardConfig {
prefix: PREFIX_DEFAULT.into(),
cache_directory: path.join("shard-cache"),
session_directory: path.join("shard-session"),
global_dedup_policy: Default::default(),
},
repo_info: Some(RepoInfo {
repo_paths: vec!["".into()],
}),
session_id: None,
progress_config: ProgressConfig { aggregate: true },
};
Ok(translator_config)
Ok(Self {
session,
shard_cache_directory: base_path.join(&config.shard.cache_subdir),
shard_session_directory: base_path.join(&config.session.dir_name),
force_disable_progress_aggregation: false,
})
}
/// Creates a TranslatorConfig that connects to a CAS server at the given endpoint.
/// Shard cache and session directories are created under the provided base directory.
/// Useful for tests that use LocalTestServer.
pub fn test_server_config(endpoint: impl AsRef<str>, base_dir: impl AsRef<Path>) -> Result<Self> {
let path = base_dir.as_ref().join("xet");
std::fs::create_dir_all(&path)?;
let translator_config = Self {
data_config: DataConfig {
endpoint: Endpoint::Server(endpoint.as_ref().to_string()),
compression: Default::default(),
auth: None,
prefix: PREFIX_DEFAULT.into(),
staging_directory: None,
custom_headers: None,
},
shard_config: ShardConfig {
prefix: PREFIX_DEFAULT.into(),
cache_directory: path.join("shard-cache"),
session_directory: path.join("shard-session"),
global_dedup_policy: Default::default(),
},
repo_info: Some(RepoInfo {
repo_paths: vec!["".into()],
}),
let session = SessionContext {
endpoint: endpoint.as_ref().to_string(),
auth: None,
custom_headers: None,
repo_paths: vec!["".into()],
session_id: None,
progress_config: ProgressConfig { aggregate: true },
};
let config = xet_config();
let base_path = Self::create_base_xet_dir(base_dir)?;
Ok(translator_config)
Ok(Self {
session,
shard_cache_directory: base_path.join(&config.shard.cache_subdir),
shard_session_directory: base_path.join(&config.session.dir_name),
force_disable_progress_aggregation: false,
})
}
pub fn disable_progress_aggregation(self) -> Self {
Self {
progress_config: ProgressConfig { aggregate: false },
..self
}
}
pub fn with_session_id(self, session_id: &str) -> Self {
if session_id.is_empty() {
return self;
}
Self {
session_id: Some(session_id.to_owned()),
..self
}
pub fn disable_progress_aggregation(mut self) -> Self {
self.force_disable_progress_aggregation = true;
self
}
}
/// Derives the cache directory used for a given endpoint URL.
///
/// The directory name has the form `<prefix>-<hash>` and lives under `xet_cache_root()`:
/// * `<prefix>` is the first 16 characters of `endpoint`, with every non-alphanumeric character replaced by `_`, so the name stays recognisable;
/// * `<hash>` is the first 16 characters of the `base64()` rendering of the endpoint's data hash, so endpoints that share a prefix still get distinct directories.
fn compute_cache_path(endpoint: &str) -> PathBuf {
    let root = xet_cache_root();

    // Human-readable part of the tag.
    let mut readable = String::new();
    for ch in endpoint.chars().take(16) {
        readable.push(if ch.is_alphanumeric() { ch } else { '_' });
    }

    // Disambiguating part of the tag, computed over the full endpoint string.
    let digest = xet_core_structures::merklehash::compute_data_hash(endpoint.as_bytes()).base64();
    let short_digest = &digest[..16];

    root.join(format!("{readable}-{short_digest}"))
}
#[cfg(test)]
mod tests {
    use tempfile::tempdir;

    use super::{SessionContext, TranslatorConfig};

    /// Each kind of session reports the expected mode flags, and only the local
    /// session exposes a local path.
    #[test]
    fn test_session_context_mode_detection() {
        let scratch = tempdir().unwrap();

        // Session backed by a local directory.
        let local = SessionContext::for_local_path(scratch.path());
        assert!(local.is_local());
        assert!(!local.is_memory());
        assert_eq!(local.local_path().unwrap(), scratch.path().to_path_buf());

        // Session backed by in-memory storage.
        let in_memory = SessionContext::for_memory();
        assert!(in_memory.is_memory());
        assert!(!in_memory.is_local());
        assert!(in_memory.local_path().is_none());

        // Session pointing at an HTTP endpoint: neither local nor in-memory.
        let remote = SessionContext {
            endpoint: "http://localhost:8080".into(),
            auth: None,
            custom_headers: None,
            repo_paths: Vec::new(),
            session_id: None,
        };
        assert!(!remote.is_local());
        assert!(!remote.is_memory());
        assert!(remote.local_path().is_none());
    }

    /// Both helper constructors place their shard directories under `<base>/xet`.
    #[test]
    fn test_memory_and_server_configs_use_base_xet_layout() {
        let scratch = tempdir().unwrap();
        let xet_root = scratch.path().join("xet");

        let assert_under_xet_root = |cfg: &TranslatorConfig| {
            assert!(cfg.shard_cache_directory.starts_with(&xet_root));
            assert!(cfg.shard_session_directory.starts_with(&xet_root));
        };

        let memory_config = TranslatorConfig::memory_config(scratch.path()).unwrap();
        assert_under_xet_root(&memory_config);

        let server_config = TranslatorConfig::test_server_config("http://localhost:8080", scratch.path()).unwrap();
        assert_under_xet_root(&server_config);
    }
}

View File

@@ -6,16 +6,14 @@ use std::sync::Arc;
use bytes::Bytes;
use http::header::HeaderMap;
use itertools::multizip;
use tracing::{Instrument, Span, info, info_span, instrument};
use tracing::{Instrument, Span, info_span, instrument};
use ulid::Ulid;
use xet_client::cas_client::auth::{AuthConfig, TokenRefresher};
use xet_client::cas_client::remote_client::PREFIX_DEFAULT;
use xet_core_structures::merklehash::MerkleHash;
use xet_core_structures::xorb_object::CompressionScheme;
use xet_runtime::core::par_utils::run_constrained_with_semaphore;
use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_cache_root, xet_config};
use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_config};
use super::configurations::*;
use super::configurations::{SessionContext, TranslatorConfig};
use super::errors::DataProcessingError;
use super::file_cleaner::Sha256Policy;
use super::file_download_session::FileDownloadSession;
@@ -25,70 +23,22 @@ use crate::progress_tracking::TrackingProgressUpdater;
pub fn default_config(
endpoint: String,
xorb_compression: Option<CompressionScheme>,
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
custom_headers: Option<Arc<HeaderMap>>,
) -> errors::Result<TranslatorConfig> {
// Intercept local:// to run a simulated CAS server in a specified directory.
// This is useful for testing and development.
if endpoint.starts_with("local://") {
let local_path = endpoint.strip_prefix("local://").unwrap();
let local_path = PathBuf::from(local_path);
std::fs::create_dir_all(&local_path)?;
return TranslatorConfig::local_config(local_path);
}
let cache_root_path = xet_cache_root();
info!("Using cache path {cache_root_path:?}.");
let (token, token_expiration) = token_info.unzip();
let auth_cfg = AuthConfig::maybe_new(token, token_expiration, token_refresher);
// Calculate a fingerprint of the current endpoint to make sure caches stay separated.
let endpoint_tag = {
let endpoint_prefix = endpoint
.chars()
.take(16)
.map(|c| if c.is_alphanumeric() { c } else { '_' })
.collect::<String>();
// If more gets added
let endpoint_hash = xet_core_structures::merklehash::compute_data_hash(endpoint.as_bytes()).base64();
format!("{endpoint_prefix}-{}", &endpoint_hash[..16])
};
let cache_path = cache_root_path.join(endpoint_tag);
std::fs::create_dir_all(&cache_path)?;
let staging_root = cache_path.join("staging");
std::fs::create_dir_all(&staging_root)?;
let translator_config = TranslatorConfig {
data_config: DataConfig {
endpoint: Endpoint::Server(endpoint.clone()),
compression: xorb_compression,
auth: auth_cfg.clone(),
prefix: PREFIX_DEFAULT.into(),
staging_directory: None,
custom_headers,
},
shard_config: ShardConfig {
prefix: PREFIX_DEFAULT.into(),
cache_directory: cache_path.join("shard-cache"),
session_directory: staging_root.join("shard-session"),
global_dedup_policy: Default::default(),
},
repo_info: Some(RepoInfo {
repo_paths: vec!["".into()],
}),
let session = SessionContext {
endpoint,
auth: auth_cfg,
custom_headers,
repo_paths: vec!["".into()],
session_id: Some(Ulid::new().to_string()),
progress_config: ProgressConfig { aggregate: true },
};
// Return the temp dir so that it's not dropped and thus the directory deleted.
Ok(translator_config)
TranslatorConfig::new(session)
}
#[instrument(skip_all, name = "data_client::upload_bytes", fields(session_id = tracing::field::Empty, num_files=file_contents.len()))]
@@ -111,13 +61,12 @@ pub async fn upload_bytes_async(
let config = default_config(
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
None,
token_info,
token_refresher,
custom_headers,
)?;
Span::current().record("session_id", &config.session_id);
Span::current().record("session_id", &config.session.session_id);
let semaphore = XetRuntime::current().common().file_ingestion_semaphore.clone();
let upload_session = FileUploadSession::new(config.into(), progress_updater).await?;
@@ -168,7 +117,6 @@ pub async fn upload_async(
// for each file, return the filehash
let config = default_config(
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
None,
token_info,
token_refresher,
custom_headers,
@@ -176,7 +124,7 @@ pub async fn upload_async(
let span = Span::current();
span.record("session_id", &config.session_id);
span.record("session_id", &config.session.session_id);
let upload_session = FileUploadSession::new(config.into(), progress_updater).await?;
@@ -215,14 +163,13 @@ pub async fn download_async(
}
let config: Arc<TranslatorConfig> = default_config(
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
None,
token_info,
token_refresher,
custom_headers,
)?
.into();
Span::current().record("session_id", &config.session_id);
Span::current().record("session_id", &config.session.session_id);
let updaters = match progress_updaters {
None => vec![None; file_infos.len()],
@@ -424,11 +371,11 @@ mod tests {
let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None, None);
let result = default_config(endpoint, None, None, None);
assert!(result.is_ok());
let config = result.unwrap();
assert!(config.shard_config.cache_directory.starts_with(temp_dir.path()));
assert!(config.shard_cache_directory.starts_with(temp_dir.path()));
}
#[test]
@@ -441,11 +388,11 @@ mod tests {
let hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir_hf_home.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None, None);
let result = default_config(endpoint, None, None, None);
assert!(result.is_ok());
let config = result.unwrap();
assert!(config.shard_config.cache_directory.starts_with(temp_dir_xet_cache.path()));
assert!(config.shard_cache_directory.starts_with(temp_dir_xet_cache.path()));
drop(hf_xet_cache_guard);
drop(hf_home_guard);
@@ -454,11 +401,11 @@ mod tests {
let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None, None);
let result = default_config(endpoint, None, None, None);
assert!(result.is_ok());
let config = result.unwrap();
assert!(config.shard_config.cache_directory.starts_with(temp_dir.path()));
assert!(config.shard_cache_directory.starts_with(temp_dir.path()));
}
#[test]
@@ -468,24 +415,24 @@ mod tests {
let _hf_xet_cache_guard = EnvVarGuard::set("HF_XET_CACHE", temp_dir.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None, None);
let result = default_config(endpoint, None, None, None);
assert!(result.is_ok());
let config = result.unwrap();
assert!(config.shard_config.cache_directory.starts_with(temp_dir.path()));
assert!(config.shard_cache_directory.starts_with(temp_dir.path()));
}
#[test]
#[serial(default_config_env)]
fn test_default_config_without_env_vars() {
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None, None);
let result = default_config(endpoint, None, None, None);
let expected = home_dir().unwrap().join(".cache").join("huggingface").join("xet");
assert!(result.is_ok());
let config = result.unwrap();
let test_cache_dir = &config.shard_config.cache_directory;
let test_cache_dir = &config.shard_cache_directory;
assert!(
test_cache_dir.starts_with(&expected),
"cache dir = {test_cache_dir:?}; does not start with {expected:?}",

View File

@@ -6,7 +6,6 @@ use tracing::Instrument;
use xet_core_structures::merklehash::MerkleHash;
use xet_core_structures::metadata_shard::file_structs::FileDataSequenceEntry;
use super::configurations::GlobalDedupPolicy;
use super::errors::Result;
use super::file_upload_session::FileUploadSession;
use crate::deduplication::{DeduplicationDataInterface, RawXorbData};
@@ -26,7 +25,7 @@ impl UploadSessionDataManager {
}
fn global_dedup_queries_enabled(&self) -> bool {
matches!(self.session.config.shard_config.global_dedup_policy, GlobalDedupPolicy::Always)
xet_runtime::core::xet_config().deduplication.global_dedup_query_enabled
}
}

View File

@@ -36,6 +36,7 @@ impl FileDownloadSession {
progress_updater: Option<Arc<dyn TrackingProgressUpdater>>,
) -> Result<Arc<Self>> {
let session_id = config
.session
.session_id
.as_ref()
.map(Cow::Borrowed)

View File

@@ -16,7 +16,7 @@ use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
use xet_core_structures::xorb_object::SerializedXorbObject;
use xet_runtime::core::{XetRuntime, xet_config};
use super::configurations::*;
use super::configurations::TranslatorConfig;
use super::errors::*;
use super::file_cleaner::{Sha256Policy, SingleFileCleaner};
use super::remote_client_interface::create_remote_client;
@@ -37,13 +37,9 @@ use crate::progress_tracking::{NoOpProgressUpdater, TrackingProgressUpdater};
/// that succeeds or fails as a unit; i.e. all files get uploaded on finalization, and all shards
/// and xorbs needed to reconstruct those files are properly uploaded and registered.
pub struct FileUploadSession {
// The parts of this that manage the
pub(crate) client: Arc<dyn Client + Send + Sync>,
pub(crate) shard_interface: SessionShardInterface,
/// The configuration settings, if needed.
pub(crate) config: Arc<TranslatorConfig>,
/// Tracking upload completion between xorbs and files.
pub(crate) completion_tracker: Arc<CompletionTracker>,
@@ -85,6 +81,7 @@ impl FileUploadSession {
dry_run: bool,
) -> Result<Arc<FileUploadSession>> {
let session_id = config
.session
.session_id
.as_ref()
.map(Cow::Borrowed)
@@ -94,7 +91,10 @@ impl FileUploadSession {
match upload_progress_updater {
Some(updater) => {
let flush_interval = xet_config().data.progress_update_interval;
if !flush_interval.is_zero() && config.progress_config.aggregate {
if !flush_interval.is_zero()
&& xet_config().data.aggregate_progress
&& !config.force_disable_progress_aggregation
{
let aggregator = AggregatingProgressUpdater::new(
updater,
flush_interval,
@@ -127,7 +127,6 @@ impl FileUploadSession {
Ok(Arc::new(Self {
shard_interface,
client,
config,
completion_tracker,
progress_aggregator,
current_session_data: Mutex::new(DataAggregator::default()),
@@ -320,14 +319,13 @@ impl FileUploadSession {
// Serialize the object; this can be relatively expensive, so run it on a compute thread.
// XORBs are sent without footer - the server/client reconstructs it from chunk data.
let compression_scheme = self.config.data_config.compression;
let xorb_obj = XetRuntime::current()
.spawn_blocking(move || SerializedXorbObject::from_xorb(xorb, compression_scheme, false))
.spawn_blocking(move || SerializedXorbObject::from_xorb(xorb, false))
.await??;
let session = self.clone();
let upload_permit = self.client.acquire_upload_permit().await?;
let cas_prefix = session.config.data_config.prefix.clone();
let cas_prefix = xet_config().data.default_prefix.clone();
let completion_tracker = self.completion_tracker.clone();
let xorb_hash = xorb_obj.hash;
let raw_num_bytes = xorb_obj.raw_num_bytes;

View File

@@ -6,7 +6,6 @@ use tracing::{Instrument, Span, info_span, instrument};
use xet_client::cas_client::auth::TokenRefresher;
use xet_client::hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
use xet_core_structures::xorb_object::CompressionScheme;
use xet_runtime::core::XetRuntime;
use xet_runtime::core::par_utils::run_constrained;
@@ -48,7 +47,7 @@ pub async fn migrate_with_external_runtime(
Some(Arc::new(headers)),
)?;
migrate_files_impl(file_paths, sha256s, false, hub_client, cas_endpoint, None, false).await?;
migrate_files_impl(file_paths, sha256s, false, hub_client, cas_endpoint, false).await?;
Ok(())
}
@@ -63,7 +62,6 @@ pub async fn migrate_files_impl(
sequential: bool,
hub_client: HubClient,
cas_endpoint: Option<String>,
compression: Option<CompressionScheme>,
dry_run: bool,
) -> Result<MigrationInfo> {
let operation = Operation::Upload;
@@ -80,12 +78,11 @@ pub async fn migrate_files_impl(
let config = default_config(
cas,
compression,
Some((jwt_info.access_token, jwt_info.exp)),
Some(token_refresher),
Some(Arc::new(headers)),
)?;
Span::current().record("session_id", &config.session_id);
Span::current().record("session_id", &config.session.session_id);
let num_workers = if sequential {
1

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use xet_client::cas_client::{Client, RemoteClient};
use super::configurations::*;
use super::configurations::TranslatorConfig;
use super::errors::Result;
pub(crate) async fn create_remote_client(
@@ -10,31 +10,30 @@ pub(crate) async fn create_remote_client(
session_id: &str,
dry_run: bool,
) -> Result<Arc<dyn Client>> {
let cas_storage_config = &config.data_config;
let session = &config.session;
match cas_storage_config.endpoint {
Endpoint::Server(ref endpoint) => Ok(RemoteClient::new(
endpoint,
&cas_storage_config.auth,
if let Some(local_path) = session.local_path() {
#[cfg(not(target_family = "wasm"))]
{
let xorb_path = local_path.join("xet").join("xorbs");
Ok(xet_client::cas_client::LocalClient::new(xorb_path).await?)
}
#[cfg(target_family = "wasm")]
unimplemented!("Local file system access is not supported in WASM builds")
} else if session.is_memory() {
#[cfg(not(target_family = "wasm"))]
{
Ok(xet_client::cas_client::MemoryClient::new())
}
#[cfg(target_family = "wasm")]
unimplemented!("In-memory client is not supported in WASM builds")
} else {
Ok(RemoteClient::new(
&session.endpoint,
&session.auth,
session_id,
dry_run,
cas_storage_config.custom_headers.clone(),
)),
Endpoint::FileSystem(ref path) => {
#[cfg(not(target_family = "wasm"))]
{
Ok(xet_client::cas_client::LocalClient::new(path).await?)
}
#[cfg(target_family = "wasm")]
unimplemented!("Local file system access is not supported in WASM builds")
},
Endpoint::InMemory => {
#[cfg(not(target_family = "wasm"))]
{
Ok(xet_client::cas_client::MemoryClient::new())
}
#[cfg(target_family = "wasm")]
unimplemented!("In-memory client is not supported in WASM builds")
},
session.custom_headers.clone(),
))
}
}

View File

@@ -32,7 +32,6 @@ pub struct SessionShardInterface {
cache_shard_manager: Arc<ShardFileManager>,
client: Arc<dyn Client + Send + Sync>,
config: Arc<TranslatorConfig>,
dry_run: bool,
@@ -60,17 +59,17 @@ impl SessionShardInterface {
dry_run: bool,
) -> Result<Self> {
// Create a temporary session directory where we hold all the shards before upload.
std::fs::create_dir_all(&config.shard_config.session_directory)?;
let shard_session_tempdir = TempDir::new_in(&config.shard_config.session_directory)?;
std::fs::create_dir_all(&config.shard_session_directory)?;
let shard_session_tempdir = TempDir::new_in(&config.shard_session_directory)?;
let session_dir = shard_session_tempdir.path().to_owned();
// Set up the cache dir.
let cache_dir = &config.shard_config.cache_directory;
let cache_dir = &config.shard_cache_directory;
std::fs::create_dir_all(cache_dir)?;
// Set up the shard session directory.
let xorb_metadata_staging_dir = config.shard_config.session_directory.join("xorb_metadata");
let xorb_metadata_staging_dir = config.shard_session_directory.join("xorb_metadata");
std::fs::create_dir_all(&xorb_metadata_staging_dir)?;
// To allow resume from previous session attempts, merge and copy all the valid shards in the xorb metadata
@@ -129,7 +128,6 @@ impl SessionShardInterface {
staged_shards_to_remove_on_success,
xorb_metadata_staging: Mutex::new((SystemTime::now(), MDBInMemoryShard::default())),
resumed_session_shard_manager,
config,
dry_run,
_shard_session_dir: shard_session_tempdir,
client,
@@ -140,7 +138,7 @@ impl SessionShardInterface {
pub async fn query_dedup_shard_by_chunk(&self, chunk_hash: &MerkleHash) -> Result<bool> {
let Ok(Some(new_shard)) = self
.client
.query_for_global_dedup_shard(&self.config.shard_config.prefix, chunk_hash)
.query_for_global_dedup_shard(&xet_config().data.default_prefix, chunk_hash)
.await
.info_error("Error attempting to query global dedup lookup.")
else {

View File

@@ -9,14 +9,19 @@ pub(super) fn create_translator_config(session: &XetSession) -> Result<Translato
.clone()
.unwrap_or_else(|| session.config.data.default_cas_endpoint.clone());
Ok(xet_data::processing::data_client::default_config(
let mut config = xet_data::processing::data_client::default_config(
endpoint,
None, // xorb_compression
session.token_info.clone(),
session.token_refresher.clone(),
session.custom_headers.clone(),
)?
.with_session_id(&session.id.to_string()))
)?;
let session_id = session.id.to_string();
if !session_id.is_empty() {
config.session.session_id = Some(session_id);
}
Ok(config)
}
/// State of the upload commit and download group

View File

@@ -23,6 +23,12 @@ pub const ENVIRONMENT_NAME_ALIASES: &[(&str, &str)] = &[
("HF_XET_SHARD_CACHE_SIZE_LIMIT", "HF_XET_MDB_SHARD_CACHE_SIZE_LIMIT"),
("HF_XET_SHARD_CHUNK_INDEX_TABLE_MAX_SIZE", "HF_XET_MDB_SHARD_CHUNK_INDEX_TABLE_MAX_SIZE"),
("HF_XET_SHARD_CHUNK_INDEX_TABLE_MAX_SIZE", "HF_XET_CHUNK_INDEX_TABLE_MAX_SIZE"),
// xorb compression fields were previously in the data group; support old HF_XET_DATA_XORB_* names
(
"HF_XET_XORB_COMPRESSION_SCHEME_RETEST_INTERVAL",
"HF_XET_DATA_XORB_COMPRESSION_SCHEME_RETEST_INTERVAL",
),
("HF_XET_XORB_COMPRESSION_POLICY", "HF_XET_DATA_XORB_COMPRESSION_POLICY"),
// Fixed concurrency aliases - these set initial, min, and max to the same value
("HF_XET_CLIENT_AC_INITIAL_UPLOAD_CONCURRENCY", "HF_XET_FIXED_UPLOAD_CONCURRENCY"),
("HF_XET_CLIENT_AC_MIN_UPLOAD_CONCURRENCY", "HF_XET_FIXED_UPLOAD_CONCURRENCY"),

View File

@@ -80,4 +80,27 @@ crate::config_group!({
/// Use the environment variable `HF_XET_DATA_DEFAULT_CAS_ENDPOINT` to set this value.
ref default_cas_endpoint: String = "http://localhost:8080".to_string();
/// Whether to aggregate progress updates before sending them.
/// When enabled, progress updates are batched and sent at regular intervals
/// to reduce overhead.
///
/// The default value is true.
///
/// Use the environment variable `HF_XET_DATA_AGGREGATE_PROGRESS` to set this value.
ref aggregate_progress: bool = true;
/// Default prefix used for CAS and shard operations.
///
/// The default value is "default".
///
/// Use the environment variable `HF_XET_DATA_DEFAULT_PREFIX` to set this value.
ref default_prefix: String = "default".to_string();
/// Subdirectory name for staging data within the endpoint cache directory.
///
/// The default value is "staging".
///
/// Use the environment variable `HF_XET_DATA_STAGING_SUBDIR` to set this value.
ref staging_subdir: String = "staging".to_string();
});

View File

@@ -23,4 +23,13 @@ crate::config_group!({
///
/// Use the environment variable `HF_XET_DEDUPLICATION_MIN_N_CHUNKS_PER_RANGE` to set this value.
ref min_n_chunks_per_range: f32 = 8.0;
/// Whether to enable global deduplication queries to the server.
/// When enabled, the system will query the server for deduplication shards
/// based on chunk hashes to enable cross-repository deduplication.
///
/// The default value is true.
///
/// Use the environment variable `HF_XET_DEDUPLICATION_GLOBAL_DEDUP_QUERY_ENABLED` to set this value.
ref global_dedup_query_enabled: bool = true;
});

View File

@@ -0,0 +1,9 @@
// Configuration values for the `session` group. Fields declared here are read as
// `xet_config().session.<field>` (e.g. `config.session.dir_name`).
crate::config_group!({
/// Subdirectory name for shard session data within the staging directory.
///
/// This name is joined onto a base directory to form the shard session directory
/// in which shards are held before upload.
///
/// The default value is "shard-session".
///
/// Use the environment variable `HF_XET_SESSION_DIR_NAME` to set this value.
ref dir_name: String = "shard-session".to_string();
});

View File

@@ -26,7 +26,7 @@ crate::config_group!({
///
/// The default value is 16gb.
///
/// Use the environment variable `HF_XET_METADATA_SHARD_CACHE_SIZE_LIMIT` to set this value.
/// Use the environment variable `HF_XET_SHARD_CACHE_SIZE_LIMIT` to set this value.
ref cache_size_limit : ByteSize = ByteSize::from("16gb");
/// The maximum size of the chunk index table that's stored in memory. After this,
@@ -36,4 +36,12 @@ crate::config_group!({
///
/// Use the environment variable `HF_XET_SHARD_CHUNK_INDEX_TABLE_MAX_SIZE` to set this value.
ref chunk_index_table_max_size: usize = 64 * 1024 * 1024;
/// Subdirectory name for shard cache within the endpoint cache directory.
///
/// The default value is "shard-cache".
///
/// Use the environment variable `HF_XET_SHARD_CACHE_SUBDIR` to set this value.
ref cache_subdir: String = "shard-cache".to_string();
});

View File

@@ -1,3 +1,5 @@
use crate::utils::ConfigEnum;
crate::config_group!({
/// How often should we retest the compression scheme?
/// Determining the optimal compression scheme takes time, but
@@ -9,4 +11,13 @@ crate::config_group!({
///
/// Use the environment variable `HF_XET_XORB_COMPRESSION_SCHEME_RETEST_INTERVAL` to set this value.
ref compression_scheme_retest_interval : usize = 32;
/// Compression policy for xorb data.
/// Valid values: "" or "auto" for automatic detection, "none", "lz4", "bg4-lz4".
/// When set to "" or "auto", the best compression scheme is chosen based on data analysis.
///
/// The default value is "auto" (auto-detect).
///
/// Use the environment variable `HF_XET_XORB_COMPRESSION_POLICY` to set this value.
ref compression_policy: ConfigEnum = ConfigEnum::new("auto", &["", "auto", "none", "lz4", "bg4-lz4"]);
});

View File

@@ -6,7 +6,7 @@
#[macro_export]
macro_rules! all_config_groups {
($mac:ident) => {
$mac!(data, shard, deduplication, chunk_cache, client, log, reconstruction, xorb);
$mac!(data, shard, deduplication, chunk_cache, client, log, reconstruction, xorb, session);
};
}
@@ -107,7 +107,7 @@ macro_rules! config_group {
}
let default_value: $type = $value;
self.$name = <$type>::parse(stringify!($name), maybe_env_value, default_value);
self.$name = <$type>::parse_config_value(stringify!($name), maybe_env_value, default_value);
}
)+
}
@@ -123,11 +123,12 @@ macro_rules! config_group {
match name {
$(
stringify!($name) => {
self.$name = <$type>::parse_user_value(&value_string)
.ok_or_else(|| $crate::config::ConfigError::ParseError {
if !self.$name.try_update_in_place(&value_string) {
return Err($crate::config::ConfigError::ParseError {
field: name.to_owned(),
value: value_string,
})?;
});
}
Ok(())
}
)+
@@ -155,7 +156,7 @@ macro_rules! config_group {
match name {
$(
stringify!($name) => {
self.$name = <$type as $crate::config::python::PythonConfigValue>::from_python(value)?;
<$type as $crate::config::python::PythonConfigValue>::update_from_python(&mut self.$name, value)?;
Ok(())
}
)+

View File

@@ -12,8 +12,6 @@ pub mod groups;
#[cfg(feature = "python")]
pub mod python;
// Re-export types from utils for backward compatibility and for use in config_group macro
// Re-export XetConfig for convenience
pub use xet_config::XetConfig;
pub use crate::utils::configuration_utils::ParsableConfigValue;
@@ -26,6 +24,7 @@ pub type ChunkCacheConfig = groups::chunk_cache::ConfigValues;
pub type ClientConfig = groups::client::ConfigValues;
pub type LogConfig = groups::log::ConfigValues;
pub type XorbConfig = groups::xorb::ConfigValues;
pub type SessionConfig = groups::session::ConfigValues;
#[cfg(feature = "python")]
pub use xet_config::py_xet_config::PyXetConfig;

View File

@@ -1,9 +1,9 @@
use pyo3::conversion::IntoPyObjectExt;
use pyo3::prelude::*;
use crate::utils::ByteSize;
#[cfg(not(target_family = "wasm"))]
use crate::utils::TemplatedPathBuf;
use crate::utils::{ByteSize, ConfigEnum};
/// Trait for converting config values to/from Python objects.
///
@@ -19,6 +19,17 @@ pub trait PythonConfigValue {
fn from_python(obj: &Bound<'_, PyAny>) -> PyResult<Self>
where
Self: Sized;
/// Replaces `self` with the value converted from a Python object.
///
/// This default implementation simply converts with `from_python` and overwrites the
/// current value; if the conversion fails, the error is returned and `self` is left
/// untouched. Implementors that need the existing value to validate the new one
/// (as `ConfigEnum` does) override this method.
fn update_from_python(&mut self, obj: &Bound<'_, PyAny>) -> PyResult<()>
where
    Self: Sized,
{
    Self::from_python(obj).map(|converted| {
        *self = converted;
    })
}
}
macro_rules! impl_python_extract {
@@ -72,6 +83,22 @@ impl<T: PythonConfigValue> PythonConfigValue for Option<T> {
}
}
impl PythonConfigValue for ConfigEnum {
    /// Exposes the stored value to Python as a string.
    fn to_python(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        let current: &str = self.as_str();
        current.into_py_any(py)
    }

    /// Builds a `ConfigEnum` from a Python string without validating it
    /// (the result is created with `ConfigEnum::new_unchecked`).
    fn from_python(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
        obj.extract::<String>().map(ConfigEnum::new_unchecked)
    }

    /// Validates the Python string against this field's allowed values via `try_set`
    /// and stores it; a rejected value is reported as a Python `ValueError`.
    fn update_from_python(&mut self, obj: &Bound<'_, PyAny>) -> PyResult<()> {
        let candidate = obj.extract::<String>()?;
        match self.try_set(&candidate) {
            Ok(()) => Ok(()),
            Err(message) => Err(pyo3::exceptions::PyValueError::new_err(message)),
        }
    }
}
#[cfg(not(target_family = "wasm"))]
impl PythonConfigValue for TemplatedPathBuf {
fn to_python(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {

View File

@@ -17,13 +17,11 @@ crate::all_config_groups!(define_xet_config);
macro_rules! impl_xet_config_group_dispatch {
($($group:ident),*) => {
impl XetConfig {
// Internal: parse a dotted config path into (group, field).
fn split_path(path: &str) -> Result<(&str, &str), ConfigError> {
path.split_once('.')
.ok_or_else(|| ConfigError::InvalidPath(path.to_owned()))
}
// Internal: parse a dotted config path and convert parse errors to Python exceptions.
#[cfg(feature = "python")]
fn split_path_for_python(path: &str) -> pyo3::PyResult<(&str, &str)> {
Self::split_path(path).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
@@ -40,7 +38,6 @@ macro_rules! impl_xet_config_group_dispatch {
self
}
// Internal: dispatches with_config field updates to the correct group.
fn update_field(&mut self, path: &str, value: impl ToString) -> Result<(), ConfigError> {
let (group, field) = Self::split_path(path)?;
match group {
@@ -51,7 +48,6 @@ macro_rules! impl_xet_config_group_dispatch {
}
}
// Internal: dispatches Python with_config field updates to the correct group.
#[cfg(feature = "python")]
fn update_field_from_python(
&mut self,
@@ -69,7 +65,6 @@ macro_rules! impl_xet_config_group_dispatch {
}
}
// Internal: dispatches Python get/item access to the correct group.
#[cfg(feature = "python")]
fn get_field_to_python(
&self,
@@ -114,7 +109,6 @@ macro_rules! impl_xet_config_group_dispatch {
keys
}
// Internal: collects all (key, value) pairs for Python iteration.
#[cfg(feature = "python")]
fn all_items_to_python(
&self,

View File

@@ -0,0 +1,332 @@
use std::fmt;
use std::ops::Deref;
use std::str::FromStr;
use tracing::{event, info, warn};
use super::configuration_utils::{INFORMATION_LOG_LEVEL, ParsableConfigValue};
/// A config value restricted to a fixed set of valid lowercase string options.
///
/// Stores the normalized (lowercased) value and validates against `valid_values`
/// at parse time. If the user provides an invalid value via environment variable,
/// a warning is logged and the default is used instead.
///
/// # Usage in `config_group!`
///
/// ```rust,ignore
/// use crate::utils::ConfigEnum;
///
/// crate::config_group!({
/// ref compression_policy: ConfigEnum = ConfigEnum::new("auto", &["", "auto", "none", "lz4"]);
/// });
/// ```
#[derive(Clone)]
pub struct ConfigEnum {
    /// Currently selected value; every constructor and setter stores it lowercased.
    value: String,
    /// Strings this setting accepts. Matching against them is case-insensitive.
    valid_values: &'static [&'static str],
}

impl ConfigEnum {
    /// Creates a `ConfigEnum` holding `default` (lowercased) and restricted to `valid_values`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `default` does not match any entry of `valid_values`
    /// (case-insensitively). Release builds skip this check.
    pub fn new(default: &str, valid_values: &'static [&'static str]) -> Self {
        let lower = default.to_lowercase();
        debug_assert!(
            Self::is_valid_value(valid_values, &lower),
            "Default value \"{default}\" is not in the valid values list: {valid_values:?}"
        );
        ConfigEnum {
            value: lower,
            valid_values,
        }
    }

    /// Creates a `ConfigEnum` with a string value and an empty `valid_values` list.
    ///
    /// Because `valid_values` is empty, `try_set` will reject every value on the
    /// resulting instance. This is intended only for deserialization paths
    /// (e.g. Python's `from_python`) where the caller validates through the
    /// *existing* field value's `try_set` rather than through this instance.
    pub fn new_unchecked(value: impl Into<String>) -> Self {
        ConfigEnum {
            value: value.into().to_lowercase(),
            valid_values: &[],
        }
    }

    /// Returns the current (lowercased) value.
    pub fn as_str(&self) -> &str {
        &self.value
    }

    /// Returns the list of accepted values exactly as it was supplied to `new`.
    pub fn valid_values(&self) -> &'static [&'static str] {
        self.valid_values
    }

    /// Sets the value if it is valid (case-insensitive), otherwise returns an error.
    ///
    /// # Errors
    ///
    /// Returns a message naming the rejected input and listing the valid values.
    /// The stored value is left untouched in that case.
    pub fn try_set(&mut self, value: &str) -> Result<(), String> {
        let lower = value.to_lowercase();
        if Self::is_valid_value(self.valid_values, &lower) {
            self.value = lower;
            Ok(())
        } else {
            Err(format!("\"{value}\" is not a valid value. Valid values are: {:?}", self.valid_values))
        }
    }

    /// Parses the stored value into a target type via `FromStr`, matching the
    /// familiar `str::parse::<T>()` signature.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if any entry of `valid_values`, once lowercased,
    /// cannot be parsed into `T`. The lowercased form is checked because that is
    /// the only form this type ever stores, so it is the form `T`'s parser will
    /// actually receive. This catches mismatches between the config's allowed
    /// strings and the target type's parser at development time.
    pub fn parse<T>(&self) -> Result<T, T::Err>
    where
        T: FromStr,
        T::Err: fmt::Debug + fmt::Display,
    {
        #[cfg(debug_assertions)]
        for v in self.valid_values {
            let normalized = v.to_lowercase();
            if let Err(e) = normalized.parse::<T>() {
                panic!(
                    "ConfigEnum valid value \"{normalized}\" cannot be parsed into {}: {e}",
                    std::any::type_name::<T>()
                );
            }
        }
        self.value.parse::<T>()
    }

    /// Returns true when `lower` (already lowercased) matches an entry of
    /// `valid_values`, ignoring the case of the entries.
    fn is_valid_value(valid_values: &[&str], lower: &str) -> bool {
        valid_values.iter().any(|v| v.to_lowercase() == lower)
    }
}
/// Debug output is the stored value rendered as a quoted, escaped string;
/// the list of valid values is not included.
impl fmt::Debug for ConfigEnum {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(self.value.as_str(), f)
    }
}
/// Displays the stored value as plain text.
///
/// Formatting goes through `Formatter::pad`, so width, fill, alignment and
/// precision flags supplied by the caller (e.g. `{:<12}`) are honoured the same
/// way they are for a plain `str`, instead of being silently dropped.
impl fmt::Display for ConfigEnum {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.value)
    }
}
/// Lets a `ConfigEnum` be used wherever a `&str` is expected by exposing the
/// stored value.
impl Deref for ConfigEnum {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        self.value.as_str()
    }
}
/// Borrows the stored value as a string slice.
impl AsRef<str> for ConfigEnum {
    fn as_ref(&self) -> &str {
        self.value.as_str()
    }
}
impl PartialEq<str> for ConfigEnum {
fn eq(&self, other: &str) -> bool {
self.value == other.to_lowercase()
}
}
impl PartialEq<&str> for ConfigEnum {
fn eq(&self, other: &&str) -> bool {
self.value == other.to_lowercase()
}
}
/// Two `ConfigEnum`s are equal when their stored values are equal; the
/// `valid_values` lists do not take part in the comparison.
impl PartialEq for ConfigEnum {
    fn eq(&self, other: &Self) -> bool {
        self.value.as_str() == other.value.as_str()
    }
}

impl Eq for ConfigEnum {}
/// Plugs `ConfigEnum` into the generic configuration-parsing trait.
///
/// Validation needs the list of allowed values, which lives on an existing
/// instance, so the context-free `parse_user_value` never succeeds and the
/// context-aware entry points (`try_update_in_place`, `parse_config_value`)
/// are overridden instead.
impl ParsableConfigValue for ConfigEnum {
    /// Always returns `None`: a bare string cannot be validated without an
    /// instance carrying the allowed values.
    fn parse_user_value(_value: &str) -> Option<Self> {
        None
    }

    /// Serializes the current (lowercased) value.
    fn to_config_string(&self) -> String {
        self.value.as_str().to_owned()
    }

    /// Updates the value through `try_set`; returns whether it was accepted.
    fn try_update_in_place(&mut self, value: &str) -> bool {
        matches!(self.try_set(value), Ok(()))
    }

    /// Resolves the effective value for `variable_name`.
    ///
    /// * `value` is `None`: `default` is returned.
    /// * `value` matches one of `default`'s valid values (case-insensitively): a new instance holding the lowercased
    ///   value and the same valid-values list is returned.
    /// * otherwise: a warning is logged and `default` is returned.
    fn parse_config_value(variable_name: &str, value: Option<String>, default: Self) -> Self {
        // Nothing supplied by the user: keep the default.
        let Some(v) = value else {
            event!(INFORMATION_LOG_LEVEL, "Config: {variable_name} = {:?} (default)", default.value);
            return default;
        };

        let lower = v.to_lowercase();
        let accepted = default.valid_values.iter().any(|valid| valid.to_lowercase() == lower);

        if !accepted {
            warn!(
                "Configuration value \"{v}\" for {variable_name} is not valid. \
                 Valid values are: {:?}. Reverting to default \"{}\".",
                default.valid_values, default.value
            );
            info!("Config: {variable_name} = {:?} (default due to invalid value)", default.value);
            return default;
        }

        info!("Config: {variable_name} = {lower:?} (user set)");
        // Only the value changes; the valid-values list is carried over from the default.
        ConfigEnum {
            value: lower,
            ..default
        }
    }
}
#[cfg(test)]
mod tests {
    use std::num::ParseIntError;

    use super::*;

    /// Valid values used by most tests; includes the empty string on purpose.
    const VALID: &[&str] = &["", "auto", "none", "lz4", "bg4-lz4"];
    /// Valid values that all parse as integers, for the `parse::<T>()` tests.
    const INT_VALID: &[&str] = &["1", "2", "3"];

    #[test]
    fn test_new_default_value() {
        let ce = ConfigEnum::new("auto", VALID);
        assert_eq!(ce.as_str(), "auto");
    }

    #[test]
    fn test_new_normalizes_to_lowercase() {
        let ce = ConfigEnum::new("AUTO", VALID);
        assert_eq!(ce.as_str(), "auto");
    }

    #[test]
    fn test_new_unchecked_normalizes_and_has_no_valid_values() {
        let ce = ConfigEnum::new_unchecked("LZ4");
        assert_eq!(ce.as_str(), "lz4");
        assert!(ce.valid_values().is_empty());
    }

    #[test]
    fn test_new_unchecked_rejects_every_try_set() {
        let mut ce = ConfigEnum::new_unchecked("lz4");
        assert!(ce.try_set("lz4").is_err());
        assert!(ce.try_set("").is_err());
        assert_eq!(ce.as_str(), "lz4");
    }

    #[test]
    fn test_deref_and_asref() {
        let ce = ConfigEnum::new("lz4", VALID);
        let s: &str = &ce;
        assert_eq!(s, "lz4");
        assert_eq!(ce.as_ref(), "lz4");
    }

    #[test]
    fn test_partial_eq_str() {
        let ce = ConfigEnum::new("lz4", VALID);
        // Dereference the literals so the comparison goes through `PartialEq<str>`;
        // comparing against a bare literal would select the `PartialEq<&str>` impl.
        assert!(ce == *"lz4");
        assert!(ce == *"LZ4");
        assert!(ce != *"auto");
    }

    #[test]
    fn test_partial_eq_str_ref() {
        let ce = ConfigEnum::new("lz4", VALID);
        assert_eq!(ce, "lz4");
        assert_eq!(ce, "LZ4");
        assert_ne!(ce, "auto");
    }

    #[test]
    fn test_partial_eq_self() {
        let a = ConfigEnum::new("lz4", VALID);
        let b = ConfigEnum::new("lz4", VALID);
        assert_eq!(a, b);
    }

    #[test]
    fn test_display() {
        let ce = ConfigEnum::new("bg4-lz4", VALID);
        assert_eq!(format!("{ce}"), "bg4-lz4");
    }

    #[test]
    fn test_debug() {
        let ce = ConfigEnum::new("bg4-lz4", VALID);
        assert_eq!(format!("{ce:?}"), "\"bg4-lz4\"");
    }

    #[test]
    fn test_parse_valid_value() {
        let default = ConfigEnum::new("auto", VALID);
        let result = ConfigEnum::parse_config_value("test", Some("LZ4".to_string()), default);
        assert_eq!(result.as_str(), "lz4");
    }

    #[test]
    fn test_parse_invalid_value_returns_default() {
        let default = ConfigEnum::new("auto", VALID);
        let result = ConfigEnum::parse_config_value("test", Some("zstd".to_string()), default);
        assert_eq!(result.as_str(), "auto");
    }

    #[test]
    fn test_parse_none_returns_default() {
        let default = ConfigEnum::new("auto", VALID);
        let result = ConfigEnum::parse_config_value("test", None, default);
        assert_eq!(result.as_str(), "auto");
    }

    #[test]
    fn test_parse_empty_string_valid() {
        let default = ConfigEnum::new("auto", VALID);
        let result = ConfigEnum::parse_config_value("test", Some("".to_string()), default);
        assert_eq!(result.as_str(), "");
    }

    #[test]
    #[should_panic(expected = "not in the valid values list")]
    #[cfg(debug_assertions)]
    fn test_new_invalid_default_panics() {
        let _ = ConfigEnum::new("invalid", VALID);
    }

    #[test]
    fn test_valid_values_accessor() {
        let ce = ConfigEnum::new("auto", VALID);
        assert_eq!(ce.valid_values(), VALID);
    }

    #[test]
    fn test_try_set_valid() {
        let mut ce = ConfigEnum::new("auto", VALID);
        assert!(ce.try_set("lz4").is_ok());
        assert_eq!(ce.as_str(), "lz4");
    }

    #[test]
    fn test_try_set_case_insensitive() {
        let mut ce = ConfigEnum::new("auto", VALID);
        assert!(ce.try_set("LZ4").is_ok());
        assert_eq!(ce.as_str(), "lz4");
    }

    #[test]
    fn test_try_set_invalid() {
        let mut ce = ConfigEnum::new("auto", VALID);
        assert!(ce.try_set("zstd").is_err());
        assert_eq!(ce.as_str(), "auto");
    }

    #[test]
    fn test_try_set_empty_string() {
        let mut ce = ConfigEnum::new("auto", VALID);
        assert!(ce.try_set("").is_ok());
        assert_eq!(ce.as_str(), "");
    }

    #[test]
    fn test_try_update_in_place() {
        let mut ce = ConfigEnum::new("auto", VALID);
        assert!(ce.try_update_in_place("NONE"));
        assert_eq!(ce.as_str(), "none");
        assert!(!ce.try_update_in_place("zstd"));
        assert_eq!(ce.as_str(), "none");
    }

    #[test]
    fn test_parse_success() {
        let ce = ConfigEnum::new("2", INT_VALID);
        let val: Result<u32, ParseIntError> = ce.parse();
        assert_eq!(val.unwrap(), 2);
    }

    #[test]
    fn test_parse_all_values_parseable() {
        let ce = ConfigEnum::new("1", INT_VALID);
        let _: u32 = ce.parse().unwrap();
    }

    #[test]
    #[should_panic(expected = "cannot be parsed")]
    #[cfg(debug_assertions)]
    fn test_parse_panics_on_unparseable_valid_value() {
        let ce = ConfigEnum::new("auto", VALID);
        let _: Result<u32, ParseIntError> = ce.parse();
    }
}

View File

@@ -20,9 +20,21 @@ pub trait ParsableConfigValue: std::fmt::Debug + Sized {
/// Serialize this value to a string that can be parsed back via `parse_user_value`.
fn to_config_string(&self) -> String;
/// Try to update this value in place from a string. Returns true on success.
/// The default implementation delegates to `parse_user_value`, but types like
/// `ConfigEnum` override this to use context-aware validation.
fn try_update_in_place(&mut self, value: &str) -> bool {
if let Some(v) = Self::parse_user_value(value) {
*self = v;
true
} else {
false
}
}
/// Parse the value, returning the default if it can't be parsed or the string is empty.
/// Issue a warning if it can't be parsed.
fn parse(variable_name: &str, value: Option<String>, default: Self) -> Self {
fn parse_config_value(variable_name: &str, value: Option<String>, default: Self) -> Self {
match value {
Some(v) => match Self::parse_user_value(&v) {
Some(v) => {
@@ -170,7 +182,7 @@ macro_rules! test_configurable_constants {
{
let default_value = $value;
let maybe_env_value = std::env::var(concat!("HF_XET_",stringify!($name))).ok();
<$type>::parse(stringify!($name), maybe_env_value, default_value)
<$type>::parse_config_value(stringify!($name), maybe_env_value, default_value)
}
#[cfg(not(debug_assertions))]
{

View File

@@ -3,6 +3,9 @@ pub mod adjustable_semaphore;
pub mod byte_size;
pub use byte_size::ByteSize;
pub mod config_enum;
pub use config_enum::ConfigEnum;
pub mod configuration_utils;
pub use configuration_utils::is_high_performance;

View File

@@ -63,7 +63,21 @@ fn test_environment_variable_aliases() {
assert_eq!(XetConfig::new().data.session_xorb_metadata_flush_max_count, 128);
}
// MDB shard aliases
// Xorb aliases (old HF_XET_DATA_XORB_* names)
{
let _guard = EnvVarGuard::set("HF_XET_DATA_XORB_COMPRESSION_SCHEME_RETEST_INTERVAL", "64");
assert_eq!(XetConfig::new().xorb.compression_scheme_retest_interval, 64);
}
{
let _guard = EnvVarGuard::set("HF_XET_DATA_XORB_COMPRESSION_POLICY", "lz4");
assert_eq!(XetConfig::new().xorb.compression_policy.as_str(), "lz4");
}
// Shard aliases
{
let _guard = EnvVarGuard::set("HF_XET_MDB_SHARD_CACHE_SIZE_LIMIT", "24gb");
assert_eq!(*XetConfig::new().shard.cache_size_limit, 24_000_000_000);
}
{
let _guard = EnvVarGuard::set("HF_XET_SHARD_CACHE_SIZE_LIMIT", "32gb");
assert_eq!(*XetConfig::new().shard.cache_size_limit, 32_000_000_000);
@@ -74,6 +88,31 @@ fn test_environment_variable_aliases() {
}
}
/// Checks that each listed `HF_XET_*` environment variable is picked up by a
/// freshly constructed `XetConfig`.
#[test]
#[serial(config_env)]
fn test_new_config_env_overrides() {
    // Each scope sets exactly one variable and builds its own config, so the
    // checks do not influence each other.
    // NOTE(review): presumably `EnvVarGuard` restores the variable when dropped - confirm.
    {
        let _env = EnvVarGuard::set("HF_XET_DATA_AGGREGATE_PROGRESS", "false");
        let config = XetConfig::new();
        assert!(!config.data.aggregate_progress);
    }

    {
        let _env = EnvVarGuard::set("HF_XET_DATA_DEFAULT_PREFIX", "test-prefix");
        let config = XetConfig::new();
        assert_eq!(config.data.default_prefix, "test-prefix");
    }

    {
        let _env = EnvVarGuard::set("HF_XET_DATA_STAGING_SUBDIR", "tmp-stage");
        let config = XetConfig::new();
        assert_eq!(config.data.staging_subdir, "tmp-stage");
    }

    {
        let _env = EnvVarGuard::set("HF_XET_SESSION_DIR_NAME", "tmp-session");
        let config = XetConfig::new();
        assert_eq!(config.session.dir_name, "tmp-session");
    }

    {
        let _env = EnvVarGuard::set("HF_XET_DEDUPLICATION_GLOBAL_DEDUP_QUERY_ENABLED", "false");
        let config = XetConfig::new();
        assert!(!config.deduplication.global_dedup_query_enabled);
    }
}
/// Test that primary environment variable takes precedence over alias when both are set.
#[test]
#[serial(config_env)]