Better support "xet-write-token" API authorization model and LFS Batch Api change (#498)

1. This PR updates the hub client xet access token request to use custom
rev in addition to the default "main". This better supports the
"xet-write-token" API authorization model:
Clients can get a xet write token, if
- the "rev" is a regular branch, with a HF write token;
- the "rev" is a pr branch with an corresponding open PR, with a HF
write or read token;
- it intends to create a pr and repo is enabled for discussion, with a
HF write or read token.

2. Fixed a bug when getting the current branch name in a repo, which
didn't parse branch names with "/" correctly: change
`refs_heads_branch.rsplit('/').next()` to
`refs_heads_branch.strip_prefix("refs/heads/")`.

3. Also updated xet transfer agent to use the refresh route in the LFS
Batch Api
[response](e3be2b3c8f/server/app/gitHostingRoutes.ts (L1713)).

4. Use the session id in the LFS Batch Api
[response](e3be2b3c8f/server/app/gitHostingRoutes.ts (L1657))
for token refresh and CAS requests.
This commit is contained in:
Di Xiao
2025-09-23 16:07:54 -07:00
committed by GitHub
parent 15942e295e
commit 75952ae618
29 changed files with 459 additions and 251 deletions

7
Cargo.lock generated
View File

@@ -1561,6 +1561,7 @@ dependencies = [
"serde_json",
"thiserror 2.0.12",
"tokio",
"urlencoding",
]
[[package]]
@@ -3968,6 +3969,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf8_iter"
version = "1.0.4"

View File

@@ -102,6 +102,7 @@ tower-service = "0.3"
tracing = "0.1"
ulid = "1.2"
url = "2.5"
urlencoding = "2.1"
uuid = "1"
walkdir = "2"
web-time = "1.1"

View File

@@ -11,7 +11,7 @@ use clap::{Args, Parser, Subcommand};
use data::data_client::default_config;
use data::migration_tool::hub_client_token_refresher::HubClientTokenRefresher;
use data::migration_tool::migrate::migrate_files_impl;
use hub_client::{BearerCredentialHelper, HubClient, Operation};
use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
use merklehash::MerkleHash;
use utils::auth::TokenRefresher;
use walkdir::WalkDir;
@@ -56,8 +56,14 @@ impl XCommand {
.unwrap_or_else(|| std::env::var("HF_TOKEN").unwrap_or_default());
let cred_helper = BearerCredentialHelper::new(token, "");
let hub_client =
HubClient::new(&endpoint, &self.overrides.repo_type, &self.overrides.repo_id, "xtool", "", cred_helper)?;
let hub_client = HubClient::new(
&endpoint,
RepoInfo::try_from(&self.overrides.repo_type, &self.overrides.repo_id)?,
Some("main".to_owned()),
"xtool",
"",
cred_helper,
)?;
self.command.run(hub_client).await
}

View File

@@ -121,4 +121,15 @@ impl TranslatorConfig {
..self
}
}
pub fn with_session_id(self, session_id: &str) -> Self {
if session_id.is_empty() {
return self;
}
Self {
session_id: Some(session_id.to_owned()),
..self
}
}
}

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
use anyhow::Result;
use cas_object::CompressionScheme;
use hub_client::{BearerCredentialHelper, HubClient, Operation};
use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
use mdb_shard::file_structs::MDBFileInfo;
use tracing::{Instrument, Span, info_span, instrument};
use utils::auth::TokenRefresher;
@@ -33,7 +33,8 @@ pub async fn migrate_with_external_runtime(
repo_id: &str,
) -> Result<()> {
let cred_helper = BearerCredentialHelper::new(hub_token.to_owned(), "");
let hub_client = HubClient::new(hub_endpoint, repo_type, repo_id, "xtool", "", cred_helper)?;
let hub_client =
HubClient::new(hub_endpoint, RepoInfo::try_from(repo_type, repo_id)?, None, "xtool", "", cred_helper)?;
migrate_files_impl(file_paths, false, hub_client, cas_endpoint, None, false).await?;

View File

@@ -2,7 +2,7 @@ use std::path::PathBuf;
use crate::app::Command::Transfer;
use crate::constants::{GIT_LFS_CUSTOM_TRANSFER_AGENT_NAME, GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM};
use crate::errors::{Result, config_error};
use crate::errors::{GitXetError, Result};
use crate::git_process_wrapping::run_git_captured;
#[derive(Default)]
@@ -48,7 +48,7 @@ fn install_impl(location: ConfigLocation, concurrency: Option<u32>) -> Result<()
let concurrent = if let Some(c) = concurrency {
if c == 0 {
return Err(config_error("concurrency can't be 0"));
return Err(GitXetError::config_error("concurrency can't be 0"));
}
if c == 1 {
"false"

View File

@@ -7,14 +7,18 @@ use data::FileUploadSession;
use data::data_client::{clean_file, default_config};
use hub_client::Operation;
use progress_tracking::{ProgressUpdate, TrackingProgressUpdater};
use utils::auth::TokenRefresher;
use crate::constants::{HF_ENDPOINT_ENV, XET_ACCESS_TOKEN_HEADER, XET_TOKEN_EXPIRATION_HEADER};
use crate::errors::{Result, config_error, internal, not_supported};
use crate::constants::{
HF_ENDPOINT_ENV, XET_ACCESS_TOKEN_HEADER, XET_CAS_URL, XET_SESSION_ID, XET_TOKEN_EXPIRATION_HEADER,
};
use crate::errors::{GitXetError, Result};
use crate::git_repo::GitRepo;
use crate::git_url::{GitUrl, Scheme};
use crate::hub_client_token_refresher::HubClientTokenRefresher;
use crate::lfs_agent_protocol::errors::bad_syntax;
use crate::lfs_agent_protocol::{InitRequestInner, ProgressUpdater, TransferAgent, TransferRequest};
use crate::lfs_agent_protocol::{
GitLFSProtocolError, InitRequestInner, ProgressUpdater, TransferAgent, TransferRequest,
};
use crate::token_refresher::DirectRefreshRouteTokenRefresher;
// This implements a Git LFS custom transfer agent that uploads and downloads files using the Xet protocol.
#[derive(Default)]
@@ -38,7 +42,7 @@ impl TransferAgent for XetAgent {
let hf_endpoint = if !matches!(remote_url.scheme(), Scheme::Http | Scheme::Https) && remote_url.port().is_some()
{
Some(std::env::var(HF_ENDPOINT_ENV).map_err(|_| {
config_error(
GitXetError::config_error(
r#"This repository has a non-standard Hugging Face remote URL,
please specify the Hugging Face server endpoint using environment variable "HF_ENDPOINT""#,
)
@@ -55,7 +59,7 @@ impl TransferAgent for XetAgent {
}
async fn init_download(&mut self, _: &InitRequestInner) -> Result<()> {
Err(not_supported(
Err(GitXetError::not_supported(
"custom transfer for download is not implemented yet. Downloads should operate through standard git-lfs download protocol.
If you encounter errors downloading, contact Xet Team at Hugging Face.",
))
@@ -70,14 +74,15 @@ impl TransferAgent for XetAgent {
// so that if the internal git credential helper needs to prompt the user for credential,
// only one prompt is presented.
let repo = self.repo.get().unwrap(); // protocol state guarantees self.repo is set.
let token_refresher = HubClientTokenRefresher::new(
let session_id = req.action.header.get(XET_SESSION_ID).map(|s| s.as_str()).unwrap_or_default();
let token_refresher: Arc<dyn TokenRefresher> = Arc::new(DirectRefreshRouteTokenRefresher::new(
repo,
self.remote_url.clone(),
self.hf_endpoint.clone(),
&req.action.href,
Operation::Upload,
"",
)?;
session_id,
)?);
// From git-lfs:
// > First worker is the only one allowed to start immediately.
// > The rest wait until successful response from 1st worker to
@@ -94,16 +99,33 @@ impl TransferAgent for XetAgent {
updater: progress_updater,
};
let cas_url = req.action.href.clone();
let token = req.action.header[XET_ACCESS_TOKEN_HEADER].clone();
let token_expiry: u64 = req.action.header[XET_TOKEN_EXPIRATION_HEADER].parse().map_err(internal)?;
let cas_url = req
.action
.header
.get(XET_CAS_URL)
.ok_or_else(|| GitXetError::internal("Hugging Face Hub didn't provide a CAS URL"))?
.clone();
let token = req
.action
.header
.get(XET_ACCESS_TOKEN_HEADER)
.ok_or_else(|| GitXetError::internal("Hugging Face Hub didn't provide a CAS access token"))?
.clone();
let token_expiry: u64 = req
.action
.header
.get(XET_TOKEN_EXPIRATION_HEADER)
.ok_or_else(|| GitXetError::internal("Hugging Face Hub didn't provide a CAS access token expiration"))?
.parse()
.map_err(GitXetError::internal)?;
let config = default_config(cas_url, None, Some((token, token_expiry)), Some(Arc::new(token_refresher)))?
.disable_progress_aggregation(); // upload one file at a time so no need for the heavy progress aggregator
let config = default_config(cas_url, None, Some((token, token_expiry)), Some(token_refresher))?
.disable_progress_aggregation()
.with_session_id(session_id); // upload one file at a time so no need for the heavy progress aggregator
let session = FileUploadSession::new(config.into(), Some(Arc::new(xet_updater))).await?;
let Some(file_path) = &req.path else {
return Err(bad_syntax("file path not provided for upload request").into());
return Err(GitLFSProtocolError::bad_syntax("file path not provided for upload request").into());
};
clean_file(session.clone(), file_path).await?;

View File

@@ -5,7 +5,7 @@ use hub_client::{BearerCredentialHelper, CredentialHelper, NoopCredentialHelper,
use netrc::Netrc;
use crate::constants::HF_TOKEN_ENV;
use crate::errors::{GitXetError, Result, config_error};
use crate::errors::{GitXetError, Result};
use crate::git_repo::GitRepo;
use crate::git_url::{GitUrl, Scheme};
@@ -58,7 +58,7 @@ impl FromStr for AccessMode {
"private" => Ok(AccessMode::Private),
"negotiate" => Ok(AccessMode::Negotiate),
"" => Ok(AccessMode::Empty),
_ => Err(config_error(format!("invalid \"lfs.<url>.access\" type: {s}"))),
_ => Err(GitXetError::config_error(format!("invalid \"lfs.<url>.access\" type: {s}"))),
}
}
}
@@ -129,7 +129,7 @@ pub fn get_credential(repo: &GitRepo, remote_url: &GitUrl, operation: Operation)
#[cfg(unix)]
return Ok(SSHCredentialHelper::new(remote_url, operation));
#[cfg(not(unix))]
return Err(crate::errors::not_supported(format!(
return Err(GitXetError::not_supported(format!(
"using {} in a repository with SSH Git URL is under development; please check back for
upgrades or contact Xet Team at Hugging Face.",
crate::constants::GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM

View File

@@ -4,7 +4,7 @@ use std::sync::Arc;
use hub_client::BearerCredentialHelper;
use crate::errors::{Result, config_error};
use crate::errors::{GitXetError, Result};
use crate::git_process_wrapping::run_git_captured_with_input_and_output;
// This implements the mechanism to get credential stored in configured git credential helpers, including
@@ -76,7 +76,7 @@ impl GitCredentialHelper {
}
}
Err(config_error(format!("failed to find authentication for {host_url}")))
Err(GitXetError::config_error(format!("failed to find authentication for {host_url}")))
}
}

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use async_trait::async_trait;
use hub_client::{CredentialHelper, Operation, Result, credential_helper_error};
use hub_client::{CredentialHelper, HubClientError, Operation, Result};
#[cfg(unix)]
use openssh::{KnownHosts, Session};
use reqwest::header;
@@ -46,11 +46,11 @@ impl SSHCredentialHelper {
#[cfg(unix)]
async fn authenticate(&self) -> Result<GitLFSAuthenticateResponse> {
let host_url = self.remote_url.host_url().map_err(credential_helper_error)?;
let host_url = self.remote_url.host_url().map_err(HubClientError::credential_helper_error)?;
let full_repo_path = self.remote_url.full_repo_path();
let session = Session::connect(&host_url, KnownHosts::Add)
.await
.map_err(credential_helper_error)?;
.map_err(HubClientError::credential_helper_error)?;
let output = session
.command("git-lfs-authenticate")
@@ -58,9 +58,9 @@ impl SSHCredentialHelper {
.arg(self.operation.as_str())
.output()
.await
.map_err(credential_helper_error)?;
.map_err(HubClientError::credential_helper_error)?;
serde_json::from_slice(&output.stdout).map_err(credential_helper_error)
serde_json::from_slice(&output.stdout).map_err(HubClientError::credential_helper_error)
}
#[cfg(not(unix))]

View File

@@ -7,8 +7,10 @@ pub const GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM: &str = "git-xet";
pub const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION");
// Moon-landing Xet service headers
pub const XET_CAS_URL: &str = "X-Xet-Cas-Url";
pub const XET_ACCESS_TOKEN_HEADER: &str = "X-Xet-Access-Token";
pub const XET_TOKEN_EXPIRATION_HEADER: &str = "X-Xet-Token-Expiration";
pub const XET_SESSION_ID: &str = "X-Xet-Session-Id";
// Environment variable names
pub const HF_TOKEN_ENV: &str = "HF_TOKEN";

View File

@@ -48,23 +48,25 @@ pub enum GitXetError {
pub type Result<T> = std::result::Result<T, GitXetError>;
pub(crate) fn git_cmd_failed(e: impl Display, source: Option<std::io::Error>) -> GitXetError {
GitXetError::GitCommandFailed {
reason: e.to_string(),
source,
impl GitXetError {
pub(crate) fn git_cmd_failed(e: impl Display, source: Option<std::io::Error>) -> GitXetError {
GitXetError::GitCommandFailed {
reason: e.to_string(),
source,
}
}
}
pub(crate) fn not_supported(e: impl Display) -> GitXetError {
GitXetError::NotSupported(e.to_string())
}
pub(crate) fn not_supported(e: impl Display) -> GitXetError {
GitXetError::NotSupported(e.to_string())
}
pub(crate) fn config_error(e: impl Display) -> GitXetError {
GitXetError::InvalidGitConfig(e.to_string())
}
pub(crate) fn config_error(e: impl Display) -> GitXetError {
GitXetError::InvalidGitConfig(e.to_string())
}
pub(crate) fn internal(e: impl Display) -> GitXetError {
GitXetError::Internal(e.to_string())
pub(crate) fn internal(e: impl Display) -> GitXetError {
GitXetError::Internal(e.to_string())
}
}
impl From<CasClientError> for GitXetError {

View File

@@ -3,7 +3,7 @@ use std::path::Path;
use std::process::{Child, ChildStdin, Command, Stdio};
use crate::constants::GIT_EXECUTABLE;
use crate::errors::{Result, git_cmd_failed, internal};
use crate::errors::{GitXetError, Result};
// This mod implements utilities to invoke Git commands through child processes from the `git` program.
@@ -74,8 +74,8 @@ impl CapturedCommand {
// From past experience, if the "git" program is not found the underlying error
// only says "Not Found" and is not very helpful to identify the cause. We thus
// capture this error and make the message more explicit.
std::io::ErrorKind::NotFound => git_cmd_failed(r#"program "git" not found"#, Some(e)),
_ => git_cmd_failed("internal", Some(e)),
std::io::ErrorKind::NotFound => GitXetError::git_cmd_failed(r#"program "git" not found"#, Some(e)),
_ => GitXetError::git_cmd_failed("internal", Some(e)),
})?,
})
}
@@ -94,7 +94,7 @@ impl CapturedCommand {
self.child_process
.stdin
.take()
.ok_or_else(|| internal("stdin of child process is not captured"))
.ok_or_else(|| GitXetError::internal("stdin of child process is not captured"))
}
/// Synchronously wait for the child to exit completely, returning `Ok(())` if the child exits with status code 0;
@@ -117,7 +117,7 @@ impl CapturedCommand {
_ => {
let stdout = std::str::from_utf8(&ret.stdout).unwrap_or("<Binary Data>").trim();
let stderr = std::str::from_utf8(&ret.stderr).unwrap_or("<Binary Data>").trim();
Err(git_cmd_failed(
Err(GitXetError::git_cmd_failed(
format!("err_code = {:?}, stdout = \"{}\", stderr = \"{}\"", ret.status.code(), stdout, stderr),
None,
))

View File

@@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex};
use git2::{Config, Repository};
use crate::errors::{GitXetError, Result, config_error, internal};
use crate::errors::{GitXetError, Result};
use crate::git_url::GitUrl;
#[derive(Clone)]
@@ -33,10 +33,27 @@ impl GitRepo {
// Returns the path to the .git folder for normal repositories or the repository itself for bare repositories.
pub fn git_path(&self) -> Result<PathBuf> {
let repo = self.repo.lock().map_err(internal)?;
let repo = self.repo.lock().map_err(GitXetError::internal)?;
Ok(repo.path().to_path_buf())
}
// Resolves the reference pointed at by HEAD, returns branch name if it is a branch.
pub fn branch_name(&self) -> Result<Option<String>> {
let repo = self.repo.lock().map_err(GitXetError::internal)?;
let maybe_head_ref = repo.head();
Ok(maybe_head_ref.ok().and_then(|head_ref| {
if head_ref.is_branch() {
head_ref
.name()
.and_then(|refs_heads_branch| refs_heads_branch.strip_prefix("refs/heads/"))
.map(|branch| branch.to_owned())
} else {
None
}
}))
}
// Returns the remote that a git push/fetch/pull operation
// is targeted at, based on:
// 1. The currently tracked remote branch, if present
@@ -44,20 +61,9 @@ impl GitRepo {
// 3. Any other SINGLE remote defined in .git/config
// 4. Use "origin" as a fallback.
pub fn remote_name(&self) -> Result<String> {
let repo = self.repo.lock().map_err(internal)?;
let maybe_head_ref = repo.head();
let maybe_branch_name = maybe_head_ref.ok().and_then(|head_ref| {
if head_ref.is_branch() {
head_ref
.name()
.and_then(|refs_heads_branch| refs_heads_branch.rsplit('/').next())
.map(|branch| branch.to_owned())
} else {
None
}
});
let maybe_branch_name = self.branch_name()?;
let repo = self.repo.lock().map_err(GitXetError::internal)?;
let config = repo.config()?.snapshot()?;
// try tracking remote
@@ -87,13 +93,13 @@ impl GitRepo {
// Returns the URL for a specific remote name.
pub fn remote_name_to_url(&self, remote: &str) -> Result<GitUrl> {
let repo = self.repo.lock().map_err(internal)?;
let repo = self.repo.lock().map_err(GitXetError::internal)?;
let url: GitUrl = repo
.find_remote(remote)?
.url()
.map(|s| s.to_string())
.ok_or_else(|| config_error(format!("no url for remote \"{remote}\"")))?
.ok_or_else(|| GitXetError::config_error(format!("no url for remote \"{remote}\"")))?
.parse()?;
Ok(url)
@@ -108,7 +114,7 @@ impl GitRepo {
// Returns a snapshot of the current Git repo config.
pub fn config(&self) -> Result<Config> {
let repo = self.repo.lock().map_err(internal)?;
let repo = self.repo.lock().map_err(GitXetError::internal)?;
Ok(repo.config()?.snapshot()?)
}
@@ -122,6 +128,25 @@ mod tests {
use crate::git_repo::GitRepo;
use crate::test_utils::TestRepo;
#[test]
#[serial(env_var_write_read)]
fn test_get_ref_name() -> Result<()> {
let test_repo = TestRepo::new("main")?;
let repo = GitRepo::open(test_repo.path())?;
test_repo.new_commit("data", "hello".as_bytes(), "add new file")?;
assert_eq!(repo.branch_name()?, Some("main".to_owned()));
test_repo.new_branch("pr/1", "main")?;
test_repo.new_commit("data", "world".as_bytes(), "update file")?;
assert_eq!(repo.branch_name()?, Some("pr/1".to_owned()));
test_repo.checkout(&["HEAD^"])?;
assert_eq!(repo.branch_name()?, None);
Ok(())
}
#[test]
#[serial(env_var_write_read)]
fn test_get_remote_from_local_config() -> Result<()> {

View File

@@ -1,10 +1,10 @@
use std::fmt::Display;
use std::str::FromStr;
use git_url_parse::GitUrl as innerGitUrl;
pub use git_url_parse::Scheme;
use hub_client::{HFRepoType, RepoInfo};
use crate::errors::{GitXetError, Result, config_error, not_supported};
use crate::errors::{GitXetError, Result};
// This mod implements funtionalities to handle Git remote URLs, especially tailored for
// Git LFS and Hugging Face repo needs, including deriving Git LFS server endpoint from
@@ -62,7 +62,7 @@ impl GitUrl {
.inner
.host
.as_ref()
.ok_or_else(|| config_error("remote URL missing host name"))?;
.ok_or_else(|| GitXetError::config_error("remote URL missing host name"))?;
let port = self.inner.port;
let port_str = if translated || port.is_none() {
@@ -86,7 +86,7 @@ impl GitUrl {
.inner
.host
.as_ref()
.ok_or_else(|| config_error("remote URL missing host name"))?;
.ok_or_else(|| GitXetError::config_error("remote URL missing host name"))?;
let port = self.inner.port;
let port_str = if translated || port.is_none() {
@@ -105,11 +105,14 @@ impl GitUrl {
match self.inner.scheme {
Scheme::Http => Ok(("http", false)),
Scheme::Https => Ok(("https", false)),
Scheme::File | Scheme::Ftp | Scheme::Ftps => {
Err(not_supported(format!("cannot convert from scheme \"{}://\" to \"http(s)://\"", self.inner.scheme)))
},
Scheme::File | Scheme::Ftp | Scheme::Ftps => Err(GitXetError::not_supported(format!(
"cannot convert from scheme \"{}://\" to \"http(s)://\"",
self.inner.scheme
))),
Scheme::Git | Scheme::GitSsh | Scheme::Ssh => Ok(("https", true)),
Scheme::Unspecified => Err(not_supported("cannot convert from unspecified scheme to \"http(s)://\"")),
Scheme::Unspecified => {
Err(GitXetError::not_supported("cannot convert from unspecified scheme to \"http(s)://\""))
},
}
}
@@ -143,7 +146,7 @@ impl GitUrl {
.inner
.host
.as_ref()
.ok_or_else(|| config_error("remote URL missing host name"))?;
.ok_or_else(|| GitXetError::config_error("remote URL missing host name"))?;
let port_str = if let Some(p) = self.inner.port {
format!(":{}", p)
@@ -171,6 +174,7 @@ impl GitUrl {
}
// Returns the parsed full repo path into `RepoInfo`.
#[allow(unused)]
pub fn repo_info(&self) -> Result<RepoInfo> {
let path = self.full_repo_path();
let full_name = self.inner.fullname.clone(); // The full name of the repo, formatted as "owner/name"
@@ -181,58 +185,6 @@ impl GitUrl {
}
}
// This defines the exact three types of repos served on HF Hub.
#[derive(Debug, PartialEq)]
pub enum HFRepoType {
Model,
Dataset,
Space,
}
impl FromStr for HFRepoType {
type Err = GitXetError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"" => Ok(HFRepoType::Model), // when repo type is omitted from the URL the default type is "model"
"model" | "models" => Ok(HFRepoType::Model),
"dataset" | "datasets" => Ok(HFRepoType::Dataset),
"space" | "spaces" => Ok(HFRepoType::Space),
t => Err(config_error(format!("invalid repo type {t}"))),
}
}
}
impl HFRepoType {
pub fn as_str(&self) -> &str {
match self {
HFRepoType::Model => "model",
HFRepoType::Dataset => "dataset",
HFRepoType::Space => "space",
}
}
}
impl Display for HFRepoType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
#[derive(Debug, PartialEq)]
pub struct RepoInfo {
// The type of a repo, one of "model | dataset | space"
pub repo_type: HFRepoType,
// The full name of a repo, formatted as "owner/name"
pub full_name: String,
}
impl Display for RepoInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}/{}", self.repo_type, self.full_name)
}
}
#[cfg(test)]
mod test_lfs_server_discovery {
use super::GitUrl;

View File

@@ -1,67 +0,0 @@
use std::sync::Arc;
use async_trait::async_trait;
use hub_client::{HubClient, Operation};
use utils::auth::{TokenInfo, TokenRefresher};
use utils::errors::AuthError;
use crate::auth::get_credential;
use crate::constants::GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM;
use crate::errors::Result;
use crate::git_repo::GitRepo;
use crate::git_url::GitUrl;
pub struct HubClientTokenRefresher {
operation: Operation,
client: Arc<HubClient>,
}
impl HubClientTokenRefresher {
pub fn new(
repo: &GitRepo,
remote_url: Option<GitUrl>,
token_endpoint: Option<String>,
operation: Operation,
session_id: &str,
) -> Result<Self> {
let remote_url = match remote_url {
Some(r) => r,
None => repo.remote_url()?,
};
let repo_info = remote_url.repo_info()?;
let endpoint = match token_endpoint {
Some(e) => e,
None => remote_url.to_derived_http_host_url()?,
};
let cred_helper = get_credential(repo, &remote_url, operation)?;
let client = HubClient::new(
&endpoint,
repo_info.repo_type.as_str(),
&repo_info.full_name,
GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM,
session_id,
cred_helper,
)?;
Ok(Self {
operation,
client: Arc::new(client),
})
}
}
#[async_trait]
impl TokenRefresher for HubClientTokenRefresher {
async fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
let jwt_info = self
.client
.get_cas_jwt(self.operation)
.await
.map_err(AuthError::token_refresh_failure)?;
Ok((jwt_info.access_token, jwt_info.exp))
}
}

View File

@@ -2,7 +2,7 @@ use std::io::{BufRead, Write};
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use crate::errors::{Result, internal};
use crate::errors::{GitXetError, Result};
mod agent_state;
pub mod errors;
@@ -116,7 +116,7 @@ where
},
};
stdout.lock().map_err(internal)?.write_all(response.as_bytes())?;
stdout.lock().map_err(GitXetError::internal)?.write_all(response.as_bytes())?;
}
Ok(())

View File

@@ -1,4 +1,4 @@
use super::errors::{Result, bad_state};
use super::errors::{GitLFSProtocolError, Result};
// This defines the state of a transfer agent to make sure that request events are initiated
// in the correct order. Unlike a traditional state machines, we don't define a "terminated"
@@ -18,27 +18,27 @@ impl LFSAgentState {
match self {
Self::PendingInit => match to {
Self::InitedForUpload | Self::InitedForDownload => (),
_ => return Err(bad_state("init event not yet received")),
_ => return Err(GitLFSProtocolError::bad_state("init event not yet received")),
},
Self::InitedForUpload => match to {
Self::Uploading => (),
Self::Downloading => return Err(bad_state("agent initiated for upload")),
_ => return Err(bad_state("init event already received")),
Self::Downloading => return Err(GitLFSProtocolError::bad_state("agent initiated for upload")),
_ => return Err(GitLFSProtocolError::bad_state("init event already received")),
},
Self::InitedForDownload => match to {
Self::Downloading => (),
Self::Uploading => return Err(bad_state("agent initiated for download")),
_ => return Err(bad_state("init event already received")),
Self::Uploading => return Err(GitLFSProtocolError::bad_state("agent initiated for download")),
_ => return Err(GitLFSProtocolError::bad_state("init event already received")),
},
Self::Uploading => match to {
Self::Uploading => (),
Self::Downloading => return Err(bad_state("agent initiated for upload")),
_ => return Err(bad_state("data transfer already in progress")),
Self::Downloading => return Err(GitLFSProtocolError::bad_state("agent initiated for upload")),
_ => return Err(GitLFSProtocolError::bad_state("data transfer already in progress")),
},
Self::Downloading => match to {
Self::Downloading => (),
Self::Uploading => return Err(bad_state("agent initiated for download")),
_ => return Err(bad_state("data transfer already in progress")),
Self::Uploading => return Err(GitLFSProtocolError::bad_state("agent initiated for download")),
_ => return Err(GitLFSProtocolError::bad_state("data transfer already in progress")),
},
};

View File

@@ -22,14 +22,16 @@ pub enum GitLFSProtocolError {
pub(super) type Result<T> = std::result::Result<T, GitLFSProtocolError>;
pub(crate) fn bad_syntax(e: impl Display) -> GitLFSProtocolError {
GitLFSProtocolError::Syntax(e.to_string())
}
impl GitLFSProtocolError {
pub(crate) fn bad_syntax(e: impl Display) -> GitLFSProtocolError {
GitLFSProtocolError::Syntax(e.to_string())
}
pub(crate) fn bad_argument(e: impl Display) -> GitLFSProtocolError {
GitLFSProtocolError::Argument(e.to_string())
}
pub(crate) fn bad_argument(e: impl Display) -> GitLFSProtocolError {
GitLFSProtocolError::Argument(e.to_string())
}
pub(crate) fn bad_state(e: impl Display) -> GitLFSProtocolError {
GitLFSProtocolError::State(e.to_string())
pub(crate) fn bad_state(e: impl Display) -> GitLFSProtocolError {
GitLFSProtocolError::State(e.to_string())
}
}

View File

@@ -6,7 +6,7 @@ use std::str::FromStr;
use serde::{Deserialize, Serialize};
use serde_json::json;
use super::errors::{GitLFSProtocolError, Result, bad_argument, bad_syntax};
use super::errors::{GitLFSProtocolError, Result};
// This file defines the protocol that Git LFS uses to talk to
// custom transfer agents. This implementation follows the protocol specification
@@ -105,37 +105,45 @@ impl LFSProtocolRequestEvent {
InitRequest::Download(inner) => inner,
};
if inner.remote.is_empty() {
return Err(bad_argument("invalid remote"));
return Err(GitLFSProtocolError::bad_argument("invalid remote"));
}
Ok(())
},
LFSProtocolRequestEvent::Upload(req) => {
if req.oid.len() != OID_LEN {
return Err(bad_argument("invalid oid"));
return Err(GitLFSProtocolError::bad_argument("invalid oid"));
}
if req.size == 0 {
return Err(bad_argument("invalid size"));
return Err(GitLFSProtocolError::bad_argument("invalid size"));
}
if req.path.is_none() {
return Err(bad_syntax("file path not provided for upload request"));
return Err(GitLFSProtocolError::bad_syntax("file path not provided for upload request"));
}
if req.action.href.is_empty() {
return Err(GitLFSProtocolError::bad_argument("empty action.href in server response"));
}
Ok(())
},
LFSProtocolRequestEvent::Download(req) => {
if req.oid.len() != OID_LEN {
return Err(bad_argument("invalid oid"));
return Err(GitLFSProtocolError::bad_argument("invalid oid"));
}
if req.size == 0 {
return Err(bad_argument("invalid size"));
return Err(GitLFSProtocolError::bad_argument("invalid size"));
}
if req.path.is_some() {
return Err(bad_syntax("file path provided for download request"));
return Err(GitLFSProtocolError::bad_syntax("file path provided for download request"));
}
if req.action.href.is_empty() {
return Err(GitLFSProtocolError::bad_argument("empty action.href in server response"));
}
Ok(())
@@ -149,7 +157,7 @@ impl FromStr for LFSProtocolRequestEvent {
type Err = GitLFSProtocolError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let req: LFSProtocolRequestEvent = serde_json::from_str(s).map_err(bad_syntax)?;
let req: LFSProtocolRequestEvent = serde_json::from_str(s).map_err(GitLFSProtocolError::bad_syntax)?;
req.validate()?;
Ok(req)
}

View File

@@ -5,6 +5,6 @@ mod errors;
mod git_process_wrapping;
mod git_repo;
mod git_url;
mod hub_client_token_refresher;
mod lfs_agent_protocol;
mod test_utils;
mod token_refresher;

View File

@@ -68,6 +68,20 @@ impl TestRepo {
Ok(())
}
// Create a new branch `new_branch_name` off `base`.
pub fn new_branch(&self, new_branch_name: &str, base: &str) -> Result<()> {
run_git_captured(&self.repo_path, "checkout", &[base, "-b", new_branch_name])?;
Ok(())
}
// Run the versatile checkout command.
pub fn checkout(&self, args: &[&str]) -> Result<()> {
run_git_captured(&self.repo_path, "checkout", args)?;
Ok(())
}
// Create a new local branch `local_branch_name` off HEAD in this repo that tracks a remote branch
// `remote`:`remote_branch_name`.
pub fn new_branch_tracking_remote(
@@ -76,8 +90,8 @@ impl TestRepo {
remote_branch_name: &str,
local_branch_name: &str,
) -> Result<()> {
self.new_branch(local_branch_name, "HEAD")?;
run_git_captured(&self.repo_path, "fetch", &[remote, remote_branch_name])?;
run_git_captured(&self.repo_path, "checkout", &["-b", local_branch_name])?;
run_git_captured(
&self.repo_path,
"branch",

View File

@@ -0,0 +1,65 @@
use std::sync::Arc;
use async_trait::async_trait;
use cas_client::{Api, RetryConfig, build_http_client};
use hub_client::{CasJWTInfo, CredentialHelper, Operation};
use reqwest::header;
use reqwest_middleware::ClientWithMiddleware;
use utils::auth::{TokenInfo, TokenRefresher};
use utils::errors::AuthError;
use crate::auth::get_credential;
use crate::constants::GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM;
use crate::errors::Result;
use crate::git_repo::GitRepo;
use crate::git_url::GitUrl;
pub struct DirectRefreshRouteTokenRefresher {
refresh_route: String,
client: ClientWithMiddleware,
cred_helper: Arc<dyn CredentialHelper>,
}
impl DirectRefreshRouteTokenRefresher {
pub fn new(
repo: &GitRepo,
remote_url: Option<GitUrl>,
refresh_route: &str,
operation: Operation,
session_id: &str,
) -> Result<Self> {
let remote_url = match remote_url {
Some(r) => r,
None => repo.remote_url()?,
};
let cred_helper = get_credential(repo, &remote_url, operation)?;
Ok(Self {
refresh_route: refresh_route.to_owned(),
client: build_http_client(RetryConfig::default(), session_id)?,
cred_helper,
})
}
}
#[async_trait]
impl TokenRefresher for DirectRefreshRouteTokenRefresher {
async fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
let req = self
.client
.get(&self.refresh_route)
.with_extension(Api("xet-token"))
.header(header::USER_AGENT, GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM);
let req = self
.cred_helper
.fill_credential(req)
.await
.map_err(AuthError::token_refresh_failure)?;
let response = req.send().await.map_err(AuthError::token_refresh_failure)?;
let jwt_info: CasJWTInfo = response.json().await.map_err(AuthError::token_refresh_failure)?;
Ok((jwt_info.access_token, jwt_info.exp))
}
}

7
hf_xet/Cargo.lock generated
View File

@@ -1355,6 +1355,7 @@ dependencies = [
"reqwest-middleware",
"serde",
"thiserror 2.0.15",
"urlencoding",
]
[[package]]
@@ -3793,6 +3794,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "urlencoding"
version = "2.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
[[package]]
name = "utf8_iter"
version = "1.0.4"

View File

@@ -13,6 +13,7 @@ reqwest = { workspace = true }
reqwest-middleware = { workspace = true }
serde = { workspace = true }
thiserror = { workspace = true }
urlencoding = { workspace = true }
[dev-dependencies]
serde_json = { workspace = true }

View File

@@ -3,10 +3,11 @@ use std::sync::Arc;
use cas_client::exports::ClientWithMiddleware;
use cas_client::{Api, RetryConfig, build_http_client};
use http::header;
use urlencoding::encode;
use crate::auth::CredentialHelper;
use crate::errors::*;
use crate::types::CasJWTInfo;
use crate::types::{CasJWTInfo, RepoInfo};
/// The type of operation to perform, either to upload files or to download files.
/// Different operations lead to CAS access token with different authorization levels.
@@ -34,8 +35,8 @@ impl Operation {
pub struct HubClient {
endpoint: String,
repo_type: String,
repo_id: String,
repo_info: RepoInfo,
reference: Option<String>,
user_agent: String,
client: ClientWithMiddleware,
cred_helper: Arc<dyn CredentialHelper>,
@@ -44,16 +45,16 @@ pub struct HubClient {
impl HubClient {
pub fn new(
endpoint: &str,
repo_type: &str,
repo_id: &str,
repo_info: RepoInfo,
reference: Option<String>,
user_agent: &str,
session_id: &str,
cred_helper: Arc<dyn CredentialHelper>,
) -> Result<Self> {
Ok(HubClient {
endpoint: endpoint.to_owned(),
repo_type: repo_type.to_owned(),
repo_id: repo_id.to_owned(),
repo_info,
reference,
user_agent: user_agent.to_owned(),
client: build_http_client(RetryConfig::default(), session_id)?,
cred_helper,
@@ -63,12 +64,27 @@ impl HubClient {
// Get CAS access token from Hub access token.
pub async fn get_cas_jwt(&self, operation: Operation) -> Result<CasJWTInfo> {
let endpoint = self.endpoint.as_str();
let repo_type = self.repo_type.as_str();
let repo_id = self.repo_id.as_str();
let repo_type = self.repo_info.repo_type.as_str();
let repo_id = self.repo_info.full_name.as_str();
let token_type = operation.token_type();
// The reference may contain "/" but the "xet-[]-token" API only parses "rev" from a single component,
// thus we encode the reference. It defaults to "main" if not specified by caller because the
// API route expects a "rev" component.
let rev = encode(self.reference.as_deref().unwrap_or("main"));
// Clients can get a xet write token, if
// - the "rev" is a regular branch, with a HF write token;
// - the "rev" is a pr branch, with a HF write or read token;
// - it intends to create a pr and repo is enabled for discussion, with a HF write or read token.
let query = if matches!(operation, Operation::Upload) && self.reference.is_none() {
"?create_pr=1"
} else {
""
};
// note that this API doesn't take a Basic auth
let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/main");
let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/{rev}{query}");
let req = self
.client
@@ -92,17 +108,79 @@ impl HubClient {
mod tests {
use super::HubClient;
use crate::errors::Result;
use crate::{BearerCredentialHelper, Operation};
use crate::{BearerCredentialHelper, HFRepoType, Operation, RepoInfo};
#[tokio::test]
#[ignore = "need valid token"]
async fn test_get_jwt_token() -> Result<()> {
let cred_helper = BearerCredentialHelper::new("[hf_token]".to_owned(), "");
let hub_client = HubClient::new("https://huggingface.co", "model", "seanses/tm", "xtool", "", cred_helper)?;
#[ignore = "need valid write token"]
async fn test_get_jwt_token_with_hf_write_token() -> Result<()> {
let cred_helper = BearerCredentialHelper::new("[hf_write_token]".to_owned(), "");
let hub_client = HubClient::new(
"https://huggingface.co",
RepoInfo {
repo_type: HFRepoType::Model,
full_name: "seanses/tm".into(),
},
Some("main".into()),
"xtool",
"",
cred_helper,
)?;
let read_info = hub_client.get_cas_jwt(Operation::Download).await?;
let read_info = hub_client.get_cas_jwt(Operation::Upload).await?;
println!("{:?}", read_info);
assert!(read_info.access_token.len() > 0);
assert!(read_info.cas_url.len() > 0);
assert!(read_info.exp > 0);
Ok(())
}
#[tokio::test]
#[ignore = "need valid read token and pr created on hub"]
async fn test_get_jwt_token_with_hf_read_token_pr_branch() -> Result<()> {
let cred_helper = BearerCredentialHelper::new("[hf_read_token]".to_owned(), "");
let hub_client = HubClient::new(
"https://huggingface.co",
RepoInfo {
repo_type: HFRepoType::Model,
full_name: "seanses/tm".into(),
},
Some("refs/pr/1".into()),
"xtool",
"",
cred_helper,
)?;
let read_info = hub_client.get_cas_jwt(Operation::Upload).await?;
assert!(read_info.access_token.len() > 0);
assert!(read_info.cas_url.len() > 0);
assert!(read_info.exp > 0);
Ok(())
}
#[tokio::test]
#[ignore = "need valid read token"]
async fn test_get_jwt_token_with_hf_read_token_create_pr() -> Result<()> {
let cred_helper = BearerCredentialHelper::new("[hf_read_token]".to_owned(), "");
let hub_client = HubClient::new(
"https://huggingface.co",
RepoInfo {
repo_type: HFRepoType::Model,
full_name: "seanses/tm".into(),
},
None,
"xtool",
"",
cred_helper,
)?;
let read_info = hub_client.get_cas_jwt(Operation::Upload).await?;
assert!(read_info.access_token.len() > 0);
assert!(read_info.cas_url.len() > 0);
assert!(read_info.exp > 0);
Ok(())
}

View File

@@ -13,10 +13,15 @@ pub enum HubClientError {
#[error("Credential helper error: {0}")]
CredentialHelper(anyhow::Error),
#[error("Invalid repo type: {0}")]
InvalidRepoType(String),
}
pub type Result<T> = std::result::Result<T, HubClientError>;
pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> HubClientError {
HubClientError::CredentialHelper(e.into())
impl HubClientError {
pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> HubClientError {
HubClientError::CredentialHelper(e.into())
}
}

View File

@@ -5,5 +5,5 @@ mod types;
pub use auth::{BearerCredentialHelper, CredentialHelper, NoopCredentialHelper};
pub use client::{HubClient, Operation};
pub use errors::{HubClientError, Result, credential_helper_error};
pub use types::CasJWTInfo;
pub use errors::{HubClientError, Result};
pub use types::{CasJWTInfo, HFRepoType, RepoInfo};

View File

@@ -1,5 +1,10 @@
use std::fmt::Display;
use std::str::FromStr;
use serde::Deserialize;
use crate::errors::{HubClientError, Result};
/// This defines the response format from the Huggingface Hub Xet CAS access token API.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
@@ -9,6 +14,67 @@ pub struct CasJWTInfo {
pub access_token: String,
}
// This defines the exact three types of repos served on HF Hub.
#[derive(Debug, PartialEq)]
pub enum HFRepoType {
Model,
Dataset,
Space,
}
impl FromStr for HFRepoType {
type Err = HubClientError;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"" => Ok(HFRepoType::Model), // when repo type is omitted from the URL the default type is "model"
"model" | "models" => Ok(HFRepoType::Model),
"dataset" | "datasets" => Ok(HFRepoType::Dataset),
"space" | "spaces" => Ok(HFRepoType::Space),
t => Err(HubClientError::InvalidRepoType(t.to_owned())),
}
}
}
impl HFRepoType {
pub fn as_str(&self) -> &str {
match self {
HFRepoType::Model => "model",
HFRepoType::Dataset => "dataset",
HFRepoType::Space => "space",
}
}
}
impl Display for HFRepoType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}
#[derive(Debug, PartialEq)]
pub struct RepoInfo {
// The type of a repo, one of "model | dataset | space"
pub repo_type: HFRepoType,
// The full name of a repo, formatted as "owner/name"
pub full_name: String,
}
impl RepoInfo {
pub fn try_from(repo_type: &str, repo_id: &str) -> Result<Self> {
Ok(Self {
repo_type: repo_type.parse()?,
full_name: repo_id.into(),
})
}
}
impl Display for RepoInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}/{}", self.repo_type, self.full_name)
}
}
#[cfg(test)]
mod tests {
use anyhow::Result;