mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
Better support "xet-write-token" API authorization model and LFS Batch Api change (#498)
1. This PR updates the hub client's xet access token request to use a custom
rev in addition to the default "main". This better supports the
"xet-write-token" API authorization model:
Clients can get a xet write token, if
- the "rev" is a regular branch, with a HF write token;
- the "rev" is a PR branch with a corresponding open PR, with a HF
write or read token;
- the client intends to create a PR and the repo has discussions enabled, with a
HF write or read token.
2. Fixed a bug when getting the current branch name in a repo, which
didn't parse branch names with "/" correctly: change
`refs_heads_branch.rsplit('/').next()` to
`refs_heads_branch.strip_prefix("refs/heads/")`.
3. Also updated xet transfer agent to use the refresh route in the LFS
Batch Api
[response](e3be2b3c8f/server/app/gitHostingRoutes.ts (L1713)).
4. Use the session id in the LFS Batch Api
[response](e3be2b3c8f/server/app/gitHostingRoutes.ts (L1657))
for token refresh and CAS requests.
This commit is contained in:
7
Cargo.lock
generated
7
Cargo.lock
generated
@@ -1561,6 +1561,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"thiserror 2.0.12",
|
||||
"tokio",
|
||||
"urlencoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3968,6 +3969,12 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urlencoding"
|
||||
version = "2.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
|
||||
|
||||
[[package]]
|
||||
name = "utf8_iter"
|
||||
version = "1.0.4"
|
||||
|
||||
@@ -102,6 +102,7 @@ tower-service = "0.3"
|
||||
tracing = "0.1"
|
||||
ulid = "1.2"
|
||||
url = "2.5"
|
||||
urlencoding = "2.1"
|
||||
uuid = "1"
|
||||
walkdir = "2"
|
||||
web-time = "1.1"
|
||||
|
||||
@@ -11,7 +11,7 @@ use clap::{Args, Parser, Subcommand};
|
||||
use data::data_client::default_config;
|
||||
use data::migration_tool::hub_client_token_refresher::HubClientTokenRefresher;
|
||||
use data::migration_tool::migrate::migrate_files_impl;
|
||||
use hub_client::{BearerCredentialHelper, HubClient, Operation};
|
||||
use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
|
||||
use merklehash::MerkleHash;
|
||||
use utils::auth::TokenRefresher;
|
||||
use walkdir::WalkDir;
|
||||
@@ -56,8 +56,14 @@ impl XCommand {
|
||||
.unwrap_or_else(|| std::env::var("HF_TOKEN").unwrap_or_default());
|
||||
|
||||
let cred_helper = BearerCredentialHelper::new(token, "");
|
||||
let hub_client =
|
||||
HubClient::new(&endpoint, &self.overrides.repo_type, &self.overrides.repo_id, "xtool", "", cred_helper)?;
|
||||
let hub_client = HubClient::new(
|
||||
&endpoint,
|
||||
RepoInfo::try_from(&self.overrides.repo_type, &self.overrides.repo_id)?,
|
||||
Some("main".to_owned()),
|
||||
"xtool",
|
||||
"",
|
||||
cred_helper,
|
||||
)?;
|
||||
|
||||
self.command.run(hub_client).await
|
||||
}
|
||||
|
||||
@@ -121,4 +121,15 @@ impl TranslatorConfig {
|
||||
..self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_session_id(self, session_id: &str) -> Self {
|
||||
if session_id.is_empty() {
|
||||
return self;
|
||||
}
|
||||
|
||||
Self {
|
||||
session_id: Some(session_id.to_owned()),
|
||||
..self
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use cas_object::CompressionScheme;
|
||||
use hub_client::{BearerCredentialHelper, HubClient, Operation};
|
||||
use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
|
||||
use mdb_shard::file_structs::MDBFileInfo;
|
||||
use tracing::{Instrument, Span, info_span, instrument};
|
||||
use utils::auth::TokenRefresher;
|
||||
@@ -33,7 +33,8 @@ pub async fn migrate_with_external_runtime(
|
||||
repo_id: &str,
|
||||
) -> Result<()> {
|
||||
let cred_helper = BearerCredentialHelper::new(hub_token.to_owned(), "");
|
||||
let hub_client = HubClient::new(hub_endpoint, repo_type, repo_id, "xtool", "", cred_helper)?;
|
||||
let hub_client =
|
||||
HubClient::new(hub_endpoint, RepoInfo::try_from(repo_type, repo_id)?, None, "xtool", "", cred_helper)?;
|
||||
|
||||
migrate_files_impl(file_paths, false, hub_client, cas_endpoint, None, false).await?;
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::path::PathBuf;
|
||||
|
||||
use crate::app::Command::Transfer;
|
||||
use crate::constants::{GIT_LFS_CUSTOM_TRANSFER_AGENT_NAME, GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM};
|
||||
use crate::errors::{Result, config_error};
|
||||
use crate::errors::{GitXetError, Result};
|
||||
use crate::git_process_wrapping::run_git_captured;
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -48,7 +48,7 @@ fn install_impl(location: ConfigLocation, concurrency: Option<u32>) -> Result<()
|
||||
|
||||
let concurrent = if let Some(c) = concurrency {
|
||||
if c == 0 {
|
||||
return Err(config_error("concurrency can't be 0"));
|
||||
return Err(GitXetError::config_error("concurrency can't be 0"));
|
||||
}
|
||||
if c == 1 {
|
||||
"false"
|
||||
|
||||
@@ -7,14 +7,18 @@ use data::FileUploadSession;
|
||||
use data::data_client::{clean_file, default_config};
|
||||
use hub_client::Operation;
|
||||
use progress_tracking::{ProgressUpdate, TrackingProgressUpdater};
|
||||
use utils::auth::TokenRefresher;
|
||||
|
||||
use crate::constants::{HF_ENDPOINT_ENV, XET_ACCESS_TOKEN_HEADER, XET_TOKEN_EXPIRATION_HEADER};
|
||||
use crate::errors::{Result, config_error, internal, not_supported};
|
||||
use crate::constants::{
|
||||
HF_ENDPOINT_ENV, XET_ACCESS_TOKEN_HEADER, XET_CAS_URL, XET_SESSION_ID, XET_TOKEN_EXPIRATION_HEADER,
|
||||
};
|
||||
use crate::errors::{GitXetError, Result};
|
||||
use crate::git_repo::GitRepo;
|
||||
use crate::git_url::{GitUrl, Scheme};
|
||||
use crate::hub_client_token_refresher::HubClientTokenRefresher;
|
||||
use crate::lfs_agent_protocol::errors::bad_syntax;
|
||||
use crate::lfs_agent_protocol::{InitRequestInner, ProgressUpdater, TransferAgent, TransferRequest};
|
||||
use crate::lfs_agent_protocol::{
|
||||
GitLFSProtocolError, InitRequestInner, ProgressUpdater, TransferAgent, TransferRequest,
|
||||
};
|
||||
use crate::token_refresher::DirectRefreshRouteTokenRefresher;
|
||||
|
||||
// This implements a Git LFS custom transfer agent that uploads and downloads files using the Xet protocol.
|
||||
#[derive(Default)]
|
||||
@@ -38,7 +42,7 @@ impl TransferAgent for XetAgent {
|
||||
let hf_endpoint = if !matches!(remote_url.scheme(), Scheme::Http | Scheme::Https) && remote_url.port().is_some()
|
||||
{
|
||||
Some(std::env::var(HF_ENDPOINT_ENV).map_err(|_| {
|
||||
config_error(
|
||||
GitXetError::config_error(
|
||||
r#"This repository has a non-standard Hugging Face remote URL,
|
||||
please specify the Hugging Face server endpoint using environment variable "HF_ENDPOINT""#,
|
||||
)
|
||||
@@ -55,7 +59,7 @@ impl TransferAgent for XetAgent {
|
||||
}
|
||||
|
||||
async fn init_download(&mut self, _: &InitRequestInner) -> Result<()> {
|
||||
Err(not_supported(
|
||||
Err(GitXetError::not_supported(
|
||||
"custom transfer for download is not implemented yet. Downloads should operate through standard git-lfs download protocol.
|
||||
If you encounter errors downloading, contact Xet Team at Hugging Face.",
|
||||
))
|
||||
@@ -70,14 +74,15 @@ impl TransferAgent for XetAgent {
|
||||
// so that if the internal git credential helper needs to prompt the user for credential,
|
||||
// only one prompt is presented.
|
||||
let repo = self.repo.get().unwrap(); // protocol state guarantees self.repo is set.
|
||||
let token_refresher = HubClientTokenRefresher::new(
|
||||
|
||||
let session_id = req.action.header.get(XET_SESSION_ID).map(|s| s.as_str()).unwrap_or_default();
|
||||
let token_refresher: Arc<dyn TokenRefresher> = Arc::new(DirectRefreshRouteTokenRefresher::new(
|
||||
repo,
|
||||
self.remote_url.clone(),
|
||||
self.hf_endpoint.clone(),
|
||||
&req.action.href,
|
||||
Operation::Upload,
|
||||
"",
|
||||
)?;
|
||||
|
||||
session_id,
|
||||
)?);
|
||||
// From git-lfs:
|
||||
// > First worker is the only one allowed to start immediately.
|
||||
// > The rest wait until successful response from 1st worker to
|
||||
@@ -94,16 +99,33 @@ impl TransferAgent for XetAgent {
|
||||
updater: progress_updater,
|
||||
};
|
||||
|
||||
let cas_url = req.action.href.clone();
|
||||
let token = req.action.header[XET_ACCESS_TOKEN_HEADER].clone();
|
||||
let token_expiry: u64 = req.action.header[XET_TOKEN_EXPIRATION_HEADER].parse().map_err(internal)?;
|
||||
let cas_url = req
|
||||
.action
|
||||
.header
|
||||
.get(XET_CAS_URL)
|
||||
.ok_or_else(|| GitXetError::internal("Hugging Face Hub didn't provide a CAS URL"))?
|
||||
.clone();
|
||||
let token = req
|
||||
.action
|
||||
.header
|
||||
.get(XET_ACCESS_TOKEN_HEADER)
|
||||
.ok_or_else(|| GitXetError::internal("Hugging Face Hub didn't provide a CAS access token"))?
|
||||
.clone();
|
||||
let token_expiry: u64 = req
|
||||
.action
|
||||
.header
|
||||
.get(XET_TOKEN_EXPIRATION_HEADER)
|
||||
.ok_or_else(|| GitXetError::internal("Hugging Face Hub didn't provide a CAS access token expiration"))?
|
||||
.parse()
|
||||
.map_err(GitXetError::internal)?;
|
||||
|
||||
let config = default_config(cas_url, None, Some((token, token_expiry)), Some(Arc::new(token_refresher)))?
|
||||
.disable_progress_aggregation(); // upload one file at a time so no need for the heavy progress aggregator
|
||||
let config = default_config(cas_url, None, Some((token, token_expiry)), Some(token_refresher))?
|
||||
.disable_progress_aggregation()
|
||||
.with_session_id(session_id); // upload one file at a time so no need for the heavy progress aggregator
|
||||
let session = FileUploadSession::new(config.into(), Some(Arc::new(xet_updater))).await?;
|
||||
|
||||
let Some(file_path) = &req.path else {
|
||||
return Err(bad_syntax("file path not provided for upload request").into());
|
||||
return Err(GitLFSProtocolError::bad_syntax("file path not provided for upload request").into());
|
||||
};
|
||||
|
||||
clean_file(session.clone(), file_path).await?;
|
||||
|
||||
@@ -5,7 +5,7 @@ use hub_client::{BearerCredentialHelper, CredentialHelper, NoopCredentialHelper,
|
||||
use netrc::Netrc;
|
||||
|
||||
use crate::constants::HF_TOKEN_ENV;
|
||||
use crate::errors::{GitXetError, Result, config_error};
|
||||
use crate::errors::{GitXetError, Result};
|
||||
use crate::git_repo::GitRepo;
|
||||
use crate::git_url::{GitUrl, Scheme};
|
||||
|
||||
@@ -58,7 +58,7 @@ impl FromStr for AccessMode {
|
||||
"private" => Ok(AccessMode::Private),
|
||||
"negotiate" => Ok(AccessMode::Negotiate),
|
||||
"" => Ok(AccessMode::Empty),
|
||||
_ => Err(config_error(format!("invalid \"lfs.<url>.access\" type: {s}"))),
|
||||
_ => Err(GitXetError::config_error(format!("invalid \"lfs.<url>.access\" type: {s}"))),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -129,7 +129,7 @@ pub fn get_credential(repo: &GitRepo, remote_url: &GitUrl, operation: Operation)
|
||||
#[cfg(unix)]
|
||||
return Ok(SSHCredentialHelper::new(remote_url, operation));
|
||||
#[cfg(not(unix))]
|
||||
return Err(crate::errors::not_supported(format!(
|
||||
return Err(GitXetError::not_supported(format!(
|
||||
"using {} in a repository with SSH Git URL is under development; please check back for
|
||||
upgrades or contact Xet Team at Hugging Face.",
|
||||
crate::constants::GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM
|
||||
|
||||
@@ -4,7 +4,7 @@ use std::sync::Arc;
|
||||
|
||||
use hub_client::BearerCredentialHelper;
|
||||
|
||||
use crate::errors::{Result, config_error};
|
||||
use crate::errors::{GitXetError, Result};
|
||||
use crate::git_process_wrapping::run_git_captured_with_input_and_output;
|
||||
|
||||
// This implements the mechanism to get credential stored in configured git credential helpers, including
|
||||
@@ -76,7 +76,7 @@ impl GitCredentialHelper {
|
||||
}
|
||||
}
|
||||
|
||||
Err(config_error(format!("failed to find authentication for {host_url}")))
|
||||
Err(GitXetError::config_error(format!("failed to find authentication for {host_url}")))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use hub_client::{CredentialHelper, Operation, Result, credential_helper_error};
|
||||
use hub_client::{CredentialHelper, HubClientError, Operation, Result};
|
||||
#[cfg(unix)]
|
||||
use openssh::{KnownHosts, Session};
|
||||
use reqwest::header;
|
||||
@@ -46,11 +46,11 @@ impl SSHCredentialHelper {
|
||||
|
||||
#[cfg(unix)]
|
||||
async fn authenticate(&self) -> Result<GitLFSAuthenticateResponse> {
|
||||
let host_url = self.remote_url.host_url().map_err(credential_helper_error)?;
|
||||
let host_url = self.remote_url.host_url().map_err(HubClientError::credential_helper_error)?;
|
||||
let full_repo_path = self.remote_url.full_repo_path();
|
||||
let session = Session::connect(&host_url, KnownHosts::Add)
|
||||
.await
|
||||
.map_err(credential_helper_error)?;
|
||||
.map_err(HubClientError::credential_helper_error)?;
|
||||
|
||||
let output = session
|
||||
.command("git-lfs-authenticate")
|
||||
@@ -58,9 +58,9 @@ impl SSHCredentialHelper {
|
||||
.arg(self.operation.as_str())
|
||||
.output()
|
||||
.await
|
||||
.map_err(credential_helper_error)?;
|
||||
.map_err(HubClientError::credential_helper_error)?;
|
||||
|
||||
serde_json::from_slice(&output.stdout).map_err(credential_helper_error)
|
||||
serde_json::from_slice(&output.stdout).map_err(HubClientError::credential_helper_error)
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
|
||||
@@ -7,8 +7,10 @@ pub const GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM: &str = "git-xet";
|
||||
pub const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
// Moon-landing Xet service headers
|
||||
pub const XET_CAS_URL: &str = "X-Xet-Cas-Url";
|
||||
pub const XET_ACCESS_TOKEN_HEADER: &str = "X-Xet-Access-Token";
|
||||
pub const XET_TOKEN_EXPIRATION_HEADER: &str = "X-Xet-Token-Expiration";
|
||||
pub const XET_SESSION_ID: &str = "X-Xet-Session-Id";
|
||||
|
||||
// Environment variable names
|
||||
pub const HF_TOKEN_ENV: &str = "HF_TOKEN";
|
||||
|
||||
@@ -48,23 +48,25 @@ pub enum GitXetError {
|
||||
|
||||
pub type Result<T> = std::result::Result<T, GitXetError>;
|
||||
|
||||
pub(crate) fn git_cmd_failed(e: impl Display, source: Option<std::io::Error>) -> GitXetError {
|
||||
GitXetError::GitCommandFailed {
|
||||
reason: e.to_string(),
|
||||
source,
|
||||
impl GitXetError {
|
||||
pub(crate) fn git_cmd_failed(e: impl Display, source: Option<std::io::Error>) -> GitXetError {
|
||||
GitXetError::GitCommandFailed {
|
||||
reason: e.to_string(),
|
||||
source,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn not_supported(e: impl Display) -> GitXetError {
|
||||
GitXetError::NotSupported(e.to_string())
|
||||
}
|
||||
pub(crate) fn not_supported(e: impl Display) -> GitXetError {
|
||||
GitXetError::NotSupported(e.to_string())
|
||||
}
|
||||
|
||||
pub(crate) fn config_error(e: impl Display) -> GitXetError {
|
||||
GitXetError::InvalidGitConfig(e.to_string())
|
||||
}
|
||||
pub(crate) fn config_error(e: impl Display) -> GitXetError {
|
||||
GitXetError::InvalidGitConfig(e.to_string())
|
||||
}
|
||||
|
||||
pub(crate) fn internal(e: impl Display) -> GitXetError {
|
||||
GitXetError::Internal(e.to_string())
|
||||
pub(crate) fn internal(e: impl Display) -> GitXetError {
|
||||
GitXetError::Internal(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CasClientError> for GitXetError {
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::path::Path;
|
||||
use std::process::{Child, ChildStdin, Command, Stdio};
|
||||
|
||||
use crate::constants::GIT_EXECUTABLE;
|
||||
use crate::errors::{Result, git_cmd_failed, internal};
|
||||
use crate::errors::{GitXetError, Result};
|
||||
|
||||
// This mod implements utilities to invoke Git commands through child processes from the `git` program.
|
||||
|
||||
@@ -74,8 +74,8 @@ impl CapturedCommand {
|
||||
// From past experience, if the "git" program is not found the underlying error
|
||||
// only says "Not Found" and is not very helpful to identify the cause. We thus
|
||||
// capture this error and make the message more explicit.
|
||||
std::io::ErrorKind::NotFound => git_cmd_failed(r#"program "git" not found"#, Some(e)),
|
||||
_ => git_cmd_failed("internal", Some(e)),
|
||||
std::io::ErrorKind::NotFound => GitXetError::git_cmd_failed(r#"program "git" not found"#, Some(e)),
|
||||
_ => GitXetError::git_cmd_failed("internal", Some(e)),
|
||||
})?,
|
||||
})
|
||||
}
|
||||
@@ -94,7 +94,7 @@ impl CapturedCommand {
|
||||
self.child_process
|
||||
.stdin
|
||||
.take()
|
||||
.ok_or_else(|| internal("stdin of child process is not captured"))
|
||||
.ok_or_else(|| GitXetError::internal("stdin of child process is not captured"))
|
||||
}
|
||||
|
||||
/// Synchronously wait for the child to exit completely, returning `Ok(())` if the child exits with status code 0;
|
||||
@@ -117,7 +117,7 @@ impl CapturedCommand {
|
||||
_ => {
|
||||
let stdout = std::str::from_utf8(&ret.stdout).unwrap_or("<Binary Data>").trim();
|
||||
let stderr = std::str::from_utf8(&ret.stderr).unwrap_or("<Binary Data>").trim();
|
||||
Err(git_cmd_failed(
|
||||
Err(GitXetError::git_cmd_failed(
|
||||
format!("err_code = {:?}, stdout = \"{}\", stderr = \"{}\"", ret.status.code(), stdout, stderr),
|
||||
None,
|
||||
))
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex};
|
||||
|
||||
use git2::{Config, Repository};
|
||||
|
||||
use crate::errors::{GitXetError, Result, config_error, internal};
|
||||
use crate::errors::{GitXetError, Result};
|
||||
use crate::git_url::GitUrl;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -33,10 +33,27 @@ impl GitRepo {
|
||||
|
||||
// Returns the path to the .git folder for normal repositories or the repository itself for bare repositories.
|
||||
pub fn git_path(&self) -> Result<PathBuf> {
|
||||
let repo = self.repo.lock().map_err(internal)?;
|
||||
let repo = self.repo.lock().map_err(GitXetError::internal)?;
|
||||
Ok(repo.path().to_path_buf())
|
||||
}
|
||||
|
||||
// Resolves the reference pointed at by HEAD, returns branch name if it is a branch.
|
||||
pub fn branch_name(&self) -> Result<Option<String>> {
|
||||
let repo = self.repo.lock().map_err(GitXetError::internal)?;
|
||||
|
||||
let maybe_head_ref = repo.head();
|
||||
Ok(maybe_head_ref.ok().and_then(|head_ref| {
|
||||
if head_ref.is_branch() {
|
||||
head_ref
|
||||
.name()
|
||||
.and_then(|refs_heads_branch| refs_heads_branch.strip_prefix("refs/heads/"))
|
||||
.map(|branch| branch.to_owned())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
// Returns the remote that a git push/fetch/pull operation
|
||||
// is targeted at, based on:
|
||||
// 1. The currently tracked remote branch, if present
|
||||
@@ -44,20 +61,9 @@ impl GitRepo {
|
||||
// 3. Any other SINGLE remote defined in .git/config
|
||||
// 4. Use "origin" as a fallback.
|
||||
pub fn remote_name(&self) -> Result<String> {
|
||||
let repo = self.repo.lock().map_err(internal)?;
|
||||
|
||||
let maybe_head_ref = repo.head();
|
||||
let maybe_branch_name = maybe_head_ref.ok().and_then(|head_ref| {
|
||||
if head_ref.is_branch() {
|
||||
head_ref
|
||||
.name()
|
||||
.and_then(|refs_heads_branch| refs_heads_branch.rsplit('/').next())
|
||||
.map(|branch| branch.to_owned())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
let maybe_branch_name = self.branch_name()?;
|
||||
|
||||
let repo = self.repo.lock().map_err(GitXetError::internal)?;
|
||||
let config = repo.config()?.snapshot()?;
|
||||
|
||||
// try tracking remote
|
||||
@@ -87,13 +93,13 @@ impl GitRepo {
|
||||
|
||||
// Returns the URL for a specific remote name.
|
||||
pub fn remote_name_to_url(&self, remote: &str) -> Result<GitUrl> {
|
||||
let repo = self.repo.lock().map_err(internal)?;
|
||||
let repo = self.repo.lock().map_err(GitXetError::internal)?;
|
||||
|
||||
let url: GitUrl = repo
|
||||
.find_remote(remote)?
|
||||
.url()
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| config_error(format!("no url for remote \"{remote}\"")))?
|
||||
.ok_or_else(|| GitXetError::config_error(format!("no url for remote \"{remote}\"")))?
|
||||
.parse()?;
|
||||
|
||||
Ok(url)
|
||||
@@ -108,7 +114,7 @@ impl GitRepo {
|
||||
|
||||
// Returns a snapshot of the current Git repo config.
|
||||
pub fn config(&self) -> Result<Config> {
|
||||
let repo = self.repo.lock().map_err(internal)?;
|
||||
let repo = self.repo.lock().map_err(GitXetError::internal)?;
|
||||
|
||||
Ok(repo.config()?.snapshot()?)
|
||||
}
|
||||
@@ -122,6 +128,25 @@ mod tests {
|
||||
use crate::git_repo::GitRepo;
|
||||
use crate::test_utils::TestRepo;
|
||||
|
||||
#[test]
|
||||
#[serial(env_var_write_read)]
|
||||
fn test_get_ref_name() -> Result<()> {
|
||||
let test_repo = TestRepo::new("main")?;
|
||||
let repo = GitRepo::open(test_repo.path())?;
|
||||
|
||||
test_repo.new_commit("data", "hello".as_bytes(), "add new file")?;
|
||||
assert_eq!(repo.branch_name()?, Some("main".to_owned()));
|
||||
|
||||
test_repo.new_branch("pr/1", "main")?;
|
||||
test_repo.new_commit("data", "world".as_bytes(), "update file")?;
|
||||
assert_eq!(repo.branch_name()?, Some("pr/1".to_owned()));
|
||||
|
||||
test_repo.checkout(&["HEAD^"])?;
|
||||
assert_eq!(repo.branch_name()?, None);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[serial(env_var_write_read)]
|
||||
fn test_get_remote_from_local_config() -> Result<()> {
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use std::fmt::Display;
|
||||
use std::str::FromStr;
|
||||
|
||||
use git_url_parse::GitUrl as innerGitUrl;
|
||||
pub use git_url_parse::Scheme;
|
||||
use hub_client::{HFRepoType, RepoInfo};
|
||||
|
||||
use crate::errors::{GitXetError, Result, config_error, not_supported};
|
||||
use crate::errors::{GitXetError, Result};
|
||||
|
||||
// This mod implements functionalities to handle Git remote URLs, especially tailored for
|
||||
// Git LFS and Hugging Face repo needs, including deriving Git LFS server endpoint from
|
||||
@@ -62,7 +62,7 @@ impl GitUrl {
|
||||
.inner
|
||||
.host
|
||||
.as_ref()
|
||||
.ok_or_else(|| config_error("remote URL missing host name"))?;
|
||||
.ok_or_else(|| GitXetError::config_error("remote URL missing host name"))?;
|
||||
|
||||
let port = self.inner.port;
|
||||
let port_str = if translated || port.is_none() {
|
||||
@@ -86,7 +86,7 @@ impl GitUrl {
|
||||
.inner
|
||||
.host
|
||||
.as_ref()
|
||||
.ok_or_else(|| config_error("remote URL missing host name"))?;
|
||||
.ok_or_else(|| GitXetError::config_error("remote URL missing host name"))?;
|
||||
|
||||
let port = self.inner.port;
|
||||
let port_str = if translated || port.is_none() {
|
||||
@@ -105,11 +105,14 @@ impl GitUrl {
|
||||
match self.inner.scheme {
|
||||
Scheme::Http => Ok(("http", false)),
|
||||
Scheme::Https => Ok(("https", false)),
|
||||
Scheme::File | Scheme::Ftp | Scheme::Ftps => {
|
||||
Err(not_supported(format!("cannot convert from scheme \"{}://\" to \"http(s)://\"", self.inner.scheme)))
|
||||
},
|
||||
Scheme::File | Scheme::Ftp | Scheme::Ftps => Err(GitXetError::not_supported(format!(
|
||||
"cannot convert from scheme \"{}://\" to \"http(s)://\"",
|
||||
self.inner.scheme
|
||||
))),
|
||||
Scheme::Git | Scheme::GitSsh | Scheme::Ssh => Ok(("https", true)),
|
||||
Scheme::Unspecified => Err(not_supported("cannot convert from unspecified scheme to \"http(s)://\"")),
|
||||
Scheme::Unspecified => {
|
||||
Err(GitXetError::not_supported("cannot convert from unspecified scheme to \"http(s)://\""))
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,7 +146,7 @@ impl GitUrl {
|
||||
.inner
|
||||
.host
|
||||
.as_ref()
|
||||
.ok_or_else(|| config_error("remote URL missing host name"))?;
|
||||
.ok_or_else(|| GitXetError::config_error("remote URL missing host name"))?;
|
||||
|
||||
let port_str = if let Some(p) = self.inner.port {
|
||||
format!(":{}", p)
|
||||
@@ -171,6 +174,7 @@ impl GitUrl {
|
||||
}
|
||||
|
||||
// Returns the parsed full repo path into `RepoInfo`.
|
||||
#[allow(unused)]
|
||||
pub fn repo_info(&self) -> Result<RepoInfo> {
|
||||
let path = self.full_repo_path();
|
||||
let full_name = self.inner.fullname.clone(); // The full name of the repo, formatted as "owner/name"
|
||||
@@ -181,58 +185,6 @@ impl GitUrl {
|
||||
}
|
||||
}
|
||||
|
||||
// This defines the exact three types of repos served on HF Hub.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum HFRepoType {
|
||||
Model,
|
||||
Dataset,
|
||||
Space,
|
||||
}
|
||||
|
||||
impl FromStr for HFRepoType {
|
||||
type Err = GitXetError;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"" => Ok(HFRepoType::Model), // when repo type is omitted from the URL the default type is "model"
|
||||
"model" | "models" => Ok(HFRepoType::Model),
|
||||
"dataset" | "datasets" => Ok(HFRepoType::Dataset),
|
||||
"space" | "spaces" => Ok(HFRepoType::Space),
|
||||
t => Err(config_error(format!("invalid repo type {t}"))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HFRepoType {
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
HFRepoType::Model => "model",
|
||||
HFRepoType::Dataset => "dataset",
|
||||
HFRepoType::Space => "space",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for HFRepoType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct RepoInfo {
|
||||
// The type of a repo, one of "model | dataset | space"
|
||||
pub repo_type: HFRepoType,
|
||||
// The full name of a repo, formatted as "owner/name"
|
||||
pub full_name: String,
|
||||
}
|
||||
|
||||
impl Display for RepoInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}/{}", self.repo_type, self.full_name)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_lfs_server_discovery {
|
||||
use super::GitUrl;
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use hub_client::{HubClient, Operation};
|
||||
use utils::auth::{TokenInfo, TokenRefresher};
|
||||
use utils::errors::AuthError;
|
||||
|
||||
use crate::auth::get_credential;
|
||||
use crate::constants::GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM;
|
||||
use crate::errors::Result;
|
||||
use crate::git_repo::GitRepo;
|
||||
use crate::git_url::GitUrl;
|
||||
|
||||
pub struct HubClientTokenRefresher {
|
||||
operation: Operation,
|
||||
client: Arc<HubClient>,
|
||||
}
|
||||
|
||||
impl HubClientTokenRefresher {
|
||||
pub fn new(
|
||||
repo: &GitRepo,
|
||||
remote_url: Option<GitUrl>,
|
||||
token_endpoint: Option<String>,
|
||||
operation: Operation,
|
||||
session_id: &str,
|
||||
) -> Result<Self> {
|
||||
let remote_url = match remote_url {
|
||||
Some(r) => r,
|
||||
None => repo.remote_url()?,
|
||||
};
|
||||
let repo_info = remote_url.repo_info()?;
|
||||
|
||||
let endpoint = match token_endpoint {
|
||||
Some(e) => e,
|
||||
None => remote_url.to_derived_http_host_url()?,
|
||||
};
|
||||
|
||||
let cred_helper = get_credential(repo, &remote_url, operation)?;
|
||||
|
||||
let client = HubClient::new(
|
||||
&endpoint,
|
||||
repo_info.repo_type.as_str(),
|
||||
&repo_info.full_name,
|
||||
GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM,
|
||||
session_id,
|
||||
cred_helper,
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
operation,
|
||||
client: Arc::new(client),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TokenRefresher for HubClientTokenRefresher {
|
||||
async fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
|
||||
let jwt_info = self
|
||||
.client
|
||||
.get_cas_jwt(self.operation)
|
||||
.await
|
||||
.map_err(AuthError::token_refresh_failure)?;
|
||||
|
||||
Ok((jwt_info.access_token, jwt_info.exp))
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ use std::io::{BufRead, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use crate::errors::{Result, internal};
|
||||
use crate::errors::{GitXetError, Result};
|
||||
|
||||
mod agent_state;
|
||||
pub mod errors;
|
||||
@@ -116,7 +116,7 @@ where
|
||||
},
|
||||
};
|
||||
|
||||
stdout.lock().map_err(internal)?.write_all(response.as_bytes())?;
|
||||
stdout.lock().map_err(GitXetError::internal)?.write_all(response.as_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use super::errors::{Result, bad_state};
|
||||
use super::errors::{GitLFSProtocolError, Result};
|
||||
|
||||
// This defines the state of a transfer agent to make sure that request events are initiated
|
||||
// in the correct order. Unlike a traditional state machines, we don't define a "terminated"
|
||||
@@ -18,27 +18,27 @@ impl LFSAgentState {
|
||||
match self {
|
||||
Self::PendingInit => match to {
|
||||
Self::InitedForUpload | Self::InitedForDownload => (),
|
||||
_ => return Err(bad_state("init event not yet received")),
|
||||
_ => return Err(GitLFSProtocolError::bad_state("init event not yet received")),
|
||||
},
|
||||
Self::InitedForUpload => match to {
|
||||
Self::Uploading => (),
|
||||
Self::Downloading => return Err(bad_state("agent initiated for upload")),
|
||||
_ => return Err(bad_state("init event already received")),
|
||||
Self::Downloading => return Err(GitLFSProtocolError::bad_state("agent initiated for upload")),
|
||||
_ => return Err(GitLFSProtocolError::bad_state("init event already received")),
|
||||
},
|
||||
Self::InitedForDownload => match to {
|
||||
Self::Downloading => (),
|
||||
Self::Uploading => return Err(bad_state("agent initiated for download")),
|
||||
_ => return Err(bad_state("init event already received")),
|
||||
Self::Uploading => return Err(GitLFSProtocolError::bad_state("agent initiated for download")),
|
||||
_ => return Err(GitLFSProtocolError::bad_state("init event already received")),
|
||||
},
|
||||
Self::Uploading => match to {
|
||||
Self::Uploading => (),
|
||||
Self::Downloading => return Err(bad_state("agent initiated for upload")),
|
||||
_ => return Err(bad_state("data transfer already in progress")),
|
||||
Self::Downloading => return Err(GitLFSProtocolError::bad_state("agent initiated for upload")),
|
||||
_ => return Err(GitLFSProtocolError::bad_state("data transfer already in progress")),
|
||||
},
|
||||
Self::Downloading => match to {
|
||||
Self::Downloading => (),
|
||||
Self::Uploading => return Err(bad_state("agent initiated for download")),
|
||||
_ => return Err(bad_state("data transfer already in progress")),
|
||||
Self::Uploading => return Err(GitLFSProtocolError::bad_state("agent initiated for download")),
|
||||
_ => return Err(GitLFSProtocolError::bad_state("data transfer already in progress")),
|
||||
},
|
||||
};
|
||||
|
||||
|
||||
@@ -22,14 +22,16 @@ pub enum GitLFSProtocolError {
|
||||
|
||||
pub(super) type Result<T> = std::result::Result<T, GitLFSProtocolError>;
|
||||
|
||||
pub(crate) fn bad_syntax(e: impl Display) -> GitLFSProtocolError {
|
||||
GitLFSProtocolError::Syntax(e.to_string())
|
||||
}
|
||||
impl GitLFSProtocolError {
|
||||
pub(crate) fn bad_syntax(e: impl Display) -> GitLFSProtocolError {
|
||||
GitLFSProtocolError::Syntax(e.to_string())
|
||||
}
|
||||
|
||||
pub(crate) fn bad_argument(e: impl Display) -> GitLFSProtocolError {
|
||||
GitLFSProtocolError::Argument(e.to_string())
|
||||
}
|
||||
pub(crate) fn bad_argument(e: impl Display) -> GitLFSProtocolError {
|
||||
GitLFSProtocolError::Argument(e.to_string())
|
||||
}
|
||||
|
||||
pub(crate) fn bad_state(e: impl Display) -> GitLFSProtocolError {
|
||||
GitLFSProtocolError::State(e.to_string())
|
||||
pub(crate) fn bad_state(e: impl Display) -> GitLFSProtocolError {
|
||||
GitLFSProtocolError::State(e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::str::FromStr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
use super::errors::{GitLFSProtocolError, Result, bad_argument, bad_syntax};
|
||||
use super::errors::{GitLFSProtocolError, Result};
|
||||
|
||||
// This file defines the protocol that Git LFS uses to talk to
|
||||
// custom transfer agents. This implementation follows the protocol specification
|
||||
@@ -105,37 +105,45 @@ impl LFSProtocolRequestEvent {
|
||||
InitRequest::Download(inner) => inner,
|
||||
};
|
||||
if inner.remote.is_empty() {
|
||||
return Err(bad_argument("invalid remote"));
|
||||
return Err(GitLFSProtocolError::bad_argument("invalid remote"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
LFSProtocolRequestEvent::Upload(req) => {
|
||||
if req.oid.len() != OID_LEN {
|
||||
return Err(bad_argument("invalid oid"));
|
||||
return Err(GitLFSProtocolError::bad_argument("invalid oid"));
|
||||
}
|
||||
|
||||
if req.size == 0 {
|
||||
return Err(bad_argument("invalid size"));
|
||||
return Err(GitLFSProtocolError::bad_argument("invalid size"));
|
||||
}
|
||||
|
||||
if req.path.is_none() {
|
||||
return Err(bad_syntax("file path not provided for upload request"));
|
||||
return Err(GitLFSProtocolError::bad_syntax("file path not provided for upload request"));
|
||||
}
|
||||
|
||||
if req.action.href.is_empty() {
|
||||
return Err(GitLFSProtocolError::bad_argument("empty action.href in server response"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
},
|
||||
LFSProtocolRequestEvent::Download(req) => {
|
||||
if req.oid.len() != OID_LEN {
|
||||
return Err(bad_argument("invalid oid"));
|
||||
return Err(GitLFSProtocolError::bad_argument("invalid oid"));
|
||||
}
|
||||
|
||||
if req.size == 0 {
|
||||
return Err(bad_argument("invalid size"));
|
||||
return Err(GitLFSProtocolError::bad_argument("invalid size"));
|
||||
}
|
||||
|
||||
if req.path.is_some() {
|
||||
return Err(bad_syntax("file path provided for download request"));
|
||||
return Err(GitLFSProtocolError::bad_syntax("file path provided for download request"));
|
||||
}
|
||||
|
||||
if req.action.href.is_empty() {
|
||||
return Err(GitLFSProtocolError::bad_argument("empty action.href in server response"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -149,7 +157,7 @@ impl FromStr for LFSProtocolRequestEvent {
|
||||
type Err = GitLFSProtocolError;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
let req: LFSProtocolRequestEvent = serde_json::from_str(s).map_err(bad_syntax)?;
|
||||
let req: LFSProtocolRequestEvent = serde_json::from_str(s).map_err(GitLFSProtocolError::bad_syntax)?;
|
||||
req.validate()?;
|
||||
Ok(req)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,6 @@ mod errors;
|
||||
mod git_process_wrapping;
|
||||
mod git_repo;
|
||||
mod git_url;
|
||||
mod hub_client_token_refresher;
|
||||
mod lfs_agent_protocol;
|
||||
mod test_utils;
|
||||
mod token_refresher;
|
||||
|
||||
@@ -68,6 +68,20 @@ impl TestRepo {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Create a new branch `new_branch_name` off `base`.
|
||||
pub fn new_branch(&self, new_branch_name: &str, base: &str) -> Result<()> {
|
||||
run_git_captured(&self.repo_path, "checkout", &[base, "-b", new_branch_name])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Run the versatile checkout command.
|
||||
pub fn checkout(&self, args: &[&str]) -> Result<()> {
|
||||
run_git_captured(&self.repo_path, "checkout", args)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// Create a new local branch `local_branch_name` off HEAD in this repo that tracks a remote branch
|
||||
// `remote`:`remote_branch_name`.
|
||||
pub fn new_branch_tracking_remote(
|
||||
@@ -76,8 +90,8 @@ impl TestRepo {
|
||||
remote_branch_name: &str,
|
||||
local_branch_name: &str,
|
||||
) -> Result<()> {
|
||||
self.new_branch(local_branch_name, "HEAD")?;
|
||||
run_git_captured(&self.repo_path, "fetch", &[remote, remote_branch_name])?;
|
||||
run_git_captured(&self.repo_path, "checkout", &["-b", local_branch_name])?;
|
||||
run_git_captured(
|
||||
&self.repo_path,
|
||||
"branch",
|
||||
|
||||
65
git_xet/src/token_refresher.rs
Normal file
65
git_xet/src/token_refresher.rs
Normal file
@@ -0,0 +1,65 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use cas_client::{Api, RetryConfig, build_http_client};
|
||||
use hub_client::{CasJWTInfo, CredentialHelper, Operation};
|
||||
use reqwest::header;
|
||||
use reqwest_middleware::ClientWithMiddleware;
|
||||
use utils::auth::{TokenInfo, TokenRefresher};
|
||||
use utils::errors::AuthError;
|
||||
|
||||
use crate::auth::get_credential;
|
||||
use crate::constants::GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM;
|
||||
use crate::errors::Result;
|
||||
use crate::git_repo::GitRepo;
|
||||
use crate::git_url::GitUrl;
|
||||
|
||||
pub struct DirectRefreshRouteTokenRefresher {
|
||||
refresh_route: String,
|
||||
client: ClientWithMiddleware,
|
||||
cred_helper: Arc<dyn CredentialHelper>,
|
||||
}
|
||||
|
||||
impl DirectRefreshRouteTokenRefresher {
|
||||
pub fn new(
|
||||
repo: &GitRepo,
|
||||
remote_url: Option<GitUrl>,
|
||||
refresh_route: &str,
|
||||
operation: Operation,
|
||||
session_id: &str,
|
||||
) -> Result<Self> {
|
||||
let remote_url = match remote_url {
|
||||
Some(r) => r,
|
||||
None => repo.remote_url()?,
|
||||
};
|
||||
|
||||
let cred_helper = get_credential(repo, &remote_url, operation)?;
|
||||
|
||||
Ok(Self {
|
||||
refresh_route: refresh_route.to_owned(),
|
||||
client: build_http_client(RetryConfig::default(), session_id)?,
|
||||
cred_helper,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl TokenRefresher for DirectRefreshRouteTokenRefresher {
|
||||
async fn refresh(&self) -> std::result::Result<TokenInfo, AuthError> {
|
||||
let req = self
|
||||
.client
|
||||
.get(&self.refresh_route)
|
||||
.with_extension(Api("xet-token"))
|
||||
.header(header::USER_AGENT, GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM);
|
||||
let req = self
|
||||
.cred_helper
|
||||
.fill_credential(req)
|
||||
.await
|
||||
.map_err(AuthError::token_refresh_failure)?;
|
||||
let response = req.send().await.map_err(AuthError::token_refresh_failure)?;
|
||||
|
||||
let jwt_info: CasJWTInfo = response.json().await.map_err(AuthError::token_refresh_failure)?;
|
||||
|
||||
Ok((jwt_info.access_token, jwt_info.exp))
|
||||
}
|
||||
}
|
||||
7
hf_xet/Cargo.lock
generated
7
hf_xet/Cargo.lock
generated
@@ -1355,6 +1355,7 @@ dependencies = [
|
||||
"reqwest-middleware",
|
||||
"serde",
|
||||
"thiserror 2.0.15",
|
||||
"urlencoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3793,6 +3794,12 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "urlencoding"
|
||||
version = "2.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da"
|
||||
|
||||
[[package]]
|
||||
name = "utf8_iter"
|
||||
version = "1.0.4"
|
||||
|
||||
@@ -13,6 +13,7 @@ reqwest = { workspace = true }
|
||||
reqwest-middleware = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
urlencoding = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
serde_json = { workspace = true }
|
||||
|
||||
@@ -3,10 +3,11 @@ use std::sync::Arc;
|
||||
use cas_client::exports::ClientWithMiddleware;
|
||||
use cas_client::{Api, RetryConfig, build_http_client};
|
||||
use http::header;
|
||||
use urlencoding::encode;
|
||||
|
||||
use crate::auth::CredentialHelper;
|
||||
use crate::errors::*;
|
||||
use crate::types::CasJWTInfo;
|
||||
use crate::types::{CasJWTInfo, RepoInfo};
|
||||
|
||||
/// The type of operation to perform, either to upload files or to download files.
|
||||
/// Different operations lead to CAS access token with different authorization levels.
|
||||
@@ -34,8 +35,8 @@ impl Operation {
|
||||
|
||||
pub struct HubClient {
|
||||
endpoint: String,
|
||||
repo_type: String,
|
||||
repo_id: String,
|
||||
repo_info: RepoInfo,
|
||||
reference: Option<String>,
|
||||
user_agent: String,
|
||||
client: ClientWithMiddleware,
|
||||
cred_helper: Arc<dyn CredentialHelper>,
|
||||
@@ -44,16 +45,16 @@ pub struct HubClient {
|
||||
impl HubClient {
|
||||
pub fn new(
|
||||
endpoint: &str,
|
||||
repo_type: &str,
|
||||
repo_id: &str,
|
||||
repo_info: RepoInfo,
|
||||
reference: Option<String>,
|
||||
user_agent: &str,
|
||||
session_id: &str,
|
||||
cred_helper: Arc<dyn CredentialHelper>,
|
||||
) -> Result<Self> {
|
||||
Ok(HubClient {
|
||||
endpoint: endpoint.to_owned(),
|
||||
repo_type: repo_type.to_owned(),
|
||||
repo_id: repo_id.to_owned(),
|
||||
repo_info,
|
||||
reference,
|
||||
user_agent: user_agent.to_owned(),
|
||||
client: build_http_client(RetryConfig::default(), session_id)?,
|
||||
cred_helper,
|
||||
@@ -63,12 +64,27 @@ impl HubClient {
|
||||
// Get CAS access token from Hub access token.
|
||||
pub async fn get_cas_jwt(&self, operation: Operation) -> Result<CasJWTInfo> {
|
||||
let endpoint = self.endpoint.as_str();
|
||||
let repo_type = self.repo_type.as_str();
|
||||
let repo_id = self.repo_id.as_str();
|
||||
let repo_type = self.repo_info.repo_type.as_str();
|
||||
let repo_id = self.repo_info.full_name.as_str();
|
||||
let token_type = operation.token_type();
|
||||
|
||||
// The reference may contain "/" but the "xet-[]-token" API only parses "rev" from a single component,
|
||||
// thus we encode the reference. It defaults to "main" if not specified by caller because the
|
||||
// API route expects a "rev" component.
|
||||
let rev = encode(self.reference.as_deref().unwrap_or("main"));
|
||||
|
||||
// Clients can get a xet write token, if
|
||||
// - the "rev" is a regular branch, with a HF write token;
|
||||
// - the "rev" is a pr branch, with a HF write or read token;
|
||||
// - it intends to create a pr and repo is enabled for discussion, with a HF write or read token.
|
||||
let query = if matches!(operation, Operation::Upload) && self.reference.is_none() {
|
||||
"?create_pr=1"
|
||||
} else {
|
||||
""
|
||||
};
|
||||
|
||||
// note that this API doesn't take a Basic auth
|
||||
let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/main");
|
||||
let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/{rev}{query}");
|
||||
|
||||
let req = self
|
||||
.client
|
||||
@@ -92,17 +108,79 @@ impl HubClient {
|
||||
mod tests {
|
||||
use super::HubClient;
|
||||
use crate::errors::Result;
|
||||
use crate::{BearerCredentialHelper, Operation};
|
||||
use crate::{BearerCredentialHelper, HFRepoType, Operation, RepoInfo};
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "need valid token"]
|
||||
async fn test_get_jwt_token() -> Result<()> {
|
||||
let cred_helper = BearerCredentialHelper::new("[hf_token]".to_owned(), "");
|
||||
let hub_client = HubClient::new("https://huggingface.co", "model", "seanses/tm", "xtool", "", cred_helper)?;
|
||||
#[ignore = "need valid write token"]
|
||||
async fn test_get_jwt_token_with_hf_write_token() -> Result<()> {
|
||||
let cred_helper = BearerCredentialHelper::new("[hf_write_token]".to_owned(), "");
|
||||
let hub_client = HubClient::new(
|
||||
"https://huggingface.co",
|
||||
RepoInfo {
|
||||
repo_type: HFRepoType::Model,
|
||||
full_name: "seanses/tm".into(),
|
||||
},
|
||||
Some("main".into()),
|
||||
"xtool",
|
||||
"",
|
||||
cred_helper,
|
||||
)?;
|
||||
|
||||
let read_info = hub_client.get_cas_jwt(Operation::Download).await?;
|
||||
let read_info = hub_client.get_cas_jwt(Operation::Upload).await?;
|
||||
|
||||
println!("{:?}", read_info);
|
||||
assert!(read_info.access_token.len() > 0);
|
||||
assert!(read_info.cas_url.len() > 0);
|
||||
assert!(read_info.exp > 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "need valid read token and pr created on hub"]
|
||||
async fn test_get_jwt_token_with_hf_read_token_pr_branch() -> Result<()> {
|
||||
let cred_helper = BearerCredentialHelper::new("[hf_read_token]".to_owned(), "");
|
||||
let hub_client = HubClient::new(
|
||||
"https://huggingface.co",
|
||||
RepoInfo {
|
||||
repo_type: HFRepoType::Model,
|
||||
full_name: "seanses/tm".into(),
|
||||
},
|
||||
Some("refs/pr/1".into()),
|
||||
"xtool",
|
||||
"",
|
||||
cred_helper,
|
||||
)?;
|
||||
|
||||
let read_info = hub_client.get_cas_jwt(Operation::Upload).await?;
|
||||
|
||||
assert!(read_info.access_token.len() > 0);
|
||||
assert!(read_info.cas_url.len() > 0);
|
||||
assert!(read_info.exp > 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore = "need valid read token"]
|
||||
async fn test_get_jwt_token_with_hf_read_token_create_pr() -> Result<()> {
|
||||
let cred_helper = BearerCredentialHelper::new("[hf_read_token]".to_owned(), "");
|
||||
let hub_client = HubClient::new(
|
||||
"https://huggingface.co",
|
||||
RepoInfo {
|
||||
repo_type: HFRepoType::Model,
|
||||
full_name: "seanses/tm".into(),
|
||||
},
|
||||
None,
|
||||
"xtool",
|
||||
"",
|
||||
cred_helper,
|
||||
)?;
|
||||
|
||||
let read_info = hub_client.get_cas_jwt(Operation::Upload).await?;
|
||||
|
||||
assert!(read_info.access_token.len() > 0);
|
||||
assert!(read_info.cas_url.len() > 0);
|
||||
assert!(read_info.exp > 0);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -13,10 +13,15 @@ pub enum HubClientError {
|
||||
|
||||
#[error("Credential helper error: {0}")]
|
||||
CredentialHelper(anyhow::Error),
|
||||
|
||||
#[error("Invalid repo type: {0}")]
|
||||
InvalidRepoType(String),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, HubClientError>;
|
||||
|
||||
pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> HubClientError {
|
||||
HubClientError::CredentialHelper(e.into())
|
||||
impl HubClientError {
|
||||
pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> HubClientError {
|
||||
HubClientError::CredentialHelper(e.into())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,5 +5,5 @@ mod types;
|
||||
|
||||
pub use auth::{BearerCredentialHelper, CredentialHelper, NoopCredentialHelper};
|
||||
pub use client::{HubClient, Operation};
|
||||
pub use errors::{HubClientError, Result, credential_helper_error};
|
||||
pub use types::CasJWTInfo;
|
||||
pub use errors::{HubClientError, Result};
|
||||
pub use types::{CasJWTInfo, HFRepoType, RepoInfo};
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
use std::fmt::Display;
|
||||
use std::str::FromStr;
|
||||
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::errors::{HubClientError, Result};
|
||||
|
||||
/// This defines the response format from the Huggingface Hub Xet CAS access token API.
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -9,6 +14,67 @@ pub struct CasJWTInfo {
|
||||
pub access_token: String,
|
||||
}
|
||||
|
||||
/// The exact three types of repos served on HF Hub.
///
/// The enum carries no data, so it is cheap to copy and can be used as a
/// hash-map key or compared for equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HFRepoType {
    Model,
    Dataset,
    Space,
}
|
||||
|
||||
impl FromStr for HFRepoType {
|
||||
type Err = HubClientError;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"" => Ok(HFRepoType::Model), // when repo type is omitted from the URL the default type is "model"
|
||||
"model" | "models" => Ok(HFRepoType::Model),
|
||||
"dataset" | "datasets" => Ok(HFRepoType::Dataset),
|
||||
"space" | "spaces" => Ok(HFRepoType::Space),
|
||||
t => Err(HubClientError::InvalidRepoType(t.to_owned())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HFRepoType {
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
HFRepoType::Model => "model",
|
||||
HFRepoType::Dataset => "dataset",
|
||||
HFRepoType::Space => "space",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for HFRepoType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct RepoInfo {
|
||||
// The type of a repo, one of "model | dataset | space"
|
||||
pub repo_type: HFRepoType,
|
||||
// The full name of a repo, formatted as "owner/name"
|
||||
pub full_name: String,
|
||||
}
|
||||
|
||||
impl RepoInfo {
|
||||
pub fn try_from(repo_type: &str, repo_id: &str) -> Result<Self> {
|
||||
Ok(Self {
|
||||
repo_type: repo_type.parse()?,
|
||||
full_name: repo_id.into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for RepoInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}/{}", self.repo_type, self.full_name)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use anyhow::Result;
|
||||
|
||||
Reference in New Issue
Block a user