Feat: optional request_headers on hf_xet API calls (#661)

Adding support for setting an optional `request_header` map on the
hf_xet upload and download API calls. This map is augmented with the
hf_xet user agent string and is passed along with the requests to
xetcas.

This PR also adds some unit tests for testing the map merging behavior
to `hf_xet/lib.rs` and adds support for running these with cargo test
and in github actions CI step.
This commit is contained in:
Brian Ronan
2026-02-23 14:43:58 -08:00
committed by GitHub
parent b3c5d05fb7
commit 17e900a70e
24 changed files with 348 additions and 77 deletions

View File

@@ -52,6 +52,9 @@ jobs:
- name: Build and Test
run: |
cargo test --verbose --no-fail-fast --features "strict git-xet-for-integration-test"
- name: Build and Test hf_xet
run: |
cd hf_xet && cargo test --verbose --no-fail-fast
- name: Check Cargo.lock has no uncommitted changes
run: |
# the build and test steps would update Cargo.lock if it is out of date
@@ -69,6 +72,9 @@ jobs:
- name: Build and Test
run: |
cargo test --verbose --no-fail-fast --features "strict git-xet-for-integration-test"
- name: Build and Test hf_xet
run: |
cd hf_xet && cargo test --verbose --no-fail-fast
build_and_test-macos:
runs-on: macos-latest
steps:
@@ -86,6 +92,9 @@ jobs:
- name: Build and Test
run: |
cargo test --verbose --no-fail-fast --features "strict git-xet-for-integration-test"
- name: Build and Test hf_xet
run: |
cd hf_xet && cargo test --verbose --no-fail-fast
build_and_test-wasm:
name: Build WASM
runs-on: ubuntu-latest

3
Cargo.lock generated
View File

@@ -1015,6 +1015,7 @@ dependencies = [
"dirs",
"error_printer",
"file_reconstruction",
"http",
"hub_client",
"lazy_static",
"mdb_shard",
@@ -1702,6 +1703,7 @@ dependencies = [
"data",
"git-url-parse",
"git2",
"http",
"hub_client",
"progress_tracking",
"rand_core 0.6.4",
@@ -2036,6 +2038,7 @@ dependencies = [
"anyhow",
"async-trait",
"cas_client",
"http",
"reqwest",
"reqwest-middleware",
"serde",

View File

@@ -3,8 +3,8 @@ use std::sync::Arc;
use anyhow::anyhow;
use cas_types::{REQUEST_ID_HEADER, SESSION_ID_HEADER};
use error_printer::{ErrorPrinter, OptionPrinter};
use http::{Extensions, StatusCode};
use reqwest::header::{AUTHORIZATION, HeaderValue};
use http::{Extensions, HeaderMap, StatusCode};
use reqwest::header::{AUTHORIZATION, COOKIE, HeaderValue, SET_COOKIE};
use reqwest::{Request, Response};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next};
use tokio::sync::Mutex;
@@ -38,9 +38,25 @@ impl Middleware for HttpsToHttpMiddleware {
}
}
/// Utility to redact sensitive headers from a HeaderMap before logging
fn redact_headers(headers: &HeaderMap) -> HeaderMap {
let mut sanitized_headers = headers.clone();
let sensitive_keys = vec![AUTHORIZATION, COOKIE, SET_COOKIE];
for key in sensitive_keys {
if sanitized_headers.contains_key(&key) {
sanitized_headers.insert(key, "[REDACTED]".parse().unwrap());
}
}
sanitized_headers
}
#[allow(unused_variables)]
#[cfg(not(target_family = "wasm"))]
fn reqwest_client(user_agent: &str, unix_socket_path: Option<&str>) -> Result<reqwest::Client, CasClientError> {
fn reqwest_client(
unix_socket_path: Option<&str>,
custom_headers: Option<Arc<HeaderMap>>,
) -> Result<reqwest::Client, CasClientError> {
// Check config if explicit socket path is not provided
let socket_path = unix_socket_path
.map(|s| s.to_string())
@@ -51,7 +67,7 @@ fn reqwest_client(user_agent: &str, unix_socket_path: Option<&str>) -> Result<re
// Create client function
let socket_path_clone = socket_path.clone();
let user_agent_for_closure = user_agent.to_string();
let custom_headers_for_client = custom_headers.clone();
let create_client = move || {
let config = &xet_config().client;
let mut builder = reqwest::Client::builder()
@@ -61,15 +77,15 @@ fn reqwest_client(user_agent: &str, unix_socket_path: Option<&str>) -> Result<re
.read_timeout(config.read_timeout)
.http1_only();
if !user_agent_for_closure.is_empty() {
builder = builder.user_agent(&user_agent_for_closure);
}
#[cfg(unix)]
if let Some(ref path) = socket_path_clone {
builder = builder.unix_socket(path.clone());
}
if let Some(headers) = custom_headers_for_client {
builder = builder.default_headers((*headers).clone());
}
builder.build()
};
@@ -80,10 +96,11 @@ fn reqwest_client(user_agent: &str, unix_socket_path: Option<&str>) -> Result<re
info!(socket_path=?socket_path, "HTTP client configured with Unix socket");
} else {
let config = &xet_config().client;
let custom_headers = custom_headers.as_deref().map(redact_headers);
info!(
idle_timeout=?config.idle_connection_timeout,
max_idle_connections=config.max_idle_connections,
user_agent=?if user_agent.is_empty() { None } else { Some(user_agent) },
custom_headers=?custom_headers,
"HTTP client configured"
);
}
@@ -92,13 +109,16 @@ fn reqwest_client(user_agent: &str, unix_socket_path: Option<&str>) -> Result<re
}
#[cfg(target_family = "wasm")]
fn reqwest_client(user_agent: &str, _unix_socket_path: Option<&str>) -> Result<reqwest::Client, CasClientError> {
// For WASM, create a new client with the specified user_agent
fn reqwest_client(
_unix_socket_path: Option<&str>,
custom_headers: Option<Arc<HeaderMap>>,
) -> Result<reqwest::Client, CasClientError> {
// For WASM, create a new client with the specified headers, including the user-agent.
// Note: we could cache this, but user_agent can vary, so we create per-call
// Unix socket path is ignored on WASM
let mut builder = reqwest::Client::builder();
if !user_agent.is_empty() {
builder = builder.user_agent(user_agent);
if let Some(custom_headers) = custom_headers {
builder = builder.default_headers((*custom_headers).clone());
}
Ok(builder.build()?)
}
@@ -108,14 +128,14 @@ fn reqwest_client(user_agent: &str, _unix_socket_path: Option<&str>) -> Result<r
pub fn build_auth_http_client(
auth_config: &Option<AuthConfig>,
session_id: &str,
user_agent: &str,
unix_socket_path: Option<&str>,
custom_headers: Option<Arc<HeaderMap>>,
) -> Result<ClientWithMiddleware, CasClientError> {
let auth_middleware = auth_config.as_ref().map(AuthMiddleware::from).info_none("CAS auth disabled");
let logging_middleware = Some(LoggingMiddleware);
let session_middleware = (!session_id.is_empty()).then(|| SessionMiddleware(session_id.to_owned()));
let mut builder = ClientBuilder::new(reqwest_client(user_agent, unix_socket_path)?);
let mut builder = ClientBuilder::new(reqwest_client(unix_socket_path, custom_headers)?);
#[cfg(unix)]
if unix_socket_path.is_some() {
@@ -132,10 +152,10 @@ pub fn build_auth_http_client(
/// Builds HTTP Client to talk to CAS.
pub fn build_http_client(
session_id: &str,
user_agent: &str,
unix_socket_path: Option<&str>,
custom_headers: Option<Arc<HeaderMap>>,
) -> Result<ClientWithMiddleware, CasClientError> {
build_auth_http_client(&None, session_id, user_agent, unix_socket_path)
build_auth_http_client(&None, session_id, unix_socket_path, custom_headers)
}
/// Helper trait to allow the reqwest_middleware client to optionally add a middleware.
@@ -365,7 +385,7 @@ mod tests {
#[test]
fn test_build_http_client_without_uds() {
let result = build_http_client("test-session", "test-agent", None);
let result = build_http_client("test-session", None, None);
assert!(result.is_ok());
}
}

View File

@@ -9,7 +9,7 @@ use cas_types::{
};
use futures::TryStreamExt;
use http::HeaderValue;
use http::header::{CONTENT_LENGTH, RANGE};
use http::header::{CONTENT_LENGTH, HeaderMap, RANGE};
use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDBFileInfo};
use merklehash::MerkleHash;
use reqwest::{Body, Response, StatusCode, Url};
@@ -54,23 +54,26 @@ impl RemoteClient {
/// * `auth` - Optional authentication configuration
/// * `session_id` - Session identifier
/// * `dry_run` - Whether to run in dry-run mode
/// * `user_agent` - User agent string
/// * `unix_socket_path` - Optional Unix socket path for proxying connections (ignored on non-Unix platforms)
/// * `custom_headers` - Optional custom headers to include in HTTP requests (should include User-Agent)
pub fn new_with_socket(
endpoint: &str,
auth: &Option<AuthConfig>,
session_id: &str,
dry_run: bool,
user_agent: &str,
unix_socket_path: Option<&str>,
custom_headers: Option<Arc<HeaderMap>>,
) -> Arc<Self> {
Arc::new(Self {
endpoint: endpoint.to_string(),
dry_run,
authenticated_http_client: Arc::new(
http_client::build_auth_http_client(auth, session_id, user_agent, unix_socket_path).unwrap(),
http_client::build_auth_http_client(auth, session_id, unix_socket_path, custom_headers.clone())
.unwrap(),
),
http_client: Arc::new(
http_client::build_http_client(session_id, unix_socket_path, custom_headers).unwrap(),
),
http_client: Arc::new(http_client::build_http_client(session_id, user_agent, unix_socket_path).unwrap()),
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload("upload"),
download_concurrency_controller: AdaptiveConcurrencyController::new_download("download"),
})
@@ -80,20 +83,27 @@ impl RemoteClient {
///
/// If `HF_XET_CLIENT_UNIX_SOCKET_PATH` is set in the configuration, this will
/// automatically use the Unix socket for connections (checked by build_http_client).
///
/// # Arguments
/// * `endpoint` - The CAS endpoint URL
/// * `auth` - Optional authentication configuration
/// * `session_id` - Session identifier
/// * `dry_run` - Whether to run in dry-run mode
/// * `custom_headers` - Optional custom headers to include in HTTP requests (should include User-Agent)
pub fn new(
endpoint: &str,
auth: &Option<AuthConfig>,
session_id: &str,
dry_run: bool,
user_agent: &str,
custom_headers: Option<Arc<HeaderMap>>,
) -> Arc<Self> {
Arc::new(Self {
endpoint: endpoint.to_string(),
dry_run,
authenticated_http_client: Arc::new(
http_client::build_auth_http_client(auth, session_id, user_agent, None).unwrap(),
http_client::build_auth_http_client(auth, session_id, None, custom_headers.clone()).unwrap(),
),
http_client: Arc::new(http_client::build_http_client(session_id, user_agent, None).unwrap()),
http_client: Arc::new(http_client::build_http_client(session_id, None, custom_headers).unwrap()),
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload("upload"),
download_concurrency_controller: AdaptiveConcurrencyController::new_download("download"),
})
@@ -226,6 +236,7 @@ impl Client for RemoteClient {
// convert exclusive-end to inclusive-end range
request = request.header(RANGE, HttpRange::from(range).range_header())
}
request.send()
})
.await;
@@ -630,7 +641,7 @@ mod tests {
let raw_xorb = build_raw_xorb(3, ChunkSize::Random(512, 10248));
let threadpool = XetRuntime::new().unwrap();
let client = RemoteClient::new(CAS_ENDPOINT, &None, "", false, "");
let client = RemoteClient::new(CAS_ENDPOINT, &None, "", false, None);
let cas_object = build_and_verify_cas_object(raw_xorb, Some(CompressionScheme::LZ4));

View File

@@ -36,6 +36,7 @@ use std::time::Duration;
use async_trait::async_trait;
use axum::Router;
use axum::routing::{get, head, post};
use http::header::{self, HeaderMap, HeaderValue};
use tokio::net::TcpListener;
use tokio::sync::oneshot;
use tower_http::cors::CorsLayer;
@@ -303,6 +304,8 @@ impl LocalTestServer {
tokio::time::sleep(Duration::from_millis(50)).await;
let mut headers = HeaderMap::new();
headers.insert(header::USER_AGENT, HeaderValue::from_static("test-agent"));
let (remote_client, socket_proxy) = {
#[cfg(unix)]
{
@@ -323,20 +326,21 @@ impl LocalTestServer {
&None,
"test-session",
false,
"test-agent",
Some(&socket_path_str),
Some(Arc::new(headers)),
);
(client, Some(proxy))
} else {
let client = RemoteClient::new(&tcp_endpoint, &None, "test-session", false, "test-agent");
let client =
RemoteClient::new(&tcp_endpoint, &None, "test-session", false, Some(Arc::new(headers)));
(client, None)
}
}
#[cfg(not(unix))]
{
let client = RemoteClient::new(&tcp_endpoint, &None, "test-session", false, "test-agent");
let client = RemoteClient::new(&tcp_endpoint, &None, "test-session", false, None);
(client, Option::<()>::None)
}
};

View File

@@ -10,8 +10,7 @@ use cas_client::http_client::build_http_client;
use cas_client::progress_tracked_streams::{StreamProgressReporter, UploadProgressStream};
use cas_client::retry_wrapper::RetryWrapper;
use clap::Parser;
use http::HeaderValue;
use http::header::CONTENT_LENGTH;
use http::header::{self, CONTENT_LENGTH, HeaderMap, HeaderValue};
use rand::Rng;
use reqwest::Body;
use reqwest_middleware::ClientWithMiddleware;
@@ -102,7 +101,10 @@ async fn run_client(min_data_kb: u64, max_data_kb: u64, repeat_duration_seconds:
eprintln!("Connecting to server at {}", server_addr);
// Create HTTP client
let http_client = build_http_client("test_session", "test_user_agent", None).expect("Failed to create HTTP client");
let mut headers = HeaderMap::new();
headers.insert(header::USER_AGENT, HeaderValue::from_static("test_user_agent"));
let http_client =
build_http_client("test_session", None, Some(Arc::new(headers))).expect("Failed to create HTTP client");
// Wait for server to be ready before starting
wait_for_server_ready(&http_client, server_addr)

View File

@@ -45,6 +45,7 @@ async-trait = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true }
clap = { workspace = true }
http = { workspace = true }
lazy_static = { workspace = true }
more-asserts = { workspace = true }
prometheus = { workspace = true }

View File

@@ -11,6 +11,7 @@ use clap::{Args, Parser, Subcommand};
use data::data_client::default_config;
use data::migration_tool::hub_client_token_refresher::HubClientTokenRefresher;
use data::migration_tool::migrate::migrate_files_impl;
use http::header::{self, HeaderMap, HeaderValue};
use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
use merklehash::MerkleHash;
use utils::auth::TokenRefresher;
@@ -56,14 +57,17 @@ impl XCommand {
.token
.unwrap_or_else(|| std::env::var("HF_TOKEN").unwrap_or_default());
let mut headers = HeaderMap::new();
headers.insert(header::USER_AGENT, HeaderValue::from_static(USER_AGENT));
let cred_helper = BearerCredentialHelper::new(token, "");
let hub_client = HubClient::new(
&endpoint,
RepoInfo::try_from(&self.overrides.repo_type, &self.overrides.repo_id)?,
Some("main".to_owned()),
USER_AGENT,
"",
cred_helper,
Some(Arc::new(headers)),
)?;
self.command.run(hub_client).await
@@ -206,16 +210,25 @@ async fn query_reconstruction(
client: Arc::new(hub_client),
}) as Arc<dyn TokenRefresher>;
// Create headers with USER_AGENT
let mut headers = http::HeaderMap::new();
headers.insert(http::header::USER_AGENT, http::HeaderValue::from_static(USER_AGENT));
let config = default_config(
jwt_info.cas_url.clone(),
None,
Some((jwt_info.access_token, jwt_info.exp)),
Some(token_refresher),
USER_AGENT.to_string(),
Some(Arc::new(headers)),
)?;
let cas_storage_config = &config.data_config;
let remote_client =
RemoteClient::new(&jwt_info.cas_url, &cas_storage_config.auth, "", true, &cas_storage_config.user_agent);
let remote_client = RemoteClient::new(
&jwt_info.cas_url,
&cas_storage_config.auth,
"",
true,
cas_storage_config.custom_headers.clone(),
);
remote_client
.get_reconstruction(&file_hash, bytes_range)

View File

@@ -1,8 +1,10 @@
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::sync::Arc;
use cas_client::remote_client::PREFIX_DEFAULT;
use cas_object::CompressionScheme;
use http::HeaderMap;
use utils::auth::AuthConfig;
use crate::errors::Result;
@@ -21,7 +23,7 @@ pub struct DataConfig {
pub auth: Option<AuthConfig>,
pub prefix: String,
pub staging_directory: Option<PathBuf>,
pub user_agent: String,
pub custom_headers: Option<Arc<HeaderMap>>,
}
#[derive(Debug)]
@@ -93,7 +95,7 @@ impl TranslatorConfig {
auth: None,
prefix: PREFIX_DEFAULT.into(),
staging_directory: None,
user_agent: String::new(),
custom_headers: None,
},
shard_config: ShardConfig {
prefix: PREFIX_DEFAULT.into(),
@@ -124,7 +126,7 @@ impl TranslatorConfig {
auth: None,
prefix: PREFIX_DEFAULT.into(),
staging_directory: None,
user_agent: String::new(),
custom_headers: None,
},
shard_config: ShardConfig {
prefix: PREFIX_DEFAULT.into(),
@@ -156,7 +158,7 @@ impl TranslatorConfig {
auth: None,
prefix: PREFIX_DEFAULT.into(),
staging_directory: None,
user_agent: String::new(),
custom_headers: None,
},
shard_config: ShardConfig {
prefix: PREFIX_DEFAULT.into(),

View File

@@ -7,6 +7,7 @@ use bytes::Bytes;
use cas_client::remote_client::PREFIX_DEFAULT;
use cas_object::CompressionScheme;
use deduplication::{Chunker, DeduplicationMetrics};
use http::header::HeaderMap;
use mdb_shard::Sha256;
use merklehash::MerkleHash;
use progress_tracking::TrackingProgressUpdater;
@@ -27,7 +28,7 @@ pub fn default_config(
xorb_compression: Option<CompressionScheme>,
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
user_agent: String,
custom_headers: Option<Arc<HeaderMap>>,
) -> errors::Result<TranslatorConfig> {
// Intercept local:// to run a simulated CAS server in a specified directory.
// This is useful for testing and development.
@@ -71,7 +72,7 @@ pub fn default_config(
auth: auth_cfg.clone(),
prefix: PREFIX_DEFAULT.into(),
staging_directory: None,
user_agent: user_agent.clone(),
custom_headers,
},
shard_config: ShardConfig {
prefix: PREFIX_DEFAULT.into(),
@@ -97,14 +98,14 @@ pub async fn upload_bytes_async(
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
progress_updater: Option<Arc<dyn TrackingProgressUpdater>>,
user_agent: String,
custom_headers: Option<Arc<HeaderMap>>,
) -> errors::Result<Vec<XetFileInfo>> {
let config = default_config(
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
None,
token_info,
token_refresher,
user_agent,
custom_headers,
)?;
Span::current().record("session_id", &config.session_id);
@@ -142,7 +143,7 @@ pub async fn upload_async(
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
progress_updater: Option<Arc<dyn TrackingProgressUpdater>>,
user_agent: String,
custom_headers: Option<Arc<HeaderMap>>,
) -> errors::Result<Vec<XetFileInfo>> {
// chunk files
// produce Xorbs + Shards
@@ -153,7 +154,7 @@ pub async fn upload_async(
None,
token_info,
token_refresher,
user_agent,
custom_headers,
)?;
let span = Span::current();
@@ -201,7 +202,7 @@ pub async fn download_async(
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
progress_updaters: Option<Vec<Arc<dyn TrackingProgressUpdater>>>,
user_agent: String,
custom_headers: Option<Arc<HeaderMap>>,
) -> errors::Result<Vec<String>> {
if let Some(updaters) = &progress_updaters
&& updaters.len() != file_infos.len()
@@ -213,9 +214,10 @@ pub async fn download_async(
None,
token_info,
token_refresher,
user_agent,
custom_headers,
)?
.into();
Span::current().record("session_id", &config.session_id);
let updaters = match progress_updaters {
@@ -404,7 +406,7 @@ mod tests {
let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None, String::new());
let result = default_config(endpoint, None, None, None, None);
assert!(result.is_ok());
let config = result.unwrap();
@@ -421,7 +423,7 @@ mod tests {
let hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir_hf_home.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None, String::new());
let result = default_config(endpoint, None, None, None, None);
assert!(result.is_ok());
let config = result.unwrap();
@@ -434,7 +436,7 @@ mod tests {
let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None, String::new());
let result = default_config(endpoint, None, None, None, None);
assert!(result.is_ok());
let config = result.unwrap();
@@ -448,7 +450,7 @@ mod tests {
let _hf_xet_cache_guard = EnvVarGuard::set("HF_XET_CACHE", temp_dir.path().to_str().unwrap());
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None, String::new());
let result = default_config(endpoint, None, None, None, None);
assert!(result.is_ok());
let config = result.unwrap();
@@ -459,7 +461,7 @@ mod tests {
#[serial(default_config_env)]
fn test_default_config_without_env_vars() {
let endpoint = "http://localhost:8080".to_string();
let result = default_config(endpoint, None, None, None, String::new());
let result = default_config(endpoint, None, None, None, None);
let expected = home_dir().unwrap().join(".cache").join("huggingface").join("xet");

View File

@@ -2,6 +2,7 @@ use std::sync::Arc;
use anyhow::{Result, anyhow};
use cas_object::CompressionScheme;
use http::header;
use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
use mdb_shard::file_structs::MDBFileInfo;
use tracing::{Instrument, Span, info_span, instrument};
@@ -36,13 +37,15 @@ pub async fn migrate_with_external_runtime(
repo_id: &str,
) -> Result<()> {
let cred_helper = BearerCredentialHelper::new(hub_token.to_owned(), "");
let mut headers = header::HeaderMap::new();
headers.insert(header::USER_AGENT, header::HeaderValue::from_static(USER_AGENT));
let hub_client = HubClient::new(
hub_endpoint,
RepoInfo::try_from(repo_type, repo_id)?,
Some("main".to_owned()),
USER_AGENT,
"",
cred_helper,
Some(Arc::new(headers)),
)?;
migrate_files_impl(file_paths, sha256s, false, hub_client, cas_endpoint, None, false).await?;
@@ -71,12 +74,16 @@ pub async fn migrate_files_impl(
}) as Arc<dyn TokenRefresher>;
let cas = cas_endpoint.unwrap_or(jwt_info.cas_url);
// Create headers with USER_AGENT
let mut headers = http::HeaderMap::new();
headers.insert(http::header::USER_AGENT, http::HeaderValue::from_static(USER_AGENT));
let config = default_config(
cas,
compression,
Some((jwt_info.access_token, jwt_info.exp)),
Some(token_refresher),
USER_AGENT.to_string(),
Some(Arc::new(headers)),
)?;
Span::current().record("session_id", &config.session_id);

View File

@@ -18,7 +18,7 @@ pub(crate) async fn create_remote_client(
&cas_storage_config.auth,
session_id,
dry_run,
&cas_storage_config.user_agent,
cas_storage_config.custom_headers.clone(),
)),
Endpoint::FileSystem(ref path) => {
#[cfg(not(target_family = "wasm"))]

View File

@@ -19,6 +19,7 @@ async-trait = { workspace = true }
clap = { workspace = true }
git-url-parse = { workspace = true }
git2 = { workspace = true }
http = { workspace = true }
rand_core = { workspace = true }
reqwest = { workspace = true }
reqwest-middleware = { workspace = true }

View File

@@ -5,6 +5,7 @@ use std::sync::{Arc, OnceLock};
use async_trait::async_trait;
use data::FileUploadSession;
use data::data_client::{clean_file, default_config};
use http::header;
use hub_client::Operation;
use progress_tracking::{ProgressUpdate, TrackingProgressUpdater};
use utils::auth::TokenRefresher;
@@ -77,15 +78,17 @@ impl TransferAgent for XetAgent {
// only one prompt is presented.
let repo = self.repo.get().unwrap(); // protocol state guarantees self.repo is set.
let mut headers = header::HeaderMap::new();
headers.insert(header::USER_AGENT, header::HeaderValue::from_static(USER_AGENT));
let session_id = req.action.header.get(XET_SESSION_ID).map(|s| s.as_str()).unwrap_or_default();
let user_agent = USER_AGENT;
let token_refresher: Arc<dyn TokenRefresher> = Arc::new(DirectRefreshRouteTokenRefresher::new(
repo,
self.remote_url.clone(),
&req.action.href,
Operation::Upload,
session_id,
user_agent,
Some(Arc::new(headers)),
)?);
// From git-lfs:
// > First worker is the only one allowed to start immediately.
@@ -123,8 +126,12 @@ impl TransferAgent for XetAgent {
.parse()
.map_err(GitXetError::internal)?;
// Create headers with user agent
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(header::USER_AGENT, header::HeaderValue::from_static(USER_AGENT));
let config =
default_config(cas_url, None, Some((token, token_expiry)), Some(token_refresher), user_agent.to_string())?
default_config(cas_url, None, Some((token, token_expiry)), Some(token_refresher), Some(Arc::new(headers)))?
.disable_progress_aggregation()
.with_session_id(session_id); // upload one file at a time so no need for the heavy progress aggregator
let session = FileUploadSession::new(config.into(), Some(Arc::new(xet_updater))).await?;

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use async_trait::async_trait;
use cas_client::retry_wrapper::RetryWrapper;
use cas_client::{Api, build_http_client};
use http::header::HeaderMap;
use hub_client::{CasJWTInfo, CredentialHelper, Operation};
use reqwest_middleware::ClientWithMiddleware;
use utils::auth::{TokenInfo, TokenRefresher};
@@ -26,7 +27,7 @@ impl DirectRefreshRouteTokenRefresher {
refresh_route: &str,
operation: Operation,
session_id: &str,
user_agent: &str,
custom_headers: Option<Arc<HeaderMap>>,
) -> Result<Self> {
let remote_url = match remote_url {
Some(r) => r,
@@ -37,7 +38,7 @@ impl DirectRefreshRouteTokenRefresher {
Ok(Self {
refresh_route: refresh_route.to_owned(),
client: build_http_client(session_id, user_agent, None)?,
client: build_http_client(session_id, None, custom_headers)?,
cred_helper,
})
}

3
hf_xet/Cargo.lock generated
View File

@@ -761,6 +761,7 @@ dependencies = [
"deduplication",
"error_printer",
"file_reconstruction",
"http",
"hub_client",
"lazy_static",
"mdb_shard",
@@ -1394,6 +1395,7 @@ dependencies = [
"chrono",
"data",
"error_printer",
"http",
"itertools 0.14.0",
"lazy_static",
"pprof",
@@ -1461,6 +1463,7 @@ dependencies = [
"anyhow",
"async-trait",
"cas_client",
"http",
"reqwest",
"reqwest-middleware",
"serde",

View File

@@ -7,7 +7,7 @@ license = "Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[lib]
name = "hf_xet"
crate-type = ["cdylib"]
crate-type = ["cdylib", "lib"]
[dependencies]
cas_client = { path = "../cas_client" }
@@ -29,12 +29,12 @@ pprof = { version = "0.14", features = [
"protobuf-codec",
], optional = true }
pyo3 = { version = "0.26", features = [
"extension-module",
"abi3-py37",
"auto-initialize",
] }
rand = "0.9"
tracing = "0.1"
http = "1"
# Unix-specific dependencies
[target.'cfg(unix)'.dependencies]
@@ -46,6 +46,7 @@ winapi = { version = "0.3", features = ["consoleapi", "wincon", "errhandlingapi"
[features]
default = ["no-default-cache", "elevated_information_level"] # By default, hf_xet disables the disk cache and uses aggressive logging level
extension-module = ["pyo3/extension-module"] # Only enabled when building with maturin
native-tls = ["cas_client/native-tls-vendored"]
native-tls-vendored = ["cas_client/native-tls-vendored"]
no-default-cache = ["xet_config/no-default-cache"]

View File

@@ -3,12 +3,14 @@ mod progress_update;
mod runtime;
mod token_refresh;
use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::IntoIterator;
use std::sync::Arc;
use data::errors::DataProcessingError;
use data::{XetFileInfo, data_client};
use http::header::{self, HeaderMap, HeaderName, HeaderValue};
use itertools::Itertools;
use progress_tracking::TrackingProgressUpdater;
use pyo3::exceptions::{PyKeyboardInterrupt, PyRuntimeError};
@@ -29,6 +31,47 @@ const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VE
#[cfg(feature = "profiling")]
pub(crate) mod profiling;
/// Converts a HashMap of headers to a HeaderMap and merges in the USER_AGENT.
///
/// If the input contains a User-Agent header, the USER_AGENT is appended to it.
/// Otherwise, USER_AGENT is set as the only User-Agent header.
fn build_headers_with_user_agent(request_headers: Option<HashMap<String, String>>) -> PyResult<Option<Arc<HeaderMap>>> {
let mut map = request_headers
.map(|headers| {
let mut map = HeaderMap::new();
for (key, value) in headers {
let name = HeaderName::from_bytes(key.as_bytes())
.map_err(|e| PyRuntimeError::new_err(format!("Invalid header name '{}': {}", key, e)))?;
let value = HeaderValue::from_str(&value)
.map_err(|e| PyRuntimeError::new_err(format!("Invalid header value for '{}': {}", key, e)))?;
map.insert(name, value);
}
Ok::<_, PyErr>(map)
})
.transpose()?
.unwrap_or_default();
// Append our USER_AGENT to any existing User-Agent header, or add it if not present
let combined_user_agent = if let Some(existing_ua) = map.get(header::USER_AGENT) {
// Append our user agent to the existing one
let existing_str = existing_ua.to_str().unwrap_or("");
format!("{}; {}", existing_str, USER_AGENT)
} else {
// No existing user agent, use ours
USER_AGENT.to_string()
};
// Try to create the combined header value, fall back gracefully if invalid
let user_agent_value = HeaderValue::from_str(&combined_user_agent)
.or_else(|_: http::header::InvalidHeaderValue| {
Ok::<HeaderValue, http::header::InvalidHeaderValue>(HeaderValue::from_static(USER_AGENT))
})
.unwrap_or_else(|_: http::header::InvalidHeaderValue| HeaderValue::from_static("unknown"));
map.insert(header::USER_AGENT, user_agent_value);
Ok(Some(Arc::new(map)))
}
fn convert_data_processing_error(e: DataProcessingError) -> PyErr {
if cfg!(debug_assertions) {
PyRuntimeError::new_err(format!("Data processing error: {e:?}"))
@@ -38,7 +81,8 @@ fn convert_data_processing_error(e: DataProcessingError) -> PyErr {
}
#[pyfunction]
#[pyo3(signature = (file_contents, endpoint, token_info, token_refresher, progress_updater, _repo_type), text_signature = "(file_contents: List[bytes], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str]) -> List[PyXetUploadInfo]")]
#[pyo3(signature = (file_contents, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None), text_signature = "(file_contents: List[bytes], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]]) -> List[PyXetUploadInfo]")]
#[allow(clippy::too_many_arguments)]
pub fn upload_bytes(
py: Python,
file_contents: Vec<Vec<u8>>,
@@ -47,11 +91,15 @@ pub fn upload_bytes(
token_refresher: Option<Py<PyAny>>,
progress_updater: Option<Py<PyAny>>,
_repo_type: Option<String>,
request_headers: Option<HashMap<String, String>>,
) -> PyResult<Vec<PyXetUploadInfo>> {
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
let updater = progress_updater.map(WrappedProgressUpdater::new).transpose()?.map(Arc::new);
let x: u64 = rand::rng().random();
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
let header_map = build_headers_with_user_agent(request_headers)?;
async_run(py, async move {
debug!(
"Upload bytes call {x:x}: (PID = {}) Uploading {} files as bytes.",
@@ -65,7 +113,7 @@ pub fn upload_bytes(
token_info,
refresher.map(|v| v as Arc<_>),
updater.map(|v| v as Arc<_>),
USER_AGENT.to_string(),
header_map,
)
.await
.map_err(convert_data_processing_error)?
@@ -80,7 +128,8 @@ pub fn upload_bytes(
}
#[pyfunction]
#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater, _repo_type), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str]) -> List[PyXetUploadInfo]")]
#[pyo3(signature = (file_paths, endpoint, token_info, token_refresher, progress_updater, _repo_type, request_headers=None), text_signature = "(file_paths: List[str], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[Callable[[int], None]], _repo_type: Optional[str], request_headers: Optional[Dict[str, str]]) -> List[PyXetUploadInfo]")]
#[allow(clippy::too_many_arguments)]
pub fn upload_files(
py: Python,
file_paths: Vec<String>,
@@ -89,6 +138,7 @@ pub fn upload_files(
token_refresher: Option<Py<PyAny>>,
progress_updater: Option<Py<PyAny>>,
_repo_type: Option<String>,
request_headers: Option<HashMap<String, String>>,
) -> PyResult<Vec<PyXetUploadInfo>> {
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
let updater = progress_updater.map(WrappedProgressUpdater::new).transpose()?.map(Arc::new);
@@ -97,6 +147,9 @@ pub fn upload_files(
let x: u64 = rand::rng().random();
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
let header_map = build_headers_with_user_agent(request_headers)?;
async_run(py, async move {
debug!(
"Upload call {x:x}: (PID = {}) Uploading {} files {file_names}{}",
@@ -112,7 +165,7 @@ pub fn upload_files(
token_info,
refresher.map(|v| v as Arc<_>),
updater.map(|v| v as Arc<_>),
USER_AGENT.to_string(),
header_map,
)
.await
.map_err(convert_data_processing_error)?
@@ -167,7 +220,7 @@ pub fn hash_files(py: Python, file_paths: Vec<String>) -> PyResult<Vec<PyXetUplo
}
#[pyfunction]
#[pyo3(signature = (files, endpoint, token_info, token_refresher, progress_updater), text_signature = "(files: List[PyXetDownloadInfo], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[List[Callable[[int], None]]]) -> List[str]")]
#[pyo3(signature = (files, endpoint, token_info, token_refresher, progress_updater, request_headers=None), text_signature = "(files: List[PyXetDownloadInfo], endpoint: Optional[str], token_info: Optional[(str, int)], token_refresher: Optional[Callable[[], (str, int)]], progress_updater: Optional[List[Callable[[int], None]]], request_headers: Optional[Dict[str, str]]) -> List[str]")]
pub fn download_files(
py: Python,
files: Vec<PyXetDownloadInfo>,
@@ -175,11 +228,15 @@ pub fn download_files(
token_info: Option<(String, u64)>,
token_refresher: Option<Py<PyAny>>,
progress_updater: Option<Vec<Py<PyAny>>>,
request_headers: Option<HashMap<String, String>>,
) -> PyResult<Vec<String>> {
let file_infos: Vec<_> = files.into_iter().map(<(XetFileInfo, DestinationPath)>::from).collect();
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
let updaters = progress_updater.map(try_parse_progress_updaters).transpose()?;
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
let header_map = build_headers_with_user_agent(request_headers)?;
let x: u64 = rand::rng().random();
let file_names = file_infos.iter().take(3).map(|(_, p)| p).join(", ");
@@ -198,7 +255,7 @@ pub fn download_files(
token_info,
refresher.map(|v| v as Arc<_>),
updaters,
USER_AGENT.to_string(),
header_map,
)
.await
.map_err(convert_data_processing_error)?;
@@ -386,3 +443,106 @@ pub fn hf_xet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
Ok(())
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use super::*;
// Initialize Python once for all tests
fn setup() {
// When auto-initialize is enabled, Python will be initialized on first use
// This ensures Python is available for the tests
let _ = pyo3::Python::attach(|_py| {});
}
#[test]
fn test_build_headers_with_none_empty_hashmap() {
setup();
let empty_map: HashMap<String, String> = HashMap::new();
let result = build_headers_with_user_agent(Some(empty_map)).unwrap();
let headers = result.unwrap();
// Should have exactly one header: USER_AGENT
assert_eq!(headers.len(), 1);
assert!(headers.contains_key(header::USER_AGENT));
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
assert_eq!(user_agent, USER_AGENT);
let result = build_headers_with_user_agent(None).unwrap();
let headers = result.unwrap();
// Should have exactly one header: USER_AGENT
assert_eq!(headers.len(), 1);
assert!(headers.contains_key(header::USER_AGENT));
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
assert_eq!(user_agent, USER_AGENT);
}
#[test]
fn test_build_headers_with_valid_headers() {
setup();
let mut headers_map = HashMap::new();
headers_map.insert("Content-Type".to_string(), "application/json".to_string());
headers_map.insert("Authorization".to_string(), "Bearer token123".to_string());
let result = build_headers_with_user_agent(Some(headers_map)).unwrap();
let headers = result.unwrap();
// Should have 3 headers: Content-Type, Authorization, and USER_AGENT
assert_eq!(headers.len(), 3);
// Verify each header was converted correctly
assert_eq!(headers.get(header::CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
assert_eq!(headers.get(header::AUTHORIZATION).unwrap().to_str().unwrap(), "Bearer token123");
// Verify USER_AGENT was added
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
assert_eq!(user_agent, USER_AGENT);
}
#[test]
fn test_build_headers_appends_to_existing_user_agent() {
setup();
let mut headers_map = HashMap::new();
headers_map.insert("User-Agent".to_string(), "CustomClient/1.0".to_string());
let result = build_headers_with_user_agent(Some(headers_map)).unwrap();
let headers = result.unwrap();
// Should have exactly one header: USER_AGENT
assert_eq!(headers.len(), 1);
// Verify USER_AGENT was appended to existing one
let user_agent = headers.get(header::USER_AGENT).unwrap().to_str().unwrap();
assert_eq!(user_agent, format!("CustomClient/1.0; {}", USER_AGENT));
}
#[test]
fn test_build_headers_with_invalid_header_name_or_value() {
setup();
let mut headers_map = HashMap::new();
headers_map.insert("Invalid Header!".to_string(), "value".to_string());
let result = build_headers_with_user_agent(Some(headers_map));
// Should return an error for invalid header name
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Invalid header name"));
let mut headers_map = HashMap::new();
// Header values cannot contain newlines
headers_map.insert("X-Custom".to_string(), "value\nwith\nnewlines".to_string());
let result = build_headers_with_user_agent(Some(headers_map));
// Should return an error for invalid header value
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(err_msg.contains("Invalid header value"));
}
}

View File

@@ -1144,6 +1144,7 @@ dependencies = [
"env_logger",
"futures",
"getrandom 0.3.4",
"http",
"js-sys",
"log",
"mdb_shard",

View File

@@ -22,6 +22,7 @@ console_log = { version = "1.0.0", features = ["color"] }
env_logger = "0.11"
futures = "0.3"
getrandom = { version = "0.3", features = ["wasm_js"] }
http = "1"
js-sys = "0.3.77"
log = "0.4"
serde = { version = "1", features = ["derive"] }

View File

@@ -5,6 +5,7 @@ use cas_client::{Client, RemoteClient};
use cas_object::SerializedCasObject;
use deduplication::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
use deduplication::{DataAggregator, DeduplicationMetrics, RawXorbData};
use http::header::{self, HeaderValue};
use mdb_shard::MDBShardInfo;
use mdb_shard::shard_in_memory::MDBInMemoryShard;
use merklehash::{HashedWrite, MerkleHash};
@@ -38,12 +39,21 @@ pub struct FileUploadSession {
impl FileUploadSession {
pub fn new(config: Arc<TranslatorConfig>) -> Self {
let headers = match HeaderValue::from_str(&config.data_config.user_agent) {
Ok(value) => {
let mut headers = http::HeaderMap::new();
headers.insert(header::USER_AGENT, value);
Some(Arc::new(headers))
}
Err(_) => None
};
let client = RemoteClient::new(
&config.data_config.endpoint,
&config.data_config.auth,
&config.session_id,
false,
&config.data_config.user_agent,
headers,
);
let xorb_uploader =

View File

@@ -8,6 +8,7 @@ cas_client = { path = "../cas_client" }
anyhow = { workspace = true }
async-trait = { workspace = true }
http = { workspace = true }
reqwest = { workspace = true }
reqwest-middleware = { workspace = true }
serde = { workspace = true }

View File

@@ -3,6 +3,7 @@ use std::sync::Arc;
use cas_client::exports::ClientWithMiddleware;
use cas_client::retry_wrapper::RetryWrapper;
use cas_client::{Api, build_http_client};
use http::header::HeaderMap;
use urlencoding::encode;
use crate::auth::CredentialHelper;
@@ -46,15 +47,15 @@ impl HubClient {
endpoint: &str,
repo_info: RepoInfo,
reference: Option<String>,
user_agent: &str,
session_id: &str,
cred_helper: Arc<dyn CredentialHelper>,
custom_headers: Option<Arc<HeaderMap>>,
) -> Result<Self> {
Ok(HubClient {
endpoint: endpoint.to_owned(),
repo_info,
reference,
client: build_http_client(session_id, user_agent, None)?,
client: build_http_client(session_id, None, custom_headers)?,
cred_helper,
})
}
@@ -109,6 +110,10 @@ impl HubClient {
#[cfg(test)]
mod tests {
use std::sync::Arc;
use http::header::{self, HeaderMap, HeaderValue};
use super::HubClient;
use crate::errors::Result;
use crate::{BearerCredentialHelper, HFRepoType, Operation, RepoInfo};
@@ -117,6 +122,8 @@ mod tests {
#[ignore = "need valid write token"]
async fn test_get_jwt_token_with_hf_write_token() -> Result<()> {
let cred_helper = BearerCredentialHelper::new("[hf_write_token]".to_owned(), "");
let mut headers = HeaderMap::new();
headers.insert(header::USER_AGENT, HeaderValue::from_static("xtool"));
let hub_client = HubClient::new(
"https://huggingface.co",
RepoInfo {
@@ -124,9 +131,9 @@ mod tests {
full_name: "seanses/tm".into(),
},
Some("main".into()),
"xtool",
"",
cred_helper,
Some(Arc::new(headers)),
)?;
let read_info = hub_client.get_cas_jwt(Operation::Upload).await?;
@@ -142,6 +149,8 @@ mod tests {
#[ignore = "need valid read token and pr created on hub"]
async fn test_get_jwt_token_with_hf_read_token_pr_branch() -> Result<()> {
let cred_helper = BearerCredentialHelper::new("[hf_read_token]".to_owned(), "");
let mut headers = HeaderMap::new();
headers.insert(header::USER_AGENT, HeaderValue::from_static("xtool"));
let hub_client = HubClient::new(
"https://huggingface.co",
RepoInfo {
@@ -149,9 +158,9 @@ mod tests {
full_name: "seanses/tm".into(),
},
Some("refs/pr/1".into()),
"xtool",
"",
cred_helper,
Some(Arc::new(headers)),
)?;
let read_info = hub_client.get_cas_jwt(Operation::Upload).await?;
@@ -167,6 +176,8 @@ mod tests {
#[ignore = "need valid read token"]
async fn test_get_jwt_token_with_hf_read_token_create_pr() -> Result<()> {
let cred_helper = BearerCredentialHelper::new("[hf_read_token]".to_owned(), "");
let mut headers = HeaderMap::new();
headers.insert(header::USER_AGENT, HeaderValue::from_static("xtool"));
let hub_client = HubClient::new(
"https://huggingface.co",
RepoInfo {
@@ -174,9 +185,9 @@ mod tests {
full_name: "seanses/tm".into(),
},
None,
"xtool",
"",
cred_helper,
Some(Arc::new(headers)),
)?;
let read_info = hub_client.get_cas_jwt(Operation::Upload).await?;

View File

@@ -25,7 +25,6 @@ futures = { workspace = true }
lazy_static = { workspace = true }
pin-project = { workspace = true }
serde = { workspace = true }
shellexpand = { workspace = true, features = ["path"] }
thiserror = { workspace = true }
tokio = { workspace = true, features = [
"time",
@@ -39,6 +38,7 @@ tracing = { workspace = true }
[target.'cfg(not(target_family = "wasm"))'.dependencies]
bincode = { workspace = true }
rand = { workspace = true }
shellexpand = { workspace = true, features = ["path"] }
tokio-util = { workspace = true, features = ["io"] }
[target.'cfg(not(target_family = "wasm"))'.dev-dependencies]