feat: expose skip_sha256 parameter in Python upload API (#705)

## Summary

Add `skip_sha256` and `sha256s` parameters to the `upload_bytes()` Python
binding (and `skip_sha256` to `upload_files()`) for per-file SHA-256 policies:
- `skip_sha256: bool = False` - Skip SHA-256 computation entirely (sets
`Sha256Policy::Skip`)
- `sha256s: Optional[List[str]] = None` - Provide pre-computed SHA-256
hashes (companion to existing parameter on `upload_files()`)
- `skip_sha256=True` and `sha256s` are mutually exclusive; passing both raises an error

## Changes

**Python binding changes:**
- Add `skip_sha256` + `sha256s` params to `upload_bytes()`, and `skip_sha256`
to `upload_files()` (which already accepts `sha256s`)
- All policy conversion happens at the Python boundary

**Internal refactoring:**
- Add `Clone`/`Copy` derives + `from_skip()`/`from_hex()` helpers to
`Sha256Policy`
- Update `upload_bytes_async` and `upload_async` to take `Vec<Sha256Policy>`,
and `clean_bytes` / `clean_file` to take a single `Sha256Policy`
- Update all internal callers across `git_xet`, `xet_pkg`, migration
tool, tests

## Motivation

`huggingface_hub` already knows whether SHA-256 is required. This change
enables skipping expensive computation when unnecessary, or passing
pre-computed hashes for bulk operations.

Companion to #678.

---------

Co-authored-by: Wauplin <lucainp@gmail.com>
This commit is contained in:
Adrien
2026-03-12 18:17:12 +01:00
committed by GitHub
parent cacd713218
commit 0fb930c8d0
8 changed files with 116 additions and 49 deletions

View File

@@ -6,8 +6,8 @@ use async_trait::async_trait;
use http::header;
use xet_client::cas_client::auth::TokenRefresher;
use xet_client::hub_client::Operation;
use xet_data::processing::FileUploadSession;
use xet_data::processing::data_client::{clean_file, default_config};
use xet_data::processing::{FileUploadSession, Sha256Policy};
use xet_data::progress_tracking::{ProgressUpdate, TrackingProgressUpdater};
use crate::constants::{
@@ -141,7 +141,7 @@ impl TransferAgent for XetAgent {
return Err(GitLFSProtocolError::bad_syntax("file path not provided for upload request").into());
};
clean_file(session.clone(), file_path, &req.oid, None).await?;
clean_file(session.clone(), file_path, Sha256Policy::from_hex(&req.oid), None).await?;
// We need to actually upload the shard after each file upload to have the files registered, because
//

View File

@@ -18,7 +18,7 @@ use runtime::async_run;
use token_refresh::WrappedTokenRefresher;
use tracing::debug;
use xet_data::processing::errors::DataProcessingError;
use xet_data::processing::{XetFileInfo, data_client};
use xet_data::processing::{Sha256Policy, XetFileInfo, data_client};
use xet_data::progress_tracking::TrackingProgressUpdater;
use xet_runtime::core::file_handle_limits;
@@ -81,7 +81,7 @@ fn convert_data_processing_error(e: DataProcessingError) -> PyErr {
}
#[pyfunction]
#[pyo3(signature = (file_contents, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None), text_signature = "(file_contents: List[bytes], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]]) -> List[PyXetUploadInfo]")]
#[pyo3(signature = (file_contents, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_contents: List[bytes], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")]
#[allow(clippy::too_many_arguments)]
pub fn upload_bytes(
py: Python,
@@ -92,7 +92,29 @@ pub fn upload_bytes(
progress_updater: Option<Py<PyAny>>,
_repo_type: Option<String>,
request_headers: Option<HashMap<String, String>>,
sha256s: Option<Vec<String>>,
skip_sha256: bool,
) -> PyResult<Vec<PyXetUploadInfo>> {
if skip_sha256 && sha256s.is_some() {
return Err(PyRuntimeError::new_err("skip_sha256=True and sha256s are mutually exclusive"));
}
if let Some(ref s) = sha256s
&& s.len() != file_contents.len()
{
return Err(PyRuntimeError::new_err(format!(
"sha256s length ({}) must match file_contents length ({})",
s.len(),
file_contents.len()
)));
}
let sha256_policies: Vec<Sha256Policy> = match sha256s {
_ if skip_sha256 => vec![Sha256Policy::Skip; file_contents.len()],
Some(v) => v.iter().map(|s| Sha256Policy::from_hex(s)).collect(),
None => vec![Sha256Policy::Compute; file_contents.len()],
};
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
let updater = progress_updater.map(WrappedProgressUpdater::new).transpose()?.map(Arc::new);
let x: u64 = rand::rng().random();
@@ -109,6 +131,7 @@ pub fn upload_bytes(
let out: Vec<PyXetUploadInfo> = data_client::upload_bytes_async(
file_contents,
sha256_policies,
endpoint,
token_info,
refresher.map(|v| v as Arc<_>),
@@ -128,7 +151,7 @@ pub fn upload_bytes(
}
#[pyfunction]
#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]]) -> List[PyXetUploadInfo]")]
#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")]
#[allow(clippy::too_many_arguments)]
pub fn upload_files(
py: Python,
@@ -140,7 +163,12 @@ pub fn upload_files(
_repo_type: Option<String>,
request_headers: Option<HashMap<String, String>>,
sha256s: Option<Vec<String>>,
skip_sha256: bool,
) -> PyResult<Vec<PyXetUploadInfo>> {
if skip_sha256 && sha256s.is_some() {
return Err(PyRuntimeError::new_err("skip_sha256=True and sha256s are mutually exclusive"));
}
if let Some(ref s) = sha256s
&& s.len() != file_paths.len()
{
@@ -151,6 +179,12 @@ pub fn upload_files(
)));
}
let sha256_policies: Vec<Sha256Policy> = match sha256s {
_ if skip_sha256 => vec![Sha256Policy::Skip; file_paths.len()],
Some(v) => v.iter().map(|s| Sha256Policy::from_hex(s)).collect(),
None => vec![Sha256Policy::Compute; file_paths.len()],
};
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
let updater = progress_updater.map(WrappedProgressUpdater::new).transpose()?.map(Arc::new);
@@ -171,7 +205,7 @@ pub fn upload_files(
let out: Vec<PyXetUploadInfo> = data_client::upload_async(
file_paths,
sha256s,
sha256_policies,
endpoint,
token_info,
refresher.map(|v| v as Arc<_>),
@@ -564,11 +598,26 @@ mod tests {
let file_paths = vec!["a.txt".to_string(), "b.txt".to_string()];
let sha256s = Some(vec!["abc123".to_string()]); // 1 hash for 2 files
let result = upload_files(py, file_paths, None, None, None, None, None, None, sha256s);
let result = upload_files(py, file_paths, None, None, None, None, None, None, sha256s, false);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("sha256s length (1) must match file_paths length (2)"), "got: {err_msg}");
});
}
#[test]
fn test_upload_files_skip_sha256_conflicts_with_sha256s() {
    setup();
    pyo3::Python::attach(|py| {
        // One path paired with one hash: the list lengths agree, so a length
        // mismatch cannot be the reason the call is rejected.
        let paths = vec![String::from("a.txt")];
        let hashes = Some(vec![String::from("abc123")]);

        // Supplying explicit hashes together with skip_sha256 = true must fail.
        let outcome = upload_files(py, paths, None, None, None, None, None, None, hashes, true);
        assert!(outcome.is_err());

        // The error text must name the conflict between the two options.
        let err_msg = outcome.unwrap_err().to_string();
        assert!(err_msg.contains("mutually exclusive"), "got: {err_msg}");
    });
}
}

View File

@@ -11,7 +11,6 @@ use ulid::Ulid;
use xet_client::cas_client::auth::{AuthConfig, TokenRefresher};
use xet_client::cas_client::remote_client::PREFIX_DEFAULT;
use xet_core_structures::merklehash::MerkleHash;
use xet_core_structures::metadata_shard::Sha256;
use xet_core_structures::xorb_object::CompressionScheme;
use xet_runtime::core::par_utils::run_constrained_with_semaphore;
use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_cache_root, xet_config};
@@ -95,12 +94,21 @@ pub fn default_config(
#[instrument(skip_all, name = "data_client::upload_bytes", fields(session_id = tracing::field::Empty, num_files=file_contents.len()))]
pub async fn upload_bytes_async(
file_contents: Vec<Vec<u8>>,
sha256_policies: Vec<Sha256Policy>,
endpoint: Option<String>,
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
progress_updater: Option<Arc<dyn TrackingProgressUpdater>>,
custom_headers: Option<Arc<HeaderMap>>,
) -> errors::Result<Vec<XetFileInfo>> {
if sha256_policies.len() != file_contents.len() {
return Err(DataProcessingError::ParameterError(format!(
"sha256_policies length ({}) must match file_contents length ({})",
sha256_policies.len(),
file_contents.len()
)));
}
let config = default_config(
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
None,
@@ -113,9 +121,9 @@ pub async fn upload_bytes_async(
let semaphore = XetRuntime::current().common().file_ingestion_semaphore.clone();
let upload_session = FileUploadSession::new(config.into(), progress_updater).await?;
let clean_futures = file_contents.into_iter().map(|blob| {
let clean_futures = file_contents.into_iter().zip(sha256_policies).map(|(blob, policy)| {
let upload_session = upload_session.clone();
async move { clean_bytes(upload_session, blob, None).await.map(|(xf, _metrics)| xf) }
async move { clean_bytes(upload_session, blob, None, policy).await.map(|(xf, _metrics)| xf) }
.instrument(info_span!("clean_task"))
});
let files = run_constrained_with_semaphore(clean_futures, semaphore).await?;
@@ -139,13 +147,21 @@ pub async fn upload_bytes_async(
))]
pub async fn upload_async(
file_paths: Vec<String>,
sha256s: Option<Vec<String>>,
sha256_policies: Vec<Sha256Policy>,
endpoint: Option<String>,
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
progress_updater: Option<Arc<dyn TrackingProgressUpdater>>,
custom_headers: Option<Arc<HeaderMap>>,
) -> errors::Result<Vec<XetFileInfo>> {
if sha256_policies.len() != file_paths.len() {
return Err(DataProcessingError::ParameterError(format!(
"sha256_policies length ({}) must match file_paths length ({})",
sha256_policies.len(),
file_paths.len()
)));
}
// chunk files
// produce Xorbs + Shards
// upload shards and xorbs
@@ -164,21 +180,8 @@ pub async fn upload_async(
let upload_session = FileUploadSession::new(config.into(), progress_updater).await?;
// Parse sha256 hex string and ignore invalid ones, or if no sha256 is provided,
// create an iterator of infinite number of "None"s.
let sha256s: Box<dyn Iterator<Item = Option<Sha256>> + Send> = match &sha256s {
Some(v) => {
if v.len() != file_paths.len() {
return Err(DataProcessingError::ParameterError(
"mismatched length of the file list and the sha256 list".into(),
));
}
Box::new(v.iter().map(|s| Sha256::from_hex(s).ok()))
},
None => Box::new(std::iter::repeat(None)),
};
let files_sha256_and_tracking_ids = multizip((file_paths.into_iter(), sha256s, std::iter::repeat_with(Ulid::new)));
let files_sha256_and_tracking_ids =
multizip((file_paths.into_iter(), sha256_policies.into_iter(), std::iter::repeat_with(Ulid::new)));
let ret = upload_session.upload_files(files_sha256_and_tracking_ids).await?;
@@ -261,22 +264,22 @@ pub async fn clean_bytes(
processor: Arc<FileUploadSession>,
bytes: Vec<u8>,
tracking_id: Option<Ulid>,
sha256_policy: Sha256Policy,
) -> errors::Result<(XetFileInfo, DeduplicationMetrics)> {
#[allow(clippy::unwrap_or_default)] // Ulid::default is Ulid::nil
let tracking_id = tracking_id.unwrap_or_else(Ulid::new);
let mut handle = processor
.start_clean(None, bytes.len() as u64, Sha256Policy::Compute, tracking_id)
.start_clean(None, bytes.len() as u64, sha256_policy, tracking_id)
.await;
handle.add_data(&bytes).await?;
handle.finish().await
}
// The provided sha256, if valid, will be directly used in shard upload to avoid redundant computation.
#[instrument(skip_all, name = "clean_file", fields(file.name = tracing::field::Empty, file.len = tracing::field::Empty))]
pub async fn clean_file(
processor: Arc<FileUploadSession>,
filename: impl AsRef<Path>,
sha256: impl AsRef<str>,
sha256_policy: Sha256Policy,
tracking_id: Option<Ulid>,
) -> errors::Result<(XetFileInfo, DeduplicationMetrics)> {
#[allow(clippy::unwrap_or_default)] // Ulid::default is Ulid::nil
@@ -290,12 +293,7 @@ pub async fn clean_file(
let mut buffer = vec![0u8; u64::min(filesize, *xet_config().data.ingestion_block_size) as usize];
let mut handle = processor
.start_clean(
Some(filename.as_ref().to_string_lossy().into()),
filesize,
Sha256::from_hex(sha256.as_ref()).ok().into(),
tracking_id,
)
.start_clean(Some(filename.as_ref().to_string_lossy().into()), filesize, sha256_policy, tracking_id)
.await;
loop {

View File

@@ -18,6 +18,7 @@ use crate::deduplication::{Chunk, Chunker, DeduplicationMetrics, FileDeduper};
use crate::progress_tracking::upload_tracking::CompletionTrackerFileId;
/// Controls how SHA-256 is handled during file cleaning.
#[derive(Clone, Copy)]
pub enum Sha256Policy {
/// Compute SHA-256 from the file data.
Compute,
@@ -27,6 +28,20 @@ pub enum Sha256Policy {
Skip,
}
impl Sha256Policy {
/// Returns `Skip` when `true`, `Compute` when `false`.
pub fn from_skip(skip: bool) -> Self {
if skip { Self::Skip } else { Self::Compute }
}
/// Parses a hex-encoded SHA-256 string into a policy.
///
/// Returns `Provided(hash)` if the hex is valid, `Compute` otherwise.
pub fn from_hex(hex: &str) -> Self {
Sha256::from_hex(hex).ok().into()
}
}
impl From<Option<Sha256>> for Sha256Policy {
fn from(sha256: Option<Sha256>) -> Self {
match sha256 {

View File

@@ -12,7 +12,6 @@ use tokio::task::{JoinHandle, JoinSet};
use tracing::{Instrument, Span, info_span, instrument};
use ulid::Ulid;
use xet_client::cas_client::{Client, ProgressCallback};
use xet_core_structures::metadata_shard::Sha256;
use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
use xet_core_structures::xorb_object::SerializedXorbObject;
use xet_runtime::core::{XetRuntime, xet_config};
@@ -142,7 +141,7 @@ impl FileUploadSession {
pub async fn upload_files(
self: &Arc<Self>,
files_sha256_and_tracking_ids: impl IntoIterator<Item = (impl AsRef<Path>, Option<Sha256>, Ulid)> + Send,
files_sha256_and_tracking_ids: impl IntoIterator<Item = (impl AsRef<Path>, Sha256Policy, Ulid)> + Send,
) -> Result<Vec<XetFileInfo>> {
let mut cleaning_tasks: Vec<JoinHandle<_>> = vec![];
@@ -185,7 +184,7 @@ impl FileUploadSession {
let mut reader = File::open(&file_path)?;
// Start the clean process for each file.
let mut cleaner = SingleFileCleaner::new(Some(file_name), file_id, sha256.into(), session);
let mut cleaner = SingleFileCleaner::new(Some(file_name), file_id, sha256, session);
let mut bytes_read = 0;
while bytes_read < file_size {

View File

@@ -12,7 +12,7 @@ use xet_runtime::core::par_utils::run_constrained;
use super::super::data_client::{clean_file, default_config};
use super::super::errors::DataProcessingError;
use super::super::{FileUploadSession, XetFileInfo};
use super::super::{FileUploadSession, Sha256Policy, XetFileInfo};
use super::hub_client_token_refresher::HubClientTokenRefresher;
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
@@ -98,20 +98,20 @@ pub async fn migrate_files_impl(
FileUploadSession::new(config.into(), None).await?
};
let sha256s: Box<dyn Iterator<Item = String> + Send> = match sha256s {
let sha256_policies: Vec<Sha256Policy> = match sha256s {
Some(v) => {
if v.len() != file_paths.len() {
return Err(anyhow!("mismatched length of the file list and the sha256 list"));
}
Box::new(v.into_iter())
v.iter().map(|s| Sha256Policy::from_hex(s)).collect()
},
None => Box::new(std::iter::repeat(String::new())),
None => vec![Sha256Policy::Compute; file_paths.len()],
};
let clean_futs = file_paths.into_iter().zip(sha256s).map(|(file_path, sha256)| {
let clean_futs = file_paths.into_iter().zip(sha256_policies).map(|(file_path, policy)| {
let proc = processor.clone();
async move {
let (pf, metrics) = clean_file(proc, file_path, sha256, None).await?;
let (pf, metrics) = clean_file(proc, file_path, policy, None).await?;
Ok::<(XetFileInfo, u64), DataProcessingError>((pf, metrics.new_bytes))
}
.instrument(info_span!("clean_file"))

View File

@@ -11,6 +11,7 @@ use xet_client::cas_client::{Client, LocalClient, LocalTestServer, LocalTestServ
use super::configurations::TranslatorConfig;
use super::data_client::clean_file;
use super::file_cleaner::Sha256Policy;
use super::{FileDownloadSession, FileUploadSession, XetFileInfo};
use crate::progress_tracking::TrackingProgressUpdater;
@@ -204,7 +205,9 @@ impl HydrateDehydrateTest {
let upload_session = upload_session.clone();
if sequential {
let (pf, metrics) = clean_file(upload_session.clone(), entry.path(), "", None).await.unwrap();
let (pf, metrics) = clean_file(upload_session.clone(), entry.path(), Sha256Policy::Compute, None)
.await
.unwrap();
assert_eq!({ metrics.total_bytes }, entry.metadata().unwrap().len());
std::fs::write(out_file, pf.as_pointer_file().unwrap().as_bytes()).unwrap();
@@ -218,8 +221,11 @@ impl HydrateDehydrateTest {
.map(|entry| self.src_dir.join(entry.unwrap().file_name()))
.collect();
let files_sha256_and_tracking_ids =
multizip((files.iter(), std::iter::repeat(None), std::iter::repeat_with(Ulid::new)));
let files_sha256_and_tracking_ids = multizip((
files.iter(),
std::iter::repeat_with(|| Sha256Policy::Compute),
std::iter::repeat_with(Ulid::new),
));
let clean_results = upload_session.upload_files(files_sha256_and_tracking_ids).await.unwrap();

View File

@@ -248,7 +248,7 @@ impl UploadCommitInner {
*status.lock()? = TaskStatus::Running;
let result = clean_file(upload_session, &file_path, "", Some(tracking_id))
let result = clean_file(upload_session, &file_path, Sha256Policy::Compute, Some(tracking_id))
.await
.map_err(SessionError::from)
.map(|(file_info, _metrics)| file_info);
@@ -279,7 +279,7 @@ impl UploadCommitInner {
*status.lock()? = TaskStatus::Running;
let result = clean_bytes(upload_session, bytes, Some(tracking_id))
let result = clean_bytes(upload_session, bytes, Some(tracking_id), Sha256Policy::Compute)
.await
.map_err(SessionError::from)
.map(|(file_info, _metrics)| file_info);