upgrade rust edition to 2024; upgrade rustc to 1.89 (#494)

- Upgrade Rust edition and rustc version to bring in some nice features,
e.g. let chains instead of nested if block.
- Fix clippy and format due to the upgrade.
- Fix a bug identified by the new rustc:
6cb0a7fb4e/xet_runtime/src/runtime.rs (L195)
```
#[cfg(not(target_family = "wasm"))]
{
    // A new multithreaded runtime with a capped number of threads
    TokioRuntimeBuilder::new_multi_thread().worker_threads(get_num_tokio_worker_threads())
}
```
here the end curly bracket drops the temporary builder while a `&mut
Self` to the dropped value is returned. (this may be due to a difference
between compilers regarding how they treat the scope of "{...}" of
`#[cfg(...))] {...}`?)
This commit is contained in:
Di Xiao
2025-09-17 10:28:50 -07:00
committed by GitHub
parent 6cb0a7fb4e
commit fa030edcd5
90 changed files with 473 additions and 437 deletions

View File

@@ -30,8 +30,8 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Rust 1.86
uses: dtolnay/rust-toolchain@1.86.0
- name: Install Rust 1.89
uses: dtolnay/rust-toolchain@1.89.0
with:
components: clippy
- name: Lint
@@ -55,8 +55,8 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Rust 1.86
uses: dtolnay/rust-toolchain@1.86.0
- name: Install Rust 1.89
uses: dtolnay/rust-toolchain@1.89.0
with:
components: clippy
- name: Build and Test
@@ -67,8 +67,8 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Install Rust 1.86
uses: dtolnay/rust-toolchain@1.86.0
- name: Install Rust 1.89
uses: dtolnay/rust-toolchain@1.89.0
with:
components: clippy
- name: Set up Git LFS

View File

@@ -1,7 +1,7 @@
[package]
name = "cas_client"
version = "0.14.5"
edition = "2021"
edition = "2024"
[dependencies]

View File

@@ -10,8 +10,8 @@ use deduplication::constants::MAX_XORB_BYTES;
use derivative::Derivative;
use error_printer::ErrorPrinter;
use futures::TryStreamExt;
use http::header::RANGE;
use http::StatusCode;
use http::header::RANGE;
use merklehash::MerkleHash;
use reqwest::Response;
use reqwest_middleware::ClientWithMiddleware;
@@ -22,7 +22,7 @@ use utils::singleflight::Group;
use crate::error::{CasClientError, Result};
use crate::http_client::Api;
use crate::output_provider::OutputProvider;
use crate::remote_client::{get_reconstruction_with_endpoint_and_client, PREFIX_DEFAULT};
use crate::remote_client::{PREFIX_DEFAULT, get_reconstruction_with_endpoint_and_client};
use crate::retry_wrapper::{RetryWrapper, RetryableReqwestError};
utils::configurable_constants! {
@@ -572,7 +572,7 @@ mod tests {
use tokio::time::sleep;
use super::*;
use crate::{build_http_client, RetryConfig};
use crate::{RetryConfig, build_http_client};
#[tokio::test]
async fn test_fetch_info_query_and_find() -> Result<()> {

View File

@@ -5,8 +5,8 @@ use anyhow::anyhow;
use http::StatusCode;
use merklehash::MerkleHash;
use thiserror::Error;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::AcquireError;
use tokio::sync::mpsc::error::SendError;
use tokio::task::JoinError;
#[non_exhaustive]

View File

@@ -5,21 +5,21 @@ use anyhow::anyhow;
use cas_types::{REQUEST_ID_HEADER, SESSION_ID_HEADER};
use error_printer::{ErrorPrinter, OptionPrinter};
use http::{Extensions, StatusCode};
use reqwest::header::{HeaderValue, AUTHORIZATION};
use reqwest::header::{AUTHORIZATION, HeaderValue};
use reqwest::{Request, Response};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next};
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::{
default_on_request_failure, default_on_request_success, DefaultRetryableStrategy, RetryTransientMiddleware,
Retryable, RetryableStrategy,
DefaultRetryableStrategy, RetryTransientMiddleware, Retryable, RetryableStrategy, default_on_request_failure,
default_on_request_success,
};
use tokio::sync::Mutex;
use tracing::{debug, info_span, warn, Instrument};
use tracing::{Instrument, debug, info_span, warn};
use utils::auth::{AuthConfig, TokenProvider};
use crate::constants::{CLIENT_IDLE_CONNECTION_TIMEOUT_SECS, CLIENT_MAX_IDLE_CONNECTIONS};
use crate::retry_wrapper::on_request_failure;
use crate::{error, CasClientError};
use crate::{CasClientError, error};
pub(crate) const NUM_RETRIES: u32 = 5;
pub(crate) const BASE_RETRY_DELAY_MS: u64 = 3000; // 3s
@@ -30,10 +30,10 @@ pub struct No429RetryStrategy;
impl RetryableStrategy for No429RetryStrategy {
fn handle(&self, res: &Result<Response, reqwest_middleware::Error>) -> Option<Retryable> {
if let Ok(success) = res {
if success.status() == StatusCode::TOO_MANY_REQUESTS {
return Some(Retryable::Fatal);
}
if let Ok(success) = res
&& success.status() == StatusCode::TOO_MANY_REQUESTS
{
return Some(Retryable::Fatal);
}
const DEFAULT_STRATEGY: DefaultRetryableStrategy = DefaultRetryableStrategy;
@@ -340,12 +340,10 @@ impl ResponseErrorLogger<error::Result<Response>> for reqwest_middleware::Result
}
pub fn request_id_from_response(res: &Response) -> &str {
let request_id = res
.headers()
res.headers()
.get(REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.unwrap_or_default();
request_id
.unwrap_or_default()
}
#[cfg(test)]

View File

@@ -9,9 +9,9 @@ use merklehash::MerkleHash;
use progress_tracking::item_tracking::SingleItemProgressUpdater;
use progress_tracking::upload_tracking::CompletionTracker;
use crate::error::Result;
#[cfg(not(target_family = "wasm"))]
use crate::OutputProvider;
use crate::error::Result;
/// A Client to the Shard service. The shard service
/// provides for

View File

@@ -1,7 +1,7 @@
#![allow(dead_code)]
pub use chunk_cache::{CacheConfig, CHUNK_CACHE_SIZE_BYTES};
pub use http_client::{build_auth_http_client, build_http_client, Api, RetryConfig};
pub use chunk_cache::{CHUNK_CACHE_SIZE_BYTES, CacheConfig};
pub use http_client::{Api, RetryConfig, build_auth_http_client, build_http_client};
pub use interface::Client;
#[cfg(not(target_family = "wasm"))]
pub use local_client::LocalClient;

View File

@@ -1,4 +1,4 @@
use std::fs::{metadata, File};
use std::fs::{File, metadata};
use std::io::{BufReader, Cursor, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;
@@ -21,9 +21,9 @@ use tempfile::TempDir;
use tokio::runtime::Handle;
use tracing::{debug, error, info, warn};
use crate::Client;
use crate::error::{CasClientError, Result};
use crate::output_provider::OutputProvider;
use crate::Client;
pub struct LocalClient {
tmp_dir: Option<TempDir>, // To hold directory to use for local testing
@@ -395,8 +395,8 @@ fn map_heed_db_error(e: heed::Error) -> CasClientError {
#[cfg(test)]
mod tests {
use cas_object::test_utils::*;
use cas_object::CompressionScheme::LZ4;
use cas_object::test_utils::*;
use deduplication::test_utils::raw_xorb_to_vec;
use mdb_shard::utils::parse_shard_filename;

View File

@@ -13,21 +13,21 @@ use cas_types::{
};
use chunk_cache::{CacheConfig, ChunkCache};
use error_printer::ErrorPrinter;
use http::header::{CONTENT_LENGTH, RANGE};
use http::HeaderValue;
use http::header::{CONTENT_LENGTH, RANGE};
use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDBFileInfo};
use merklehash::MerkleHash;
use progress_tracking::item_tracking::SingleItemProgressUpdater;
use progress_tracking::upload_tracking::CompletionTracker;
use reqwest::{Body, Response, StatusCode, Url};
use reqwest_middleware::ClientWithMiddleware;
use tokio::sync::{mpsc, OwnedSemaphorePermit};
use tokio::sync::{OwnedSemaphorePermit, mpsc};
use tokio::task::{JoinHandle, JoinSet};
use tracing::{debug, info, instrument};
use utils::auth::AuthConfig;
#[cfg(not(target_family = "wasm"))]
use utils::singleflight::Group;
use xet_runtime::{global_semaphore_handle, GlobalSemaphoreHandle, XetRuntime};
use xet_runtime::{GlobalSemaphoreHandle, XetRuntime, global_semaphore_handle};
#[cfg(not(target_family = "wasm"))]
use crate::download_utils::*;
@@ -36,7 +36,7 @@ use crate::http_client::{Api, ResponseErrorLogger, RetryConfig};
#[cfg(not(target_family = "wasm"))]
use crate::output_provider::OutputProvider;
use crate::retry_wrapper::RetryWrapper;
use crate::{http_client, Client};
use crate::{Client, http_client};
pub const CAS_ENDPOINT: &str = "http://localhost:8080";
pub const PREFIX_DEFAULT: &str = "default";
@@ -97,10 +97,10 @@ pub(crate) async fn get_reconstruction_with_endpoint_and_client(
let e = response.unwrap_err();
// bytes_range not satisfiable
if let CasClientError::ReqwestError(e, _) = &e {
if let Some(StatusCode::RANGE_NOT_SATISFIABLE) = e.status() {
return Ok(None);
}
if let CasClientError::ReqwestError(e, _) = &e
&& let Some(StatusCode::RANGE_NOT_SATISFIABLE) = e.status()
{
return Ok(None);
}
return Err(e);
@@ -818,8 +818,8 @@ mod tests {
use std::collections::HashMap;
use anyhow::Result;
use cas_object::test_utils::*;
use cas_object::CompressionScheme;
use cas_object::test_utils::*;
use cas_types::{CASReconstructionFetchInfo, CASReconstructionTerm, ChunkRange};
use deduplication::constants::MAX_XORB_BYTES;
use httpmock::Method::GET;

View File

@@ -1,10 +1,10 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use reqwest::{Error as ReqwestError, Response, StatusCode};
use reqwest_retry::{default_on_request_success, Retryable};
use tokio_retry::strategy::{jitter, ExponentialBackoff};
use reqwest_retry::{Retryable, default_on_request_success};
use tokio_retry::RetryIf;
use tokio_retry::strategy::{ExponentialBackoff, jitter};
use tracing::{error, info};
use crate::constants::{CLIENT_RETRY_BASE_DELAY_MS, CLIENT_RETRY_MAX_ATTEMPTS};
@@ -387,8 +387,8 @@ fn get_source_error_type<T: std::error::Error + 'static>(err: &dyn std::error::E
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use serde::{Deserialize, Serialize};

View File

@@ -1,6 +1,6 @@
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::task::{Context, Poll};
use bytes::Bytes;

View File

@@ -1,7 +1,7 @@
[package]
name = "cas_object"
version = "0.1.0"
edition = "2021"
edition = "2024"
[[bench]]
name = "compression_bench"

View File

@@ -148,7 +148,9 @@ fn main() {
);
unsafe {
println!("{BG4_SPLIT_RUNTIME} s, {BG4_LZ4_COMPRESS_RUNTIME} s , {BG4_LZ4_DECOMPRESS_RUNTIME} s, {BG4_REGROUP_RUNTIME} s");
println!(
"{BG4_SPLIT_RUNTIME} s, {BG4_LZ4_COMPRESS_RUNTIME} s , {BG4_LZ4_DECOMPRESS_RUNTIME} s, {BG4_REGROUP_RUNTIME} s"
);
}
// For CSV exporting

View File

@@ -57,9 +57,11 @@ impl BG4Predictor {
for i in byte_range.0..byte_range.1 {
let idx = i + offset;
let n_ones = *per_byte_popcnt.get_unchecked(i) as usize;
let loc = (idx % 4) * 9 + n_ones;
*dest_ptr.add(loc) += 1;
unsafe {
let n_ones = *per_byte_popcnt.get_unchecked(i) as usize;
let loc = (idx % 4) * 9 + n_ones;
*dest_ptr.add(loc) += 1;
}
}
}

View File

@@ -6,8 +6,8 @@
/// Then, use compression_prediction_tests.py to analyze this data.
use std::fs::OpenOptions;
use std::io::{Read, Seek, SeekFrom};
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use anyhow::Result;
use cas_object::serialize_chunk;

View File

@@ -4,9 +4,9 @@ use std::mem::size_of;
use anyhow::anyhow;
use deduplication::constants::MAX_CHUNK_SIZE;
use crate::CompressionScheme;
use crate::cas_object_format::CAS_OBJECT_FORMAT_IDENT;
use crate::error::CasObjectError;
use crate::CompressionScheme;
#[cfg(not(target_family = "wasm"))]
pub mod deserialize_async;
@@ -220,7 +220,7 @@ mod tests {
use std::io::Cursor;
use super::*;
use crate::test_utils::{build_cas_object, ChunkSize};
use crate::test_utils::{ChunkSize, build_cas_object};
const COMPRESSED_LEN: u32 = 66051;
const UNCOMPRESSED_LEN: u32 = 131072;

View File

@@ -6,7 +6,7 @@ use futures::io::{AsyncRead, AsyncReadExt};
use futures::{Stream, TryStreamExt};
use crate::error::CasObjectError;
use crate::{parse_chunk_header, CASChunkHeader, CAS_CHUNK_HEADER_LENGTH};
use crate::{CAS_CHUNK_HEADER_LENGTH, CASChunkHeader, parse_chunk_header};
pub async fn deserialize_chunk_header<R: AsyncRead + Unpin>(reader: &mut R) -> Result<CASChunkHeader, CasObjectError> {
let mut buf = [0u8; size_of::<CASChunkHeader>()];
@@ -113,10 +113,10 @@ where
mod tests {
use bytes::Bytes;
use futures::Stream;
use rand::{rng, Rng};
use rand::{Rng, rng};
use crate::deserialize_async::deserialize_chunks_to_writer_from_stream;
use crate::{serialize_chunk, CompressionScheme};
use crate::{CompressionScheme, serialize_chunk};
fn gen_random_bytes(rng: &mut impl Rng, uncompressed_chunk_size: u32) -> Vec<u8> {
let mut data = vec![0u8; uncompressed_chunk_size as usize];

View File

@@ -4,8 +4,8 @@ use std::mem::{size_of, size_of_val};
use anyhow::anyhow;
use bytes::Buf;
use deduplication::constants::TARGET_CHUNK_SIZE;
use deduplication::RawXorbData;
use deduplication::constants::TARGET_CHUNK_SIZE;
#[cfg(not(target_family = "wasm"))]
use futures::AsyncReadExt;
use mdb_shard::chunk_verification::range_hash_from_chunks;
@@ -1769,14 +1769,16 @@ mod tests {
// Act & Assert
let mut writer: Cursor<Vec<u8>> = Cursor::new(Vec::new());
assert!(serialize_xorb_to_stream_reference(
&mut writer,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut writer,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok()
);
let mut reader = writer.clone();
reader.set_position(0);
@@ -1787,14 +1789,16 @@ mod tests {
let c_bytes = c.get_all_bytes(&mut reader).unwrap();
let mut writer: Cursor<Vec<u8>> = Cursor::new(Vec::new());
assert!(serialize_xorb_to_stream_reference(
&mut writer,
&c.info.cashash,
&c_bytes,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut writer,
&c.info.cashash,
&c_bytes,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
)
.is_ok()
);
let mut reader = writer.clone();
reader.set_position(0);
@@ -1815,14 +1819,16 @@ mod tests {
build_cas_object(55, ChunkSize::Fixed(53212), CompressionScheme::LZ4);
// Act & Assert
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
assert!(serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok()
);
let serialized_all_bytes = c.get_all_bytes(&mut buf).unwrap();
@@ -1837,14 +1843,16 @@ mod tests {
build_cas_object(3, ChunkSize::Fixed(100), CompressionScheme::None);
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// Act & Assert
assert!(serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
)
.is_ok()
);
assert!(CasObject::validate_cas_object(&mut buf, &c.info.cashash).unwrap().is_some());
}
@@ -1856,14 +1864,16 @@ mod tests {
build_cas_object(32, ChunkSize::Fixed(16384), CompressionScheme::None);
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// Act & Assert
assert!(serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
)
.is_ok()
);
assert!(CasObject::validate_cas_object(&mut buf, &c.info.cashash).unwrap().is_some());
@@ -1887,14 +1897,16 @@ mod tests {
build_cas_object(32, ChunkSize::Random(512, 65536), CompressionScheme::None);
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// Act & Assert
assert!(serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
)
.is_ok()
);
assert!(CasObject::validate_cas_object(&mut buf, &c.info.cashash).unwrap().is_some());
@@ -1917,14 +1929,16 @@ mod tests {
build_cas_object(256, ChunkSize::Random(512, 65536), CompressionScheme::None);
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// Act & Assert
assert!(serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::None)
)
.is_ok()
);
assert!(CasObject::validate_cas_object(&mut buf, &c.info.cashash).unwrap().is_some());
@@ -1947,14 +1961,16 @@ mod tests {
build_cas_object(1, ChunkSize::Fixed(8), CompressionScheme::LZ4);
let mut writer: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// Act & Assert
assert!(serialize_xorb_to_stream_reference(
&mut writer,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut writer,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok()
);
let mut reader = writer.clone();
reader.set_position(0);
@@ -1976,14 +1992,16 @@ mod tests {
build_cas_object(32, ChunkSize::Fixed(16384), CompressionScheme::LZ4);
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// Act & Assert
assert!(serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok()
);
assert!(CasObject::validate_cas_object(&mut buf, &c.info.cashash).unwrap().is_some());
@@ -2007,14 +2025,16 @@ mod tests {
build_cas_object(32, ChunkSize::Random(512, 65536), CompressionScheme::LZ4);
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// Act & Assert
assert!(serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok()
);
assert!(CasObject::validate_cas_object(&mut buf, &c.info.cashash).unwrap().is_some());
@@ -2037,14 +2057,16 @@ mod tests {
build_cas_object(256, ChunkSize::Random(512, 65536), CompressionScheme::LZ4);
let mut writer: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// Act & Assert
assert!(serialize_xorb_to_stream_reference(
&mut writer,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut writer,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok()
);
let mut reader = writer.clone();
reader.set_position(0);
@@ -2065,14 +2087,16 @@ mod tests {
build_cas_object(64, ChunkSize::Random(512, 2048), CompressionScheme::LZ4);
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// Act & Assert
assert!(serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok()
);
let xorb_bytes = buf.into_inner();
// length - 4 byte for the info_length - info_length + ident + version (already read ident + version)
@@ -2095,14 +2119,16 @@ mod tests {
build_cas_object(4, ChunkSize::Random(512, 2048), CompressionScheme::LZ4);
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
// Act & Assert
assert!(serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4)
)
.is_ok()
);
let xorb_bytes = buf.into_inner();
@@ -2159,14 +2185,16 @@ mod tests {
build_cas_object(4, ChunkSize::Random(512, 2048), CompressionScheme::LZ4);
let mut buf: Cursor<Vec<u8>> = Cursor::new(Vec::new());
assert!(serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4),
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut buf,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(CompressionScheme::LZ4),
)
.is_ok()
);
// Switch V1 footer to V0
let mut cas_info_v0 = CasObjectInfoV0::default();
@@ -2258,14 +2286,16 @@ mod tests {
// Act & Assert
let mut writer: Cursor<Vec<u8>> = Cursor::new(Vec::new());
assert!(serialize_xorb_to_stream_reference(
&mut writer,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(COMPRESSION_SCHEME)
)
.is_ok());
assert!(
serialize_xorb_to_stream_reference(
&mut writer,
&c.info.cashash,
&raw_data,
&raw_chunk_boundaries,
Some(COMPRESSION_SCHEME)
)
.is_ok()
);
let original = c.info;
writer.seek(SeekFrom::Start(0)).unwrap();

View File

@@ -1,13 +1,13 @@
use std::borrow::Cow;
use std::fmt::Display;
use std::io::{copy, Cursor, Read, Write};
use std::io::{Cursor, Read, Write, copy};
use std::time::Instant;
use anyhow::anyhow;
use lz4_flex::frame::{FrameDecoder, FrameEncoder};
use crate::byte_grouping::bg4::{bg4_regroup, bg4_split};
use crate::byte_grouping::BG4Predictor;
use crate::byte_grouping::bg4::{bg4_regroup, bg4_split};
use crate::error::{CasObjectError, Result};
pub static mut BG4_SPLIT_RUNTIME: f64 = 0.;

View File

@@ -1,7 +1,7 @@
[package]
name = "cas_types"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
merklehash = { path = "../merklehash" }

View File

@@ -1,8 +1,8 @@
use std::fmt::{Display, Formatter};
use std::str::FromStr;
use merklehash::data_hash::hex;
use merklehash::MerkleHash;
use merklehash::data_hash::hex;
use serde::{Deserialize, Serialize};
use crate::error::CasTypesError;

View File

@@ -1,7 +1,7 @@
[package]
name = "chunk_cache"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
cas_types = { path = "../cas_types" }

View File

@@ -6,9 +6,9 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use async_trait::async_trait;
use base64::engine::general_purpose::URL_SAFE;
use base64::engine::GeneralPurpose;
use base64::Engine;
use base64::engine::GeneralPurpose;
use base64::engine::general_purpose::URL_SAFE;
use cas_types::{ChunkRange, Key};
use error_printer::ErrorPrinter;
use file_utils::SafeFileCreator;
@@ -725,20 +725,20 @@ fn try_parse_cache_file(file_result: io::Result<DirEntry>, capacity: u64) -> Opt
/// removes a file but disregards a "NotFound" error if the file is already gone
fn remove_file(path: impl AsRef<Path>) -> Result<(), ChunkCacheError> {
if let Err(e) = std::fs::remove_file(path) {
if e.kind() != ErrorKind::NotFound {
return Err(e.into());
}
if let Err(e) = std::fs::remove_file(path)
&& e.kind() != ErrorKind::NotFound
{
return Err(e.into());
}
Ok(())
}
/// removes a directory but disregards a "NotFound" error if the directory is already gone
fn remove_dir(path: impl AsRef<Path>) -> Result<(), ChunkCacheError> {
if let Err(e) = std::fs::remove_dir(path) {
if e.kind() != ErrorKind::NotFound {
return Err(e.into());
}
if let Err(e) = std::fs::remove_dir(path)
&& e.kind() != ErrorKind::NotFound
{
return Err(e.into());
}
Ok(())
}
@@ -815,12 +815,12 @@ mod tests {
use std::collections::BTreeSet;
use cas_types::{ChunkRange, Key};
use rand::rngs::StdRng;
use rand::SeedableRng;
use rand::rngs::StdRng;
use tempdir::TempDir;
use utils::output_bytes;
use super::{DiskCache, DEFAULT_CHUNK_CACHE_CAPACITY};
use super::{DEFAULT_CHUNK_CACHE_CAPACITY, DiskCache};
use crate::disk::test_utils::*;
use crate::disk::try_parse_key;
use crate::{CacheConfig, ChunkCache};
@@ -837,11 +837,13 @@ mod tests {
..Default::default()
};
let cache = DiskCache::initialize(&config).unwrap();
assert!(cache
.get(&random_key(&mut rng), &random_range(&mut rng))
.await
.unwrap()
.is_none());
assert!(
cache
.get(&random_key(&mut rng), &random_range(&mut rng))
.await
.unwrap()
.is_none()
);
}
#[tokio::test]
@@ -1173,10 +1175,12 @@ mod tests {
let right_chunk_byte_indices: Vec<u32> =
(&chunk_byte_indices[1..]).iter().map(|v| v - chunk_byte_indices[1]).collect();
let right_data = &data[chunk_byte_indices[1] as usize..];
assert!(cache
.put(&key, &right_range, &right_chunk_byte_indices, right_data)
.await
.is_ok());
assert!(
cache
.put(&key, &right_range, &right_chunk_byte_indices, right_data)
.await
.is_ok()
);
assert_eq!(total_bytes, cache.total_bytes().await);
// middle range
@@ -1188,10 +1192,12 @@ mod tests {
let middle_data =
&data[chunk_byte_indices[1] as usize..chunk_byte_indices[chunk_byte_indices.len() - 2] as usize];
assert!(cache
.put(&key, &middle_range, &middle_chunk_byte_indices, middle_data)
.await
.is_ok());
assert!(
cache
.put(&key, &middle_range, &middle_chunk_byte_indices, middle_data)
.await
.is_ok()
);
assert_eq!(total_bytes, cache.total_bytes().await);
}
@@ -1239,11 +1245,13 @@ mod tests {
#[test]
fn test_initialize_with_cache_size_0() {
assert!(DiskCache::initialize(&CacheConfig {
cache_directory: "/tmp".into(),
cache_size: 0,
})
.is_err());
assert!(
DiskCache::initialize(&CacheConfig {
cache_directory: "/tmp".into(),
cache_size: 0,
})
.is_err()
);
}
}
@@ -1253,7 +1261,7 @@ mod concurrency_tests {
use super::DiskCache;
use crate::disk::DEFAULT_CHUNK_CACHE_CAPACITY;
use crate::{CacheConfig, ChunkCache, RandomEntryIterator, RANGE_LEN};
use crate::{CacheConfig, ChunkCache, RANGE_LEN, RandomEntryIterator};
const NUM_ITEMS_PER_TASK: usize = 20;
const RANDOM_SEED: u64 = 878987298749287;

View File

@@ -3,8 +3,8 @@ use std::hash::{Hash, Hasher};
use std::io::Cursor;
use std::mem::size_of;
use std::ops::Deref;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use base64::Engine;
use cas_types::ChunkRange;
@@ -168,7 +168,7 @@ mod tests {
use cas_types::ChunkRange;
use crate::disk::cache_item::CACHE_ITEM_FILE_NAME_BUF_SIZE;
use crate::disk::{CacheItem, BASE64_ENGINE};
use crate::disk::{BASE64_ENGINE, CacheItem};
impl Default for CacheItem {
fn default() -> Self {

View File

@@ -4,7 +4,7 @@ use cas_types::{ChunkRange, Key};
use merklehash::MerkleHash;
use rand::rngs::{StdRng, ThreadRng};
use rand::seq::SliceRandom;
use rand::{rng, Rng, SeedableRng};
use rand::{Rng, SeedableRng, rng};
#[cfg(test)]
pub const RANGE_LEN: u32 = 16 << 10;

View File

@@ -7,8 +7,8 @@ use std::path::PathBuf;
use async_trait::async_trait;
pub use cache_manager::get_cache;
use cas_types::{ChunkRange, Key};
pub use disk::test_utils::*;
pub use disk::DiskCache;
pub use disk::test_utils::*;
use error::ChunkCacheError;
use mockall::automock;

View File

@@ -1,7 +1,7 @@
[package]
name = "chunk_cache_bench"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
cas_types = { path = "../cas_types" }

View File

@@ -1,7 +1,7 @@
[package]
name = "data"
version = "0.14.5"
edition = "2021"
edition = "2024"
[lib]
doctest = false

View File

@@ -160,7 +160,7 @@ impl Command {
fn walk_files(files: Vec<String>, recursive: bool) -> Vec<String> {
// Scan all files if under recursive mode
let file_paths = if recursive {
if recursive {
files
.iter()
.flat_map(|dir| {
@@ -179,9 +179,7 @@ fn walk_files(files: Vec<String>, recursive: bool) -> Vec<String> {
.collect::<Vec<_>>()
} else {
files
};
file_paths
}
}
fn is_git_special_files(path: &str) -> bool {

View File

@@ -3,7 +3,7 @@ use std::str::FromStr;
use std::sync::Arc;
use cas_client::remote_client::PREFIX_DEFAULT;
use cas_client::{CacheConfig, CHUNK_CACHE_SIZE_BYTES};
use cas_client::{CHUNK_CACHE_SIZE_BYTES, CacheConfig};
use cas_object::CompressionScheme;
use utils::auth::AuthConfig;

View File

@@ -6,24 +6,24 @@ use std::path::{Path, PathBuf};
use std::sync::Arc;
use cas_client::remote_client::PREFIX_DEFAULT;
use cas_client::{CacheConfig, FileProvider, OutputProvider, CHUNK_CACHE_SIZE_BYTES};
use cas_client::{CHUNK_CACHE_SIZE_BYTES, CacheConfig, FileProvider, OutputProvider};
use cas_object::CompressionScheme;
use deduplication::DeduplicationMetrics;
use dirs::home_dir;
use progress_tracking::item_tracking::ItemProgressUpdater;
use progress_tracking::TrackingProgressUpdater;
use tracing::{info, info_span, instrument, Instrument, Span};
use progress_tracking::item_tracking::ItemProgressUpdater;
use tracing::{Instrument, Span, info, info_span, instrument};
use ulid::Ulid;
use utils::auth::{AuthConfig, TokenRefresher};
use utils::normalized_path_from_user_string;
use xet_runtime::utils::run_constrained_with_semaphore;
use xet_runtime::{global_semaphore_handle, GlobalSemaphoreHandle, XetRuntime};
use xet_runtime::{GlobalSemaphoreHandle, XetRuntime, global_semaphore_handle};
use crate::configurations::*;
use crate::constants::{INGESTION_BLOCK_SIZE, MAX_CONCURRENT_DOWNLOADS};
use crate::errors::DataProcessingError;
use crate::file_upload_session::CONCURRENT_FILE_INGESTION_LIMITER;
use crate::{errors, FileDownloader, FileUploadSession, XetFileInfo};
use crate::{FileDownloader, FileUploadSession, XetFileInfo, errors};
utils::configurable_constants! {
ref DEFAULT_CAS_ENDPOINT: String = "http://localhost:8080".to_string();
@@ -201,12 +201,10 @@ pub async fn download_async(
global_semaphore_handle!(*MAX_CONCURRENT_DOWNLOADS);
}
if let Some(updaters) = &progress_updaters {
if updaters.len() != file_infos.len() {
return Err(DataProcessingError::ParameterError(
"updaters are not same length as pointer_files".to_string(),
));
}
if let Some(updaters) = &progress_updaters
&& updaters.len() != file_infos.len()
{
return Err(DataProcessingError::ParameterError("updaters are not same length as pointer_files".to_string()));
}
let config =
default_config(endpoint.unwrap_or(DEFAULT_CAS_ENDPOINT.to_string()), None, token_info, token_refresher)?;
@@ -289,10 +287,9 @@ async fn smudge_file(
#[cfg(test)]
mod tests {
use std::env;
use serial_test::serial;
use tempfile::tempdir;
use utils::EnvVarGuard;
use super::*;
@@ -300,7 +297,7 @@ mod tests {
#[serial(default_config_env)]
fn test_default_config_with_hf_home() {
let temp_dir = tempdir().unwrap();
env::set_var("HF_HOME", temp_dir.path().to_str().unwrap());
let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None);
@@ -308,8 +305,6 @@ mod tests {
assert!(result.is_ok());
let config = result.unwrap();
assert!(config.data_config.cache_config.cache_directory.starts_with(&temp_dir.path()));
env::remove_var("HF_HOME");
}
#[test]
@@ -318,25 +313,27 @@ mod tests {
let temp_dir_xet_cache = tempdir().unwrap();
let temp_dir_hf_home = tempdir().unwrap();
env::set_var("HF_XET_CACHE", temp_dir_xet_cache.path().to_str().unwrap());
env::set_var("HF_HOME", temp_dir_hf_home.path().to_str().unwrap());
let hf_xet_cache_guard = EnvVarGuard::set("HF_XET_CACHE", temp_dir_xet_cache.path().to_str().unwrap());
let hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir_hf_home.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None);
assert!(result.is_ok());
let config = result.unwrap();
assert!(config
.data_config
.cache_config
.cache_directory
.starts_with(&temp_dir_xet_cache.path()));
assert!(
config
.data_config
.cache_config
.cache_directory
.starts_with(&temp_dir_xet_cache.path())
);
env::remove_var("HF_XET_CACHE");
env::remove_var("HF_HOME");
drop(hf_xet_cache_guard);
drop(hf_home_guard);
let temp_dir = tempdir().unwrap();
env::set_var("HF_HOME", temp_dir.path().to_str().unwrap());
let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None);
@@ -344,15 +341,13 @@ mod tests {
assert!(result.is_ok());
let config = result.unwrap();
assert!(config.data_config.cache_config.cache_directory.starts_with(&temp_dir.path()));
env::remove_var("HF_HOME");
}
#[test]
#[serial(default_config_env)]
fn test_default_config_with_hf_xet_cache() {
let temp_dir = tempdir().unwrap();
env::set_var("HF_XET_CACHE", temp_dir.path().to_str().unwrap());
let _hf_xet_cache_guard = EnvVarGuard::set("HF_XET_CACHE", temp_dir.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None);
@@ -360,8 +355,6 @@ mod tests {
assert!(result.is_ok());
let config = result.unwrap();
assert!(config.data_config.cache_config.cache_directory.starts_with(&temp_dir.path()));
env::remove_var("HF_XET_CACHE");
}
#[test]

View File

@@ -8,14 +8,14 @@ use deduplication::{Chunk, Chunker, DeduplicationMetrics, FileDeduper};
use mdb_shard::file_structs::FileMetadataExt;
use merklehash::MerkleHash;
use progress_tracking::upload_tracking::CompletionTrackerFileId;
use tracing::{debug_span, info, instrument, Instrument};
use tracing::{Instrument, debug_span, info, instrument};
use crate::XetFileInfo;
use crate::constants::INGESTION_BLOCK_SIZE;
use crate::deduplication_interface::UploadSessionDataManager;
use crate::errors::Result;
use crate::file_upload_session::FileUploadSession;
use crate::sha256::ShaGenerator;
use crate::XetFileInfo;
/// A class that encapsulates the clean and data task around a single file.
pub struct SingleFileCleaner {

View File

@@ -11,7 +11,7 @@ use cas_client::Client;
use cas_object::SerializedCasObject;
use deduplication::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
use deduplication::{DataAggregator, DeduplicationMetrics, RawXorbData};
use jsonwebtoken::{decode, DecodingKey, Validation};
use jsonwebtoken::{DecodingKey, Validation, decode};
use lazy_static::lazy_static;
use mdb_shard::file_structs::MDBFileInfo;
use more_asserts::*;
@@ -22,9 +22,9 @@ use progress_tracking::verification_wrapper::ProgressUpdaterVerificationWrapper;
use progress_tracking::{NoOpProgressUpdater, TrackingProgressUpdater};
use tokio::sync::{Mutex, OwnedSemaphorePermit};
use tokio::task::{JoinHandle, JoinSet};
use tracing::{info_span, instrument, Instrument, Span};
use tracing::{Instrument, Span, info_span, instrument};
use ulid::Ulid;
use xet_runtime::{global_semaphore_handle, GlobalSemaphoreHandle, XetRuntime};
use xet_runtime::{GlobalSemaphoreHandle, XetRuntime, global_semaphore_handle};
use crate::configurations::*;
use crate::constants::{
@@ -35,7 +35,7 @@ use crate::errors::*;
use crate::file_cleaner::SingleFileCleaner;
use crate::remote_client_interface::create_remote_client;
use crate::shard_interface::SessionShardInterface;
use crate::{prometheus_metrics, XetFileInfo};
use crate::{XetFileInfo, prometheus_metrics};
lazy_static! {
pub static ref CONCURRENT_FILE_INGESTION_LIMITER: GlobalSemaphoreHandle =

View File

@@ -4,10 +4,10 @@ use anyhow::Result;
use cas_object::CompressionScheme;
use hub_client::{BearerCredentialHelper, HubClient, Operation};
use mdb_shard::file_structs::MDBFileInfo;
use tracing::{info_span, instrument, Instrument, Span};
use tracing::{Instrument, Span, info_span, instrument};
use utils::auth::TokenRefresher;
use xet_runtime::utils::run_constrained;
use xet_runtime::XetRuntime;
use xet_runtime::utils::run_constrained;
use super::hub_client_token_refresher::HubClientTokenRefresher;
use crate::data_client::{clean_file, default_config};

View File

@@ -1,5 +1,5 @@
use lazy_static::lazy_static;
use prometheus::{register_int_counter, IntCounter};
use prometheus::{IntCounter, register_int_counter};
// Some of the common tracking things
lazy_static! {

View File

@@ -48,7 +48,7 @@ impl ShaGenerator {
#[cfg(test)]
mod sha_tests {
use rand::{rng, Rng};
use rand::{Rng, rng};
use super::*;

View File

@@ -1,24 +1,24 @@
use std::fs::File;
use std::io::Read;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, SystemTime};
use bytes::Bytes;
use cas_client::Client;
use error_printer::ErrorPrinter;
use mdb_shard::ShardFileManager;
use mdb_shard::cas_structs::MDBCASInfo;
use mdb_shard::constants::MDB_SHARD_MAX_TARGET_SIZE;
use mdb_shard::file_structs::{FileDataSequenceEntry, MDBFileInfo};
use mdb_shard::session_directory::{consolidate_shards_in_directory, merge_shards_background, ShardMergeResult};
use mdb_shard::session_directory::{ShardMergeResult, consolidate_shards_in_directory, merge_shards_background};
use mdb_shard::shard_in_memory::MDBInMemoryShard;
use mdb_shard::ShardFileManager;
use merklehash::MerkleHash;
use tempfile::TempDir;
use tokio::sync::Mutex;
use tokio::task::JoinSet;
use tracing::{debug, info, info_span, Instrument};
use tracing::{Instrument, debug, info, info_span};
use crate::configurations::TranslatorConfig;
use crate::constants::{
@@ -161,11 +161,11 @@ impl SessionShardInterface {
query_hashes: &[MerkleHash],
) -> Result<Option<(usize, FileDataSequenceEntry, bool)>> {
// First, see if there's something in the resumed session.
if let Some(resumed_session_sfm) = &self.resumed_session_shard_manager {
if let Some((n_entries, fse)) = resumed_session_sfm.chunk_hash_dedup_query(query_hashes).await? {
// Return true, as the data here is already known to have been uploaded.
return Ok(Some((n_entries, fse, true)));
}
if let Some(resumed_session_sfm) = &self.resumed_session_shard_manager
&& let Some((n_entries, fse)) = resumed_session_sfm.chunk_hash_dedup_query(query_hashes).await?
{
// Return true, as the data here is already known to have been uploaded.
return Ok(Some((n_entries, fse, true)));
}
// Now, check the local session directory.

View File

@@ -1,4 +1,4 @@
use std::fs::{create_dir_all, read_dir, File};
use std::fs::{File, create_dir_all, read_dir};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use std::sync::Arc;

View File

@@ -1,7 +1,7 @@
// Run tests that determine deduplication, especially across different test subjects.
use data::FileUploadSession;
use data::configurations::TranslatorConfig;
use data::constants::{PROGRESS_UPDATE_INTERVAL_MS, SESSION_XORB_METADATA_FLUSH_MAX_COUNT};
use data::FileUploadSession;
use deduplication::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS, TARGET_CHUNK_SIZE};
use mdb_shard::MDB_SHARD_TARGET_SIZE;
use tempfile::TempDir;
@@ -27,7 +27,7 @@ test_set_globals! {
// Test the deduplication framework.
#[cfg(test)]
mod tests {
use data::test_utils::{create_random_file, create_random_files, LocalHydrateDehydrateTest};
use data::test_utils::{LocalHydrateDehydrateTest, create_random_file, create_random_files};
use deduplication::constants::MAX_CHUNK_SIZE;
use more_asserts::*;
use progress_tracking::aggregator::AggregatingProgressUpdater;

View File

@@ -1,7 +1,7 @@
[package]
name = "deduplication"
version = "0.14.5"
edition = "2021"
edition = "2024"
[dependencies]
mdb_shard = { path = "../mdb_shard" }

View File

@@ -1,5 +1,5 @@
use bytes::Bytes;
use merklehash::{compute_data_hash, MerkleHash};
use merklehash::{MerkleHash, compute_data_hash};
#[derive(Debug, Clone, PartialEq)]
pub struct Chunk {

View File

@@ -4,8 +4,8 @@ use std::io::{Read, Seek, SeekFrom};
use bytes::Bytes;
use more_asserts::{debug_assert_ge, debug_assert_le};
use crate::constants::{MAXIMUM_CHUNK_MULTIPLIER, MINIMUM_CHUNK_DIVISOR, TARGET_CHUNK_SIZE};
use crate::Chunk;
use crate::constants::{MAXIMUM_CHUNK_MULTIPLIER, MINIMUM_CHUNK_DIVISOR, TARGET_CHUNK_SIZE};
/// Chunk Generator given an input stream. Do not use directly.
/// Use `chunk_target_default`.

View File

@@ -4,9 +4,9 @@ use mdb_shard::file_structs::MDBFileInfo;
use merklehash::MerkleHash;
use more_asserts::*;
use crate::Chunk;
use crate::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
use crate::raw_xorb_data::RawXorbData;
use crate::Chunk;
#[derive(Default, Debug)]
pub struct DataAggregator {

View File

@@ -5,17 +5,17 @@ use mdb_shard::file_structs::{
FileDataSequenceEntry, FileDataSequenceHeader, FileMetadataExt, FileVerificationEntry, MDBFileInfo,
};
use mdb_shard::hash_is_global_dedup_eligible;
use merklehash::{file_hash, MerkleHash};
use merklehash::{MerkleHash, file_hash};
use more_asserts::{debug_assert_le, debug_assert_lt};
use progress_tracking::upload_tracking::FileXorbDependency;
use crate::Chunk;
use crate::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
use crate::data_aggregator::DataAggregator;
use crate::dedup_metrics::DeduplicationMetrics;
use crate::defrag_prevention::DefragPrevention;
use crate::interface::DeduplicationDataInterface;
use crate::raw_xorb_data::RawXorbData;
use crate::Chunk;
pub struct FileDeduper<DataInterfaceType: DeduplicationDataInterface> {
data_mng: DataInterfaceType,
@@ -350,12 +350,12 @@ impl<DataInterfaceType: DeduplicationDataInterface> FileDeduper<DataInterfaceTyp
let mut end_idx = base_idx + 1;
for (i, chunk) in chunks.iter().enumerate().skip(1) {
if let Some(&idx) = self.new_data_hash_lookup.get(chunk) {
if idx == base_idx + i {
end_idx = idx + 1;
n_bytes += self.new_data[idx].data.len();
continue;
}
if let Some(&idx) = self.new_data_hash_lookup.get(chunk)
&& idx == base_idx + i
{
end_idx = idx + 1;
n_bytes += self.new_data[idx].data.len();
continue;
}
break;
}

View File

@@ -9,9 +9,9 @@ mod interface;
mod raw_xorb_data;
pub use chunk::Chunk;
pub use chunking::{find_partitions, Chunker};
pub use chunking::{Chunker, find_partitions};
pub use data_aggregator::DataAggregator;
pub use dedup_metrics::DeduplicationMetrics;
pub use file_deduplication::FileDeduper;
pub use interface::DeduplicationDataInterface;
pub use raw_xorb_data::{test_utils, RawXorbData};
pub use raw_xorb_data::{RawXorbData, test_utils};

View File

@@ -1,9 +1,9 @@
use mdb_shard::cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader, MDBCASInfo};
use merklehash::{xorb_hash, MerkleHash};
use merklehash::{MerkleHash, xorb_hash};
use more_asserts::*;
use crate::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
use crate::Chunk;
use crate::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
/// This struct is the data needed to cut a
#[derive(Default, Debug, Clone)]

View File

@@ -1,7 +1,7 @@
[package]
name = "error_printer"
version = "0.14.5"
edition = "2021"
edition = "2024"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

View File

@@ -1,7 +1,7 @@
[package]
name = "file_utils"
version = "0.14.2"
edition = "2021"
edition = "2024"
[dependencies]
lazy_static = { workspace = true }

View File

@@ -2,5 +2,5 @@ mod file_metadata;
mod privilege_context;
mod safe_file_creator;
pub use privilege_context::{create_dir_all, create_file, PrivilegedExecutionContext};
pub use privilege_context::{PrivilegedExecutionContext, create_dir_all, create_file};
pub use safe_file_creator::SafeFileCreator;

View File

@@ -12,7 +12,7 @@ use winapi::um::{
processthreadsapi::GetCurrentProcess,
processthreadsapi::OpenProcessToken,
securitybaseapi::GetTokenInformation,
winnt::{TokenElevation, HANDLE, TOKEN_ELEVATION, TOKEN_QUERY},
winnt::{HANDLE, TOKEN_ELEVATION, TOKEN_QUERY, TokenElevation},
};
#[cfg(test)]

View File

@@ -3,7 +3,7 @@ use std::io::{self, BufWriter, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use rand::distr::Alphanumeric;
use rand::{rng, Rng};
use rand::{Rng, rng};
use crate::create_file;
use crate::file_metadata::set_file_metadata;

View File

@@ -145,11 +145,11 @@ fn install_command(args: InstallArg) -> Result<()> {
return Ok(());
}
if let Some(c) = args.concurrency {
if c == 0 {
eprintln!(r#"Error: "--concurrency" should be a number greater than 0."#);
return Ok(());
}
if let Some(c) = args.concurrency
&& c == 0
{
eprintln!(r#"Error: "--concurrency" should be a number greater than 0."#);
return Ok(());
}
if args.system {

View File

@@ -102,9 +102,9 @@ pub fn get_credential(repo: &GitRepo, remote_url: &GitUrl, operation: Operation)
// 2. check embedded authentication
let credential = remote_url.credential();
match credential {
(Some(_user), Some(token)) => return Ok(BearerCredentialHelper::new(token, "url")),
_ => (), // valid only when both user and token exist
// valid only when both user and token exist
if let (Some(_user), Some(token)) = credential {
return Ok(BearerCredentialHelper::new(token, "url"));
}
// 3. check credential from environment

View File

@@ -69,10 +69,10 @@ impl GitCredentialHelper {
let mut line = line?;
line.retain(|c| !c.is_whitespace());
if let Some(hf_token) = line.strip_prefix("password=") {
if !hf_token.is_empty() {
return Ok(hf_token.to_owned());
}
if let Some(hf_token) = line.strip_prefix("password=")
&& !hf_token.is_empty()
{
return Ok(hf_token.to_owned());
}
}

View File

@@ -61,10 +61,10 @@ impl GitRepo {
let config = repo.config()?.snapshot()?;
// try tracking remote
if let Some(branch) = maybe_branch_name {
if let Ok(remote) = config.get_string(&format!("branch.{}.remote", branch)) {
return Ok(remote);
}
if let Some(branch) = maybe_branch_name
&& let Ok(remote) = config.get_string(&format!("branch.{}.remote", branch))
{
return Ok(remote);
}
// try lfsdefault remote
@@ -74,10 +74,10 @@ impl GitRepo {
// use only remote if there is only 1
let remotes = repo.remotes()?;
if remotes.len() == 1 {
if let Some(remote) = remotes.get(0) {
return Ok(remote.to_string());
}
if remotes.len() == 1
&& let Some(remote) = remotes.get(0)
{
return Ok(remote.to_string());
}
// fall back to default if all above lookup failed,

View File

@@ -1,7 +1,7 @@
[package]
name = "hf_xet"
version = "1.1.10"
edition = "2021"
edition = "2024"
license = "Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@@ -8,7 +8,7 @@ use std::iter::IntoIterator;
use std::sync::Arc;
use data::errors::DataProcessingError;
use data::{data_client, XetFileInfo};
use data::{XetFileInfo, data_client};
use itertools::Itertools;
use progress_tracking::TrackingProgressUpdater;
use pyo3::exceptions::{PyKeyboardInterrupt, PyRuntimeError};

View File

@@ -2,8 +2,8 @@ use std::env;
use std::path::Path;
use std::sync::OnceLock;
use pyo3::types::PyAnyMethods;
use pyo3::Python;
use pyo3::types::PyAnyMethods;
use tracing::info;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
@@ -86,12 +86,11 @@ fn get_version_info_string(py: Python<'_>) -> String {
let mut version_info = String::new();
// Get Python version
if let Ok(sys) = py.import("sys") {
if let Ok(version) = sys.getattr("version").and_then(|v| v.extract::<String>()) {
if let Some(python_version_number) = version.split_whitespace().next() {
version_info.push_str(&format!("python/{python_version_number}; "));
}
}
if let Ok(sys) = py.import("sys")
&& let Ok(version) = sys.getattr("version").and_then(|v| v.extract::<String>())
&& let Some(python_version_number) = version.split_whitespace().next()
{
version_info.push_str(&format!("python/{python_version_number}; "));
}
// Get huggingface_hub+hf_xet versions

View File

@@ -7,7 +7,7 @@ use progress_tracking::{ProgressUpdate, TrackingProgressUpdater};
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::PyAnyMethods;
use pyo3::types::{IntoPyDict, PyList, PyString};
use pyo3::{pyclass, IntoPyObjectExt, Py, PyAny, PyResult, Python};
use pyo3::{IntoPyObjectExt, Py, PyAny, PyResult, Python, pyclass};
use tracing::error;
use xet_runtime::exports::tokio;
@@ -207,7 +207,7 @@ impl WrappedProgressUpdaterImpl {
return Err(PyTypeError::new_err(format!(
"Function {name} must take exactly 1 or 2 arguments, but got {}",
param_names.len()
)))
)));
},
};

View File

@@ -6,9 +6,9 @@ use lazy_static::lazy_static;
use pyo3::exceptions::{PyKeyboardInterrupt, PyRuntimeError};
use pyo3::prelude::*;
use tracing::info;
use xet_runtime::XetRuntime;
use xet_runtime::errors::MultithreadedRuntimeError;
use xet_runtime::sync_primatives::spawn_os_thread;
use xet_runtime::XetRuntime;
lazy_static! {
static ref SIGINT_DETECTED: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
@@ -214,13 +214,13 @@ where
// Now, if we're in the middle of a shutdown, and this is an error, then
// just translate that error to a KeyboardInterrupt (or we get a lot of
if let Err(ref e) = &result {
if in_sigint_shutdown() {
if cfg!(debug_assertions) {
eprintln!("[debug] ignored error reported during shutdown: {e:?}");
}
return Err(PyKeyboardInterrupt::new_err(()));
if let Err(e) = &result
&& in_sigint_shutdown()
{
if cfg!(debug_assertions) {
eprintln!("[debug] ignored error reported during shutdown: {e:?}");
}
return Err(PyKeyboardInterrupt::new_err(()));
}
// Now return the result.

View File

@@ -1,7 +1,7 @@
[package]
name = "hf_xet_wasm"
version = "0.0.1"
edition = "2021"
edition = "2024"
[lib]
crate-type = ["cdylib", "rlib"]

View File

@@ -1,7 +1,7 @@
[package]
name = "mdb_shard"
version = "0.14.5"
edition = "2021"
edition = "2024"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

View File

@@ -54,7 +54,7 @@ pub type Result<T> = std::result::Result<T, MDBShardError>;
impl PartialEq for MDBShardError {
fn eq(&self, other: &MDBShardError) -> bool {
match (self, other) {
(MDBShardError::IOError(ref e1), MDBShardError::IOError(ref e2)) => e1.kind() == e2.kind(),
(MDBShardError::IOError(e1), MDBShardError::IOError(e2)) => e1.kind() == e2.kind(),
_ => false,
}
}

View File

@@ -3,8 +3,8 @@ use std::io::{Cursor, Read, Write};
use std::mem::size_of;
use bytes::Bytes;
use merklehash::data_hash::hex;
use merklehash::MerkleHash;
use merklehash::data_hash::hex;
use serde::Serialize;
use utils::serialization_utils::*;
@@ -598,9 +598,9 @@ impl MDBFileInfoView {
#[cfg(test)]
mod tests {
use itertools::{iproduct, Itertools};
use rand::prelude::StdRng;
use itertools::{Itertools, iproduct};
use rand::SeedableRng;
use rand::prelude::StdRng;
use super::*;
use crate::shard_file::test_routines::simple_hash;

View File

@@ -13,7 +13,7 @@ pub mod shard_format;
pub mod shard_in_memory;
pub mod utils;
pub use constants::{hash_is_global_dedup_eligible, MDB_SHARD_TARGET_SIZE};
pub use constants::{MDB_SHARD_TARGET_SIZE, hash_is_global_dedup_eligible};
pub use shard_file_handle::MDBShardFile;
pub use shard_file_manager::ShardFileManager;
pub use shard_format::{MDBShardFileFooter, MDBShardFileHeader, MDBShardInfo};

View File

@@ -69,32 +69,33 @@ fn get_next_actions_for_file_info(
op: MDBSetOperation,
) -> Option<[NextAction; 2]> {
// Special case for union operation on file info with same file hash.
if let (Some(ft0), Some(ft1)) = (h1, h2) {
if std::cmp::Ordering::Equal == ft0.file_hash.cmp(&ft1.file_hash) && op == MDBSetOperation::Union {
// Now two parties have the same file hash, and union should produce only one copy.
// Which one to use is a bit tricky as we have multiple optional pieces of information.
// We can leverage whether one party's flags are a superset of the other to directly
// copy over one of the file info. If neither party's flags are a superset, we will
// need to merge both infos into a single, complete info with info from both parties.
//
// Note: we make an assumption that if info exists in both places, it is the same
// since we can't make a distinction of what is valid and what isn't.
let superset = FileDataSequenceHeader::compare_flag_superset(ft0, ft1);
return match superset {
SupersetResult::SuperA | SupersetResult::Equal => {
// use ft0 since it has more info
Some([NextAction::CopyToOut, NextAction::SkipOver])
},
SupersetResult::SuperB => {
// use ft1 since it has more info
Some([NextAction::SkipOver, NextAction::CopyToOut])
},
SupersetResult::Neither => {
// need to merge as both have some info the other doesn't
Some([NextAction::Merge, NextAction::Nothing]) // Note: merge advances both entries
},
};
}
if let (Some(ft0), Some(ft1)) = (h1, h2)
&& std::cmp::Ordering::Equal == ft0.file_hash.cmp(&ft1.file_hash)
&& op == MDBSetOperation::Union
{
// Now two parties have the same file hash, and union should produce only one copy.
// Which one to use is a bit tricky as we have multiple optional pieces of information.
// We can leverage whether one party's flags are a superset of the other to directly
// copy over one of the file info. If neither party's flags are a superset, we will
// need to merge both infos into a single, complete info with info from both parties.
//
// Note: we make an assumption that if info exists in both places, it is the same
// since we can't make a distinction of what is valid and what isn't.
let superset = FileDataSequenceHeader::compare_flag_superset(ft0, ft1);
return match superset {
SupersetResult::SuperA | SupersetResult::Equal => {
// use ft0 since it has more info
Some([NextAction::CopyToOut, NextAction::SkipOver])
},
SupersetResult::SuperB => {
// use ft1 since it has more info
Some([NextAction::SkipOver, NextAction::CopyToOut])
},
SupersetResult::Neither => {
// need to merge as both have some info the other doesn't
Some([NextAction::Merge, NextAction::Nothing]) // Note: merge advances both entries
},
};
}
get_next_actions(h1.map(|f| &f.file_hash), h2.map(|f| &f.file_hash), op)
@@ -131,11 +132,7 @@ fn set_operation<R: Read + Seek, W: Write>(
let load_next = |_r: &mut R, _s: &MDBShardInfo| -> Result<_> {
let fdsh = FileDataSequenceHeader::deserialize(_r)?;
if fdsh.is_bookend() {
Ok(None)
} else {
Ok(Some(fdsh))
}
if fdsh.is_bookend() { Ok(None) } else { Ok(Some(fdsh)) }
};
let mut file_data_header = [load_next(r[0], s[0])?, load_next(r[1], s[1])?];
@@ -276,11 +273,7 @@ fn set_operation<R: Read + Seek, W: Write>(
let load_next = |_r: &mut R, _s: &MDBShardInfo| -> Result<_> {
let ccsh = CASChunkSequenceHeader::deserialize(_r)?;
if ccsh.is_bookend() {
Ok(None)
} else {
Ok(Some(ccsh))
}
if ccsh.is_bookend() { Ok(None) } else { Ok(Some(ccsh)) }
};
let mut cas_data_header = [load_next(r[0], s[0])?, load_next(r[1], s[1])?];
@@ -418,7 +411,7 @@ fn shard_file_op(f1: &Path, f2: &Path, out: &Path, op: MDBSetOperation) -> Resul
let temp_file_name = dir.join(format!(".{uuid}.mdb_temp"));
let mut hashed_write; // Need to access after file is closed.
// Scoped so that file is closed and flushed before name is changed.
// Scoped so that file is closed and flushed before name is changed.
let shard;
{

View File

@@ -1,16 +1,16 @@
use std::fs::File;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::{Duration, Instant};
use anyhow::{anyhow, Ok, Result};
use anyhow::{Ok, Result, anyhow};
use clap::Parser;
use mdb_shard::cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader, MDBCASInfo};
use mdb_shard::shard_file_manager::ShardFileManager;
use mdb_shard::shard_format::test_routines::rng_hash;
use mdb_shard::shard_format::MDBShardInfo;
use mdb_shard::shard_format::test_routines::rng_hash;
use mdb_shard::shard_in_memory::MDBInMemoryShard;
use merklehash::MerkleHash;
use rand::rngs::StdRng;
@@ -137,6 +137,7 @@ async fn run_shard_benchmark(
});
// Wait for all tasks to complete
#[allow(clippy::never_loop)]
for task in tasks {
task.await?;
}

View File

@@ -7,9 +7,10 @@ use std::sync::{Arc, RwLock};
use std::time::{Duration, SystemTime};
use heapify::{make_heap_with, pop_heap_with};
use merklehash::{compute_data_hash, HMACKey, HashedWrite, MerkleHash};
use merklehash::{HMACKey, HashedWrite, MerkleHash, compute_data_hash};
use tracing::{debug, error, info, warn};
use crate::MDBShardFileFooter;
use crate::cas_structs::CASChunkSequenceHeader;
use crate::constants::MDB_SHARD_EXPIRATION_BUFFER_SECS;
use crate::error::{MDBShardError, Result};
@@ -17,7 +18,6 @@ use crate::file_structs::{FileDataSequenceEntry, MDBFileInfo};
use crate::shard_file::current_timestamp;
use crate::shard_format::MDBShardInfo;
use crate::utils::{parse_shard_filename, shard_file_name, temp_shard_file_name, truncate_hash};
use crate::MDBShardFileFooter;
/// When a specific implementation of the
#[derive(Debug)]
@@ -472,7 +472,9 @@ impl MDBShardFile {
// Check the parsed shard from the filename.
if let Some(parsed_shard_hash) = parse_shard_filename(&self.path) {
if hash != parsed_shard_hash {
error!("Hash parsed from filename does not match the computed hash; hash from filename={parsed_shard_hash:?}, hash of file={hash:?}");
error!(
"Hash parsed from filename does not match the computed hash; hash from filename={parsed_shard_hash:?}, hash of file={hash:?}"
);
}
} else {
warn!("Unable to obtain hash from filename.");

View File

@@ -1,8 +1,8 @@
use std::collections::HashMap;
use std::io::Cursor;
use std::path::{Path, PathBuf};
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use merklehash::{HMACKey, MerkleHash};
use tokio::sync::RwLock;

View File

@@ -1,5 +1,5 @@
use std::collections::{BTreeMap, HashMap};
use std::io::{copy, Read, Seek, SeekFrom, Write};
use std::io::{Read, Seek, SeekFrom, Write, copy};
use std::mem::size_of;
use std::ops::Add;
use std::sync::Arc;

View File

@@ -1,4 +1,4 @@
use std::io::{copy, Cursor, Read, Write};
use std::io::{Cursor, Read, Write, copy};
use std::mem::size_of;
use bytes::Bytes;
@@ -9,7 +9,7 @@ use itertools::Itertools;
use crate::cas_structs::{CASChunkSequenceEntry, CASChunkSequenceHeader, MDBCASInfoView};
use crate::error::{MDBShardError, Result};
use crate::file_structs::{FileDataSequenceHeader, MDBFileInfoView};
use crate::shard_file::{current_timestamp, MDB_FILE_INFO_ENTRY_SIZE};
use crate::shard_file::{MDB_FILE_INFO_ENTRY_SIZE, current_timestamp};
use crate::{MDBShardFileFooter, MDBShardFileHeader};
/// Runs through a shard file info section, calling the specified callback function for each entry.
@@ -359,11 +359,11 @@ mod tests {
use anyhow::Result;
use super::MDBMinimalShard;
use crate::MDBShardInfo;
use crate::cas_structs::MDBCASInfo;
use crate::file_structs::MDBFileInfo;
use crate::shard_file::test_routines::{convert_to_file, gen_random_shard};
use crate::shard_in_memory::MDBInMemoryShard;
use crate::MDBShardInfo;
fn verify_serialization(min_shard: &MDBMinimalShard, mem_shard: &MDBInMemoryShard) -> Result<()> {
for verification in [true, false] {

View File

@@ -1,7 +1,7 @@
[package]
name = "merklehash"
version = "0.14.5"
edition = "2021"
edition = "2024"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]

View File

@@ -1,7 +1,7 @@
use std::cell::RefCell;
use std::fmt::Write;
use crate::{compute_internal_node_hash, MerkleHash};
use crate::{MerkleHash, compute_internal_node_hash};
pub const AGGREGATED_HASHES_MEAN_TREE_BRANCHING_FACTOR: u64 = 4;

View File

@@ -7,9 +7,9 @@ use std::num::ParseIntError;
use std::ops::{Deref, DerefMut};
use std::{fmt, str};
use base64::Engine as _;
// URL safe Base 64 encoding with ending characters removed.
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine as _;
use rand::rngs::SmallRng;
use rand::{RngCore, SeedableRng};
use safe_transmute::{transmute_to_bytes, transmute_to_bytes_mut};
@@ -464,7 +464,7 @@ mod tests {
use rand::prelude::*;
use crate::{compute_data_hash, DataHash, HashedWrite};
use crate::{DataHash, HashedWrite, compute_data_hash};
#[test]
fn test_try_from_bytes() {

View File

@@ -1,7 +1,7 @@
[package]
name = "progress_tracking"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
merklehash = { path = "../merklehash" }

View File

@@ -1,5 +1,5 @@
use std::collections::hash_map::Entry as HashMapEntry;
use std::collections::HashMap;
use std::collections::hash_map::Entry as HashMapEntry;
use std::sync::Arc;
use std::time::Duration;
@@ -231,8 +231,8 @@ impl TrackingProgressUpdater for AggregatingProgressUpdater {
#[cfg(test)]
mod tests {
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use super::*;

View File

@@ -1,10 +1,10 @@
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use more_asserts::debug_assert_le;
use crate::progress_info::{ItemProgressUpdate, ProgressUpdate};
use crate::TrackingProgressUpdater;
use crate::progress_info::{ItemProgressUpdate, ProgressUpdate};
/// This wraps a TrackingProgressUpdater, translating per-item updates to a full progress report.
pub struct ItemProgressUpdater {

View File

@@ -1,7 +1,7 @@
[package]
name = "utils"
version = "0.14.5"
edition = "2021"
edition = "2024"
[lib]
name = "utils"

View File

@@ -1,6 +1,6 @@
use std::io::Write;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use std::task::{Context, Poll, ready};
use futures::{AsyncRead, AsyncReadExt};

View File

@@ -20,7 +20,9 @@ pub trait ParsableConfigValue: Debug + Sized {
v
},
None => {
warn!("Configuration value {v} for {variable_name} cannot be parsed into correct type; reverting to default.");
warn!(
"Configuration value {v} for {variable_name} cannot be parsed into correct type; reverting to default."
);
info!("Config: {variable_name} = {default:?} (default due to parse error)");
default
},
@@ -182,7 +184,9 @@ macro_rules! test_set_globals {
// Construct the environment variable name, e.g. "HF_XET_MAX_NUM_CHUNKS"
let env_name = concat!("HF_XET_", stringify!($var_name));
// Convert the $val to a string and set it
std::env::set_var(env_name, val.to_string());
unsafe {
std::env::set_var(env_name, val.to_string());
}
// Force lazy_static to be read now:
let actual_value = *$var_name;

View File

@@ -11,7 +11,9 @@ pub struct EnvVarGuard {
impl EnvVarGuard {
pub fn set(key: &'static str, value: impl AsRef<OsStr>) -> Self {
let prev = env::var(key).ok();
env::set_var(key, value);
unsafe {
env::set_var(key, value);
}
Self { key, prev }
}
}
@@ -19,9 +21,13 @@ impl EnvVarGuard {
impl Drop for EnvVarGuard {
fn drop(&mut self) {
if let Some(v) = &self.prev {
env::set_var(self.key, v);
unsafe {
env::set_var(self.key, v);
}
} else {
env::remove_var(self.key);
unsafe {
env::remove_var(self.key);
}
}
}
}

View File

@@ -21,4 +21,4 @@ pub use rw_task_lock::{RwTaskLock, RwTaskLockError, RwTaskLockReadGuard};
mod file_paths;
#[cfg(not(target_family = "wasm"))]
pub use file_paths::{normalized_path_from_user_string, CwdGuard, EnvVarGuard};
pub use file_paths::{CwdGuard, EnvVarGuard, normalized_path_from_user_string};

View File

@@ -31,7 +31,7 @@ impl<T, E> Deref for RwTaskLockReadGuard<'_, T, E> {
type Target = T;
fn deref(&self) -> &T {
match &*self.guard {
RwTaskLockState::Ready(ref val) => val,
RwTaskLockState::Ready(val) => val,
_ => unreachable!("Read guard is only constructed for Ready state"),
}
}

View File

@@ -1,8 +1,8 @@
use std::io::{Read, Write};
use std::mem::{size_of, transmute};
use futures::io::AsyncRead;
use futures::AsyncReadExt;
use futures::io::AsyncRead;
use merklehash::MerkleHash;
#[inline]

View File

@@ -43,7 +43,7 @@ use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::task::{ready, Context, Poll};
use std::task::{Context, Poll, ready};
use error_printer::ErrorPrinter;
use futures::future::Either;
@@ -397,8 +397,8 @@ where
#[cfg(test)]
pub(crate) mod tests {
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
use futures::future::join_all;
@@ -647,16 +647,16 @@ mod test_deadlock {
use std::collections::HashMap;
use std::sync::Arc;
use futures::stream::iter;
use futures::StreamExt;
use futures::stream::iter;
use tests::WAITER_TIMEOUT;
use tokio::runtime::Handle;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::mpsc::{Sender, channel};
use tokio::sync::{Mutex, Notify};
use tokio::time::timeout;
use super::{tests, Group};
use super::{Group, tests};
#[tokio::test]
async fn test_deadlock() {
@@ -794,8 +794,8 @@ mod test_futures_unordered {
use std::sync::Arc;
use std::time::Duration;
use futures_util::stream::FuturesUnordered;
use futures_util::TryStreamExt;
use futures_util::stream::FuturesUnordered;
use tokio::sync::mpsc;
use tokio::time::sleep;

View File

@@ -1,7 +1,7 @@
[package]
name = "xet_runtime"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
utils = { path = "../utils" }

View File

@@ -5,7 +5,7 @@ pub mod runtime;
pub use runtime::XetRuntime;
pub mod sync_primatives;
pub use sync_primatives::{spawn_os_thread, SyncJoinHandle};
pub use sync_primatives::{SyncJoinHandle, spawn_os_thread};
#[macro_use]
mod global_semaphores;

View File

@@ -148,10 +148,10 @@ impl XetRuntime {
fn current_if_exists() -> Option<Arc<Self>> {
let maybe_rt = THREAD_RUNTIME_REF.with_borrow(|rt| rt.clone());
if let Some((pid, rt)) = maybe_rt {
if pid == std::process::id() {
return Some(rt);
}
if let Some((pid, rt)) = maybe_rt
&& pid == std::process::id()
{
return Some(rt);
}
None
@@ -188,24 +188,30 @@ impl XetRuntime {
format!("{THREADPOOL_THREAD_ID_PREFIX}-{id}")
};
let tokio_rt = {
let mut tokio_rt_builder = {
#[cfg(not(target_family = "wasm"))]
{
// A new multithreaded runtime with a capped number of threads
TokioRuntimeBuilder::new_multi_thread().worker_threads(get_num_tokio_worker_threads())
TokioRuntimeBuilder::new_multi_thread()
}
#[cfg(target_family = "wasm")]
{
TokioRuntimeBuilder::new_current_thread()
}
};
#[cfg(not(target_family = "wasm"))]
{
tokio_rt_builder.worker_threads(get_num_tokio_worker_threads());
}
.thread_name_fn(get_thread_name) // thread names will be hf-xet-0, hf-xet-1, etc.
.on_thread_start(set_threadlocal_reference) // Set the local runtime reference.
.thread_stack_size(THREADPOOL_STACK_SIZE) // 8MB stack size, default is 2MB
.enable_all() // enable all features, including IO/Timer/Signal/Reactor
.build()
.map_err(MultithreadedRuntimeError::RuntimeInitializationError)?;
let tokio_rt = tokio_rt_builder
.thread_name_fn(get_thread_name) // thread names will be hf-xet-0, hf-xet-1, etc.
.on_thread_start(set_threadlocal_reference) // Set the local runtime reference.
.thread_stack_size(THREADPOOL_STACK_SIZE) // 8MB stack size, default is 2MB
.enable_all() // enable all features, including IO/Timer/Signal/Reactor
.build()
.map_err(MultithreadedRuntimeError::RuntimeInitializationError)?;
// Now that the runtime is created, fill out the original struct.
let handle = tokio_rt.handle().clone();

View File

@@ -70,7 +70,7 @@ impl<T: Send + Sync + 'static> SyncJoinHandle<T> {
/// # Examples
///
/// ```
/// use xet_runtime::{spawn_os_thread, SyncJoinHandle};
/// use xet_runtime::{SyncJoinHandle, spawn_os_thread};
/// let handle: SyncJoinHandle<_> = spawn_os_thread(|| 42);
///
/// // Possibly do some work here...

View File

@@ -251,13 +251,7 @@ mod parallel_tests {
#[tokio::test(flavor = "multi_thread")]
async fn test_returns_join_error_on_panic() {
let futures = (0..10).map(|i| async move {
if i == 5 {
panic!("5")
} else {
Result::<_, i32>::Ok(i)
}
});
let futures = (0..10).map(|i| async move { if i == 5 { panic!("5") } else { Result::<_, i32>::Ok(i) } });
let result = run_constrained(futures, 2).await;
if let Err(ParutilsError::Join(e)) = result {