mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
Migrate hf_xet Python bindings from legacy data_client to XetSession API
Replace all usage of xet_pkg::legacy module with the new xet_session interface for uploads, downloads, and progress tracking. Token refresh changes from Python callback to URL-based refresh handled directly in Rust.
This commit is contained in:
4
hf_xet/Cargo.lock
generated
4
hf_xet/Cargo.lock
generated
@@ -1102,7 +1102,6 @@ dependencies = [
|
||||
name = "hf_xet"
|
||||
version = "1.4.2"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"chrono",
|
||||
"hf-xet",
|
||||
"http",
|
||||
@@ -1112,9 +1111,10 @@ dependencies = [
|
||||
"pyo3",
|
||||
"rand 0.9.2",
|
||||
"signal-hook",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"winapi",
|
||||
"xet-client",
|
||||
"xet-data",
|
||||
"xet-runtime",
|
||||
]
|
||||
|
||||
|
||||
@@ -11,10 +11,10 @@ crate-type = ["cdylib", "lib"]
|
||||
|
||||
[dependencies]
|
||||
xet-runtime = { path = "../xet_runtime" }
|
||||
xet-client = { path = "../xet_client" }
|
||||
xet-data = { path = "../xet_data" }
|
||||
xet-pkg = { package = "hf-xet", path = "../xet_pkg", features = ["python"] }
|
||||
tokio = { version = "1", features = ["time", "macros", "rt"] }
|
||||
|
||||
async-trait = "0.1"
|
||||
chrono = "0.4"
|
||||
itertools = "0.14"
|
||||
lazy_static = "1.5"
|
||||
@@ -42,12 +42,12 @@ winapi = { version = "0.3", features = ["consoleapi", "wincon", "errhandlingapi"
|
||||
[features]
|
||||
default = ["no-default-cache", "elevated_information_level"] # By default, hf_xet disables the disk cache and uses aggressive logging level
|
||||
extension-module = ["pyo3/extension-module"] # Only enabled when building with maturin
|
||||
native-tls = ["xet-client/native-tls-vendored"]
|
||||
native-tls-vendored = ["xet-client/native-tls-vendored"]
|
||||
native-tls = ["xet-pkg/native-tls-vendored"]
|
||||
native-tls-vendored = ["xet-pkg/native-tls-vendored"]
|
||||
no-default-cache = ["xet-runtime/no-default-cache"]
|
||||
profiling = ["pprof"]
|
||||
tokio-console = ["xet-runtime/tokio-console"]
|
||||
elevated_information_level = ["xet-client/elevated_information_level", "xet-runtime/elevated_information_level"]
|
||||
elevated_information_level = ["xet-pkg/elevated_information_level"]
|
||||
|
||||
[profile.release]
|
||||
split-debuginfo = "packed"
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
mod logging;
|
||||
mod progress_update;
|
||||
mod runtime;
|
||||
mod token_refresh;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Debug;
|
||||
use std::iter::IntoIterator;
|
||||
use std::sync::Arc;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
|
||||
use http::header::{self, HeaderMap, HeaderName, HeaderValue};
|
||||
use itertools::Itertools;
|
||||
@@ -15,15 +14,13 @@ use pyo3::prelude::*;
|
||||
use pyo3::pyfunction;
|
||||
use rand::Rng;
|
||||
use runtime::async_run;
|
||||
use token_refresh::WrappedTokenRefresher;
|
||||
use tracing::debug;
|
||||
use xet_pkg::XetError;
|
||||
use xet_pkg::legacy::progress_tracking::TrackingProgressUpdater;
|
||||
use xet_pkg::legacy::{Sha256Policy, XetFileInfo, data_client};
|
||||
use xet_pkg::xet_session::{Sha256Policy, XetFileInfo, XetSessionBuilder};
|
||||
use xet_runtime::core::file_handle_limits;
|
||||
|
||||
use crate::logging::init_logging;
|
||||
use crate::progress_update::WrappedProgressUpdater;
|
||||
use crate::progress_update::{GroupProgressDiffState, ProgressCallback, send_simple_progress};
|
||||
|
||||
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
|
||||
|
||||
@@ -31,37 +28,17 @@ const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VE
|
||||
#[cfg(feature = "profiling")]
|
||||
pub(crate) mod profiling;
|
||||
|
||||
/// Converts a HashMap of headers to a HeaderMap and merges in the USER_AGENT.
|
||||
///
|
||||
/// If the input contains a User-Agent header, the USER_AGENT is appended to it.
|
||||
/// Otherwise, USER_AGENT is set as the only User-Agent header.
|
||||
fn build_headers_with_user_agent(request_headers: Option<HashMap<String, String>>) -> PyResult<Option<Arc<HeaderMap>>> {
|
||||
let mut map = request_headers
|
||||
.map(|headers| {
|
||||
let mut map = HeaderMap::new();
|
||||
for (key, value) in headers {
|
||||
let name = HeaderName::from_bytes(key.as_bytes())
|
||||
.map_err(|e| PyValueError::new_err(format!("Invalid header name '{}': {}", key, e)))?;
|
||||
let value = HeaderValue::from_str(&value)
|
||||
.map_err(|e| PyValueError::new_err(format!("Invalid header value for '{}': {}", key, e)))?;
|
||||
map.insert(name, value);
|
||||
}
|
||||
Ok::<_, PyErr>(map)
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or_default();
|
||||
/// Build a HeaderMap from a Python dict and merge in the USER_AGENT.
|
||||
fn build_headers_with_user_agent(request_headers: Option<HashMap<String, String>>) -> PyResult<HeaderMap> {
|
||||
let mut map = request_headers.map(build_header_map).transpose()?.unwrap_or_default();
|
||||
|
||||
// Append our USER_AGENT to any existing User-Agent header, or add it if not present
|
||||
let combined_user_agent = if let Some(existing_ua) = map.get(header::USER_AGENT) {
|
||||
// Append our user agent to the existing one
|
||||
let existing_str = existing_ua.to_str().unwrap_or("");
|
||||
format!("{}; {}", existing_str, USER_AGENT)
|
||||
} else {
|
||||
// No existing user agent, use ours
|
||||
USER_AGENT.to_string()
|
||||
};
|
||||
|
||||
// Try to create the combined header value, fall back gracefully if invalid
|
||||
let user_agent_value = HeaderValue::from_str(&combined_user_agent)
|
||||
.or_else(|_: http::header::InvalidHeaderValue| {
|
||||
Ok::<HeaderValue, http::header::InvalidHeaderValue>(HeaderValue::from_static(USER_AGENT))
|
||||
@@ -69,22 +46,58 @@ fn build_headers_with_user_agent(request_headers: Option<HashMap<String, String>
|
||||
.unwrap_or_else(|_: http::header::InvalidHeaderValue| HeaderValue::from_static("unknown"));
|
||||
map.insert(header::USER_AGENT, user_agent_value);
|
||||
|
||||
Ok(Some(Arc::new(map)))
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
/// Build a HeaderMap from a Python dict without adding USER_AGENT.
|
||||
fn build_header_map(headers: HashMap<String, String>) -> PyResult<HeaderMap> {
|
||||
let mut map = HeaderMap::new();
|
||||
for (key, value) in headers {
|
||||
let name = HeaderName::from_bytes(key.as_bytes())
|
||||
.map_err(|e| PyValueError::new_err(format!("Invalid header name '{}': {}", key, e)))?;
|
||||
let value = HeaderValue::from_str(&value)
|
||||
.map_err(|e| PyValueError::new_err(format!("Invalid header value for '{}': {}", key, e)))?;
|
||||
map.insert(name, value);
|
||||
}
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
fn convert_xet_error(e: impl Into<XetError>) -> PyErr {
|
||||
PyErr::from(e.into())
|
||||
}
|
||||
|
||||
/// Configure an auth group builder with optional endpoint, token, headers, and refresh URL.
|
||||
macro_rules! configure_auth_builder {
|
||||
($builder:expr, $endpoint:expr, $token_info:expr, $custom_headers:expr,
|
||||
$token_refresh_url:expr, $token_refresh_headers:expr) => {{
|
||||
let mut builder = $builder;
|
||||
if let Some(ep) = $endpoint {
|
||||
builder = builder.with_endpoint(ep);
|
||||
}
|
||||
if let Some((token, expiry)) = $token_info {
|
||||
builder = builder.with_token_info(token, expiry);
|
||||
}
|
||||
if let Some(headers) = $custom_headers {
|
||||
builder = builder.with_custom_headers(headers);
|
||||
}
|
||||
if let Some(url) = $token_refresh_url {
|
||||
let refresh_headers = $token_refresh_headers.unwrap_or_default();
|
||||
builder = builder.with_token_refresh_url(url, refresh_headers);
|
||||
}
|
||||
builder
|
||||
}};
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (file_contents, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_contents: List[bytes], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")]
|
||||
#[pyo3(signature = (file_contents, endpoint, token_info, token_refresh_url, token_refresh_headers, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_contents: List[bytes], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresh_url: Optional[str], token_refresh_headers: Optional[Dict[str, str]], progress_updater: Optional[Callable], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn upload_bytes(
|
||||
py: Python,
|
||||
file_contents: Vec<Vec<u8>>,
|
||||
endpoint: Option<String>,
|
||||
token_info: Option<(String, u64)>,
|
||||
token_refresher: Option<Py<PyAny>>,
|
||||
token_refresh_url: Option<String>,
|
||||
token_refresh_headers: Option<HashMap<String, String>>,
|
||||
progress_updater: Option<Py<PyAny>>,
|
||||
_repo_type: Option<String>,
|
||||
request_headers: Option<HashMap<String, String>>,
|
||||
@@ -111,13 +124,11 @@ pub fn upload_bytes(
|
||||
None => vec![Sha256Policy::Compute; file_contents.len()],
|
||||
};
|
||||
|
||||
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
|
||||
let updater = progress_updater.map(WrappedProgressUpdater::new).transpose()?.map(Arc::new);
|
||||
let callback = progress_updater.map(ProgressCallback::new).transpose()?;
|
||||
let custom_headers = build_headers_with_user_agent(request_headers)?;
|
||||
let refresh_headers = token_refresh_headers.map(build_header_map).transpose()?;
|
||||
let x: u64 = rand::rng().random();
|
||||
|
||||
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
|
||||
let header_map = build_headers_with_user_agent(request_headers)?;
|
||||
|
||||
async_run(py, async move {
|
||||
debug!(
|
||||
"Upload bytes call {x:x}: (PID = {}) Uploading {} files as bytes.",
|
||||
@@ -125,36 +136,43 @@ pub fn upload_bytes(
|
||||
file_contents.len(),
|
||||
);
|
||||
|
||||
let out: Vec<PyXetUploadInfo> = data_client::upload_bytes_async(
|
||||
file_contents,
|
||||
sha256_policies,
|
||||
let session = XetSessionBuilder::new().build().map_err(convert_xet_error)?;
|
||||
let builder = session.new_upload_commit().map_err(convert_xet_error)?;
|
||||
let builder = configure_auth_builder!(
|
||||
builder,
|
||||
endpoint,
|
||||
token_info,
|
||||
refresher.map(|v| v as Arc<_>),
|
||||
updater.map(|v| v as Arc<_>),
|
||||
header_map,
|
||||
)
|
||||
.await
|
||||
.map_err(convert_xet_error)?
|
||||
.into_iter()
|
||||
.map(PyXetUploadInfo::from)
|
||||
.collect();
|
||||
Some(custom_headers),
|
||||
token_refresh_url,
|
||||
refresh_headers
|
||||
);
|
||||
let commit = builder.build().await.map_err(convert_xet_error)?;
|
||||
|
||||
// Upload all byte blobs
|
||||
let mut handles = Vec::with_capacity(file_contents.len());
|
||||
for (blob, sha256) in file_contents.into_iter().zip(sha256_policies) {
|
||||
let handle = commit.upload_bytes(blob, sha256, None).await.map_err(convert_xet_error)?;
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Commit with concurrent progress polling
|
||||
let out = commit_with_progress(&commit, &handles, callback.as_ref()).await?;
|
||||
|
||||
debug!("Upload bytes call {x:x} finished.");
|
||||
|
||||
PyResult::Ok(out)
|
||||
})
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")]
|
||||
#[pyo3(signature = (file_paths, endpoint, token_info, token_refresh_url, token_refresh_headers, progress_updater, _repo_type, request_headers=None, sha256s=None, skip_sha256=false), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresh_url: Optional[str], token_refresh_headers: Optional[Dict[str, str]], progress_updater: Optional[Callable], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]], sha256s: Optional[List[str]], skip_sha256: bool = False) -> List[PyXetUploadInfo]")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn upload_files(
|
||||
py: Python,
|
||||
file_paths: Vec<String>,
|
||||
endpoint: Option<String>,
|
||||
token_info: Option<(String, u64)>,
|
||||
token_refresher: Option<Py<PyAny>>,
|
||||
token_refresh_url: Option<String>,
|
||||
token_refresh_headers: Option<HashMap<String, String>>,
|
||||
progress_updater: Option<Py<PyAny>>,
|
||||
_repo_type: Option<String>,
|
||||
request_headers: Option<HashMap<String, String>>,
|
||||
@@ -181,16 +199,13 @@ pub fn upload_files(
|
||||
None => vec![Sha256Policy::Compute; file_paths.len()],
|
||||
};
|
||||
|
||||
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
|
||||
let updater = progress_updater.map(WrappedProgressUpdater::new).transpose()?.map(Arc::new);
|
||||
let callback = progress_updater.map(ProgressCallback::new).transpose()?;
|
||||
let custom_headers = build_headers_with_user_agent(request_headers)?;
|
||||
let refresh_headers = token_refresh_headers.map(build_header_map).transpose()?;
|
||||
|
||||
let file_names = file_paths.iter().take(3).join(", ");
|
||||
|
||||
let x: u64 = rand::rng().random();
|
||||
|
||||
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
|
||||
let header_map = build_headers_with_user_agent(request_headers)?;
|
||||
|
||||
async_run(py, async move {
|
||||
debug!(
|
||||
"Upload call {x:x}: (PID = {}) Uploading {} files {file_names}{}",
|
||||
@@ -199,57 +214,117 @@ pub fn upload_files(
|
||||
if file_paths.len() > 3 { "..." } else { "." }
|
||||
);
|
||||
|
||||
let out: Vec<PyXetUploadInfo> = data_client::upload_async(
|
||||
file_paths,
|
||||
sha256_policies,
|
||||
let session = XetSessionBuilder::new().build().map_err(convert_xet_error)?;
|
||||
let builder = session.new_upload_commit().map_err(convert_xet_error)?;
|
||||
let builder = configure_auth_builder!(
|
||||
builder,
|
||||
endpoint,
|
||||
token_info,
|
||||
refresher.map(|v| v as Arc<_>),
|
||||
updater.map(|v| v as Arc<_>),
|
||||
header_map,
|
||||
)
|
||||
.await
|
||||
.map_err(convert_xet_error)?
|
||||
.into_iter()
|
||||
.map(PyXetUploadInfo::from)
|
||||
.collect();
|
||||
Some(custom_headers),
|
||||
token_refresh_url,
|
||||
refresh_headers
|
||||
);
|
||||
let commit = builder.build().await.map_err(convert_xet_error)?;
|
||||
|
||||
// Upload all files
|
||||
let mut handles = Vec::with_capacity(file_paths.len());
|
||||
for (path, sha256) in file_paths.into_iter().zip(sha256_policies) {
|
||||
let handle = commit
|
||||
.upload_from_path(PathBuf::from(path), sha256)
|
||||
.await
|
||||
.map_err(convert_xet_error)?;
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// Commit with concurrent progress polling
|
||||
let out = commit_with_progress(&commit, &handles, callback.as_ref()).await?;
|
||||
|
||||
debug!("Upload call {x:x} finished.");
|
||||
PyResult::Ok(out)
|
||||
})
|
||||
}
|
||||
|
||||
/// Commit an upload, polling progress concurrently if a callback is provided.
|
||||
/// Returns results in the same order as the handles.
|
||||
async fn commit_with_progress(
|
||||
commit: &xet_pkg::xet_session::XetUploadCommit,
|
||||
handles: &[xet_pkg::xet_session::XetFileUpload],
|
||||
callback: Option<&ProgressCallback>,
|
||||
) -> PyResult<Vec<PyXetUploadInfo>> {
|
||||
// Collect task_ids in order so we can map results back
|
||||
let task_ids: Vec<_> = handles.iter().map(|h| h.task_id()).collect();
|
||||
|
||||
let use_progress = callback.is_some_and(|c| c.is_enabled());
|
||||
|
||||
if use_progress {
|
||||
let callback = callback.unwrap();
|
||||
let commit_clone = commit.clone();
|
||||
let mut commit_task = tokio::spawn(async move { commit_clone.commit().await });
|
||||
|
||||
let mut diff_state = GroupProgressDiffState::new();
|
||||
let mut interval = tokio::time::interval(Duration::from_millis(250));
|
||||
|
||||
let report = loop {
|
||||
tokio::select! {
|
||||
result = &mut commit_task => {
|
||||
let report = result
|
||||
.map_err(|e| convert_xet_error(XetError::from(e)))?
|
||||
.map_err(convert_xet_error)?;
|
||||
|
||||
// Final progress update
|
||||
let group = commit.progress();
|
||||
let items = handles
|
||||
.iter()
|
||||
.filter_map(|h| h.progress().map(|p| (h.task_id(), p)))
|
||||
.collect();
|
||||
let diff = diff_state.compute_diff(group, items);
|
||||
let _ = callback.send_update(diff).await;
|
||||
|
||||
break report;
|
||||
}
|
||||
_ = interval.tick() => {
|
||||
let group = commit.progress();
|
||||
let items = handles
|
||||
.iter()
|
||||
.filter_map(|h| h.progress().map(|p| (h.task_id(), p)))
|
||||
.collect();
|
||||
let diff = diff_state.compute_diff(group, items);
|
||||
let _ = callback.send_update(diff).await;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Map results in order
|
||||
let mut out = Vec::with_capacity(task_ids.len());
|
||||
for id in &task_ids {
|
||||
let meta = report
|
||||
.uploads
|
||||
.get(id)
|
||||
.ok_or_else(|| convert_xet_error(XetError::Internal(format!("missing upload result for task {id}"))))?;
|
||||
out.push(PyXetUploadInfo::from(meta.xet_info.clone()));
|
||||
}
|
||||
Ok(out)
|
||||
} else {
|
||||
let report = commit.commit().await.map_err(convert_xet_error)?;
|
||||
|
||||
let mut out = Vec::with_capacity(task_ids.len());
|
||||
for id in &task_ids {
|
||||
let meta = report
|
||||
.uploads
|
||||
.get(id)
|
||||
.ok_or_else(|| convert_xet_error(XetError::Internal(format!("missing upload result for task {id}"))))?;
|
||||
out.push(PyXetUploadInfo::from(meta.xet_info.clone()));
|
||||
}
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute xet hashes for files without uploading.
|
||||
///
|
||||
/// This function computes cryptographic hashes for the specified files using the same
|
||||
/// chunking and hashing algorithm as upload operations, but without requiring
|
||||
/// authentication or server connection. The resulting hashes can be used to verify
|
||||
/// file integrity after downloads or to determine which files need to be uploaded.
|
||||
///
|
||||
/// Args:
|
||||
/// file_paths: List of file paths to hash.
|
||||
///
|
||||
/// Returns:
|
||||
/// List[PyXetUploadInfo]: List of hash results in the same order as input paths.
|
||||
/// Each result contains the hash (as hex string) and file size in bytes.
|
||||
///
|
||||
/// Raises:
|
||||
/// RuntimeError: If any file cannot be read or hashed.
|
||||
///
|
||||
/// Example:
|
||||
/// >>> import hf_xet
|
||||
/// >>> results = hf_xet.hash_files(["/path/to/file1.txt", "/path/to/file2.txt"])
|
||||
/// >>> for path, info in zip(file_paths, results):
|
||||
/// ... print(f"Hash: {info.hash}, Size: {info.file_size}")
|
||||
///
|
||||
/// Note:
|
||||
/// This function is primarily used for validation and verification of transferred
|
||||
/// files. Clients can verify that downloaded files are correctly reassembled by
|
||||
/// comparing the computed hash with the expected hash from the server.
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (file_paths), text_signature = "(file_paths: List[str]) -> List[PyXetUploadInfo]")]
|
||||
pub fn hash_files(py: Python, file_paths: Vec<String>) -> PyResult<Vec<PyXetUploadInfo>> {
|
||||
async_run(py, async move {
|
||||
let out: Vec<PyXetUploadInfo> = data_client::hash_files_async(file_paths)
|
||||
let out: Vec<PyXetUploadInfo> = xet_data::processing::data_client::hash_files_async(file_paths)
|
||||
.await
|
||||
.map_err(convert_xet_error)?
|
||||
.into_iter()
|
||||
@@ -261,25 +336,23 @@ pub fn hash_files(py: Python, file_paths: Vec<String>) -> PyResult<Vec<PyXetUplo
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (files, endpoint, token_info, token_refresher, progress_updater, request_headers=None), text_signature = "(files: List[PyXetDownloadInfo], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[List[Callable[[int], None]]], request_headers: Optional[Dict[str, str]]) -> List[str]")]
|
||||
#[pyo3(signature = (files, endpoint, token_info, token_refresh_url, token_refresh_headers, progress_updater, request_headers=None), text_signature = "(files: List[PyXetDownloadInfo], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresh_url: Optional[str], token_refresh_headers: Optional[Dict[str, str]], progress_updater: Optional[List[Callable[[int], None]]], request_headers: Optional[Dict[str, str]]) -> List[str]")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn download_files(
|
||||
py: Python,
|
||||
files: Vec<PyXetDownloadInfo>,
|
||||
endpoint: Option<String>,
|
||||
token_info: Option<(String, u64)>,
|
||||
token_refresher: Option<Py<PyAny>>,
|
||||
token_refresh_url: Option<String>,
|
||||
token_refresh_headers: Option<HashMap<String, String>>,
|
||||
progress_updater: Option<Vec<Py<PyAny>>>,
|
||||
request_headers: Option<HashMap<String, String>>,
|
||||
) -> PyResult<Vec<String>> {
|
||||
let file_infos: Vec<_> = files.into_iter().map(<(XetFileInfo, DestinationPath)>::from).collect();
|
||||
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
|
||||
let updaters = progress_updater.map(try_parse_progress_updaters).transpose()?;
|
||||
|
||||
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
|
||||
let header_map = build_headers_with_user_agent(request_headers)?;
|
||||
let custom_headers = build_headers_with_user_agent(request_headers)?;
|
||||
let refresh_headers = token_refresh_headers.map(build_header_map).transpose()?;
|
||||
|
||||
let x: u64 = rand::rng().random();
|
||||
|
||||
let file_names = file_infos.iter().take(3).map(|(_, p)| p).join(", ");
|
||||
|
||||
async_run(py, async move {
|
||||
@@ -290,39 +363,89 @@ pub fn download_files(
|
||||
if file_infos.len() > 3 { "..." } else { "." }
|
||||
);
|
||||
|
||||
let out: Vec<String> = data_client::download_async(
|
||||
file_infos,
|
||||
let session = XetSessionBuilder::new().build().map_err(convert_xet_error)?;
|
||||
let builder = session.new_file_download_group().map_err(convert_xet_error)?;
|
||||
let builder = configure_auth_builder!(
|
||||
builder,
|
||||
endpoint,
|
||||
token_info,
|
||||
refresher.map(|v| v as Arc<_>),
|
||||
updaters,
|
||||
header_map,
|
||||
)
|
||||
.await
|
||||
.map_err(convert_xet_error)?;
|
||||
Some(custom_headers),
|
||||
token_refresh_url,
|
||||
refresh_headers
|
||||
);
|
||||
let group = builder.build().await.map_err(convert_xet_error)?;
|
||||
|
||||
// Queue all downloads
|
||||
let mut dl_handles = Vec::with_capacity(file_infos.len());
|
||||
let mut paths = Vec::with_capacity(file_infos.len());
|
||||
for (file_info, dest_path) in file_infos {
|
||||
let handle = group
|
||||
.download_file_to_path(file_info, PathBuf::from(&dest_path))
|
||||
.await
|
||||
.map_err(convert_xet_error)?;
|
||||
dl_handles.push(handle);
|
||||
paths.push(dest_path);
|
||||
}
|
||||
|
||||
// Finish with concurrent progress polling
|
||||
let group_for_finish = group.clone();
|
||||
let mut finish_task = tokio::spawn(async move { group_for_finish.finish().await });
|
||||
|
||||
if let Some(updaters) = progress_updater {
|
||||
let mut prev_completed: Vec<u64> = vec![0; dl_handles.len()];
|
||||
let mut interval = tokio::time::interval(Duration::from_millis(250));
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = &mut finish_task => {
|
||||
result
|
||||
.map_err(|e| convert_xet_error(XetError::from(e)))?
|
||||
.map_err(convert_xet_error)?;
|
||||
|
||||
// Final per-file progress update
|
||||
for (i, handle) in dl_handles.iter().enumerate() {
|
||||
if i < updaters.len()
|
||||
&& let Some(report) = handle.progress()
|
||||
{
|
||||
let increment = report.bytes_completed.saturating_sub(prev_completed[i]);
|
||||
send_simple_progress(&updaters[i], increment).await;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
_ = interval.tick() => {
|
||||
for (i, handle) in dl_handles.iter().enumerate() {
|
||||
if i < updaters.len()
|
||||
&& let Some(report) = handle.progress()
|
||||
{
|
||||
let increment = report.bytes_completed.saturating_sub(prev_completed[i]);
|
||||
if increment > 0 {
|
||||
prev_completed[i] = report.bytes_completed;
|
||||
send_simple_progress(&updaters[i], increment).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
finish_task
|
||||
.await
|
||||
.map_err(|e| convert_xet_error(XetError::from(e)))?
|
||||
.map_err(convert_xet_error)?;
|
||||
}
|
||||
|
||||
debug!("Download call {x:x}: Completed.");
|
||||
|
||||
PyResult::Ok(out)
|
||||
PyResult::Ok(paths)
|
||||
})
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
pub fn force_sigint_shutdown() -> PyResult<()> {
|
||||
// Force a signint shutdown in the case where it gets intercepted by another process.
|
||||
crate::runtime::perform_sigint_shutdown();
|
||||
Err(PyKeyboardInterrupt::new_err(()))
|
||||
}
|
||||
|
||||
fn try_parse_progress_updaters(funcs: Vec<Py<PyAny>>) -> PyResult<Vec<Arc<dyn TrackingProgressUpdater>>> {
|
||||
let mut updaters = Vec::with_capacity(funcs.len());
|
||||
for updater_func in funcs {
|
||||
let wrapped = Arc::new(WrappedProgressUpdater::new(updater_func)?);
|
||||
updaters.push(wrapped as Arc<dyn TrackingProgressUpdater>);
|
||||
}
|
||||
Ok(updaters)
|
||||
}
|
||||
|
||||
// TODO: we won't need to subclass this in the next major version update.
|
||||
#[pyclass(subclass)]
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -508,9 +631,7 @@ mod tests {
|
||||
|
||||
// Initialize Python once for all tests
|
||||
fn setup() {
|
||||
// When auto-initialize is enabled, Python will be initialized on first use
|
||||
// This ensures Python is available for the tests
|
||||
let _ = pyo3::Python::attach(|_py| {});
|
||||
pyo3::Python::attach(|_py| {});
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -518,23 +639,19 @@ mod tests {
|
||||
setup();
|
||||
let empty_map: HashMap<String, String> = HashMap::new();
|
||||
let result = build_headers_with_user_agent(Some(empty_map)).unwrap();
|
||||
let headers = result.unwrap();
|
||||
|
||||
// Should have exactly one header: USER_AGENT
|
||||
assert_eq!(headers.len(), 1);
|
||||
assert!(headers.contains_key(header::USER_AGENT));
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result.contains_key(header::USER_AGENT));
|
||||
|
||||
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
|
||||
let user_agent = result.get(header::USER_AGENT).unwrap().to_str().unwrap();
|
||||
assert_eq!(user_agent, USER_AGENT);
|
||||
|
||||
let result = build_headers_with_user_agent(None).unwrap();
|
||||
let headers = result.unwrap();
|
||||
|
||||
// Should have exactly one header: USER_AGENT
|
||||
assert_eq!(headers.len(), 1);
|
||||
assert!(headers.contains_key(header::USER_AGENT));
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result.contains_key(header::USER_AGENT));
|
||||
|
||||
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
|
||||
let user_agent = result.get(header::USER_AGENT).unwrap().to_str().unwrap();
|
||||
assert_eq!(user_agent, USER_AGENT);
|
||||
}
|
||||
|
||||
@@ -546,17 +663,12 @@ mod tests {
|
||||
headers_map.insert("Authorization".to_string(), "Bearer token123".to_string());
|
||||
|
||||
let result = build_headers_with_user_agent(Some(headers_map)).unwrap();
|
||||
let headers = result.unwrap();
|
||||
|
||||
// Should have 3 headers: Content-Type, Authorization, and USER_AGENT
|
||||
assert_eq!(headers.len(), 3);
|
||||
assert_eq!(result.len(), 3);
|
||||
assert_eq!(result.get(header::CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
|
||||
assert_eq!(result.get(header::AUTHORIZATION).unwrap().to_str().unwrap(), "Bearer token123");
|
||||
|
||||
// Verify each header was converted correctly
|
||||
assert_eq!(headers.get(header::CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
|
||||
assert_eq!(headers.get(header::AUTHORIZATION).unwrap().to_str().unwrap(), "Bearer token123");
|
||||
|
||||
// Verify USER_AGENT was added
|
||||
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
|
||||
let user_agent = result.get(header::USER_AGENT).unwrap().to_str().unwrap();
|
||||
assert_eq!(user_agent, USER_AGENT);
|
||||
}
|
||||
|
||||
@@ -567,13 +679,10 @@ mod tests {
|
||||
headers_map.insert("User-Agent".to_string(), "CustomClient/1.0".to_string());
|
||||
|
||||
let result = build_headers_with_user_agent(Some(headers_map)).unwrap();
|
||||
let headers = result.unwrap();
|
||||
|
||||
// Should have exactly one header: USER_AGENT
|
||||
assert_eq!(headers.len(), 1);
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
// Verify USER_AGENT was appended to existing one
|
||||
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
|
||||
let user_agent = result.get(header::USER_AGENT).unwrap().to_str().unwrap();
|
||||
assert_eq!(user_agent, format!("CustomClient/1.0; {}", USER_AGENT));
|
||||
}
|
||||
|
||||
@@ -585,18 +694,15 @@ mod tests {
|
||||
|
||||
let result = build_headers_with_user_agent(Some(headers_map));
|
||||
|
||||
// Should return an error for invalid header name
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(err_msg.contains("Invalid header name"));
|
||||
|
||||
let mut headers_map = HashMap::new();
|
||||
// Header values cannot contain newlines
|
||||
headers_map.insert("X-Custom".to_string(), "value\nwith\nnewlines".to_string());
|
||||
|
||||
let result = build_headers_with_user_agent(Some(headers_map));
|
||||
|
||||
// Should return an error for invalid header value
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
assert!(err_msg.contains("Invalid header value"));
|
||||
@@ -607,9 +713,9 @@ mod tests {
|
||||
setup();
|
||||
pyo3::Python::attach(|py| {
|
||||
let file_paths = vec!["a.txt".to_string(), "b.txt".to_string()];
|
||||
let sha256s = Some(vec!["abc123".to_string()]); // 1 hash for 2 files
|
||||
let sha256s = Some(vec!["abc123".to_string()]);
|
||||
|
||||
let result = upload_files(py, file_paths, None, None, None, None, None, None, sha256s, false);
|
||||
let result = upload_files(py, file_paths, None, None, None, None, None, None, None, sha256s, false);
|
||||
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
@@ -624,7 +730,7 @@ mod tests {
|
||||
let file_paths = vec!["a.txt".to_string()];
|
||||
let sha256s = Some(vec!["abc123".to_string()]);
|
||||
|
||||
let result = upload_files(py, file_paths, None, None, None, None, None, None, sha256s, true);
|
||||
let result = upload_files(py, file_paths, None, None, None, None, None, None, None, sha256s, true);
|
||||
|
||||
assert!(result.is_err());
|
||||
let err_msg = result.unwrap_err().to_string();
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::{Debug, Formatter};
|
||||
use std::sync::Arc;
|
||||
|
||||
use itertools::Itertools;
|
||||
use pyo3::exceptions::PyTypeError;
|
||||
@@ -7,163 +7,81 @@ use pyo3::prelude::PyAnyMethods;
|
||||
use pyo3::types::{IntoPyDict, PyList, PyString};
|
||||
use pyo3::{IntoPyObjectExt, Py, PyAny, PyResult, Python, pyclass};
|
||||
use tracing::error;
|
||||
use xet_pkg::legacy::progress_tracking::{ProgressUpdate, TrackingProgressUpdater};
|
||||
use xet_pkg::xet_session::{GroupProgressReport, ItemProgressReport, UniqueID};
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::error_printer::ErrorPrinter;
|
||||
|
||||
use crate::runtime::convert_multithreading_error;
|
||||
|
||||
/// Python-exposed versions of the per-item and total progress update classes.
|
||||
///
|
||||
/// Both `PyTotalProgressUpdate` and `PyItemProgressUpdate` are passed
|
||||
/// into a Python callback given to the wrapper class below. For example:
|
||||
///
|
||||
/// ```python
|
||||
/// def update_progress(self, total_update, item_updates):
|
||||
/// from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn
|
||||
///
|
||||
/// # Update overall progress (we assume this has been initialized).
|
||||
/// self.progress.update(
|
||||
/// self.bytes_processed_task_id,
|
||||
/// advance=total_update.total_bytes_completion_increment,
|
||||
/// total = total_update.total_bytes
|
||||
/// )
|
||||
///
|
||||
/// # Update upload progress ; the total may have changed so set that too.
|
||||
/// self.progress.update(
|
||||
/// self.bytes_uploaded_task_id,
|
||||
/// advance=total_update.total_transfer_bytes_completion_increment,
|
||||
/// total = total_update.total_transfer_bytes
|
||||
/// )
|
||||
///
|
||||
/// # Update each item:
|
||||
/// for item in item_updates:
|
||||
/// name = item.item_name
|
||||
/// if name not in self.item_tasks:
|
||||
/// self.item_tasks[name] = self.progress.add_task(
|
||||
/// name, total=item.total_bytes
|
||||
/// )
|
||||
/// self.progress.update(
|
||||
/// self.item_tasks[name],
|
||||
/// advance=item.bytes_completion_increment,
|
||||
/// )
|
||||
/// ```
|
||||
///
|
||||
/// In addition, the other possible bookkeeping values for everything are contained in this
|
||||
/// as needed.
|
||||
// === PyO3 progress update classes (exposed to Python) ===
|
||||
|
||||
#[pyclass]
|
||||
pub struct PyItemProgressUpdate {
|
||||
/// The name of the item, or a tag that is translated later.
|
||||
#[pyo3(get)]
|
||||
pub item_name: Py<PyString>,
|
||||
|
||||
/// The total bytes contained in this item.
|
||||
#[pyo3(get)]
|
||||
pub total_bytes: u64,
|
||||
|
||||
/// The number of bytes completed so far, either by deduplication or transfer.
|
||||
#[pyo3(get)]
|
||||
pub bytes_completed: u64,
|
||||
|
||||
/// The change in bytes completed since the last update.
|
||||
#[pyo3(get)]
|
||||
pub bytes_completion_increment: u64,
|
||||
}
|
||||
|
||||
/// Update class for total updates
|
||||
#[pyclass]
|
||||
pub struct PyTotalProgressUpdate {
|
||||
/// The total bytes known for processing and possibly uploaded or downloaded.
|
||||
#[pyo3(get)]
|
||||
pub total_bytes: u64,
|
||||
|
||||
/// How much total_bytes has changed from the last update.
|
||||
#[pyo3(get)]
|
||||
pub total_bytes_increment: u64,
|
||||
|
||||
/// How many of the bytes queued for processing have been examined
|
||||
/// and either deduped or queued for upload or download.
|
||||
#[pyo3(get)]
|
||||
pub total_bytes_completed: u64,
|
||||
|
||||
/// The change in total_bytes_completed since the same upload.
|
||||
#[pyo3(get)]
|
||||
pub total_bytes_completion_increment: u64,
|
||||
|
||||
/// If known, the current completion speed.
|
||||
#[pyo3(get)]
|
||||
pub total_bytes_completion_rate: Option<f64>,
|
||||
|
||||
/// The total bytes scheduled for transfer; also contained in total_bytes.
|
||||
#[pyo3(get)]
|
||||
pub total_transfer_bytes: u64,
|
||||
|
||||
/// How much total_transfer_bytes has changed since the last update.
|
||||
#[pyo3(get)]
|
||||
pub total_transfer_bytes_increment: u64,
|
||||
|
||||
/// The cumulative bytes uploaded or downloaded so far. Also contained in total_bytes_completed.
|
||||
#[pyo3(get)]
|
||||
pub total_transfer_bytes_completed: u64,
|
||||
|
||||
/// The change in total_transfer_bytes_completed since the last update.
|
||||
#[pyo3(get)]
|
||||
pub total_transfer_bytes_completion_increment: u64,
|
||||
|
||||
/// If known, the current completion speed for bytes transferred.
|
||||
#[pyo3(get)]
|
||||
pub total_transfer_bytes_completion_rate: Option<f64>,
|
||||
}
|
||||
|
||||
/// A wrapper over a passed-in python function to update
|
||||
/// the python process of some download/upload progress
|
||||
/// implements the ProgressUpdater trait and should be
|
||||
/// passed around as a ProgressUpdater trait object or
|
||||
/// as a template parameter
|
||||
struct WrappedProgressUpdaterImpl {
|
||||
/// Is this enabled?
|
||||
progress_updating_enabled: bool,
|
||||
|
||||
/// the function py_func is responsible for passing in the update value
|
||||
/// into the python context. Expects 1 int (uint64) parameter that
|
||||
/// is a number to increment the progress counter by.
|
||||
py_func: Py<PyAny>,
|
||||
name: String,
|
||||
|
||||
/// Whether to use the simple incremental progress updating method or
|
||||
/// the more detailed
|
||||
update_with_detailed_progress: bool,
|
||||
}
|
||||
|
||||
impl Debug for WrappedProgressUpdaterImpl {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "WrappedTokenRefresher({})", self.name)
|
||||
}
|
||||
}
|
||||
// === Progress callback (validated Python callable) ===
|
||||
|
||||
const DETAILED_PROGRESS_ARG_NAMES: [&str; 2] = ["total_update", "item_updates"];
|
||||
|
||||
impl WrappedProgressUpdaterImpl {
|
||||
pub fn new(py_func: Py<PyAny>) -> PyResult<Self> {
|
||||
// Analyze the function to make sure it's the correct form. If it's 4 arguments with
|
||||
// the appropriate names, than we call it using the detailed progress update; if it's
|
||||
// a single function, we assume it's a global increment function and just pass in the update
|
||||
// increment.
|
||||
//
|
||||
// Run on compute thread that doesn't block async workers
|
||||
Python::attach(|py| {
|
||||
let func = py_func.bind(py);
|
||||
/// A validated Python progress callback. Determines on construction whether
|
||||
/// the callback uses the simple (1-arg increment) or detailed (2-arg total +
|
||||
/// items) calling convention.
|
||||
pub struct ProgressCallback {
|
||||
py_func: Py<PyAny>,
|
||||
detailed: bool,
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
// Test if it's enabled first; if None is passed in, then this is disabled.
|
||||
impl Debug for ProgressCallback {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "ProgressCallback(detailed={}, enabled={})", self.detailed, self.enabled)
|
||||
}
|
||||
}
|
||||
|
||||
impl ProgressCallback {
|
||||
pub fn new(py_func: Py<PyAny>) -> PyResult<Self> {
|
||||
Python::attach(|py| {
|
||||
if py_func.is_none(py) {
|
||||
return Ok(Self {
|
||||
progress_updating_enabled: false,
|
||||
py_func,
|
||||
name: Default::default(),
|
||||
update_with_detailed_progress: false,
|
||||
detailed: false,
|
||||
enabled: false,
|
||||
});
|
||||
}
|
||||
|
||||
let func = py_func.bind(py);
|
||||
let name = func
|
||||
.repr()
|
||||
.and_then(|repr| repr.extract::<String>())
|
||||
@@ -187,7 +105,7 @@ impl WrappedProgressUpdaterImpl {
|
||||
})
|
||||
.collect::<PyResult<_>>()?;
|
||||
|
||||
let update_with_detailed_progress = match param_names.len() {
|
||||
let detailed = match param_names.len() {
|
||||
1 => false,
|
||||
2 => {
|
||||
if param_names
|
||||
@@ -212,43 +130,53 @@ impl WrappedProgressUpdaterImpl {
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
progress_updating_enabled: true,
|
||||
py_func,
|
||||
name,
|
||||
update_with_detailed_progress,
|
||||
detailed,
|
||||
enabled: true,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
async fn register_updates_impl(self: Arc<Self>, updates: ProgressUpdate) -> PyResult<()> {
|
||||
// Run on compute thread that doesn't block async workers
|
||||
pub fn is_enabled(&self) -> bool {
|
||||
self.enabled
|
||||
}
|
||||
|
||||
/// Send a progress diff to the Python callback via spawn_blocking
|
||||
/// (to avoid blocking the async runtime while holding the GIL).
|
||||
pub async fn send_update(&self, diff: ProgressDiff) -> PyResult<()> {
|
||||
if !self.enabled || diff.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let py_func = Python::attach(|py| self.py_func.clone_ref(py));
|
||||
let detailed = self.detailed;
|
||||
|
||||
let rt = XetRuntime::current();
|
||||
rt.spawn_blocking(move || {
|
||||
Python::attach(|py| {
|
||||
let f = self.py_func.bind(py);
|
||||
let f = py_func.bind(py);
|
||||
|
||||
if self.update_with_detailed_progress {
|
||||
let total_update_report: Py<PyAny> = Py::new(
|
||||
if detailed {
|
||||
let total_update: Py<PyAny> = Py::new(
|
||||
py,
|
||||
PyTotalProgressUpdate {
|
||||
total_bytes: updates.total_bytes,
|
||||
total_bytes_increment: updates.total_bytes_increment,
|
||||
total_bytes_completed: updates.total_bytes_completed,
|
||||
total_bytes_completion_increment: updates.total_bytes_completion_increment,
|
||||
total_bytes_completion_rate: updates.total_bytes_completion_rate,
|
||||
total_transfer_bytes: updates.total_transfer_bytes,
|
||||
total_transfer_bytes_increment: updates.total_transfer_bytes_increment,
|
||||
total_transfer_bytes_completed: updates.total_transfer_bytes_completed,
|
||||
total_transfer_bytes_completion_increment: updates
|
||||
.total_transfer_bytes_completion_increment,
|
||||
total_transfer_bytes_completion_rate: updates.total_transfer_bytes_completion_rate,
|
||||
total_bytes: diff.total_bytes,
|
||||
total_bytes_increment: diff.total_bytes_increment,
|
||||
total_bytes_completed: diff.total_bytes_completed,
|
||||
total_bytes_completion_increment: diff.total_bytes_completion_increment,
|
||||
total_bytes_completion_rate: diff.total_bytes_completion_rate,
|
||||
total_transfer_bytes: diff.total_transfer_bytes,
|
||||
total_transfer_bytes_increment: diff.total_transfer_bytes_increment,
|
||||
total_transfer_bytes_completed: diff.total_transfer_bytes_completed,
|
||||
total_transfer_bytes_completion_increment: diff.total_transfer_bytes_completion_increment,
|
||||
total_transfer_bytes_completion_rate: diff.total_transfer_bytes_completion_rate,
|
||||
},
|
||||
)?
|
||||
.into_py_any(py)?;
|
||||
|
||||
let item_updates_v: Vec<Py<PyAny>> = updates
|
||||
.item_updates
|
||||
.into_iter()
|
||||
let item_updates_v: Vec<Py<PyAny>> = diff
|
||||
.item_diffs
|
||||
.iter()
|
||||
.map(|u| {
|
||||
Py::new(
|
||||
py,
|
||||
@@ -269,16 +197,20 @@ impl WrappedProgressUpdaterImpl {
|
||||
let argname_item_updates: Py<PyAny> = DETAILED_PROGRESS_ARG_NAMES[1].into_py_any(py)?;
|
||||
|
||||
let kwargs = [
|
||||
(argname_total_update, total_update_report),
|
||||
(argname_total_update, total_update),
|
||||
(argname_item_updates, item_updates),
|
||||
]
|
||||
.into_py_dict(py)?;
|
||||
|
||||
f.call((), Some(&kwargs))?;
|
||||
} else {
|
||||
let update_increment: u64 =
|
||||
updates.item_updates.iter().map(|pr| pr.bytes_completion_increment).sum();
|
||||
let _ = f.call1((update_increment,))?;
|
||||
let update_increment: u64 = diff.item_diffs.iter().map(|d| d.bytes_completion_increment).sum();
|
||||
let increment = if update_increment > 0 {
|
||||
update_increment
|
||||
} else {
|
||||
diff.total_bytes_completion_increment
|
||||
};
|
||||
let _ = f.call1((increment,))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -289,29 +221,126 @@ impl WrappedProgressUpdaterImpl {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct WrappedProgressUpdater {
|
||||
inner: Arc<WrappedProgressUpdaterImpl>,
|
||||
// === Simple per-file progress callback for downloads ===
|
||||
|
||||
/// Send a simple byte-increment update to a Python callback.
|
||||
pub async fn send_simple_progress(py_func: &Py<PyAny>, increment: u64) {
|
||||
if increment == 0 {
|
||||
return;
|
||||
}
|
||||
let py_func = Python::attach(|py| py_func.clone_ref(py));
|
||||
let rt = XetRuntime::current();
|
||||
let _ = rt
|
||||
.spawn_blocking(move || {
|
||||
Python::attach(|py| {
|
||||
let f = py_func.bind(py);
|
||||
let _ = f.call1((increment,));
|
||||
})
|
||||
})
|
||||
.await
|
||||
.log_error("Python exception updating download progress:");
|
||||
}
|
||||
|
||||
impl WrappedProgressUpdater {
|
||||
pub fn new(py_func: Py<PyAny>) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: Arc::new(WrappedProgressUpdaterImpl::new(py_func)?),
|
||||
})
|
||||
// === Progress diff types ===
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ItemDiff {
|
||||
pub item_name: String,
|
||||
pub total_bytes: u64,
|
||||
pub bytes_completed: u64,
|
||||
pub bytes_completion_increment: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ProgressDiff {
|
||||
pub total_bytes: u64,
|
||||
pub total_bytes_increment: u64,
|
||||
pub total_bytes_completed: u64,
|
||||
pub total_bytes_completion_increment: u64,
|
||||
pub total_bytes_completion_rate: Option<f64>,
|
||||
pub total_transfer_bytes: u64,
|
||||
pub total_transfer_bytes_increment: u64,
|
||||
pub total_transfer_bytes_completed: u64,
|
||||
pub total_transfer_bytes_completion_increment: u64,
|
||||
pub total_transfer_bytes_completion_rate: Option<f64>,
|
||||
pub item_diffs: Vec<ItemDiff>,
|
||||
}
|
||||
|
||||
impl ProgressDiff {
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.total_bytes_increment == 0
|
||||
&& self.total_bytes_completion_increment == 0
|
||||
&& self.total_transfer_bytes_increment == 0
|
||||
&& self.total_transfer_bytes_completion_increment == 0
|
||||
&& self.item_diffs.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl TrackingProgressUpdater for WrappedProgressUpdater {
|
||||
async fn register_updates(&self, updates: ProgressUpdate) {
|
||||
let inner = self.inner.clone();
|
||||
// === Diff state for group-level progress polling ===
|
||||
|
||||
if inner.progress_updating_enabled {
|
||||
let _ = inner
|
||||
.register_updates_impl(updates)
|
||||
.await
|
||||
.log_error("Python exception updating progress:");
|
||||
/// Tracks the previous progress snapshot so that incremental diffs can be
|
||||
/// computed each time the session's `progress()` is polled.
|
||||
pub struct GroupProgressDiffState {
|
||||
prev_group: GroupProgressReport,
|
||||
prev_items: HashMap<UniqueID, ItemProgressReport>,
|
||||
}
|
||||
|
||||
impl GroupProgressDiffState {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
prev_group: GroupProgressReport::default(),
|
||||
prev_items: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn compute_diff(
|
||||
&mut self,
|
||||
group: GroupProgressReport,
|
||||
items: HashMap<UniqueID, ItemProgressReport>,
|
||||
) -> ProgressDiff {
|
||||
let total_bytes_increment = group.total_bytes.saturating_sub(self.prev_group.total_bytes);
|
||||
let total_bytes_completion_increment = group
|
||||
.total_bytes_completed
|
||||
.saturating_sub(self.prev_group.total_bytes_completed);
|
||||
let total_transfer_bytes_increment =
|
||||
group.total_transfer_bytes.saturating_sub(self.prev_group.total_transfer_bytes);
|
||||
let total_transfer_bytes_completion_increment = group
|
||||
.total_transfer_bytes_completed
|
||||
.saturating_sub(self.prev_group.total_transfer_bytes_completed);
|
||||
|
||||
let mut item_diffs = Vec::new();
|
||||
for (&id, report) in &items {
|
||||
let prev = self.prev_items.get(&id);
|
||||
let prev_completed = prev.map_or(0, |p| p.bytes_completed);
|
||||
let increment = report.bytes_completed.saturating_sub(prev_completed);
|
||||
|
||||
if increment > 0 || prev.is_none() {
|
||||
item_diffs.push(ItemDiff {
|
||||
item_name: report.item_name.clone(),
|
||||
total_bytes: report.total_bytes,
|
||||
bytes_completed: report.bytes_completed,
|
||||
bytes_completion_increment: increment,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let diff = ProgressDiff {
|
||||
total_bytes: group.total_bytes,
|
||||
total_bytes_increment,
|
||||
total_bytes_completed: group.total_bytes_completed,
|
||||
total_bytes_completion_increment,
|
||||
total_bytes_completion_rate: group.total_bytes_completion_rate,
|
||||
total_transfer_bytes: group.total_transfer_bytes,
|
||||
total_transfer_bytes_increment,
|
||||
total_transfer_bytes_completed: group.total_transfer_bytes_completed,
|
||||
total_transfer_bytes_completion_increment,
|
||||
total_transfer_bytes_completion_rate: group.total_transfer_bytes_completion_rate,
|
||||
item_diffs,
|
||||
};
|
||||
|
||||
self.prev_group = group;
|
||||
self.prev_items = items;
|
||||
|
||||
diff
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
use std::fmt::{Debug, Formatter};
|
||||
|
||||
use pyo3::exceptions::PyTypeError;
|
||||
use pyo3::prelude::PyAnyMethods;
|
||||
use pyo3::{Py, PyAny, PyErr, PyResult, Python};
|
||||
use tracing::error;
|
||||
use xet_client::cas_client::auth::{AuthError, TokenInfo, TokenRefresher};
|
||||
|
||||
/// A wrapper struct of a python function to refresh the CAS auth token.
|
||||
/// Since tokens are generated by hub, we want to be able to refresh the
|
||||
/// token using the hub client, which is only available in python.
|
||||
pub struct WrappedTokenRefresher {
|
||||
/// The function responsible for refreshing a token.
|
||||
/// Expects no inputs and returns a (str, u64) representing the new token
|
||||
/// and the unixtime (in seconds) of expiration, raising an exception
|
||||
/// if there is an issue.
|
||||
py_func: Py<PyAny>,
|
||||
name: String,
|
||||
}
|
||||
|
||||
impl Debug for WrappedTokenRefresher {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "WrappedTokenRefresher({})", self.name)
|
||||
}
|
||||
}
|
||||
|
||||
impl WrappedTokenRefresher {
|
||||
pub fn from_func(py_func: Py<PyAny>) -> PyResult<Self> {
|
||||
let name = Self::validate_callable(&py_func)?;
|
||||
Ok(Self { py_func, name })
|
||||
}
|
||||
|
||||
/// Validate that the inputted python object is callable
|
||||
fn validate_callable(py_func: &Py<PyAny>) -> Result<String, PyErr> {
|
||||
Python::attach(|py| {
|
||||
let f = py_func.bind(py);
|
||||
let name = f
|
||||
.repr()
|
||||
.and_then(|repr| repr.extract::<String>())
|
||||
.unwrap_or("unknown".to_string());
|
||||
if !f.is_callable() {
|
||||
error!("TokenRefresher func: {name} is not callable");
|
||||
return Err(PyTypeError::new_err(format!("refresh func: {name} is not callable")));
|
||||
}
|
||||
Ok(name)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
|
||||
#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
|
||||
impl TokenRefresher for WrappedTokenRefresher {
|
||||
async fn refresh(&self) -> Result<TokenInfo, AuthError> {
|
||||
Python::attach(|py| {
|
||||
let f = self.py_func.bind(py);
|
||||
if !f.is_callable() {
|
||||
return Err(AuthError::RefreshFunctionNotCallable(self.name.clone()));
|
||||
}
|
||||
let result = f
|
||||
.call0()
|
||||
.map_err(|e| AuthError::TokenRefreshFailure(format!("Error refreshing token: {e:?}")))?;
|
||||
result.extract::<(String, u64)>().map_err(|e| {
|
||||
AuthError::TokenRefreshFailure(format!("refresh function didn't return a (String, u64) tuple: {e:?}"))
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -48,6 +48,8 @@ smoke-test = []
|
||||
fd-track = ["xet-runtime/fd-track", "xet-client/fd-track", "xet-data/fd-track"]
|
||||
python = ["xet-runtime/python", "dep:pyo3"]
|
||||
simulation = ["xet-client/simulation"]
|
||||
native-tls-vendored = ["xet-client/native-tls-vendored"]
|
||||
elevated_information_level = ["xet-client/elevated_information_level", "xet-runtime/elevated_information_level"]
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
||||
Reference in New Issue
Block a user