Migrate hf_xet Python bindings from legacy data_client to XetSession API

Replace all usage of xet_pkg::legacy module with the new xet_session
interface for uploads, downloads, and progress tracking. Token refresh
changes from Python callback to URL-based refresh handled directly in
Rust.
This commit is contained in:
Assaf Vayner
2026-04-06 16:25:47 -07:00
parent 08377eab3c
commit 40dbc9773f
6 changed files with 454 additions and 384 deletions

4
hf_xet/Cargo.lock generated
View File

@@ -1102,7 +1102,6 @@ dependencies = [
name = "hf_xet"
version = "1.4.2"
dependencies = [
"async-trait",
"chrono",
"hf-xet",
"http",
@@ -1112,9 +1111,10 @@ dependencies = [
"pyo3",
"rand 0.9.2",
"signal-hook",
"tokio",
"tracing",
"winapi",
"xet-client",
"xet-data",
"xet-runtime",
]

View File

@@ -11,10 +11,10 @@ crate-type = ["cdylib", "lib"]
[dependencies]
xet-runtime = { path = "../xet_runtime" }
xet-client = { path = "../xet_client" }
xet-data = { path = "../xet_data" }
xet-pkg = { package = "hf-xet", path = "../xet_pkg", features = ["python"] }
tokio = { version = "1", features = ["time", "macros", "rt"] }
async-trait = "0.1"
chrono = "0.4"
itertools = "0.14"
lazy_static = "1.5"
@@ -42,12 +42,12 @@ winapi = { version = "0.3", features = ["consoleapi", "wincon", "errhandlingapi"
[features]
default = ["no-default-cache", "elevated_information_level"] # By default, hf_xet disables the disk cache and uses aggressive logging level
extension-module = ["pyo3/extension-module"] # Only enabled when building with maturin
native-tls = ["xet-client/native-tls-vendored"]
native-tls-vendored = ["xet-client/native-tls-vendored"]
native-tls = ["xet-pkg/native-tls-vendored"]
native-tls-vendored = ["xet-pkg/native-tls-vendored"]
no-default-cache = ["xet-runtime/no-default-cache"]
profiling = ["pprof"]
tokio-console = ["xet-runtime/tokio-console"]
elevated_information_level = ["xet-client/elevated_information_level", "xet-runtime/elevated_information_level"]
elevated_information_level = ["xet-pkg/elevated_information_level"]
[profile.release]
split-debuginfo = "packed"

View File

@@ -1,12 +1,11 @@
mod logging;
mod progress_update;
mod runtime;
mod token_refresh;
use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::IntoIterator;
use std::sync::Arc;
use std::path::PathBuf;
use std::time::Duration;
use http::header::{self, HeaderMap, HeaderName, HeaderValue};
use itertools::Itertools;
@@ -15,15 +14,13 @@ use pyo3::prelude::*;
use pyo3::pyfunction;
use rand::Rng;
use runtime::async_run;
use token_refresh::WrappedTokenRefresher;
use tracing::debug;
use xet_pkg::XetError;
use xet_pkg::legacy::progress_tracking::TrackingProgressUpdater;
use xet_pkg::legacy::{Sha256Policy, XetFileInfo, data_client};
use xet_pkg::xet_session::{Sha256Policy, XetFileInfo, XetSessionBuilder};
use xet_runtime::core::file_handle_limits;
use crate::logging::init_logging;
use crate::progress_update::WrappedProgressUpdater;
use crate::progress_update::{GroupProgressDiffState, ProgressCallback, send_simple_progress};
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
@@ -31,37 +28,17 @@ const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VE
#[cfg(feature = "profiling")]
pub(crate) mod profiling;
/// Converts a HashMap of headers to a HeaderMap and merges in the USER_AGENT.
///
/// If the input contains a User-Agent header, the USER_AGENT is appended to it.
/// Otherwise, USER_AGENT is set as the only User-Agent header.
fn build_headers_with_user_agent(request_headers: Option<HashMap<String, String>>) -> PyResult<Option<Arc<HeaderMap>>> {
let mut map = request_headers
.map(|headers| {
let mut map = HeaderMap::new();
for (key, value) in headers {
let name = HeaderName::from_bytes(key.as_bytes())
.map_err(|e| PyValueError::new_err(format!("Invalid header name '{}': {}", key, e)))?;
let value = HeaderValue::from_str(&value)
.map_err(|e| PyValueError::new_err(format!("Invalid header value for '{}': {}", key, e)))?;
map.insert(name, value);
}
Ok::<_, PyErr>(map)
})
.transpose()?
.unwrap_or_default();
/// Build a HeaderMap from a Python dict and merge in the USER_AGENT.
fn build_headers_with_user_agent(request_headers: Option<HashMap<String, String>>) -> PyResult<HeaderMap> {
let mut map = request_headers.map(build_header_map).transpose()?.unwrap_or_default();
// Append our USER_AGENT to any existing User-Agent header, or add it if not present
let combined_user_agent = if let Some(existing_ua) = map.get(header::USER_AGENT) {
// Append our user agent to the existing one
let existing_str = existing_ua.to_str().unwrap_or("");
format!("{}; {}", existing_str, USER_AGENT)
} else {
// No existing user agent, use ours
USER_AGENT.to_string()
};
// Try to create the combined header value, fall back gracefully if invalid
let user_agent_value = HeaderValue::from_str(&combined_user_agent)
.or_else(|_: http::header::InvalidHeaderValue| {
Ok::<HeaderValue, http::header::InvalidHeaderValue>(HeaderValue::from_static(USER_AGENT))
@@ -69,22 +46,58 @@ fn build_headers_with_user_agent(request_headers: Option<HashMap<String, String>
.unwrap_or_else(|_: http::header::InvalidHeaderValue| HeaderValue::from_static("unknown"));
map.insert(header::USER_AGENT, user_agent_value);
Ok(Some(Arc::new(map)))
Ok(map)
}
/// Build a HeaderMap from a Python dict without adding USER_AGENT.
fn build_header_map(headers: HashMap<String, String>) -> PyResult<HeaderMap> {
let mut map = HeaderMap::new();
for (key, value) in headers {
let name = HeaderName::from_bytes(key.as_bytes())
.map_err(|e| PyValueError::new_err(format!("Invalid header name '{}': {}", key, e)))?;
let value = HeaderValue::from_str(&value)
.map_err(|e| PyValueError::new_err(format!("Invalid header value for '{}': {}", key, e)))?;
map.insert(name, value);
}
Ok(map)
}
fn convert_xet_error(e: impl Into<XetError>) -> PyErr {
PyErr::from(e.into())
}
/// Configure an auth group builder with optional endpoint, token, headers, and refresh URL.
macro_rules! configure_auth_builder {
($builder:expr, $endpoint:expr, $token_info:expr, $custom_headers:expr,
$token_refresh_url:expr, $token_refresh_headers:expr) => {{
let mut builder = $builder;
if let Some(ep) = $endpoint {
builder = builder.with_endpoint(ep);
}
if let Some((token, expiry)) = $token_info {
builder = builder.with_token_info(token, expiry);
}
if let Some(headers) = $custom_headers {
builder = builder.with_custom_headers(headers);
}
if let Some(url) = $token_refresh_url {
let refresh_headers = $token_refresh_headers.unwrap_or_default();
builder = builder.with_token_refresh_url(url, refresh_headers);
}
builder
}};
}
#[pyfunction]
#[pyo3(signature = (file_contents, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_contents: List[bytes], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")]
#[pyo3(signature = (file_contents, endpoint, token_info, token_refresh_url, token_refresh_headers, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_contents: List[bytes], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresh_url: Optional[str], token_refresh_headers: Optional[Dict[str, str]], progress_updater: Optional[Callable], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")]
#[allow(clippy::too_many_arguments)]
pub fn upload_bytes(
py: Python,
file_contents: Vec<Vec<u8>>,
endpoint: Option<String>,
token_info: Option<(String, u64)>,
token_refresher: Option<Py<PyAny>>,
token_refresh_url: Option<String>,
token_refresh_headers: Option<HashMap<String, String>>,
progress_updater: Option<Py<PyAny>>,
_repo_type: Option<String>,
request_headers: Option<HashMap<String, String>>,
@@ -111,13 +124,11 @@ pub fn upload_bytes(
None => vec![Sha256Policy::Compute; file_contents.len()],
};
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
let updater = progress_updater.map(WrappedProgressUpdater::new).transpose()?.map(Arc::new);
let callback = progress_updater.map(ProgressCallback::new).transpose()?;
let custom_headers = build_headers_with_user_agent(request_headers)?;
let refresh_headers = token_refresh_headers.map(build_header_map).transpose()?;
let x: u64 = rand::rng().random();
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
let header_map = build_headers_with_user_agent(request_headers)?;
async_run(py, async move {
debug!(
"Upload bytes call {x:x}: (PID = {}) Uploading {} files as bytes.",
@@ -125,36 +136,43 @@ pub fn upload_bytes(
file_contents.len(),
);
let out: Vec<PyXetUploadInfo> = data_client::upload_bytes_async(
file_contents,
sha256_policies,
let session = XetSessionBuilder::new().build().map_err(convert_xet_error)?;
let builder = session.new_upload_commit().map_err(convert_xet_error)?;
let builder = configure_auth_builder!(
builder,
endpoint,
token_info,
refresher.map(|v| v as Arc<_>),
updater.map(|v| v as Arc<_>),
header_map,
)
.await
.map_err(convert_xet_error)?
.into_iter()
.map(PyXetUploadInfo::from)
.collect();
Some(custom_headers),
token_refresh_url,
refresh_headers
);
let commit = builder.build().await.map_err(convert_xet_error)?;
// Upload all byte blobs
let mut handles = Vec::with_capacity(file_contents.len());
for (blob, sha256) in file_contents.into_iter().zip(sha256_policies) {
let handle = commit.upload_bytes(blob, sha256, None).await.map_err(convert_xet_error)?;
handles.push(handle);
}
// Commit with concurrent progress polling
let out = commit_with_progress(&commit, &handles, callback.as_ref()).await?;
debug!("Upload bytes call {x:x} finished.");
PyResult::Ok(out)
})
}
#[pyfunction]
#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")]
#[pyo3(signature = (file_paths, endpoint, token_info, token_refresh_url, token_refresh_headers, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresh_url: Optional[str], token_refresh_headers: Optional[Dict[str, str]], progress_updater: Optional[Callable], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")]
#[allow(clippy::too_many_arguments)]
pub fn upload_files(
py: Python,
file_paths: Vec<String>,
endpoint: Option<String>,
token_info: Option<(String, u64)>,
token_refresher: Option<Py<PyAny>>,
token_refresh_url: Option<String>,
token_refresh_headers: Option<HashMap<String, String>>,
progress_updater: Option<Py<PyAny>>,
_repo_type: Option<String>,
request_headers: Option<HashMap<String, String>>,
@@ -181,16 +199,13 @@ pub fn upload_files(
None => vec![Sha256Policy::Compute; file_paths.len()],
};
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
let updater = progress_updater.map(WrappedProgressUpdater::new).transpose()?.map(Arc::new);
let callback = progress_updater.map(ProgressCallback::new).transpose()?;
let custom_headers = build_headers_with_user_agent(request_headers)?;
let refresh_headers = token_refresh_headers.map(build_header_map).transpose()?;
let file_names = file_paths.iter().take(3).join(", ");
let x: u64 = rand::rng().random();
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
let header_map = build_headers_with_user_agent(request_headers)?;
async_run(py, async move {
debug!(
"Upload call {x:x}: (PID = {}) Uploading {} files {file_names}{}",
@@ -199,57 +214,117 @@ pub fn upload_files(
if file_paths.len() > 3 { "..." } else { "." }
);
let out: Vec<PyXetUploadInfo> = data_client::upload_async(
file_paths,
sha256_policies,
let session = XetSessionBuilder::new().build().map_err(convert_xet_error)?;
let builder = session.new_upload_commit().map_err(convert_xet_error)?;
let builder = configure_auth_builder!(
builder,
endpoint,
token_info,
refresher.map(|v| v as Arc<_>),
updater.map(|v| v as Arc<_>),
header_map,
)
.await
.map_err(convert_xet_error)?
.into_iter()
.map(PyXetUploadInfo::from)
.collect();
Some(custom_headers),
token_refresh_url,
refresh_headers
);
let commit = builder.build().await.map_err(convert_xet_error)?;
// Upload all files
let mut handles = Vec::with_capacity(file_paths.len());
for (path, sha256) in file_paths.into_iter().zip(sha256_policies) {
let handle = commit
.upload_from_path(PathBuf::from(path), sha256)
.await
.map_err(convert_xet_error)?;
handles.push(handle);
}
// Commit with concurrent progress polling
let out = commit_with_progress(&commit, &handles, callback.as_ref()).await?;
debug!("Upload call {x:x} finished.");
PyResult::Ok(out)
})
}
/// Commit an upload, polling progress concurrently if a callback is provided.
/// Returns results in the same order as the handles.
async fn commit_with_progress(
commit: &xet_pkg::xet_session::XetUploadCommit,
handles: &[xet_pkg::xet_session::XetFileUpload],
callback: Option<&ProgressCallback>,
) -> PyResult<Vec<PyXetUploadInfo>> {
// Collect task_ids in order so we can map results back
let task_ids: Vec<_> = handles.iter().map(|h| h.task_id()).collect();
let use_progress = callback.is_some_and(|c| c.is_enabled());
if use_progress {
let callback = callback.unwrap();
let commit_clone = commit.clone();
let mut commit_task = tokio::spawn(async move { commit_clone.commit().await });
let mut diff_state = GroupProgressDiffState::new();
let mut interval = tokio::time::interval(Duration::from_millis(250));
let report = loop {
tokio::select! {
result = &mut commit_task => {
let report = result
.map_err(|e| convert_xet_error(XetError::from(e)))?
.map_err(convert_xet_error)?;
// Final progress update
let group = commit.progress();
let items = handles
.iter()
.filter_map(|h| h.progress().map(|p| (h.task_id(), p)))
.collect();
let diff = diff_state.compute_diff(group, items);
let _ = callback.send_update(diff).await;
break report;
}
_ = interval.tick() => {
let group = commit.progress();
let items = handles
.iter()
.filter_map(|h| h.progress().map(|p| (h.task_id(), p)))
.collect();
let diff = diff_state.compute_diff(group, items);
let _ = callback.send_update(diff).await;
}
}
};
// Map results in order
let mut out = Vec::with_capacity(task_ids.len());
for id in &task_ids {
let meta = report
.uploads
.get(id)
.ok_or_else(|| convert_xet_error(XetError::Internal(format!("missing upload result for task {id}"))))?;
out.push(PyXetUploadInfo::from(meta.xet_info.clone()));
}
Ok(out)
} else {
let report = commit.commit().await.map_err(convert_xet_error)?;
let mut out = Vec::with_capacity(task_ids.len());
for id in &task_ids {
let meta = report
.uploads
.get(id)
.ok_or_else(|| convert_xet_error(XetError::Internal(format!("missing upload result for task {id}"))))?;
out.push(PyXetUploadInfo::from(meta.xet_info.clone()));
}
Ok(out)
}
}
/// Compute xet hashes for files without uploading.
///
/// This function computes cryptographic hashes for the specified files using the same
/// chunking and hashing algorithm as upload operations, but without requiring
/// authentication or server connection. The resulting hashes can be used to verify
/// file integrity after downloads or to determine which files need to be uploaded.
///
/// Args:
/// file_paths: List of file paths to hash.
///
/// Returns:
/// List[PyXetUploadInfo]: List of hash results in the same order as input paths.
/// Each result contains the hash (as hex string) and file size in bytes.
///
/// Raises:
/// RuntimeError: If any file cannot be read or hashed.
///
/// Example:
/// >>> import hf_xet
/// >>> results = hf_xet.hash_files(["/path/to/file1.txt", "/path/to/file2.txt"])
/// >>> for path, info in zip(file_paths, results):
/// ... print(f"Hash: {info.hash}, Size: {info.file_size}")
///
/// Note:
/// This function is primarily used for validation and verification of transferred
/// files. Clients can verify that downloaded files are correctly reassembled by
/// comparing the computed hash with the expected hash from the server.
#[pyfunction]
#[pyo3(signature = (file_paths), text_signature = "(file_paths: List[str]) -> List[PyXetUploadInfo]")]
pub fn hash_files(py: Python, file_paths: Vec<String>) -> PyResult<Vec<PyXetUploadInfo>> {
async_run(py, async move {
let out: Vec<PyXetUploadInfo> = data_client::hash_files_async(file_paths)
let out: Vec<PyXetUploadInfo> = xet_data::processing::data_client::hash_files_async(file_paths)
.await
.map_err(convert_xet_error)?
.into_iter()
@@ -261,25 +336,23 @@ pub fn hash_files(py: Python, file_paths: Vec<String>) -> PyResult<Vec<PyXetUplo
}
#[pyfunction]
#[pyo3(signature = (files, endpoint, token_info, token_refresher, progress_updater, request_headers=None), text_signature = "(files: List[PyXetDownloadInfo], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[List[Callable[[int], None]]], request_headers: Optional[Dict[str, str]]) -> List[str]")]
#[pyo3(signature = (files, endpoint, token_info, token_refresh_url, token_refresh_headers, progress_updater, request_headers=None), text_signature = "(files: List[PyXetDownloadInfo], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresh_url: Optional[str], token_refresh_headers: Optional[Dict[str, str]], progress_updater: Optional[List[Callable[[int], None]]], request_headers: Optional[Dict[str, str]]) -> List[str]")]
#[allow(clippy::too_many_arguments)]
pub fn download_files(
py: Python,
files: Vec<PyXetDownloadInfo>,
endpoint: Option<String>,
token_info: Option<(String, u64)>,
token_refresher: Option<Py<PyAny>>,
token_refresh_url: Option<String>,
token_refresh_headers: Option<HashMap<String, String>>,
progress_updater: Option<Vec<Py<PyAny>>>,
request_headers: Option<HashMap<String, String>>,
) -> PyResult<Vec<String>> {
let file_infos: Vec<_> = files.into_iter().map(<(XetFileInfo, DestinationPath)>::from).collect();
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
let updaters = progress_updater.map(try_parse_progress_updaters).transpose()?;
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
let header_map = build_headers_with_user_agent(request_headers)?;
let custom_headers = build_headers_with_user_agent(request_headers)?;
let refresh_headers = token_refresh_headers.map(build_header_map).transpose()?;
let x: u64 = rand::rng().random();
let file_names = file_infos.iter().take(3).map(|(_, p)| p).join(", ");
async_run(py, async move {
@@ -290,39 +363,89 @@ pub fn download_files(
if file_infos.len() > 3 { "..." } else { "." }
);
let out: Vec<String> = data_client::download_async(
file_infos,
let session = XetSessionBuilder::new().build().map_err(convert_xet_error)?;
let builder = session.new_file_download_group().map_err(convert_xet_error)?;
let builder = configure_auth_builder!(
builder,
endpoint,
token_info,
refresher.map(|v| v as Arc<_>),
updaters,
header_map,
)
.await
.map_err(convert_xet_error)?;
Some(custom_headers),
token_refresh_url,
refresh_headers
);
let group = builder.build().await.map_err(convert_xet_error)?;
// Queue all downloads
let mut dl_handles = Vec::with_capacity(file_infos.len());
let mut paths = Vec::with_capacity(file_infos.len());
for (file_info, dest_path) in file_infos {
let handle = group
.download_file_to_path(file_info, PathBuf::from(&dest_path))
.await
.map_err(convert_xet_error)?;
dl_handles.push(handle);
paths.push(dest_path);
}
// Finish with concurrent progress polling
let group_for_finish = group.clone();
let mut finish_task = tokio::spawn(async move { group_for_finish.finish().await });
if let Some(updaters) = progress_updater {
let mut prev_completed: Vec<u64> = vec![0; dl_handles.len()];
let mut interval = tokio::time::interval(Duration::from_millis(250));
loop {
tokio::select! {
result = &mut finish_task => {
result
.map_err(|e| convert_xet_error(XetError::from(e)))?
.map_err(convert_xet_error)?;
// Final per-file progress update
for (i, handle) in dl_handles.iter().enumerate() {
if i < updaters.len()
&& let Some(report) = handle.progress()
{
let increment = report.bytes_completed.saturating_sub(prev_completed[i]);
send_simple_progress(&updaters[i], increment).await;
}
}
break;
}
_ = interval.tick() => {
for (i, handle) in dl_handles.iter().enumerate() {
if i < updaters.len()
&& let Some(report) = handle.progress()
{
let increment = report.bytes_completed.saturating_sub(prev_completed[i]);
if increment > 0 {
prev_completed[i] = report.bytes_completed;
send_simple_progress(&updaters[i], increment).await;
}
}
}
}
}
}
} else {
finish_task
.await
.map_err(|e| convert_xet_error(XetError::from(e)))?
.map_err(convert_xet_error)?;
}
debug!("Download call {x:x}: Completed.");
PyResult::Ok(out)
PyResult::Ok(paths)
})
}
#[pyfunction]
pub fn force_sigint_shutdown() -> PyResult<()> {
// Force a signint shutdown in the case where it gets intercepted by another process.
crate::runtime::perform_sigint_shutdown();
Err(PyKeyboardInterrupt::new_err(()))
}
fn try_parse_progress_updaters(funcs: Vec<Py<PyAny>>) -> PyResult<Vec<Arc<dyn TrackingProgressUpdater>>> {
let mut updaters = Vec::with_capacity(funcs.len());
for updater_func in funcs {
let wrapped = Arc::new(WrappedProgressUpdater::new(updater_func)?);
updaters.push(wrapped as Arc<dyn TrackingProgressUpdater>);
}
Ok(updaters)
}
// TODO: we won't need to subclass this in the next major version update.
#[pyclass(subclass)]
#[derive(Clone, Debug)]
@@ -508,9 +631,7 @@ mod tests {
// Initialize Python once for all tests
fn setup() {
// When auto-initialize is enabled, Python will be initialized on first use
// This ensures Python is available for the tests
let _ = pyo3::Python::attach(|_py| {});
pyo3::Python::attach(|_py| {});
}
#[test]
@@ -518,23 +639,19 @@ mod tests {
setup();
let empty_map: HashMap<String, String> = HashMap::new();
let result = build_headers_with_user_agent(Some(empty_map)).unwrap();
let headers = result.unwrap();
// Should have exactly one header: USER_AGENT
assert_eq!(headers.len(), 1);
assert!(headers.contains_key(header::USER_AGENT));
assert_eq!(result.len(), 1);
assert!(result.contains_key(header::USER_AGENT));
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
let user_agent = result.get(header::USER_AGENT).unwrap().to_str().unwrap();
assert_eq!(user_agent, USER_AGENT);
let result = build_headers_with_user_agent(None).unwrap();
let headers = result.unwrap();
// Should have exactly one header: USER_AGENT
assert_eq!(headers.len(), 1);
assert!(headers.contains_key(header::USER_AGENT));
assert_eq!(result.len(), 1);
assert!(result.contains_key(header::USER_AGENT));
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
let user_agent = result.get(header::USER_AGENT).unwrap().to_str().unwrap();
assert_eq!(user_agent, USER_AGENT);
}
@@ -546,17 +663,12 @@ mod tests {
headers_map.insert("Authorization".to_string(), "Bearer token123".to_string());
let result = build_headers_with_user_agent(Some(headers_map)).unwrap();
let headers = result.unwrap();
// Should have 3 headers: Content-Type, Authorization, and USER_AGENT
assert_eq!(headers.len(), 3);
assert_eq!(result.len(), 3);
assert_eq!(result.get(header::CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
assert_eq!(result.get(header::AUTHORIZATION).unwrap().to_str().unwrap(), "Bearer token123");
// Verify each header was converted correctly
assert_eq!(headers.get(header::CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
assert_eq!(headers.get(header::AUTHORIZATION).unwrap().to_str().unwrap(), "Bearer token123");
// Verify USER_AGENT was added
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
let user_agent = result.get(header::USER_AGENT).unwrap().to_str().unwrap();
assert_eq!(user_agent, USER_AGENT);
}
@@ -567,13 +679,10 @@ mod tests {
headers_map.insert("User-Agent".to_string(), "CustomClient/1.0".to_string());
let result = build_headers_with_user_agent(Some(headers_map)).unwrap();
let headers = result.unwrap();
// Should have exactly one header: USER_AGENT
assert_eq!(headers.len(), 1);
assert_eq!(result.len(), 1);
// Verify USER_AGENT was appended to existing one
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
let user_agent = result.get(header::USER_AGENT).unwrap().to_str().unwrap();
assert_eq!(user_agent, format!("CustomClient/1.0; {}", USER_AGENT));
}
@@ -585,18 +694,15 @@ mod tests {
let result = build_headers_with_user_agent(Some(headers_map));
// Should return an error for invalid header name
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Invalid header name"));
let mut headers_map = HashMap::new();
// Header values cannot contain newlines
headers_map.insert("X-Custom".to_string(), "value\nwith\nnewlines".to_string());
let result = build_headers_with_user_agent(Some(headers_map));
// Should return an error for invalid header value
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Invalid header value"));
@@ -607,9 +713,9 @@ mod tests {
setup();
pyo3::Python::attach(|py| {
let file_paths = vec!["a.txt".to_string(), "b.txt".to_string()];
let sha256s = Some(vec!["abc123".to_string()]); // 1 hash for 2 files
let sha256s = Some(vec!["abc123".to_string()]);
let result = upload_files(py, file_paths, None, None, None, None, None, None, sha256s, false);
let result = upload_files(py, file_paths, None, None, None, None, None, None, None, sha256s, false);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
@@ -624,7 +730,7 @@ mod tests {
let file_paths = vec!["a.txt".to_string()];
let sha256s = Some(vec!["abc123".to_string()]);
let result = upload_files(py, file_paths, None, None, None, None, None, None, sha256s, true);
let result = upload_files(py, file_paths, None, None, None, None, None, None, None, sha256s, true);
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();

View File

@@ -1,5 +1,5 @@
use std::collections::HashMap;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use itertools::Itertools;
use pyo3::exceptions::PyTypeError;
@@ -7,163 +7,81 @@ use pyo3::prelude::PyAnyMethods;
use pyo3::types::{IntoPyDict, PyList, PyString};
use pyo3::{IntoPyObjectExt, Py, PyAny, PyResult, Python, pyclass};
use tracing::error;
use xet_pkg::legacy::progress_tracking::{ProgressUpdate, TrackingProgressUpdater};
use xet_pkg::xet_session::{GroupProgressReport, ItemProgressReport, UniqueID};
use xet_runtime::core::XetRuntime;
use xet_runtime::error_printer::ErrorPrinter;
use crate::runtime::convert_multithreading_error;
/// Python-exposed versions of the per-item and total progress update classes.
///
/// Both `PyTotalProgressUpdate` and `PyItemProgressUpdate` are passed
/// into a Python callback given to the wrapper class below. For example:
///
/// ```python
/// def update_progress(self, total_update, item_updates):
/// from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn
///
/// # Update overall progress (we assume this has been initialized).
/// self.progress.update(
/// self.bytes_processed_task_id,
/// advance=total_update.total_bytes_completion_increment,
/// total = total_update.total_bytes
/// )
///
/// # Update upload progress ; the total may have changed so set that too.
/// self.progress.update(
/// self.bytes_uploaded_task_id,
/// advance=total_update.total_transfer_bytes_completion_increment,
/// total = total_update.total_transfer_bytes
/// )
///
/// # Update each item:
/// for item in item_updates:
/// name = item.item_name
/// if name not in self.item_tasks:
/// self.item_tasks[name] = self.progress.add_task(
/// name, total=item.total_bytes
/// )
/// self.progress.update(
/// self.item_tasks[name],
/// advance=item.bytes_completion_increment,
/// )
/// ```
///
/// In addition, the other possible bookkeeping values for everything are contained in this
/// as needed.
// === PyO3 progress update classes (exposed to Python) ===
#[pyclass]
pub struct PyItemProgressUpdate {
/// The name of the item, or a tag that is translated later.
#[pyo3(get)]
pub item_name: Py<PyString>,
/// The total bytes contained in this item.
#[pyo3(get)]
pub total_bytes: u64,
/// The number of bytes completed so far, either by deduplication or transfer.
#[pyo3(get)]
pub bytes_completed: u64,
/// The change in bytes completed since the last update.
#[pyo3(get)]
pub bytes_completion_increment: u64,
}
/// Update class for total updates
#[pyclass]
pub struct PyTotalProgressUpdate {
/// The total bytes known for processing and possibly uploaded or downloaded.
#[pyo3(get)]
pub total_bytes: u64,
/// How much total_bytes has changed from the last update.
#[pyo3(get)]
pub total_bytes_increment: u64,
/// How many of the bytes queued for processing have been examined
/// and either deduped or queued for upload or download.
#[pyo3(get)]
pub total_bytes_completed: u64,
/// The change in total_bytes_completed since the same upload.
#[pyo3(get)]
pub total_bytes_completion_increment: u64,
/// If known, the current completion speed.
#[pyo3(get)]
pub total_bytes_completion_rate: Option<f64>,
/// The total bytes scheduled for transfer; also contained in total_bytes.
#[pyo3(get)]
pub total_transfer_bytes: u64,
/// How much total_transfer_bytes has changed since the last update.
#[pyo3(get)]
pub total_transfer_bytes_increment: u64,
/// The cumulative bytes uploaded or downloaded so far. Also contained in total_bytes_completed.
#[pyo3(get)]
pub total_transfer_bytes_completed: u64,
/// The change in total_transfer_bytes_completed since the last update.
#[pyo3(get)]
pub total_transfer_bytes_completion_increment: u64,
/// If known, the current completion speed for bytes transferred.
#[pyo3(get)]
pub total_transfer_bytes_completion_rate: Option<f64>,
}
/// A wrapper over a passed-in python function to update
/// the python process of some download/upload progress
/// implements the ProgressUpdater trait and should be
/// passed around as a ProgressUpdater trait object or
/// as a template parameter
struct WrappedProgressUpdaterImpl {
/// Is this enabled?
progress_updating_enabled: bool,
/// the function py_func is responsible for passing in the update value
/// into the python context. Expects 1 int (uint64) parameter that
/// is a number to increment the progress counter by.
py_func: Py<PyAny>,
name: String,
/// Whether to use the simple incremental progress updating method or
/// the more detailed
update_with_detailed_progress: bool,
}
impl Debug for WrappedProgressUpdaterImpl {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "WrappedTokenRefresher({})", self.name)
}
}
// === Progress callback (validated Python callable) ===
const DETAILED_PROGRESS_ARG_NAMES: [&str; 2] = ["total_update", "item_updates"];
impl WrappedProgressUpdaterImpl {
pub fn new(py_func: Py<PyAny>) -> PyResult<Self> {
// Analyze the function to make sure it's the correct form. If it's 4 arguments with
// the appropriate names, than we call it using the detailed progress update; if it's
// a single function, we assume it's a global increment function and just pass in the update
// increment.
//
// Run on compute thread that doesn't block async workers
Python::attach(|py| {
let func = py_func.bind(py);
/// A validated Python progress callback. Determines on construction whether
/// the callback uses the simple (1-arg increment) or detailed (2-arg total +
/// items) calling convention.
pub struct ProgressCallback {
py_func: Py<PyAny>,
detailed: bool,
enabled: bool,
}
// Test if it's enabled first; if None is passed in, then this is disabled.
impl Debug for ProgressCallback {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "ProgressCallback(detailed={}, enabled={})", self.detailed, self.enabled)
}
}
impl ProgressCallback {
pub fn new(py_func: Py<PyAny>) -> PyResult<Self> {
Python::attach(|py| {
if py_func.is_none(py) {
return Ok(Self {
progress_updating_enabled: false,
py_func,
name: Default::default(),
update_with_detailed_progress: false,
detailed: false,
enabled: false,
});
}
let func = py_func.bind(py);
let name = func
.repr()
.and_then(|repr| repr.extract::<String>())
@@ -187,7 +105,7 @@ impl WrappedProgressUpdaterImpl {
})
.collect::<PyResult<_>>()?;
let update_with_detailed_progress = match param_names.len() {
let detailed = match param_names.len() {
1 => false,
2 => {
if param_names
@@ -212,43 +130,53 @@ impl WrappedProgressUpdaterImpl {
};
Ok(Self {
progress_updating_enabled: true,
py_func,
name,
update_with_detailed_progress,
detailed,
enabled: true,
})
})
}
async fn register_updates_impl(self: Arc<Self>, updates: ProgressUpdate) -> PyResult<()> {
// Run on compute thread that doesn't block async workers
pub fn is_enabled(&self) -> bool {
self.enabled
}
/// Send a progress diff to the Python callback via spawn_blocking
/// (to avoid blocking the async runtime while holding the GIL).
pub async fn send_update(&self, diff: ProgressDiff) -> PyResult<()> {
if !self.enabled || diff.is_empty() {
return Ok(());
}
let py_func = Python::attach(|py| self.py_func.clone_ref(py));
let detailed = self.detailed;
let rt = XetRuntime::current();
rt.spawn_blocking(move || {
Python::attach(|py| {
let f = self.py_func.bind(py);
let f = py_func.bind(py);
if self.update_with_detailed_progress {
let total_update_report: Py<PyAny> = Py::new(
if detailed {
let total_update: Py<PyAny> = Py::new(
py,
PyTotalProgressUpdate {
total_bytes: updates.total_bytes,
total_bytes_increment: updates.total_bytes_increment,
total_bytes_completed: updates.total_bytes_completed,
total_bytes_completion_increment: updates.total_bytes_completion_increment,
total_bytes_completion_rate: updates.total_bytes_completion_rate,
total_transfer_bytes: updates.total_transfer_bytes,
total_transfer_bytes_increment: updates.total_transfer_bytes_increment,
total_transfer_bytes_completed: updates.total_transfer_bytes_completed,
total_transfer_bytes_completion_increment: updates
.total_transfer_bytes_completion_increment,
total_transfer_bytes_completion_rate: updates.total_transfer_bytes_completion_rate,
total_bytes: diff.total_bytes,
total_bytes_increment: diff.total_bytes_increment,
total_bytes_completed: diff.total_bytes_completed,
total_bytes_completion_increment: diff.total_bytes_completion_increment,
total_bytes_completion_rate: diff.total_bytes_completion_rate,
total_transfer_bytes: diff.total_transfer_bytes,
total_transfer_bytes_increment: diff.total_transfer_bytes_increment,
total_transfer_bytes_completed: diff.total_transfer_bytes_completed,
total_transfer_bytes_completion_increment: diff.total_transfer_bytes_completion_increment,
total_transfer_bytes_completion_rate: diff.total_transfer_bytes_completion_rate,
},
)?
.into_py_any(py)?;
let item_updates_v: Vec<Py<PyAny>> = updates
.item_updates
.into_iter()
let item_updates_v: Vec<Py<PyAny>> = diff
.item_diffs
.iter()
.map(|u| {
Py::new(
py,
@@ -269,16 +197,20 @@ impl WrappedProgressUpdaterImpl {
let argname_item_updates: Py<PyAny> = DETAILED_PROGRESS_ARG_NAMES[1].into_py_any(py)?;
let kwargs = [
(argname_total_update, total_update_report),
(argname_total_update, total_update),
(argname_item_updates, item_updates),
]
.into_py_dict(py)?;
f.call((), Some(&kwargs))?;
} else {
let update_increment: u64 =
updates.item_updates.iter().map(|pr| pr.bytes_completion_increment).sum();
let _ = f.call1((update_increment,))?;
let update_increment: u64 = diff.item_diffs.iter().map(|d| d.bytes_completion_increment).sum();
let increment = if update_increment > 0 {
update_increment
} else {
diff.total_bytes_completion_increment
};
let _ = f.call1((increment,))?;
}
Ok(())
@@ -289,29 +221,126 @@ impl WrappedProgressUpdaterImpl {
}
}
#[derive(Debug)]
pub struct WrappedProgressUpdater {
inner: Arc<WrappedProgressUpdaterImpl>,
// === Simple per-file progress callback for downloads ===
/// Send a simple byte-increment update to a Python callback.
pub async fn send_simple_progress(py_func: &Py<PyAny>, increment: u64) {
if increment == 0 {
return;
}
let py_func = Python::attach(|py| py_func.clone_ref(py));
let rt = XetRuntime::current();
let _ = rt
.spawn_blocking(move || {
Python::attach(|py| {
let f = py_func.bind(py);
let _ = f.call1((increment,));
})
})
.await
.log_error("Python exception updating download progress:");
}
impl WrappedProgressUpdater {
pub fn new(py_func: Py<PyAny>) -> PyResult<Self> {
Ok(Self {
inner: Arc::new(WrappedProgressUpdaterImpl::new(py_func)?),
})
// === Progress diff types ===
#[derive(Clone, Debug)]
pub struct ItemDiff {
pub item_name: String,
pub total_bytes: u64,
pub bytes_completed: u64,
pub bytes_completion_increment: u64,
}
#[derive(Clone, Debug, Default)]
pub struct ProgressDiff {
pub total_bytes: u64,
pub total_bytes_increment: u64,
pub total_bytes_completed: u64,
pub total_bytes_completion_increment: u64,
pub total_bytes_completion_rate: Option<f64>,
pub total_transfer_bytes: u64,
pub total_transfer_bytes_increment: u64,
pub total_transfer_bytes_completed: u64,
pub total_transfer_bytes_completion_increment: u64,
pub total_transfer_bytes_completion_rate: Option<f64>,
pub item_diffs: Vec<ItemDiff>,
}
impl ProgressDiff {
pub fn is_empty(&self) -> bool {
self.total_bytes_increment == 0
&& self.total_bytes_completion_increment == 0
&& self.total_transfer_bytes_increment == 0
&& self.total_transfer_bytes_completion_increment == 0
&& self.item_diffs.is_empty()
}
}
#[async_trait::async_trait]
impl TrackingProgressUpdater for WrappedProgressUpdater {
async fn register_updates(&self, updates: ProgressUpdate) {
let inner = self.inner.clone();
// === Diff state for group-level progress polling ===
if inner.progress_updating_enabled {
let _ = inner
.register_updates_impl(updates)
.await
.log_error("Python exception updating progress:");
/// Tracks the previous progress snapshot so that incremental diffs can be
/// computed each time the session's `progress()` is polled.
pub struct GroupProgressDiffState {
prev_group: GroupProgressReport,
prev_items: HashMap<UniqueID, ItemProgressReport>,
}
impl GroupProgressDiffState {
pub fn new() -> Self {
Self {
prev_group: GroupProgressReport::default(),
prev_items: HashMap::new(),
}
}
pub fn compute_diff(
&mut self,
group: GroupProgressReport,
items: HashMap<UniqueID, ItemProgressReport>,
) -> ProgressDiff {
let total_bytes_increment = group.total_bytes.saturating_sub(self.prev_group.total_bytes);
let total_bytes_completion_increment = group
.total_bytes_completed
.saturating_sub(self.prev_group.total_bytes_completed);
let total_transfer_bytes_increment =
group.total_transfer_bytes.saturating_sub(self.prev_group.total_transfer_bytes);
let total_transfer_bytes_completion_increment = group
.total_transfer_bytes_completed
.saturating_sub(self.prev_group.total_transfer_bytes_completed);
let mut item_diffs = Vec::new();
for (&id, report) in &items {
let prev = self.prev_items.get(&id);
let prev_completed = prev.map_or(0, |p| p.bytes_completed);
let increment = report.bytes_completed.saturating_sub(prev_completed);
if increment > 0 || prev.is_none() {
item_diffs.push(ItemDiff {
item_name: report.item_name.clone(),
total_bytes: report.total_bytes,
bytes_completed: report.bytes_completed,
bytes_completion_increment: increment,
});
}
}
let diff = ProgressDiff {
total_bytes: group.total_bytes,
total_bytes_increment,
total_bytes_completed: group.total_bytes_completed,
total_bytes_completion_increment,
total_bytes_completion_rate: group.total_bytes_completion_rate,
total_transfer_bytes: group.total_transfer_bytes,
total_transfer_bytes_increment,
total_transfer_bytes_completed: group.total_transfer_bytes_completed,
total_transfer_bytes_completion_increment,
total_transfer_bytes_completion_rate: group.total_transfer_bytes_completion_rate,
item_diffs,
};
self.prev_group = group;
self.prev_items = items;
diff
}
}

View File

@@ -1,67 +0,0 @@
use std::fmt::{Debug, Formatter};
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::PyAnyMethods;
use pyo3::{Py, PyAny, PyErr, PyResult, Python};
use tracing::error;
use xet_client::cas_client::auth::{AuthError, TokenInfo, TokenRefresher};
/// A wrapper struct of a python function to refresh the CAS auth token.
/// Since tokens are generated by hub, we want to be able to refresh the
/// token using the hub client, which is only available in python.
pub struct WrappedTokenRefresher {
/// The function responsible for refreshing a token.
/// Expects no inputs and returns a (str, u64) representing the new token
/// and the unixtime (in seconds) of expiration, raising an exception
/// if there is an issue.
py_func: Py<PyAny>,
name: String,
}
impl Debug for WrappedTokenRefresher {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "WrappedTokenRefresher({})", self.name)
}
}
impl WrappedTokenRefresher {
pub fn from_func(py_func: Py<PyAny>) -> PyResult<Self> {
let name = Self::validate_callable(&py_func)?;
Ok(Self { py_func, name })
}
/// Validate that the inputted python object is callable
fn validate_callable(py_func: &Py<PyAny>) -> Result<String, PyErr> {
Python::attach(|py| {
let f = py_func.bind(py);
let name = f
.repr()
.and_then(|repr| repr.extract::<String>())
.unwrap_or("unknown".to_string());
if !f.is_callable() {
error!("TokenRefresher func: {name} is not callable");
return Err(PyTypeError::new_err(format!("refresh func: {name} is not callable")));
}
Ok(name)
})
}
}
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
impl TokenRefresher for WrappedTokenRefresher {
async fn refresh(&self) -> Result<TokenInfo, AuthError> {
Python::attach(|py| {
let f = self.py_func.bind(py);
if !f.is_callable() {
return Err(AuthError::RefreshFunctionNotCallable(self.name.clone()));
}
let result = f
.call0()
.map_err(|e| AuthError::TokenRefreshFailure(format!("Error refreshing token: {e:?}")))?;
result.extract::<(String, u64)>().map_err(|e| {
AuthError::TokenRefreshFailure(format!("refresh function didn't return a (String, u64) tuple: {e:?}"))
})
})
}
}

View File

@@ -48,6 +48,8 @@ smoke-test = []
fd-track = ["xet-runtime/fd-track", "xet-client/fd-track", "xet-data/fd-track"]
python = ["xet-runtime/python", "dep:pyo3"]
simulation = ["xet-client/simulation"]
native-tls-vendored = ["xet-client/native-tls-vendored"]
elevated_information_level = ["xet-client/elevated_information_level", "xet-runtime/elevated_information_level"]
[dev-dependencies]
anyhow = { workspace = true }