mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
Update config groups to handle more of the data management values. (#702)
This PR moves some config values that were part of the data configuration into XetConfig, specifically the compression_policy, staging_subdir, session_dir_name, and global_dedup_query_enabled. This also consolidates the remaining values into a single struct with endpoint and authentication information.
This commit is contained in:
@@ -131,10 +131,12 @@ impl TransferAgent for XetAgent {
|
||||
|
||||
let headers = user_agent_headers;
|
||||
|
||||
let config =
|
||||
default_config(cas_url, None, Some((token, token_expiry)), Some(token_refresher), Some(Arc::new(headers)))?
|
||||
.disable_progress_aggregation()
|
||||
.with_session_id(session_id); // upload one file at a time so no need for the heavy progress aggregator
|
||||
let mut config =
|
||||
default_config(cas_url, Some((token, token_expiry)), Some(token_refresher), Some(Arc::new(headers)))?
|
||||
.disable_progress_aggregation();
|
||||
if !session_id.is_empty() {
|
||||
config.session.session_id = Some(session_id.to_owned());
|
||||
}
|
||||
let session = FileUploadSession::new(config.into(), Some(Arc::new(xet_updater))).await?;
|
||||
|
||||
let Some(file_path) = &req.path else {
|
||||
|
||||
@@ -164,7 +164,7 @@ pub async fn clean_file(file: web_sys::File, endpoint: String, jwt_token: String
|
||||
let config = TranslatorConfig {
|
||||
data_config: DataConfig {
|
||||
endpoint,
|
||||
compression: Some(CompressionScheme::LZ4),
|
||||
compression: CompressionScheme::LZ4,
|
||||
auth: AuthConfig::maybe_new(Some(jwt_token), Some(expiration), None),
|
||||
prefix: "default".to_owned(),
|
||||
user_agent: USER_AGENT.to_string(),
|
||||
|
||||
@@ -6,7 +6,7 @@ use xet_core_structures::xorb_object::CompressionScheme;
|
||||
#[derive(Debug)]
|
||||
pub struct DataConfig {
|
||||
pub endpoint: String,
|
||||
pub compression: Option<CompressionScheme>,
|
||||
pub compression: CompressionScheme,
|
||||
pub auth: Option<AuthConfig>,
|
||||
pub prefix: String,
|
||||
pub user_agent: String,
|
||||
|
||||
@@ -55,7 +55,7 @@ impl XetSession {
|
||||
let config = TranslatorConfig {
|
||||
data_config: DataConfig {
|
||||
endpoint,
|
||||
compression: Some(CompressionScheme::LZ4),
|
||||
compression: CompressionScheme::LZ4,
|
||||
auth: Some(auth),
|
||||
prefix: "default".to_owned(),
|
||||
user_agent: USER_AGENT.to_string(),
|
||||
|
||||
@@ -126,8 +126,8 @@ impl FileUploadSession {
|
||||
}
|
||||
|
||||
// XORBs are sent without footer - the server/client reconstructs it from chunk data.
|
||||
let compression_scheme = self.config.data_config.compression;
|
||||
let xorb_obj = SerializedXorbObject::from_xorb(xorb, compression_scheme, false)?;
|
||||
let xorb_obj =
|
||||
SerializedXorbObject::from_xorb_with_compression(xorb, self.config.data_config.compression, false)?;
|
||||
|
||||
let Some(ref mut xorb_uploader) = *self.xorb_uploader.lock().await else {
|
||||
return Err(DataProcessingError::internal("register xorb after drop"));
|
||||
|
||||
@@ -757,7 +757,7 @@ mod tests {
|
||||
let threadpool = XetRuntime::new().unwrap();
|
||||
let client = RemoteClient::new(CAS_ENDPOINT, &None, "", false, None);
|
||||
|
||||
let xorb_obj = build_and_verify_xorb_object(raw_xorb, Some(CompressionScheme::LZ4));
|
||||
let xorb_obj = build_and_verify_xorb_object(raw_xorb, CompressionScheme::LZ4);
|
||||
|
||||
// Act
|
||||
let result = threadpool
|
||||
|
||||
@@ -133,7 +133,7 @@ pub trait ClientTestingUtils: Client + Send + Sync {
|
||||
|
||||
shard.add_xorb_block(raw_xorb.xorb_info.clone())?;
|
||||
|
||||
let serialized_xorb = SerializedXorbObject::from_xorb(raw_xorb.clone(), None, true)?;
|
||||
let serialized_xorb = SerializedXorbObject::from_xorb(raw_xorb.clone(), true)?;
|
||||
|
||||
let upload_permit = self.acquire_upload_permit().await?;
|
||||
self.upload_xorb("default", serialized_xorb, None, upload_permit).await?;
|
||||
|
||||
@@ -1071,6 +1071,7 @@ fn generate_v2_fetch_url(hash: &MerkleHash, ranges: &[XorbRangeDescriptor], time
|
||||
}
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use xet_core_structures::xorb_object::CompressionScheme;
|
||||
use xet_core_structures::xorb_object::xorb_format_test_utils::{
|
||||
ChunkSize, build_and_verify_xorb_object, build_raw_xorb,
|
||||
};
|
||||
@@ -1094,7 +1095,7 @@ mod tests {
|
||||
async fn test_download_fetch_term_data_validation() {
|
||||
// Setup: Create a client and upload a xorb
|
||||
let xorb = build_raw_xorb(3, ChunkSize::Fixed(2048));
|
||||
let xorb_obj = build_and_verify_xorb_object(xorb, None);
|
||||
let xorb_obj = build_and_verify_xorb_object(xorb, CompressionScheme::Auto);
|
||||
let hash = xorb_obj.hash;
|
||||
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
|
||||
@@ -73,19 +73,15 @@ fn kl_divergence(pv: &[f64], qv: &[f64]) -> f64 {
|
||||
}
|
||||
|
||||
fn lz4_compress_size(data: &[u8]) -> usize {
|
||||
serialize_chunk(
|
||||
data,
|
||||
&mut std::io::Empty::default(),
|
||||
Some(xet_core_structures::xorb_object::CompressionScheme::LZ4),
|
||||
)
|
||||
.unwrap()
|
||||
serialize_chunk(data, &mut std::io::Empty::default(), xet_core_structures::xorb_object::CompressionScheme::LZ4)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn bg4_lz4_compress_size(data: &[u8]) -> usize {
|
||||
serialize_chunk(
|
||||
data,
|
||||
&mut std::io::Empty::default(),
|
||||
Some(xet_core_structures::xorb_object::CompressionScheme::ByteGrouping4LZ4),
|
||||
xet_core_structures::xorb_object::CompressionScheme::ByteGrouping4LZ4,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::borrow::Cow;
|
||||
use std::fmt::Display;
|
||||
use std::io::{Cursor, Read, Write, copy};
|
||||
use std::str::FromStr;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::anyhow;
|
||||
@@ -19,12 +20,12 @@ pub static mut BG4_LZ4_DECOMPRESS_RUNTIME: f64 = 0.;
|
||||
#[repr(u8)]
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Default)]
|
||||
pub enum CompressionScheme {
|
||||
#[default]
|
||||
None = 0,
|
||||
LZ4 = 1,
|
||||
ByteGrouping4LZ4 = 2, // 4 byte groups
|
||||
#[default]
|
||||
Auto = 99,
|
||||
}
|
||||
pub const NUM_COMPRESSION_SCHEMES: usize = 3;
|
||||
|
||||
impl Display for CompressionScheme {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
@@ -34,9 +35,10 @@ impl Display for CompressionScheme {
|
||||
impl From<&CompressionScheme> for &'static str {
|
||||
fn from(value: &CompressionScheme) -> Self {
|
||||
match value {
|
||||
CompressionScheme::None => "none",
|
||||
CompressionScheme::Auto => "auto",
|
||||
CompressionScheme::LZ4 => "lz4",
|
||||
CompressionScheme::ByteGrouping4LZ4 => "bg4-lz4",
|
||||
CompressionScheme::None => "none",
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -55,14 +57,41 @@ impl TryFrom<u8> for CompressionScheme {
|
||||
0 => Ok(CompressionScheme::None),
|
||||
1 => Ok(CompressionScheme::LZ4),
|
||||
2 => Ok(CompressionScheme::ByteGrouping4LZ4),
|
||||
99 => Ok(CompressionScheme::Auto),
|
||||
_ => Err(XorbObjectError::Format(anyhow!("cannot convert value {value} to CompressionScheme"))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parses a `CompressionScheme` from its textual name.
///
/// Surrounding whitespace and letter case are ignored. An empty string is
/// accepted and treated the same as "auto".
///
/// # Errors
///
/// Returns `XorbObjectError::Format` when the input is not one of
/// auto, none, lz4, or bg4-lz4.
impl FromStr for CompressionScheme {
|
||||
type Err = XorbObjectError;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
// Normalize first so that inputs such as " LZ4 " match the lowercase names below.
match s.trim().to_lowercase().as_str() {
|
||||
"" | "auto" => Ok(CompressionScheme::Auto),
|
||||
"none" => Ok(CompressionScheme::None),
|
||||
"lz4" => Ok(CompressionScheme::LZ4),
|
||||
"bg4-lz4" => Ok(CompressionScheme::ByteGrouping4LZ4),
|
||||
// The error message echoes the caller's original (un-normalized) input.
_ => Err(XorbObjectError::Format(anyhow!(
|
||||
"Invalid compression scheme '{s}'. Valid values are: auto, none, lz4, bg4-lz4."
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl CompressionScheme {
|
||||
/// Resolves `Auto` to a concrete scheme by analyzing the data.
///
/// When `self` is `Auto`, the choice is delegated to
/// `CompressionScheme::choose_from_data(data)`. Any other value is returned
/// unchanged and `data` is not inspected.
|
||||
pub fn resolve_for_data(&self, data: &[u8]) -> Self {
|
||||
if *self == CompressionScheme::Auto {
|
||||
CompressionScheme::choose_from_data(data)
|
||||
} else {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn compress_from_slice<'a>(&self, data: &'a [u8]) -> Result<Cow<'a, [u8]>> {
|
||||
Ok(match self {
|
||||
CompressionScheme::Auto => return self.resolve_for_data(data).compress_from_slice(data),
|
||||
CompressionScheme::None => data.into(),
|
||||
CompressionScheme::LZ4 => lz4_compress_from_slice(data).map(Cow::from)?,
|
||||
CompressionScheme::ByteGrouping4LZ4 => bg4_lz4_compress_from_slice(data).map(Cow::from)?,
|
||||
@@ -71,6 +100,9 @@ impl CompressionScheme {
|
||||
|
||||
pub fn decompress_from_slice<'a>(&self, data: &'a [u8]) -> Result<Cow<'a, [u8]>> {
|
||||
Ok(match self {
|
||||
CompressionScheme::Auto => {
|
||||
return Err(XorbObjectError::Format(anyhow!("Cannot decompress with Auto scheme")));
|
||||
},
|
||||
CompressionScheme::None => data.into(),
|
||||
CompressionScheme::LZ4 => lz4_decompress_from_slice(data).map(Cow::from)?,
|
||||
CompressionScheme::ByteGrouping4LZ4 => bg4_lz4_decompress_from_slice(data).map(Cow::from)?,
|
||||
@@ -79,6 +111,9 @@ impl CompressionScheme {
|
||||
|
||||
pub fn decompress_from_reader<R: Read, W: Write>(&self, reader: &mut R, writer: &mut W) -> Result<u64> {
|
||||
Ok(match self {
|
||||
CompressionScheme::Auto => {
|
||||
return Err(XorbObjectError::Format(anyhow!("Cannot decompress with Auto scheme")));
|
||||
},
|
||||
CompressionScheme::None => copy(reader, writer)?,
|
||||
CompressionScheme::LZ4 => lz4_decompress_from_reader(reader, writer)?,
|
||||
CompressionScheme::ByteGrouping4LZ4 => bg4_lz4_decompress_from_reader(reader, writer)?,
|
||||
@@ -169,19 +204,73 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
|
||||
// `CompressionScheme::default()` must be the `Auto` variant.
#[test]
|
||||
fn test_default_is_auto() {
|
||||
assert_eq!(CompressionScheme::default(), CompressionScheme::Auto);
|
||||
}
|
||||
|
||||
// Every variant converts to its expected `&str` name.
#[test]
|
||||
fn test_to_str() {
|
||||
assert_eq!(Into::<&str>::into(CompressionScheme::Auto), "auto");
|
||||
assert_eq!(Into::<&str>::into(CompressionScheme::None), "none");
|
||||
assert_eq!(Into::<&str>::into(CompressionScheme::LZ4), "lz4");
|
||||
assert_eq!(Into::<&str>::into(CompressionScheme::ByteGrouping4LZ4), "bg4-lz4");
|
||||
}
|
||||
|
||||
// String parsing: each known name maps to its variant, the empty string maps
// to `Auto`, whitespace and case are ignored, and an unknown name is an error.
#[test]
|
||||
fn test_from_str() {
|
||||
assert_eq!("".parse::<CompressionScheme>().unwrap(), CompressionScheme::Auto);
|
||||
assert_eq!("auto".parse::<CompressionScheme>().unwrap(), CompressionScheme::Auto);
|
||||
assert_eq!("none".parse::<CompressionScheme>().unwrap(), CompressionScheme::None);
|
||||
assert_eq!("lz4".parse::<CompressionScheme>().unwrap(), CompressionScheme::LZ4);
|
||||
assert_eq!("bg4-lz4".parse::<CompressionScheme>().unwrap(), CompressionScheme::ByteGrouping4LZ4);
|
||||
// Padded, upper-case input is still accepted.
assert_eq!(" LZ4 ".parse::<CompressionScheme>().unwrap(), CompressionScheme::LZ4);
|
||||
assert!("zstd".parse::<CompressionScheme>().is_err());
|
||||
}
|
||||
|
||||
// `Display` output for every variant matches its expected lowercase name.
#[test]
|
||||
fn test_display() {
|
||||
assert_eq!(format!("{}", CompressionScheme::Auto), "auto");
|
||||
assert_eq!(format!("{}", CompressionScheme::None), "none");
|
||||
assert_eq!(format!("{}", CompressionScheme::LZ4), "lz4");
|
||||
assert_eq!(format!("{}", CompressionScheme::ByteGrouping4LZ4), "bg4-lz4");
|
||||
}
|
||||
|
||||
// A `ConfigEnum` value can be parsed into a `CompressionScheme` for the
// "auto", "lz4" and "none" settings.
// NOTE(review): the second argument to `ConfigEnum::new` looks like the list
// of accepted values — confirm against `xet_runtime::utils::ConfigEnum`.
#[test]
|
||||
fn test_parse_with_config_enum() {
|
||||
use xet_runtime::utils::ConfigEnum;
|
||||
|
||||
let ce = ConfigEnum::new("auto", &["", "auto", "none", "lz4", "bg4-lz4"]);
|
||||
let scheme: CompressionScheme = ce.parse().unwrap();
|
||||
assert_eq!(scheme, CompressionScheme::Auto);
|
||||
|
||||
let ce = ConfigEnum::new("lz4", &["", "auto", "none", "lz4", "bg4-lz4"]);
|
||||
let scheme: CompressionScheme = ce.parse().unwrap();
|
||||
assert_eq!(scheme, CompressionScheme::LZ4);
|
||||
|
||||
let ce = ConfigEnum::new("none", &["", "auto", "none", "lz4", "bg4-lz4"]);
|
||||
let scheme: CompressionScheme = ce.parse().unwrap();
|
||||
assert_eq!(scheme, CompressionScheme::None);
|
||||
}
|
||||
|
||||
// `TryFrom<u8>`: 0, 1, 2 and 99 map to variants; other values (3 and 4 are
// checked here) are rejected.
#[test]
|
||||
fn test_from_u8() {
|
||||
assert_eq!(CompressionScheme::try_from(0u8), Ok(CompressionScheme::None));
|
||||
assert_eq!(CompressionScheme::try_from(1u8), Ok(CompressionScheme::LZ4));
|
||||
assert_eq!(CompressionScheme::try_from(2u8), Ok(CompressionScheme::ByteGrouping4LZ4));
|
||||
assert_eq!(CompressionScheme::try_from(99u8), Ok(CompressionScheme::Auto));
|
||||
assert!(CompressionScheme::try_from(3u8).is_err());
|
||||
assert!(CompressionScheme::try_from(4u8).is_err());
|
||||
}
|
||||
|
||||
// `resolve_for_data` must never return `Auto` when called on `Auto`, and must
// return explicit schemes unchanged.
#[test]
|
||||
fn test_resolve_for_data() {
|
||||
let data = vec![0u8; 1024];
|
||||
let resolved = CompressionScheme::Auto.resolve_for_data(&data);
|
||||
// Only the "not Auto" property is checked; the specific scheme chosen is not pinned.
assert_ne!(resolved, CompressionScheme::Auto);
|
||||
|
||||
assert_eq!(CompressionScheme::LZ4.resolve_for_data(&data), CompressionScheme::LZ4);
|
||||
assert_eq!(CompressionScheme::None.resolve_for_data(&data), CompressionScheme::None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -59,6 +59,7 @@ impl XorbChunkHeader {
|
||||
}
|
||||
|
||||
/// Stores `compression_scheme` in this header as its `u8` discriminant.
///
/// `CompressionScheme::Auto` is not expected here: builds with debug
/// assertions enabled panic if it is passed.
pub fn set_compression_scheme(&mut self, compression_scheme: CompressionScheme) {
|
||||
debug_assert_ne!(compression_scheme, CompressionScheme::Auto);
|
||||
self.compression_scheme = compression_scheme as u8;
|
||||
}
|
||||
|
||||
@@ -114,9 +115,13 @@ fn convert_three_byte_num(buf: &[u8; 3]) -> u32 {
|
||||
pub fn serialize_chunk<W: Write>(
|
||||
chunk: &[u8],
|
||||
w: &mut W,
|
||||
compression_scheme: Option<CompressionScheme>,
|
||||
compression_scheme: CompressionScheme,
|
||||
) -> Result<usize, XorbObjectError> {
|
||||
let compression_scheme = compression_scheme.unwrap_or_else(|| CompressionScheme::choose_from_data(chunk));
|
||||
let compression_scheme = compression_scheme.resolve_for_data(chunk);
|
||||
debug_assert_ne!(compression_scheme, CompressionScheme::Auto);
|
||||
if compression_scheme == CompressionScheme::Auto {
|
||||
return Err(XorbObjectError::Format(anyhow!("CompressionScheme::Auto cannot be serialized into xorb chunks")));
|
||||
}
|
||||
|
||||
let compressed = compression_scheme.compress_from_slice(chunk)?;
|
||||
|
||||
@@ -309,6 +314,16 @@ mod tests {
|
||||
assert_eq!(data_copy.as_slice(), data);
|
||||
}
|
||||
|
||||
// Serializing a chunk with `Auto` must succeed, and the header read back from
// the serialized bytes must carry a scheme other than `Auto`.
#[test]
|
||||
fn test_auto_scheme_never_serialized_in_chunk_header() {
|
||||
let chunk = vec![0u8; 4096];
|
||||
let mut serialized = Vec::new();
|
||||
let _ = serialize_chunk(&chunk, &mut serialized, CompressionScheme::Auto).unwrap();
|
||||
|
||||
// Read the header back from the bytes that were just written.
let header = deserialize_chunk_header(&mut Cursor::new(serialized)).unwrap();
|
||||
assert_ne!(header.get_compression_scheme().unwrap(), CompressionScheme::Auto);
|
||||
}
|
||||
|
||||
const CASES: &[(u32, ChunkSize, CompressionScheme)] = &[
|
||||
(2, ChunkSize::Fixed(16), CompressionScheme::None),
|
||||
(10, ChunkSize::Fixed(16), CompressionScheme::None),
|
||||
|
||||
@@ -151,7 +151,7 @@ mod tests {
|
||||
let mut out = Vec::new();
|
||||
for _ in 0..num_chunks {
|
||||
let data = gen_random_bytes(rng, CHUNK_SIZE as u32);
|
||||
serialize_chunk(&data, &mut out, Some(compression_scheme)).unwrap();
|
||||
serialize_chunk(&data, &mut out, compression_scheme).unwrap();
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
@@ -1282,9 +1282,18 @@ pub struct SerializedXorbObject {
|
||||
|
||||
impl SerializedXorbObject {
|
||||
/// Builds the xorb from raw xorb data.
|
||||
pub fn from_xorb(
|
||||
///
|
||||
/// The compression scheme is determined by `HF_XET_XORB_COMPRESSION_POLICY`:
|
||||
/// auto-detect (default) or an explicit scheme (none, lz4, bg4-lz4).
|
||||
pub fn from_xorb(xorb: RawXorbData, serialize_footer: bool) -> Result<Self, XorbObjectError> {
|
||||
let compression_scheme: CompressionScheme = xet_config().xorb.compression_policy.parse()?;
|
||||
Self::from_xorb_with_compression(xorb, compression_scheme, serialize_footer)
|
||||
}
|
||||
|
||||
/// Builds the xorb from raw xorb data with an explicit compression scheme override.
|
||||
pub fn from_xorb_with_compression(
|
||||
xorb: RawXorbData,
|
||||
compression_scheme: Option<CompressionScheme>,
|
||||
compression_scheme: CompressionScheme,
|
||||
serialize_footer: bool,
|
||||
) -> Result<Self, XorbObjectError> {
|
||||
let mut xorb_object_info = XorbObjectInfoV1::default();
|
||||
@@ -1320,7 +1329,7 @@ impl SerializedXorbObject {
|
||||
num_chunks
|
||||
};
|
||||
|
||||
if compression_scheme.is_none() && num_chunks != 0 {
|
||||
if compression_scheme == CompressionScheme::Auto && num_chunks != 0 {
|
||||
debug_assert!(xorb.file_boundaries.is_sorted());
|
||||
debug_assert_ge!(xorb.file_boundaries.len(), 0);
|
||||
debug_assert_lt!(*xorb.file_boundaries.last().unwrap(), xorb.data.len());
|
||||
@@ -1334,10 +1343,10 @@ impl SerializedXorbObject {
|
||||
|
||||
// Choose the compression scheme for this block.
|
||||
let compression_scheme = CompressionScheme::choose_from_data(&xorb.data[s_idx]);
|
||||
debug_assert_ne!(compression_scheme, CompressionScheme::Auto);
|
||||
|
||||
for chunk in &xorb.data[s_idx..n_idx] {
|
||||
// now serialize chunk directly to writer (since chunks come first!)
|
||||
serialize_chunk(chunk, &mut serialized_data, Some(compression_scheme))?;
|
||||
serialize_chunk(chunk, &mut serialized_data, compression_scheme)?;
|
||||
xorb_object_info.chunk_boundary_offsets.push(serialized_data.len() as u32);
|
||||
}
|
||||
s_idx = n_idx;
|
||||
@@ -1345,7 +1354,6 @@ impl SerializedXorbObject {
|
||||
}
|
||||
} else {
|
||||
for chunk in xorb.data {
|
||||
// now serialize chunk directly to writer (since chunks come first!)
|
||||
serialize_chunk(&chunk, &mut serialized_data, compression_scheme)?;
|
||||
xorb_object_info.chunk_boundary_offsets.push(serialized_data.len() as u32);
|
||||
}
|
||||
@@ -1387,7 +1395,7 @@ pub mod test_utils {
|
||||
hash: &MerkleHash,
|
||||
data: &[u8],
|
||||
chunk_and_boundaries: &[(MerkleHash, u32)],
|
||||
compression_scheme: Option<CompressionScheme>,
|
||||
compression_scheme: CompressionScheme,
|
||||
) -> Result<(XorbObject, usize, u64), XorbObjectError> {
|
||||
let mut xorb = XorbObject::default();
|
||||
xorb.info.xorb_hash = *hash;
|
||||
@@ -1437,7 +1445,7 @@ pub mod test_utils {
|
||||
hash: &MerkleHash,
|
||||
data: Vec<u8>,
|
||||
chunk_and_boundaries: Vec<(MerkleHash, u32)>,
|
||||
compression: Option<CompressionScheme>,
|
||||
compression: CompressionScheme,
|
||||
) -> Result<SerializedXorbObject, XorbObjectError> {
|
||||
let mut writer = Cursor::new(Vec::new());
|
||||
|
||||
@@ -1455,7 +1463,7 @@ pub mod test_utils {
|
||||
|
||||
pub fn verify_serialized_xorb_object(
|
||||
xorb: &RawXorbData,
|
||||
compression_scheme: Option<CompressionScheme>,
|
||||
compression_scheme: CompressionScheme,
|
||||
xorb_obj: &SerializedXorbObject,
|
||||
) {
|
||||
let xorb_hash = xorb.hash();
|
||||
@@ -1544,9 +1552,10 @@ pub mod test_utils {
|
||||
|
||||
pub fn build_and_verify_xorb_object(
|
||||
xorb: RawXorbData,
|
||||
compression_scheme: Option<CompressionScheme>,
|
||||
compression_scheme: CompressionScheme,
|
||||
) -> SerializedXorbObject {
|
||||
let xorb_obj = SerializedXorbObject::from_xorb(xorb.clone(), compression_scheme, true).unwrap();
|
||||
let xorb_obj =
|
||||
SerializedXorbObject::from_xorb_with_compression(xorb.clone(), compression_scheme, true).unwrap();
|
||||
|
||||
verify_serialized_xorb_object(&xorb, compression_scheme, &xorb_obj);
|
||||
|
||||
@@ -1590,7 +1599,7 @@ pub mod test_utils {
|
||||
|
||||
// build chunk, create ChunkInfo and keep going
|
||||
|
||||
let bytes_written = serialize_chunk(&bytes, &mut writer, Some(compression_scheme)).unwrap();
|
||||
let bytes_written = serialize_chunk(&bytes, &mut writer, compression_scheme).unwrap();
|
||||
|
||||
total_bytes += bytes_written as u32;
|
||||
|
||||
@@ -1832,7 +1841,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::LZ4)
|
||||
CompressionScheme::LZ4
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -1852,7 +1861,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&c_bytes,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::None)
|
||||
CompressionScheme::None
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -1890,7 +1899,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::LZ4)
|
||||
CompressionScheme::LZ4
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -1914,7 +1923,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::None)
|
||||
CompressionScheme::None
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -1935,7 +1944,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::None)
|
||||
CompressionScheme::None
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -1968,7 +1977,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::None)
|
||||
CompressionScheme::None
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -2000,7 +2009,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::None)
|
||||
CompressionScheme::None
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -2032,7 +2041,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::LZ4)
|
||||
CompressionScheme::LZ4
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -2063,7 +2072,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::LZ4)
|
||||
CompressionScheme::LZ4
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -2096,7 +2105,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::LZ4)
|
||||
CompressionScheme::LZ4
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -2128,7 +2137,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::LZ4)
|
||||
CompressionScheme::LZ4
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -2158,7 +2167,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::LZ4)
|
||||
CompressionScheme::LZ4
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -2190,7 +2199,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::LZ4)
|
||||
CompressionScheme::LZ4
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -2256,7 +2265,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(CompressionScheme::LZ4),
|
||||
CompressionScheme::LZ4,
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -2357,7 +2366,7 @@ mod tests {
|
||||
&c.info.xorb_hash,
|
||||
&raw_data,
|
||||
&raw_chunk_boundaries,
|
||||
Some(COMPRESSION_SCHEME)
|
||||
COMPRESSION_SCHEME
|
||||
)
|
||||
.is_ok()
|
||||
);
|
||||
@@ -2418,7 +2427,7 @@ mod tests {
|
||||
let bytes = gen_random_bytes(*size);
|
||||
let chunk_hash = crate::merklehash::compute_data_hash(&bytes);
|
||||
chunk_hashes_and_sizes.push((chunk_hash, bytes.len() as u64));
|
||||
serialize_chunk(&bytes, &mut raw_data, Some(compression)).unwrap();
|
||||
serialize_chunk(&bytes, &mut raw_data, compression).unwrap();
|
||||
}
|
||||
|
||||
let expected_hash = crate::merklehash::xorb_hash(&chunk_hashes_and_sizes);
|
||||
@@ -2470,7 +2479,7 @@ mod tests {
|
||||
let bytes = gen_random_bytes(*size);
|
||||
let chunk_hash = crate::merklehash::compute_data_hash(&bytes);
|
||||
chunk_hashes_and_sizes.push((chunk_hash, bytes.len() as u64));
|
||||
serialize_chunk(&bytes, &mut raw_data, Some(compression)).unwrap();
|
||||
serialize_chunk(&bytes, &mut raw_data, compression).unwrap();
|
||||
}
|
||||
|
||||
let expected_hash = crate::merklehash::xorb_hash(&chunk_hashes_and_sizes);
|
||||
@@ -2495,7 +2504,7 @@ mod tests {
|
||||
let bytes = gen_random_bytes(*size);
|
||||
let chunk_hash = crate::merklehash::compute_data_hash(&bytes);
|
||||
chunk_hashes_and_sizes.push((chunk_hash, bytes.len() as u64));
|
||||
serialize_chunk(&bytes, &mut raw_data, Some(compression)).unwrap();
|
||||
serialize_chunk(&bytes, &mut raw_data, compression).unwrap();
|
||||
}
|
||||
|
||||
let mut output = Vec::new();
|
||||
@@ -2516,7 +2525,7 @@ mod tests {
|
||||
let mut raw_data = Vec::new();
|
||||
let bytes = gen_random_bytes(1024);
|
||||
let chunk_hash = crate::merklehash::compute_data_hash(&bytes);
|
||||
serialize_chunk(&bytes, &mut raw_data, Some(CompressionScheme::LZ4)).unwrap();
|
||||
serialize_chunk(&bytes, &mut raw_data, CompressionScheme::LZ4).unwrap();
|
||||
|
||||
let expected_hash = crate::merklehash::xorb_hash(&[(chunk_hash, 1024)]);
|
||||
|
||||
@@ -2533,4 +2542,12 @@ mod tests {
|
||||
let deserialized = XorbObject::deserialize(&mut reader).unwrap();
|
||||
assert_eq!(deserialized.info.num_chunks, 1);
|
||||
}
|
||||
|
||||
// Smoke test for `SerializedXorbObject::from_xorb` called without an explicit
// compression scheme: a 4-chunk raw xorb must serialize to 4 chunks with
// non-empty data.
#[test]
|
||||
fn test_from_xorb_uses_default_config() {
|
||||
let raw = build_raw_xorb(4, ChunkSize::Fixed(1024));
|
||||
// `false`: do not serialize the footer.
let serialized = SerializedXorbObject::from_xorb(raw, false).unwrap();
|
||||
assert_eq!(serialized.num_chunks, 4);
|
||||
assert!(serialized.serialized_data.len() > 0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::sync::{Arc, OnceLock};
|
||||
use anyhow::Result;
|
||||
use clap::{Args, Parser, Subcommand};
|
||||
use ulid::Ulid;
|
||||
use xet_data::processing::configurations::*;
|
||||
use xet_data::processing::configurations::TranslatorConfig;
|
||||
use xet_data::processing::{FileUploadSession, Sha256Policy, XetFileInfo};
|
||||
use xet_runtime::core::XetRuntime;
|
||||
|
||||
@@ -127,8 +127,6 @@ async fn smudge_file(arg: &SmudgeArg) -> Result<()> {
|
||||
}
|
||||
|
||||
async fn smudge(_name: Arc<str>, mut reader: impl Read, output_path: PathBuf) -> Result<()> {
|
||||
use xet_data::processing::configurations::TranslatorConfig;
|
||||
|
||||
let mut input = String::new();
|
||||
reader.read_to_string(&mut input)?;
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ use xet_core_structures::xorb_object::CompressionScheme;
|
||||
use xet_data::processing::data_client::default_config;
|
||||
use xet_data::processing::migration_tool::hub_client_token_refresher::HubClientTokenRefresher;
|
||||
use xet_data::processing::migration_tool::migrate::migrate_files_impl;
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
|
||||
const DEFAULT_HF_ENDPOINT: &str = "https://huggingface.co";
|
||||
@@ -127,16 +128,8 @@ impl Command {
|
||||
let file_paths = walk_files(arg.files, arg.recursive);
|
||||
eprintln!("Dedupping {} files...", file_paths.len());
|
||||
|
||||
let (all_file_info, clean_ret, total_bytes_trans) = migrate_files_impl(
|
||||
file_paths,
|
||||
None,
|
||||
arg.sequential,
|
||||
hub_client,
|
||||
None,
|
||||
arg.compression.and_then(|c| CompressionScheme::try_from(c).ok()),
|
||||
!arg.migrate,
|
||||
)
|
||||
.await?;
|
||||
let (all_file_info, clean_ret, total_bytes_trans) =
|
||||
migrate_files_impl(file_paths, None, arg.sequential, hub_client, None, !arg.migrate).await?;
|
||||
|
||||
// Print file info for analysis
|
||||
if !arg.migrate {
|
||||
@@ -216,19 +209,12 @@ async fn query_reconstruction(
|
||||
|
||||
let config = default_config(
|
||||
jwt_info.cas_url.clone(),
|
||||
None,
|
||||
Some((jwt_info.access_token, jwt_info.exp)),
|
||||
Some(token_refresher),
|
||||
Some(Arc::new(headers)),
|
||||
)?;
|
||||
let cas_storage_config = &config.data_config;
|
||||
let remote_client = RemoteClient::new(
|
||||
&jwt_info.cas_url,
|
||||
&cas_storage_config.auth,
|
||||
"",
|
||||
true,
|
||||
cas_storage_config.custom_headers.clone(),
|
||||
);
|
||||
let remote_client =
|
||||
RemoteClient::new(&jwt_info.cas_url, &config.session.auth, "", true, config.session.custom_headers.clone());
|
||||
|
||||
// Use V1 directly so the query tool returns the raw QueryReconstructionResponse for inspection.
|
||||
remote_client
|
||||
@@ -239,7 +225,22 @@ async fn query_reconstruction(
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let cli = XCommand::parse();
|
||||
let threadpool = XetRuntime::new()?;
|
||||
|
||||
let mut config = XetConfig::new();
|
||||
if let Command::Dedup(ref arg) = cli.command
|
||||
&& let Some(c) = arg.compression
|
||||
{
|
||||
let scheme = CompressionScheme::try_from(c).map_err(|_| {
|
||||
anyhow::anyhow!("Invalid compression value {c}; expected one of: 0 (none), 1 (lz4), 2 (bg4-lz4), 99 (auto)")
|
||||
})?;
|
||||
config
|
||||
.xorb
|
||||
.compression_policy
|
||||
.try_set(<&str>::from(scheme))
|
||||
.map_err(|e| anyhow::anyhow!(e))?;
|
||||
}
|
||||
|
||||
let threadpool = XetRuntime::new_with_config(config)?;
|
||||
threadpool.external_run_async_task(async move { cli.run().await })??;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,196 +1,237 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use http::HeaderMap;
|
||||
use tracing::info;
|
||||
use xet_client::cas_client::auth::AuthConfig;
|
||||
use xet_client::cas_client::remote_client::PREFIX_DEFAULT;
|
||||
use xet_core_structures::xorb_object::CompressionScheme;
|
||||
use xet_runtime::core::{xet_cache_root, xet_config};
|
||||
|
||||
use super::errors::Result;
|
||||
|
||||
/// Session-specific configuration that varies per upload/download session.
|
||||
/// These are runtime values that cannot be configured via environment variables.
|
||||
#[derive(Debug)]
|
||||
pub enum Endpoint {
|
||||
Server(String),
|
||||
FileSystem(PathBuf),
|
||||
InMemory,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DataConfig {
|
||||
pub endpoint: Endpoint,
|
||||
pub compression: Option<CompressionScheme>,
|
||||
pub struct SessionContext {
|
||||
/// The endpoint URL. Use the `local://` prefix (configurable via `HF_XET_DATA_LOCAL_CAS_SCHEME`)
|
||||
/// to specify a local filesystem path, or `memory://` for in-memory storage.
|
||||
pub endpoint: String,
|
||||
pub auth: Option<AuthConfig>,
|
||||
pub prefix: String,
|
||||
pub staging_directory: Option<PathBuf>,
|
||||
pub custom_headers: Option<Arc<HeaderMap>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct GlobalDedupConfig {
|
||||
pub global_dedup_policy: GlobalDedupPolicy,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RepoInfo {
|
||||
pub repo_paths: Vec<String>,
|
||||
pub session_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Default, Clone, Debug, Copy)]
|
||||
pub enum GlobalDedupPolicy {
|
||||
/// Never query for new shards using chunk hashes.
|
||||
Never,
|
||||
impl SessionContext {
|
||||
/// Returns true if this endpoint points to a local filesystem path.
|
||||
pub fn is_local(&self) -> bool {
|
||||
self.endpoint.starts_with(&xet_config().data.local_cas_scheme)
|
||||
}
|
||||
|
||||
/// Always query for new shards by chunks
|
||||
#[default]
|
||||
Always,
|
||||
}
|
||||
/// Returns the local filesystem path if this is a local endpoint.
|
||||
pub fn local_path(&self) -> Option<PathBuf> {
|
||||
let path = self.endpoint.strip_prefix(&xet_config().data.local_cas_scheme)?;
|
||||
Some(PathBuf::from(path))
|
||||
}
|
||||
|
||||
impl FromStr for GlobalDedupPolicy {
|
||||
type Err = std::io::Error;
|
||||
/// Returns true if this endpoint uses in-memory storage.
|
||||
pub fn is_memory(&self) -> bool {
|
||||
self.endpoint == "memory://"
|
||||
}
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"never" => Ok(GlobalDedupPolicy::Never),
|
||||
"always" => Ok(GlobalDedupPolicy::Always),
|
||||
_ => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::InvalidInput,
|
||||
format!("Invalid global dedup query policy, should be one of never, direct_only, always: {s}"),
|
||||
)),
|
||||
/// Creates a SessionContext for local filesystem-based operations.
|
||||
pub fn for_local_path(base_dir: impl AsRef<Path>) -> Self {
|
||||
let path = base_dir.as_ref().to_path_buf();
|
||||
let endpoint = format!("{}{}", xet_config().data.local_cas_scheme, path.display());
|
||||
Self {
|
||||
endpoint,
|
||||
auth: None,
|
||||
custom_headers: None,
|
||||
repo_paths: vec!["".into()],
|
||||
session_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a SessionContext for in-memory storage.
|
||||
pub fn for_memory() -> Self {
|
||||
Self {
|
||||
endpoint: "memory://".into(),
|
||||
auth: None,
|
||||
custom_headers: None,
|
||||
repo_paths: vec!["".into()],
|
||||
session_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ShardConfig {
|
||||
pub prefix: String,
|
||||
pub session_directory: PathBuf,
|
||||
pub cache_directory: PathBuf,
|
||||
pub global_dedup_policy: GlobalDedupPolicy,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ProgressConfig {
|
||||
pub aggregate: bool,
|
||||
}
|
||||
|
||||
/// Main configuration for file upload/download operations.
|
||||
/// Combines session-specific values with runtime-computed paths derived from the endpoint.
|
||||
#[derive(Debug)]
|
||||
pub struct TranslatorConfig {
|
||||
pub data_config: DataConfig,
|
||||
pub shard_config: ShardConfig,
|
||||
pub repo_info: Option<RepoInfo>,
|
||||
pub session_id: Option<String>,
|
||||
pub progress_config: ProgressConfig,
|
||||
pub session: SessionContext,
|
||||
|
||||
/// Directory for caching shard files.
|
||||
pub shard_cache_directory: PathBuf,
|
||||
|
||||
/// Directory for session-specific shard files.
|
||||
pub shard_session_directory: PathBuf,
|
||||
|
||||
/// Per-session override: when true, progress aggregation is disabled
|
||||
/// regardless of the global `HF_XET_DATA_AGGREGATE_PROGRESS` config value.
|
||||
pub force_disable_progress_aggregation: bool,
|
||||
}
|
||||
|
||||
impl TranslatorConfig {
|
||||
pub fn local_config(base_dir: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = base_dir.as_ref().join("xet");
|
||||
std::fs::create_dir_all(&path)?;
|
||||
fn create_base_xet_dir(base_dir: impl AsRef<Path>) -> Result<PathBuf> {
|
||||
let base_path = base_dir.as_ref().join("xet");
|
||||
std::fs::create_dir_all(&base_path)?;
|
||||
Ok(base_path)
|
||||
}
|
||||
|
||||
let translator_config = Self {
|
||||
data_config: DataConfig {
|
||||
endpoint: Endpoint::FileSystem(path.join("xorbs")),
|
||||
compression: Default::default(),
|
||||
auth: None,
|
||||
prefix: PREFIX_DEFAULT.into(),
|
||||
staging_directory: None,
|
||||
custom_headers: None,
|
||||
},
|
||||
shard_config: ShardConfig {
|
||||
prefix: PREFIX_DEFAULT.into(),
|
||||
cache_directory: path.join("shard-cache"),
|
||||
session_directory: path.join("shard-session"),
|
||||
global_dedup_policy: Default::default(),
|
||||
},
|
||||
repo_info: Some(RepoInfo {
|
||||
repo_paths: vec!["".into()],
|
||||
}),
|
||||
session_id: None,
|
||||
progress_config: ProgressConfig { aggregate: true },
|
||||
/// Creates a new TranslatorConfig from a SessionContext, computing all derived paths.
|
||||
pub fn new(session: SessionContext) -> Result<Self> {
|
||||
let config = xet_config();
|
||||
|
||||
let (shard_cache_directory, shard_session_directory) = if let Some(local_path) = session.local_path() {
|
||||
let base_path = local_path.join("xet");
|
||||
std::fs::create_dir_all(&base_path)?;
|
||||
|
||||
(base_path.join(&config.shard.cache_subdir), base_path.join(&config.session.dir_name))
|
||||
} else if session.is_memory() {
|
||||
let cache_path = xet_cache_root().join("memory");
|
||||
std::fs::create_dir_all(&cache_path)?;
|
||||
|
||||
(cache_path.join(&config.shard.cache_subdir), cache_path.join(&config.session.dir_name))
|
||||
} else {
|
||||
let cache_path = compute_cache_path(&session.endpoint);
|
||||
std::fs::create_dir_all(&cache_path)?;
|
||||
|
||||
let staging_directory = cache_path.join(&config.data.staging_subdir);
|
||||
std::fs::create_dir_all(&staging_directory)?;
|
||||
|
||||
(cache_path.join(&config.shard.cache_subdir), staging_directory.join(&config.session.dir_name))
|
||||
};
|
||||
|
||||
Ok(translator_config)
|
||||
info!(
|
||||
endpoint = %session.endpoint,
|
||||
session_id = ?session.session_id,
|
||||
shard_cache = %shard_cache_directory.display(),
|
||||
shard_session = %shard_session_directory.display(),
|
||||
"TranslatorConfig initialized"
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
session,
|
||||
shard_cache_directory,
|
||||
shard_session_directory,
|
||||
force_disable_progress_aggregation: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a TranslatorConfig for local filesystem-based storage.
|
||||
pub fn local_config(base_dir: impl AsRef<Path>) -> Result<Self> {
|
||||
Self::new(SessionContext::for_local_path(base_dir))
|
||||
}
|
||||
|
||||
/// Creates a TranslatorConfig that uses in-memory storage for XORBs.
|
||||
/// Shard data still uses file-based storage in the provided base directory.
|
||||
pub fn memory_config(base_dir: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = base_dir.as_ref().join("xet");
|
||||
std::fs::create_dir_all(&path)?;
|
||||
let session = SessionContext::for_memory();
|
||||
let config = xet_config();
|
||||
let base_path = Self::create_base_xet_dir(base_dir)?;
|
||||
|
||||
let translator_config = Self {
|
||||
data_config: DataConfig {
|
||||
endpoint: Endpoint::InMemory,
|
||||
compression: Default::default(),
|
||||
auth: None,
|
||||
prefix: PREFIX_DEFAULT.into(),
|
||||
staging_directory: None,
|
||||
custom_headers: None,
|
||||
},
|
||||
shard_config: ShardConfig {
|
||||
prefix: PREFIX_DEFAULT.into(),
|
||||
cache_directory: path.join("shard-cache"),
|
||||
session_directory: path.join("shard-session"),
|
||||
global_dedup_policy: Default::default(),
|
||||
},
|
||||
repo_info: Some(RepoInfo {
|
||||
repo_paths: vec!["".into()],
|
||||
}),
|
||||
session_id: None,
|
||||
progress_config: ProgressConfig { aggregate: true },
|
||||
};
|
||||
|
||||
Ok(translator_config)
|
||||
Ok(Self {
|
||||
session,
|
||||
shard_cache_directory: base_path.join(&config.shard.cache_subdir),
|
||||
shard_session_directory: base_path.join(&config.session.dir_name),
|
||||
force_disable_progress_aggregation: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Creates a TranslatorConfig that connects to a CAS server at the given endpoint.
|
||||
/// Shard cache and session directories are created under the provided base directory.
|
||||
/// Useful for tests that use LocalTestServer.
|
||||
pub fn test_server_config(endpoint: impl AsRef<str>, base_dir: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = base_dir.as_ref().join("xet");
|
||||
std::fs::create_dir_all(&path)?;
|
||||
|
||||
let translator_config = Self {
|
||||
data_config: DataConfig {
|
||||
endpoint: Endpoint::Server(endpoint.as_ref().to_string()),
|
||||
compression: Default::default(),
|
||||
auth: None,
|
||||
prefix: PREFIX_DEFAULT.into(),
|
||||
staging_directory: None,
|
||||
custom_headers: None,
|
||||
},
|
||||
shard_config: ShardConfig {
|
||||
prefix: PREFIX_DEFAULT.into(),
|
||||
cache_directory: path.join("shard-cache"),
|
||||
session_directory: path.join("shard-session"),
|
||||
global_dedup_policy: Default::default(),
|
||||
},
|
||||
repo_info: Some(RepoInfo {
|
||||
repo_paths: vec!["".into()],
|
||||
}),
|
||||
let session = SessionContext {
|
||||
endpoint: endpoint.as_ref().to_string(),
|
||||
auth: None,
|
||||
custom_headers: None,
|
||||
repo_paths: vec!["".into()],
|
||||
session_id: None,
|
||||
progress_config: ProgressConfig { aggregate: true },
|
||||
};
|
||||
let config = xet_config();
|
||||
let base_path = Self::create_base_xet_dir(base_dir)?;
|
||||
|
||||
Ok(translator_config)
|
||||
Ok(Self {
|
||||
session,
|
||||
shard_cache_directory: base_path.join(&config.shard.cache_subdir),
|
||||
shard_session_directory: base_path.join(&config.session.dir_name),
|
||||
force_disable_progress_aggregation: false,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn disable_progress_aggregation(self) -> Self {
|
||||
Self {
|
||||
progress_config: ProgressConfig { aggregate: false },
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_session_id(self, session_id: &str) -> Self {
|
||||
if session_id.is_empty() {
|
||||
return self;
|
||||
}
|
||||
|
||||
Self {
|
||||
session_id: Some(session_id.to_owned()),
|
||||
..self
|
||||
}
|
||||
pub fn disable_progress_aggregation(mut self) -> Self {
|
||||
self.force_disable_progress_aggregation = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Computes a cache-safe path from an endpoint URL.
|
||||
fn compute_cache_path(endpoint: &str) -> PathBuf {
|
||||
let cache_root = xet_cache_root();
|
||||
|
||||
let endpoint_prefix = endpoint
|
||||
.chars()
|
||||
.take(16)
|
||||
.map(|c| if c.is_alphanumeric() { c } else { '_' })
|
||||
.collect::<String>();
|
||||
|
||||
let endpoint_hash = xet_core_structures::merklehash::compute_data_hash(endpoint.as_bytes()).base64();
|
||||
let endpoint_tag = format!("{endpoint_prefix}-{}", &endpoint_hash[..16]);
|
||||
|
||||
cache_root.join(endpoint_tag)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use tempfile::tempdir;

    use super::{SessionContext, TranslatorConfig};

    /// Each kind of session context must be recognised only by its own predicate.
    #[test]
    fn test_session_context_mode_detection() {
        let scratch = tempdir().unwrap();

        // Local context: detected as local and round-trips the directory it was built from.
        let local = SessionContext::for_local_path(scratch.path());
        assert!(local.is_local());
        assert!(!local.is_memory());
        assert_eq!(local.local_path().unwrap(), scratch.path().to_path_buf());

        // In-memory context: detected as memory and carries no local path.
        let memory = SessionContext::for_memory();
        assert!(memory.is_memory());
        assert!(!memory.is_local());
        assert!(memory.local_path().is_none());

        // Remote context built by hand: neither local nor in-memory.
        let remote = SessionContext {
            endpoint: "http://localhost:8080".into(),
            auth: None,
            custom_headers: None,
            repo_paths: Vec::new(),
            session_id: None,
        };
        assert!(!remote.is_local());
        assert!(!remote.is_memory());
        assert!(remote.local_path().is_none());
    }

    /// The memory and test-server constructors must place both shard directories under `<base>/xet`.
    #[test]
    fn test_memory_and_server_configs_use_base_xet_layout() {
        let scratch = tempdir().unwrap();
        let xet_root = scratch.path().join("xet");

        let memory = TranslatorConfig::memory_config(scratch.path()).unwrap();
        assert!(memory.shard_cache_directory.starts_with(&xet_root));
        assert!(memory.shard_session_directory.starts_with(&xet_root));

        let server = TranslatorConfig::test_server_config("http://localhost:8080", scratch.path()).unwrap();
        assert!(server.shard_cache_directory.starts_with(&xet_root));
        assert!(server.shard_session_directory.starts_with(&xet_root));
    }
}
|
||||
|
||||
@@ -6,16 +6,14 @@ use std::sync::Arc;
|
||||
use bytes::Bytes;
|
||||
use http::header::HeaderMap;
|
||||
use itertools::multizip;
|
||||
use tracing::{Instrument, Span, info, info_span, instrument};
|
||||
use tracing::{Instrument, Span, info_span, instrument};
|
||||
use ulid::Ulid;
|
||||
use xet_client::cas_client::auth::{AuthConfig, TokenRefresher};
|
||||
use xet_client::cas_client::remote_client::PREFIX_DEFAULT;
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_core_structures::xorb_object::CompressionScheme;
|
||||
use xet_runtime::core::par_utils::run_constrained_with_semaphore;
|
||||
use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_cache_root, xet_config};
|
||||
use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_config};
|
||||
|
||||
use super::configurations::*;
|
||||
use super::configurations::{SessionContext, TranslatorConfig};
|
||||
use super::errors::DataProcessingError;
|
||||
use super::file_cleaner::Sha256Policy;
|
||||
use super::file_download_session::FileDownloadSession;
|
||||
@@ -25,70 +23,22 @@ use crate::progress_tracking::TrackingProgressUpdater;
|
||||
|
||||
pub fn default_config(
|
||||
endpoint: String,
|
||||
xorb_compression: Option<CompressionScheme>,
|
||||
token_info: Option<(String, u64)>,
|
||||
token_refresher: Option<Arc<dyn TokenRefresher>>,
|
||||
custom_headers: Option<Arc<HeaderMap>>,
|
||||
) -> errors::Result<TranslatorConfig> {
|
||||
// Intercept local:// to run a simulated CAS server in a specified directory.
|
||||
// This is useful for testing and development.
|
||||
if endpoint.starts_with("local://") {
|
||||
let local_path = endpoint.strip_prefix("local://").unwrap();
|
||||
let local_path = PathBuf::from(local_path);
|
||||
std::fs::create_dir_all(&local_path)?;
|
||||
return TranslatorConfig::local_config(local_path);
|
||||
}
|
||||
|
||||
let cache_root_path = xet_cache_root();
|
||||
info!("Using cache path {cache_root_path:?}.");
|
||||
|
||||
let (token, token_expiration) = token_info.unzip();
|
||||
let auth_cfg = AuthConfig::maybe_new(token, token_expiration, token_refresher);
|
||||
|
||||
// Calculate a fingerprint of the current endpoint to make sure caches stay separated.
|
||||
let endpoint_tag = {
|
||||
let endpoint_prefix = endpoint
|
||||
.chars()
|
||||
.take(16)
|
||||
.map(|c| if c.is_alphanumeric() { c } else { '_' })
|
||||
.collect::<String>();
|
||||
|
||||
// If more gets added
|
||||
let endpoint_hash = xet_core_structures::merklehash::compute_data_hash(endpoint.as_bytes()).base64();
|
||||
|
||||
format!("{endpoint_prefix}-{}", &endpoint_hash[..16])
|
||||
};
|
||||
|
||||
let cache_path = cache_root_path.join(endpoint_tag);
|
||||
std::fs::create_dir_all(&cache_path)?;
|
||||
|
||||
let staging_root = cache_path.join("staging");
|
||||
std::fs::create_dir_all(&staging_root)?;
|
||||
|
||||
let translator_config = TranslatorConfig {
|
||||
data_config: DataConfig {
|
||||
endpoint: Endpoint::Server(endpoint.clone()),
|
||||
compression: xorb_compression,
|
||||
auth: auth_cfg.clone(),
|
||||
prefix: PREFIX_DEFAULT.into(),
|
||||
staging_directory: None,
|
||||
custom_headers,
|
||||
},
|
||||
shard_config: ShardConfig {
|
||||
prefix: PREFIX_DEFAULT.into(),
|
||||
cache_directory: cache_path.join("shard-cache"),
|
||||
session_directory: staging_root.join("shard-session"),
|
||||
global_dedup_policy: Default::default(),
|
||||
},
|
||||
repo_info: Some(RepoInfo {
|
||||
repo_paths: vec!["".into()],
|
||||
}),
|
||||
let session = SessionContext {
|
||||
endpoint,
|
||||
auth: auth_cfg,
|
||||
custom_headers,
|
||||
repo_paths: vec!["".into()],
|
||||
session_id: Some(Ulid::new().to_string()),
|
||||
progress_config: ProgressConfig { aggregate: true },
|
||||
};
|
||||
|
||||
// Return the temp dir so that it's not dropped and thus the directory deleted.
|
||||
Ok(translator_config)
|
||||
TranslatorConfig::new(session)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, name = "data_client::upload_bytes", fields(session_id = tracing::field::Empty, num_files=file_contents.len()))]
|
||||
@@ -111,13 +61,12 @@ pub async fn upload_bytes_async(
|
||||
|
||||
let config = default_config(
|
||||
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
|
||||
None,
|
||||
token_info,
|
||||
token_refresher,
|
||||
custom_headers,
|
||||
)?;
|
||||
|
||||
Span::current().record("session_id", &config.session_id);
|
||||
Span::current().record("session_id", &config.session.session_id);
|
||||
|
||||
let semaphore = XetRuntime::current().common().file_ingestion_semaphore.clone();
|
||||
let upload_session = FileUploadSession::new(config.into(), progress_updater).await?;
|
||||
@@ -168,7 +117,6 @@ pub async fn upload_async(
|
||||
// for each file, return the filehash
|
||||
let config = default_config(
|
||||
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
|
||||
None,
|
||||
token_info,
|
||||
token_refresher,
|
||||
custom_headers,
|
||||
@@ -176,7 +124,7 @@ pub async fn upload_async(
|
||||
|
||||
let span = Span::current();
|
||||
|
||||
span.record("session_id", &config.session_id);
|
||||
span.record("session_id", &config.session.session_id);
|
||||
|
||||
let upload_session = FileUploadSession::new(config.into(), progress_updater).await?;
|
||||
|
||||
@@ -215,14 +163,13 @@ pub async fn download_async(
|
||||
}
|
||||
let config: Arc<TranslatorConfig> = default_config(
|
||||
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
|
||||
None,
|
||||
token_info,
|
||||
token_refresher,
|
||||
custom_headers,
|
||||
)?
|
||||
.into();
|
||||
|
||||
Span::current().record("session_id", &config.session_id);
|
||||
Span::current().record("session_id", &config.session.session_id);
|
||||
|
||||
let updaters = match progress_updaters {
|
||||
None => vec![None; file_infos.len()],
|
||||
@@ -424,11 +371,11 @@ mod tests {
|
||||
let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());
|
||||
|
||||
let endpoint = "http://localhost:8080".to_string();
|
||||
let result = default_config(endpoint, None, None, None, None);
|
||||
let result = default_config(endpoint, None, None, None);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let config = result.unwrap();
|
||||
assert!(config.shard_config.cache_directory.starts_with(temp_dir.path()));
|
||||
assert!(config.shard_cache_directory.starts_with(temp_dir.path()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -441,11 +388,11 @@ mod tests {
|
||||
let hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir_hf_home.path().to_str().unwrap());
|
||||
|
||||
let endpoint = "http://localhost:8080".to_string();
|
||||
let result = default_config(endpoint, None, None, None, None);
|
||||
let result = default_config(endpoint, None, None, None);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let config = result.unwrap();
|
||||
assert!(config.shard_config.cache_directory.starts_with(temp_dir_xet_cache.path()));
|
||||
assert!(config.shard_cache_directory.starts_with(temp_dir_xet_cache.path()));
|
||||
|
||||
drop(hf_xet_cache_guard);
|
||||
drop(hf_home_guard);
|
||||
@@ -454,11 +401,11 @@ mod tests {
|
||||
let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());
|
||||
|
||||
let endpoint = "http://localhost:8080".to_string();
|
||||
let result = default_config(endpoint, None, None, None, None);
|
||||
let result = default_config(endpoint, None, None, None);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let config = result.unwrap();
|
||||
assert!(config.shard_config.cache_directory.starts_with(temp_dir.path()));
|
||||
assert!(config.shard_cache_directory.starts_with(temp_dir.path()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -468,24 +415,24 @@ mod tests {
|
||||
let _hf_xet_cache_guard = EnvVarGuard::set("HF_XET_CACHE", temp_dir.path().to_str().unwrap());
|
||||
|
||||
let endpoint = "http://localhost:8080".to_string();
|
||||
let result = default_config(endpoint, None, None, None, None);
|
||||
let result = default_config(endpoint, None, None, None);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let config = result.unwrap();
|
||||
assert!(config.shard_config.cache_directory.starts_with(temp_dir.path()));
|
||||
assert!(config.shard_cache_directory.starts_with(temp_dir.path()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial(default_config_env)]
|
||||
fn test_default_config_without_env_vars() {
|
||||
let endpoint = "http://localhost:8080".to_string();
|
||||
let result = default_config(endpoint, None, None, None, None);
|
||||
let result = default_config(endpoint, None, None, None);
|
||||
|
||||
let expected = home_dir().unwrap().join(".cache").join("huggingface").join("xet");
|
||||
|
||||
assert!(result.is_ok());
|
||||
let config = result.unwrap();
|
||||
let test_cache_dir = &config.shard_config.cache_directory;
|
||||
let test_cache_dir = &config.shard_cache_directory;
|
||||
assert!(
|
||||
test_cache_dir.starts_with(&expected),
|
||||
"cache dir = {test_cache_dir:?}; does not start with {expected:?}",
|
||||
|
||||
@@ -6,7 +6,6 @@ use tracing::Instrument;
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_core_structures::metadata_shard::file_structs::FileDataSequenceEntry;
|
||||
|
||||
use super::configurations::GlobalDedupPolicy;
|
||||
use super::errors::Result;
|
||||
use super::file_upload_session::FileUploadSession;
|
||||
use crate::deduplication::{DeduplicationDataInterface, RawXorbData};
|
||||
@@ -26,7 +25,7 @@ impl UploadSessionDataManager {
|
||||
}
|
||||
|
||||
fn global_dedup_queries_enabled(&self) -> bool {
|
||||
matches!(self.session.config.shard_config.global_dedup_policy, GlobalDedupPolicy::Always)
|
||||
xet_runtime::core::xet_config().deduplication.global_dedup_query_enabled
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ impl FileDownloadSession {
|
||||
progress_updater: Option<Arc<dyn TrackingProgressUpdater>>,
|
||||
) -> Result<Arc<Self>> {
|
||||
let session_id = config
|
||||
.session
|
||||
.session_id
|
||||
.as_ref()
|
||||
.map(Cow::Borrowed)
|
||||
|
||||
@@ -16,7 +16,7 @@ use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
|
||||
use xet_core_structures::xorb_object::SerializedXorbObject;
|
||||
use xet_runtime::core::{XetRuntime, xet_config};
|
||||
|
||||
use super::configurations::*;
|
||||
use super::configurations::TranslatorConfig;
|
||||
use super::errors::*;
|
||||
use super::file_cleaner::{Sha256Policy, SingleFileCleaner};
|
||||
use super::remote_client_interface::create_remote_client;
|
||||
@@ -37,13 +37,9 @@ use crate::progress_tracking::{NoOpProgressUpdater, TrackingProgressUpdater};
|
||||
/// that succeeds or fails as a unit; i.e. all files get uploaded on finalization, and all shards
|
||||
/// and xorbs needed to reconstruct those files are properly uploaded and registered.
|
||||
pub struct FileUploadSession {
|
||||
// The parts of this that manage the
|
||||
pub(crate) client: Arc<dyn Client + Send + Sync>,
|
||||
pub(crate) shard_interface: SessionShardInterface,
|
||||
|
||||
/// The configuration settings, if needed.
|
||||
pub(crate) config: Arc<TranslatorConfig>,
|
||||
|
||||
/// Tracking upload completion between xorbs and files.
|
||||
pub(crate) completion_tracker: Arc<CompletionTracker>,
|
||||
|
||||
@@ -85,6 +81,7 @@ impl FileUploadSession {
|
||||
dry_run: bool,
|
||||
) -> Result<Arc<FileUploadSession>> {
|
||||
let session_id = config
|
||||
.session
|
||||
.session_id
|
||||
.as_ref()
|
||||
.map(Cow::Borrowed)
|
||||
@@ -94,7 +91,10 @@ impl FileUploadSession {
|
||||
match upload_progress_updater {
|
||||
Some(updater) => {
|
||||
let flush_interval = xet_config().data.progress_update_interval;
|
||||
if !flush_interval.is_zero() && config.progress_config.aggregate {
|
||||
if !flush_interval.is_zero()
|
||||
&& xet_config().data.aggregate_progress
|
||||
&& !config.force_disable_progress_aggregation
|
||||
{
|
||||
let aggregator = AggregatingProgressUpdater::new(
|
||||
updater,
|
||||
flush_interval,
|
||||
@@ -127,7 +127,6 @@ impl FileUploadSession {
|
||||
Ok(Arc::new(Self {
|
||||
shard_interface,
|
||||
client,
|
||||
config,
|
||||
completion_tracker,
|
||||
progress_aggregator,
|
||||
current_session_data: Mutex::new(DataAggregator::default()),
|
||||
@@ -320,14 +319,13 @@ impl FileUploadSession {
|
||||
|
||||
// Serialize the object; this can be relatively expensive, so run it on a compute thread.
|
||||
// XORBs are sent without footer - the server/client reconstructs it from chunk data.
|
||||
let compression_scheme = self.config.data_config.compression;
|
||||
let xorb_obj = XetRuntime::current()
|
||||
.spawn_blocking(move || SerializedXorbObject::from_xorb(xorb, compression_scheme, false))
|
||||
.spawn_blocking(move || SerializedXorbObject::from_xorb(xorb, false))
|
||||
.await??;
|
||||
|
||||
let session = self.clone();
|
||||
let upload_permit = self.client.acquire_upload_permit().await?;
|
||||
let cas_prefix = session.config.data_config.prefix.clone();
|
||||
let cas_prefix = xet_config().data.default_prefix.clone();
|
||||
let completion_tracker = self.completion_tracker.clone();
|
||||
let xorb_hash = xorb_obj.hash;
|
||||
let raw_num_bytes = xorb_obj.raw_num_bytes;
|
||||
|
||||
@@ -6,7 +6,6 @@ use tracing::{Instrument, Span, info_span, instrument};
|
||||
use xet_client::cas_client::auth::TokenRefresher;
|
||||
use xet_client::hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
|
||||
use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
|
||||
use xet_core_structures::xorb_object::CompressionScheme;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::par_utils::run_constrained;
|
||||
|
||||
@@ -48,7 +47,7 @@ pub async fn migrate_with_external_runtime(
|
||||
Some(Arc::new(headers)),
|
||||
)?;
|
||||
|
||||
migrate_files_impl(file_paths, sha256s, false, hub_client, cas_endpoint, None, false).await?;
|
||||
migrate_files_impl(file_paths, sha256s, false, hub_client, cas_endpoint, false).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -63,7 +62,6 @@ pub async fn migrate_files_impl(
|
||||
sequential: bool,
|
||||
hub_client: HubClient,
|
||||
cas_endpoint: Option<String>,
|
||||
compression: Option<CompressionScheme>,
|
||||
dry_run: bool,
|
||||
) -> Result<MigrationInfo> {
|
||||
let operation = Operation::Upload;
|
||||
@@ -80,12 +78,11 @@ pub async fn migrate_files_impl(
|
||||
|
||||
let config = default_config(
|
||||
cas,
|
||||
compression,
|
||||
Some((jwt_info.access_token, jwt_info.exp)),
|
||||
Some(token_refresher),
|
||||
Some(Arc::new(headers)),
|
||||
)?;
|
||||
Span::current().record("session_id", &config.session_id);
|
||||
Span::current().record("session_id", &config.session.session_id);
|
||||
|
||||
let num_workers = if sequential {
|
||||
1
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use xet_client::cas_client::{Client, RemoteClient};
|
||||
|
||||
use super::configurations::*;
|
||||
use super::configurations::TranslatorConfig;
|
||||
use super::errors::Result;
|
||||
|
||||
pub(crate) async fn create_remote_client(
|
||||
@@ -10,31 +10,30 @@ pub(crate) async fn create_remote_client(
|
||||
session_id: &str,
|
||||
dry_run: bool,
|
||||
) -> Result<Arc<dyn Client>> {
|
||||
let cas_storage_config = &config.data_config;
|
||||
let session = &config.session;
|
||||
|
||||
match cas_storage_config.endpoint {
|
||||
Endpoint::Server(ref endpoint) => Ok(RemoteClient::new(
|
||||
endpoint,
|
||||
&cas_storage_config.auth,
|
||||
if let Some(local_path) = session.local_path() {
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
{
|
||||
let xorb_path = local_path.join("xet").join("xorbs");
|
||||
Ok(xet_client::cas_client::LocalClient::new(xorb_path).await?)
|
||||
}
|
||||
#[cfg(target_family = "wasm")]
|
||||
unimplemented!("Local file system access is not supported in WASM builds")
|
||||
} else if session.is_memory() {
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
{
|
||||
Ok(xet_client::cas_client::MemoryClient::new())
|
||||
}
|
||||
#[cfg(target_family = "wasm")]
|
||||
unimplemented!("In-memory client is not supported in WASM builds")
|
||||
} else {
|
||||
Ok(RemoteClient::new(
|
||||
&session.endpoint,
|
||||
&session.auth,
|
||||
session_id,
|
||||
dry_run,
|
||||
cas_storage_config.custom_headers.clone(),
|
||||
)),
|
||||
Endpoint::FileSystem(ref path) => {
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
{
|
||||
Ok(xet_client::cas_client::LocalClient::new(path).await?)
|
||||
}
|
||||
#[cfg(target_family = "wasm")]
|
||||
unimplemented!("Local file system access is not supported in WASM builds")
|
||||
},
|
||||
Endpoint::InMemory => {
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
{
|
||||
Ok(xet_client::cas_client::MemoryClient::new())
|
||||
}
|
||||
#[cfg(target_family = "wasm")]
|
||||
unimplemented!("In-memory client is not supported in WASM builds")
|
||||
},
|
||||
session.custom_headers.clone(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,6 @@ pub struct SessionShardInterface {
|
||||
cache_shard_manager: Arc<ShardFileManager>,
|
||||
|
||||
client: Arc<dyn Client + Send + Sync>,
|
||||
config: Arc<TranslatorConfig>,
|
||||
|
||||
dry_run: bool,
|
||||
|
||||
@@ -60,17 +59,17 @@ impl SessionShardInterface {
|
||||
dry_run: bool,
|
||||
) -> Result<Self> {
|
||||
// Create a temporary session directory where we hold all the shards before upload.
|
||||
std::fs::create_dir_all(&config.shard_config.session_directory)?;
|
||||
let shard_session_tempdir = TempDir::new_in(&config.shard_config.session_directory)?;
|
||||
std::fs::create_dir_all(&config.shard_session_directory)?;
|
||||
let shard_session_tempdir = TempDir::new_in(&config.shard_session_directory)?;
|
||||
|
||||
let session_dir = shard_session_tempdir.path().to_owned();
|
||||
|
||||
// Set up the cache dir.
|
||||
let cache_dir = &config.shard_config.cache_directory;
|
||||
let cache_dir = &config.shard_cache_directory;
|
||||
std::fs::create_dir_all(cache_dir)?;
|
||||
|
||||
// Set up the shard session directory.
|
||||
let xorb_metadata_staging_dir = config.shard_config.session_directory.join("xorb_metadata");
|
||||
let xorb_metadata_staging_dir = config.shard_session_directory.join("xorb_metadata");
|
||||
std::fs::create_dir_all(&xorb_metadata_staging_dir)?;
|
||||
|
||||
// To allow resume from previous session attempts, merge and copy all the valid shards in the xorb metadata
|
||||
@@ -129,7 +128,6 @@ impl SessionShardInterface {
|
||||
staged_shards_to_remove_on_success,
|
||||
xorb_metadata_staging: Mutex::new((SystemTime::now(), MDBInMemoryShard::default())),
|
||||
resumed_session_shard_manager,
|
||||
config,
|
||||
dry_run,
|
||||
_shard_session_dir: shard_session_tempdir,
|
||||
client,
|
||||
@@ -140,7 +138,7 @@ impl SessionShardInterface {
|
||||
pub async fn query_dedup_shard_by_chunk(&self, chunk_hash: &MerkleHash) -> Result<bool> {
|
||||
let Ok(Some(new_shard)) = self
|
||||
.client
|
||||
.query_for_global_dedup_shard(&self.config.shard_config.prefix, chunk_hash)
|
||||
.query_for_global_dedup_shard(&xet_config().data.default_prefix, chunk_hash)
|
||||
.await
|
||||
.info_error("Error attempting to query global dedup lookup.")
|
||||
else {
|
||||
|
||||
@@ -9,14 +9,19 @@ pub(super) fn create_translator_config(session: &XetSession) -> Result<Translato
|
||||
.clone()
|
||||
.unwrap_or_else(|| session.config.data.default_cas_endpoint.clone());
|
||||
|
||||
Ok(xet_data::processing::data_client::default_config(
|
||||
let mut config = xet_data::processing::data_client::default_config(
|
||||
endpoint,
|
||||
None, // xorb_compression
|
||||
session.token_info.clone(),
|
||||
session.token_refresher.clone(),
|
||||
session.custom_headers.clone(),
|
||||
)?
|
||||
.with_session_id(&session.id.to_string()))
|
||||
)?;
|
||||
|
||||
let session_id = session.id.to_string();
|
||||
if !session_id.is_empty() {
|
||||
config.session.session_id = Some(session_id);
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// State of the upload commit and download group
|
||||
|
||||
@@ -23,6 +23,12 @@ pub const ENVIRONMENT_NAME_ALIASES: &[(&str, &str)] = &[
|
||||
("HF_XET_SHARD_CACHE_SIZE_LIMIT", "HF_XET_MDB_SHARD_CACHE_SIZE_LIMIT"),
|
||||
("HF_XET_SHARD_CHUNK_INDEX_TABLE_MAX_SIZE", "HF_XET_MDB_SHARD_CHUNK_INDEX_TABLE_MAX_SIZE"),
|
||||
("HF_XET_SHARD_CHUNK_INDEX_TABLE_MAX_SIZE", "HF_XET_CHUNK_INDEX_TABLE_MAX_SIZE"),
|
||||
// xorb compression fields were previously in the data group; support old HF_XET_DATA_XORB_* names
|
||||
(
|
||||
"HF_XET_XORB_COMPRESSION_SCHEME_RETEST_INTERVAL",
|
||||
"HF_XET_DATA_XORB_COMPRESSION_SCHEME_RETEST_INTERVAL",
|
||||
),
|
||||
("HF_XET_XORB_COMPRESSION_POLICY", "HF_XET_DATA_XORB_COMPRESSION_POLICY"),
|
||||
// Fixed concurrency aliases - these set initial, min, and max to the same value
|
||||
("HF_XET_CLIENT_AC_INITIAL_UPLOAD_CONCURRENCY", "HF_XET_FIXED_UPLOAD_CONCURRENCY"),
|
||||
("HF_XET_CLIENT_AC_MIN_UPLOAD_CONCURRENCY", "HF_XET_FIXED_UPLOAD_CONCURRENCY"),
|
||||
|
||||
@@ -80,4 +80,27 @@ crate::config_group!({
|
||||
/// Use the environment variable `HF_XET_DATA_DEFAULT_CAS_ENDPOINT` to set this value.
|
||||
ref default_cas_endpoint: String = "http://localhost:8080".to_string();
|
||||
|
||||
/// Whether to aggregate progress updates before sending them.
|
||||
/// When enabled, progress updates are batched and sent at regular intervals
|
||||
/// to reduce overhead.
|
||||
///
|
||||
/// The default value is true.
|
||||
///
|
||||
/// Use the environment variable `HF_XET_DATA_AGGREGATE_PROGRESS` to set this value.
|
||||
ref aggregate_progress: bool = true;
|
||||
|
||||
/// Default prefix used for CAS and shard operations.
|
||||
///
|
||||
/// The default value is "default".
|
||||
///
|
||||
/// Use the environment variable `HF_XET_DATA_DEFAULT_PREFIX` to set this value.
|
||||
ref default_prefix: String = "default".to_string();
|
||||
|
||||
/// Subdirectory name for staging data within the endpoint cache directory.
|
||||
///
|
||||
/// The default value is "staging".
|
||||
///
|
||||
/// Use the environment variable `HF_XET_DATA_STAGING_SUBDIR` to set this value.
|
||||
ref staging_subdir: String = "staging".to_string();
|
||||
|
||||
});
|
||||
|
||||
@@ -23,4 +23,13 @@ crate::config_group!({
|
||||
///
|
||||
/// Use the environment variable `HF_XET_DEDUPLICATION_MIN_N_CHUNKS_PER_RANGE` to set this value.
|
||||
ref min_n_chunks_per_range: f32 = 8.0;
|
||||
|
||||
/// Whether to enable global deduplication queries to the server.
|
||||
/// When enabled, the system will query the server for deduplication shards
|
||||
/// based on chunk hashes to enable cross-repository deduplication.
|
||||
///
|
||||
/// The default value is true.
|
||||
///
|
||||
/// Use the environment variable `HF_XET_DEDUPLICATION_GLOBAL_DEDUP_QUERY_ENABLED` to set this value.
|
||||
ref global_dedup_query_enabled: bool = true;
|
||||
});
|
||||
|
||||
9
xet_runtime/src/config/groups/session.rs
Normal file
9
xet_runtime/src/config/groups/session.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
crate::config_group!({

    /// Subdirectory name for shard session data within the staging directory.
    ///
    /// The default value is "shard-session".
    ///
    /// Use the environment variable `HF_XET_SESSION_DIR_NAME` to set this value.
    ref dir_name: String = "shard-session".to_string();
});
|
||||
@@ -26,7 +26,7 @@ crate::config_group!({
|
||||
///
|
||||
/// The default value is 16gb.
|
||||
///
|
||||
/// Use the environment variable `HF_XET_METADATA_SHARD_CACHE_SIZE_LIMIT` to set this value.
|
||||
/// Use the environment variable `HF_XET_SHARD_CACHE_SIZE_LIMIT` to set this value.
|
||||
ref cache_size_limit : ByteSize = ByteSize::from("16gb");
|
||||
|
||||
/// The maximum size of the chunk index table that's stored in memory. After this,
|
||||
@@ -36,4 +36,12 @@ crate::config_group!({
|
||||
///
|
||||
/// Use the environment variable `HF_XET_SHARD_CHUNK_INDEX_TABLE_MAX_SIZE` to set this value.
|
||||
ref chunk_index_table_max_size: usize = 64 * 1024 * 1024;
|
||||
|
||||
/// Subdirectory name for shard cache within the endpoint cache directory.
|
||||
///
|
||||
/// The default value is "shard-cache".
|
||||
///
|
||||
/// Use the environment variable `HF_XET_SHARD_CACHE_SUBDIR` to set this value.
|
||||
ref cache_subdir: String = "shard-cache".to_string();
|
||||
|
||||
});
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use crate::utils::ConfigEnum;
|
||||
|
||||
crate::config_group!({
|
||||
/// How often should we retest the compression scheme?
|
||||
/// Determining the optimal compression scheme takes time, but
|
||||
@@ -9,4 +11,13 @@ crate::config_group!({
|
||||
///
|
||||
/// Use the environment variable `HF_XET_XORB_COMPRESSION_SCHEME_RETEST_INTERVAL` to set this value.
|
||||
ref compression_scheme_retest_interval : usize = 32;
|
||||
|
||||
/// Compression policy for xorb data.
|
||||
/// Valid values: "" or "auto" for automatic detection, "none", "lz4", "bg4-lz4".
|
||||
/// When set to "" or "auto", the best compression scheme is chosen based on data analysis.
|
||||
///
|
||||
/// The default value is "auto" (auto-detect).
|
||||
///
|
||||
/// Use the environment variable `HF_XET_XORB_COMPRESSION_POLICY` to set this value.
|
||||
ref compression_policy: ConfigEnum = ConfigEnum::new("auto", &["", "auto", "none", "lz4", "bg4-lz4"]);
|
||||
});
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
#[macro_export]
|
||||
macro_rules! all_config_groups {
|
||||
($mac:ident) => {
|
||||
$mac!(data, shard, deduplication, chunk_cache, client, log, reconstruction, xorb);
|
||||
$mac!(data, shard, deduplication, chunk_cache, client, log, reconstruction, xorb, session);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -107,7 +107,7 @@ macro_rules! config_group {
|
||||
}
|
||||
|
||||
let default_value: $type = $value;
|
||||
self.$name = <$type>::parse(stringify!($name), maybe_env_value, default_value);
|
||||
self.$name = <$type>::parse_config_value(stringify!($name), maybe_env_value, default_value);
|
||||
}
|
||||
)+
|
||||
}
|
||||
@@ -123,11 +123,12 @@ macro_rules! config_group {
|
||||
match name {
|
||||
$(
|
||||
stringify!($name) => {
|
||||
self.$name = <$type>::parse_user_value(&value_string)
|
||||
.ok_or_else(|| $crate::config::ConfigError::ParseError {
|
||||
if !self.$name.try_update_in_place(&value_string) {
|
||||
return Err($crate::config::ConfigError::ParseError {
|
||||
field: name.to_owned(),
|
||||
value: value_string,
|
||||
})?;
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
)+
|
||||
@@ -155,7 +156,7 @@ macro_rules! config_group {
|
||||
match name {
|
||||
$(
|
||||
stringify!($name) => {
|
||||
self.$name = <$type as $crate::config::python::PythonConfigValue>::from_python(value)?;
|
||||
<$type as $crate::config::python::PythonConfigValue>::update_from_python(&mut self.$name, value)?;
|
||||
Ok(())
|
||||
}
|
||||
)+
|
||||
|
||||
@@ -12,8 +12,6 @@ pub mod groups;
|
||||
#[cfg(feature = "python")]
|
||||
pub mod python;
|
||||
|
||||
// Re-export types from utils for backward compatibility and for use in config_group macro
|
||||
// Re-export XetConfig for convenience
|
||||
pub use xet_config::XetConfig;
|
||||
|
||||
pub use crate::utils::configuration_utils::ParsableConfigValue;
|
||||
@@ -26,6 +24,7 @@ pub type ChunkCacheConfig = groups::chunk_cache::ConfigValues;
|
||||
pub type ClientConfig = groups::client::ConfigValues;
|
||||
pub type LogConfig = groups::log::ConfigValues;
|
||||
pub type XorbConfig = groups::xorb::ConfigValues;
|
||||
pub type SessionConfig = groups::session::ConfigValues;
|
||||
|
||||
#[cfg(feature = "python")]
|
||||
pub use xet_config::py_xet_config::PyXetConfig;
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use pyo3::conversion::IntoPyObjectExt;
|
||||
use pyo3::prelude::*;
|
||||
|
||||
use crate::utils::ByteSize;
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
use crate::utils::TemplatedPathBuf;
|
||||
use crate::utils::{ByteSize, ConfigEnum};
|
||||
|
||||
/// Trait for converting config values to/from Python objects.
|
||||
///
|
||||
@@ -19,6 +19,17 @@ pub trait PythonConfigValue {
|
||||
fn from_python(obj: &Bound<'_, PyAny>) -> PyResult<Self>
|
||||
where
|
||||
Self: Sized;
|
||||
|
||||
/// Update the value in place from a Python object. The default delegates to
|
||||
/// `from_python`, but types like `ConfigEnum` override this to perform
|
||||
/// context-aware validation using the existing value's metadata.
|
||||
fn update_from_python(&mut self, obj: &Bound<'_, PyAny>) -> PyResult<()>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
*self = Self::from_python(obj)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_python_extract {
|
||||
@@ -72,6 +83,22 @@ impl<T: PythonConfigValue> PythonConfigValue for Option<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl PythonConfigValue for ConfigEnum {
|
||||
fn to_python(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
|
||||
self.as_str().into_py_any(py)
|
||||
}
|
||||
|
||||
fn from_python(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
|
||||
let s: String = obj.extract()?;
|
||||
Ok(ConfigEnum::new_unchecked(s))
|
||||
}
|
||||
|
||||
fn update_from_python(&mut self, obj: &Bound<'_, PyAny>) -> PyResult<()> {
|
||||
let s: String = obj.extract()?;
|
||||
self.try_set(&s).map_err(|e| pyo3::exceptions::PyValueError::new_err(e))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
impl PythonConfigValue for TemplatedPathBuf {
|
||||
fn to_python(&self, py: Python<'_>) -> PyResult<Py<PyAny>> {
|
||||
|
||||
@@ -17,13 +17,11 @@ crate::all_config_groups!(define_xet_config);
|
||||
macro_rules! impl_xet_config_group_dispatch {
|
||||
($($group:ident),*) => {
|
||||
impl XetConfig {
|
||||
// Internal: parse a dotted config path into (group, field).
|
||||
fn split_path(path: &str) -> Result<(&str, &str), ConfigError> {
|
||||
path.split_once('.')
|
||||
.ok_or_else(|| ConfigError::InvalidPath(path.to_owned()))
|
||||
}
|
||||
|
||||
// Internal: parse a dotted config path and convert parse errors to Python exceptions.
|
||||
#[cfg(feature = "python")]
|
||||
fn split_path_for_python(path: &str) -> pyo3::PyResult<(&str, &str)> {
|
||||
Self::split_path(path).map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
|
||||
@@ -40,7 +38,6 @@ macro_rules! impl_xet_config_group_dispatch {
|
||||
self
|
||||
}
|
||||
|
||||
// Internal: dispatches with_config field updates to the correct group.
|
||||
fn update_field(&mut self, path: &str, value: impl ToString) -> Result<(), ConfigError> {
|
||||
let (group, field) = Self::split_path(path)?;
|
||||
match group {
|
||||
@@ -51,7 +48,6 @@ macro_rules! impl_xet_config_group_dispatch {
|
||||
}
|
||||
}
|
||||
|
||||
// Internal: dispatches Python with_config field updates to the correct group.
|
||||
#[cfg(feature = "python")]
|
||||
fn update_field_from_python(
|
||||
&mut self,
|
||||
@@ -69,7 +65,6 @@ macro_rules! impl_xet_config_group_dispatch {
|
||||
}
|
||||
}
|
||||
|
||||
// Internal: dispatches Python get/item access to the correct group.
|
||||
#[cfg(feature = "python")]
|
||||
fn get_field_to_python(
|
||||
&self,
|
||||
@@ -114,7 +109,6 @@ macro_rules! impl_xet_config_group_dispatch {
|
||||
keys
|
||||
}
|
||||
|
||||
// Internal: collects all (key, value) pairs for Python iteration.
|
||||
#[cfg(feature = "python")]
|
||||
fn all_items_to_python(
|
||||
&self,
|
||||
|
||||
332
xet_runtime/src/utils/config_enum.rs
Normal file
332
xet_runtime/src/utils/config_enum.rs
Normal file
@@ -0,0 +1,332 @@
|
||||
use std::fmt;
|
||||
use std::ops::Deref;
|
||||
use std::str::FromStr;
|
||||
|
||||
use tracing::{event, info, warn};
|
||||
|
||||
use super::configuration_utils::{INFORMATION_LOG_LEVEL, ParsableConfigValue};
|
||||
|
||||
/// A config value restricted to a fixed set of valid lowercase string options.
///
/// Stores the normalized (lowercased) value and validates against `valid_values`
/// at parse time. If the user provides an invalid value via environment variable,
/// a warning is logged and the default is used instead.
///
/// # Usage in `config_group!`
///
/// ```rust,ignore
/// use crate::utils::ConfigEnum;
///
/// crate::config_group!({
///     ref compression_policy: ConfigEnum = ConfigEnum::new("auto", &["", "auto", "none", "lz4"]);
/// });
/// ```
#[derive(Clone)]
pub struct ConfigEnum {
    /// Current value, always stored lowercased.
    value: String,
    /// The accepted values; candidates are matched against them case-insensitively.
    valid_values: &'static [&'static str],
}

impl ConfigEnum {
    /// Creates a `ConfigEnum` holding `default` (lowercased) and accepting `valid_values`.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `default` is not one of `valid_values`
    /// (compared case-insensitively). Release builds do not check.
    pub fn new(default: &str, valid_values: &'static [&'static str]) -> Self {
        let lower = default.to_lowercase();
        debug_assert!(
            Self::is_listed(valid_values, &lower),
            "Default value \"{default}\" is not in the valid values list: {valid_values:?}"
        );
        ConfigEnum {
            value: lower,
            valid_values,
        }
    }

    /// Creates a ConfigEnum with a string value and an empty `valid_values` list.
    /// Because `valid_values` is empty, `try_set` will reject every value on the
    /// resulting instance. This is intended only for deserialization paths
    /// (e.g. Python's `from_python`) where the caller validates through the
    /// *existing* field value's `try_set` rather than through this instance.
    pub fn new_unchecked(value: impl Into<String>) -> Self {
        ConfigEnum {
            value: value.into().to_lowercase(),
            valid_values: &[],
        }
    }

    /// Returns the current (lowercased) value.
    pub fn as_str(&self) -> &str {
        &self.value
    }

    /// Returns the list of values this instance accepts.
    pub fn valid_values(&self) -> &'static [&'static str] {
        self.valid_values
    }

    /// Set the value if it is valid (case-insensitive), otherwise return an error.
    ///
    /// On success the lowercased form of `value` is stored. On failure the
    /// current value is left unchanged and the error text names the rejected
    /// input and the accepted values.
    pub fn try_set(&mut self, value: &str) -> Result<(), String> {
        let lower = value.to_lowercase();
        if Self::is_listed(self.valid_values, &lower) {
            self.value = lower;
            Ok(())
        } else {
            Err(format!("\"{value}\" is not a valid value. Valid values are: {:?}", self.valid_values))
        }
    }

    /// Parse the stored value into a target type via `FromStr`, matching the
    /// familiar `str::parse::<T>()` signature.
    ///
    /// In debug builds, asserts that *every* entry in `valid_values` can be
    /// successfully parsed into `T`, catching mismatches between the config's
    /// allowed strings and the target type's parser at development time.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if any entry of `valid_values` fails to parse as `T`.
    pub fn parse<T>(&self) -> Result<T, T::Err>
    where
        T: FromStr,
        T::Err: fmt::Debug + fmt::Display,
    {
        #[cfg(debug_assertions)]
        for v in self.valid_values {
            if let Err(e) = v.parse::<T>() {
                panic!("ConfigEnum valid value \"{v}\" cannot be parsed into {}: {e}", std::any::type_name::<T>());
            }
        }
        self.value.parse::<T>()
    }

    /// Returns true if `lowered` (an already-lowercased candidate) equals the
    /// lowercased form of any entry in `valid_values`.
    fn is_listed(valid_values: &[&str], lowered: &str) -> bool {
        valid_values.iter().any(|candidate| candidate.to_lowercase() == lowered)
    }
}

/// Debug output is just the quoted value; the valid-values list is omitted.
impl fmt::Debug for ConfigEnum {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", self.value)
    }
}

/// Displays the bare (lowercased) value.
impl fmt::Display for ConfigEnum {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.value)
    }
}

/// Lets a `ConfigEnum` be used wherever a `&str` of its value is expected.
impl Deref for ConfigEnum {
    type Target = str;

    fn deref(&self) -> &Self::Target {
        &self.value
    }
}

impl AsRef<str> for ConfigEnum {
    fn as_ref(&self) -> &str {
        &self.value
    }
}

/// Case-insensitive comparison: `other` is lowercased before comparing with
/// the stored (already lowercased) value.
impl PartialEq<str> for ConfigEnum {
    fn eq(&self, other: &str) -> bool {
        self.value == other.to_lowercase()
    }
}

/// Same case-insensitive comparison as `PartialEq<str>`, for `&str` operands
/// such as string literals.
impl PartialEq<&str> for ConfigEnum {
    fn eq(&self, other: &&str) -> bool {
        <Self as PartialEq<str>>::eq(self, *other)
    }
}

/// Two instances are equal when their stored values match; the valid-values
/// lists are not compared.
impl PartialEq for ConfigEnum {
    fn eq(&self, other: &Self) -> bool {
        self.value == other.value
    }
}

impl Eq for ConfigEnum {}
|
||||
|
||||
impl ParsableConfigValue for ConfigEnum {
|
||||
fn parse_user_value(_value: &str) -> Option<Self> {
|
||||
None
|
||||
}
|
||||
|
||||
fn to_config_string(&self) -> String {
|
||||
self.value.clone()
|
||||
}
|
||||
|
||||
fn try_update_in_place(&mut self, value: &str) -> bool {
|
||||
self.try_set(value).is_ok()
|
||||
}
|
||||
|
||||
fn parse_config_value(variable_name: &str, value: Option<String>, default: Self) -> Self {
|
||||
match value {
|
||||
Some(v) => {
|
||||
let lower = v.to_lowercase();
|
||||
if default.valid_values.iter().any(|valid| valid.to_lowercase() == lower) {
|
||||
info!("Config: {variable_name} = {lower:?} (user set)");
|
||||
ConfigEnum {
|
||||
value: lower,
|
||||
valid_values: default.valid_values,
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
"Configuration value \"{v}\" for {variable_name} is not valid. \
|
||||
Valid values are: {:?}. Reverting to default \"{}\".",
|
||||
default.valid_values, default.value
|
||||
);
|
||||
info!("Config: {variable_name} = {:?} (default due to invalid value)", default.value);
|
||||
default
|
||||
}
|
||||
},
|
||||
None => {
|
||||
event!(INFORMATION_LOG_LEVEL, "Config: {variable_name} = {:?} (default)", default.value);
|
||||
default
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::num::ParseIntError;

    use super::*;

    const VALID: &[&str] = &["", "auto", "none", "lz4", "bg4-lz4"];
    const INT_VALID: &[&str] = &["1", "2", "3"];

    #[test]
    fn test_new_default_value() {
        let ce = ConfigEnum::new("auto", VALID);
        assert_eq!(ce.as_str(), "auto");
    }

    #[test]
    fn test_new_normalizes_to_lowercase() {
        let ce = ConfigEnum::new("AUTO", VALID);
        assert_eq!(ce.as_str(), "auto");
    }

    #[test]
    fn test_new_unchecked_rejects_updates() {
        // `new_unchecked` lowercases the value but carries no valid values,
        // so every `try_set` on the result fails and leaves it unchanged.
        let mut ce = ConfigEnum::new_unchecked("LZ4");
        assert_eq!(ce.as_str(), "lz4");
        assert!(ce.valid_values().is_empty());
        assert!(ce.try_set("lz4").is_err());
        assert_eq!(ce.as_str(), "lz4");
    }

    #[test]
    fn test_deref_and_asref() {
        let ce = ConfigEnum::new("lz4", VALID);
        let s: &str = &ce;
        assert_eq!(s, "lz4");
        assert_eq!(ce.as_ref(), "lz4");
    }

    #[test]
    fn test_partial_eq_str() {
        // A string literal is a `&str`; dereference it so the comparison goes
        // through the `PartialEq<str>` impl rather than `PartialEq<&str>`.
        let ce = ConfigEnum::new("lz4", VALID);
        assert!(ce == *"lz4");
        assert!(ce == *"LZ4");
        assert!(ce != *"auto");
    }

    #[test]
    fn test_partial_eq_str_ref() {
        let ce = ConfigEnum::new("lz4", VALID);
        assert_eq!(ce, "lz4");
        assert_eq!(ce, "LZ4");
        assert_ne!(ce, "auto");
    }

    #[test]
    fn test_partial_eq_self() {
        let a = ConfigEnum::new("lz4", VALID);
        let b = ConfigEnum::new("lz4", VALID);
        assert_eq!(a, b);
    }

    #[test]
    fn test_display() {
        let ce = ConfigEnum::new("bg4-lz4", VALID);
        assert_eq!(format!("{ce}"), "bg4-lz4");
    }

    #[test]
    fn test_debug() {
        let ce = ConfigEnum::new("bg4-lz4", VALID);
        assert_eq!(format!("{ce:?}"), "\"bg4-lz4\"");
    }

    #[test]
    fn test_parse_valid_value() {
        let default = ConfigEnum::new("auto", VALID);
        let result = ConfigEnum::parse_config_value("test", Some("LZ4".to_string()), default);
        assert_eq!(result.as_str(), "lz4");
    }

    #[test]
    fn test_parse_invalid_value_returns_default() {
        let default = ConfigEnum::new("auto", VALID);
        let result = ConfigEnum::parse_config_value("test", Some("zstd".to_string()), default);
        assert_eq!(result.as_str(), "auto");
    }

    #[test]
    fn test_parse_none_returns_default() {
        let default = ConfigEnum::new("auto", VALID);
        let result = ConfigEnum::parse_config_value("test", None, default);
        assert_eq!(result.as_str(), "auto");
    }

    #[test]
    fn test_parse_empty_string_valid() {
        let default = ConfigEnum::new("auto", VALID);
        let result = ConfigEnum::parse_config_value("test", Some("".to_string()), default);
        assert_eq!(result.as_str(), "");
    }

    #[test]
    #[should_panic(expected = "not in the valid values list")]
    #[cfg(debug_assertions)]
    fn test_new_invalid_default_panics() {
        let _ = ConfigEnum::new("invalid", VALID);
    }

    #[test]
    fn test_valid_values_accessor() {
        let ce = ConfigEnum::new("auto", VALID);
        assert_eq!(ce.valid_values(), VALID);
    }

    #[test]
    fn test_try_set_valid() {
        let mut ce = ConfigEnum::new("auto", VALID);
        assert!(ce.try_set("lz4").is_ok());
        assert_eq!(ce.as_str(), "lz4");
    }

    #[test]
    fn test_try_set_case_insensitive() {
        let mut ce = ConfigEnum::new("auto", VALID);
        assert!(ce.try_set("LZ4").is_ok());
        assert_eq!(ce.as_str(), "lz4");
    }

    #[test]
    fn test_try_set_invalid() {
        let mut ce = ConfigEnum::new("auto", VALID);
        assert!(ce.try_set("zstd").is_err());
        assert_eq!(ce.as_str(), "auto");
    }

    #[test]
    fn test_try_set_empty_string() {
        let mut ce = ConfigEnum::new("auto", VALID);
        assert!(ce.try_set("").is_ok());
        assert_eq!(ce.as_str(), "");
    }

    #[test]
    fn test_parse_success() {
        let ce = ConfigEnum::new("2", INT_VALID);
        let val: Result<u32, ParseIntError> = ce.parse();
        assert_eq!(val.unwrap(), 2);
    }

    #[test]
    fn test_parse_all_values_parseable() {
        let ce = ConfigEnum::new("1", INT_VALID);
        let _: u32 = ce.parse().unwrap();
    }

    #[test]
    #[should_panic(expected = "cannot be parsed")]
    #[cfg(debug_assertions)]
    fn test_parse_panics_on_unparseable_valid_value() {
        let ce = ConfigEnum::new("auto", VALID);
        let _: Result<u32, ParseIntError> = ce.parse();
    }
}
|
||||
@@ -20,9 +20,21 @@ pub trait ParsableConfigValue: std::fmt::Debug + Sized {
|
||||
/// Serialize this value to a string that can be parsed back via `parse_user_value`.
|
||||
fn to_config_string(&self) -> String;
|
||||
|
||||
/// Try to update this value in place from a string. Returns true on success.
|
||||
/// The default implementation delegates to `parse_user_value`, but types like
|
||||
/// `ConfigEnum` override this to use context-aware validation.
|
||||
fn try_update_in_place(&mut self, value: &str) -> bool {
|
||||
if let Some(v) = Self::parse_user_value(value) {
|
||||
*self = v;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the value, returning the default if it can't be parsed or the string is empty.
|
||||
/// Issue a warning if it can't be parsed.
|
||||
fn parse(variable_name: &str, value: Option<String>, default: Self) -> Self {
|
||||
fn parse_config_value(variable_name: &str, value: Option<String>, default: Self) -> Self {
|
||||
match value {
|
||||
Some(v) => match Self::parse_user_value(&v) {
|
||||
Some(v) => {
|
||||
@@ -170,7 +182,7 @@ macro_rules! test_configurable_constants {
|
||||
{
|
||||
let default_value = $value;
|
||||
let maybe_env_value = std::env::var(concat!("HF_XET_",stringify!($name))).ok();
|
||||
<$type>::parse(stringify!($name), maybe_env_value, default_value)
|
||||
<$type>::parse_config_value(stringify!($name), maybe_env_value, default_value)
|
||||
}
|
||||
#[cfg(not(debug_assertions))]
|
||||
{
|
||||
|
||||
@@ -3,6 +3,9 @@ pub mod adjustable_semaphore;
|
||||
pub mod byte_size;
|
||||
pub use byte_size::ByteSize;
|
||||
|
||||
pub mod config_enum;
|
||||
pub use config_enum::ConfigEnum;
|
||||
|
||||
pub mod configuration_utils;
|
||||
pub use configuration_utils::is_high_performance;
|
||||
|
||||
|
||||
@@ -63,7 +63,21 @@ fn test_environment_variable_aliases() {
|
||||
assert_eq!(XetConfig::new().data.session_xorb_metadata_flush_max_count, 128);
|
||||
}
|
||||
|
||||
// MDB shard aliases
|
||||
// Xorb aliases (old HF_XET_DATA_XORB_* names)
|
||||
{
|
||||
let _guard = EnvVarGuard::set("HF_XET_DATA_XORB_COMPRESSION_SCHEME_RETEST_INTERVAL", "64");
|
||||
assert_eq!(XetConfig::new().xorb.compression_scheme_retest_interval, 64);
|
||||
}
|
||||
{
|
||||
let _guard = EnvVarGuard::set("HF_XET_DATA_XORB_COMPRESSION_POLICY", "lz4");
|
||||
assert_eq!(XetConfig::new().xorb.compression_policy.as_str(), "lz4");
|
||||
}
|
||||
|
||||
// Shard aliases
|
||||
{
|
||||
let _guard = EnvVarGuard::set("HF_XET_MDB_SHARD_CACHE_SIZE_LIMIT", "24gb");
|
||||
assert_eq!(*XetConfig::new().shard.cache_size_limit, 24_000_000_000);
|
||||
}
|
||||
{
|
||||
let _guard = EnvVarGuard::set("HF_XET_SHARD_CACHE_SIZE_LIMIT", "32gb");
|
||||
assert_eq!(*XetConfig::new().shard.cache_size_limit, 32_000_000_000);
|
||||
@@ -74,6 +88,31 @@ fn test_environment_variable_aliases() {
|
||||
}
|
||||
}
|
||||
|
||||
/// Test that each config value has a working primary environment variable.
/// Every case runs in its own scope so its `EnvVarGuard` is dropped before
/// the next case begins.
#[test]
#[serial(config_env)]
fn test_new_config_env_overrides() {
    {
        let _guard = EnvVarGuard::set("HF_XET_DATA_AGGREGATE_PROGRESS", "false");
        assert!(!XetConfig::new().data.aggregate_progress);
    }
    {
        let _guard = EnvVarGuard::set("HF_XET_DATA_DEFAULT_PREFIX", "test-prefix");
        assert_eq!(XetConfig::new().data.default_prefix, "test-prefix");
    }
    {
        let _guard = EnvVarGuard::set("HF_XET_DATA_STAGING_SUBDIR", "tmp-stage");
        assert_eq!(XetConfig::new().data.staging_subdir, "tmp-stage");
    }
    {
        let _guard = EnvVarGuard::set("HF_XET_SESSION_DIR_NAME", "tmp-session");
        assert_eq!(XetConfig::new().session.dir_name, "tmp-session");
    }
    {
        let _guard = EnvVarGuard::set("HF_XET_DEDUPLICATION_GLOBAL_DEDUP_QUERY_ENABLED", "false");
        assert!(!XetConfig::new().deduplication.global_dedup_query_enabled);
    }
}
|
||||
|
||||
/// Test that primary environment variable takes precedence over alias when both are set.
|
||||
#[test]
|
||||
#[serial(config_env)]
|
||||
Reference in New Issue
Block a user