diff --git a/Cargo.lock b/Cargo.lock index 437c5cc4..28be0351 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1561,6 +1561,7 @@ dependencies = [ "serde_json", "thiserror 2.0.12", "tokio", + "urlencoding", ] [[package]] @@ -3968,6 +3969,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index 79fbe654..ba08a5b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -102,6 +102,7 @@ tower-service = "0.3" tracing = "0.1" ulid = "1.2" url = "2.5" +urlencoding = "2.1" uuid = "1" walkdir = "2" web-time = "1.1" diff --git a/data/src/bin/xtool.rs b/data/src/bin/xtool.rs index 14c586a7..a5e7f3ce 100644 --- a/data/src/bin/xtool.rs +++ b/data/src/bin/xtool.rs @@ -11,7 +11,7 @@ use clap::{Args, Parser, Subcommand}; use data::data_client::default_config; use data::migration_tool::hub_client_token_refresher::HubClientTokenRefresher; use data::migration_tool::migrate::migrate_files_impl; -use hub_client::{BearerCredentialHelper, HubClient, Operation}; +use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo}; use merklehash::MerkleHash; use utils::auth::TokenRefresher; use walkdir::WalkDir; @@ -56,8 +56,14 @@ impl XCommand { .unwrap_or_else(|| std::env::var("HF_TOKEN").unwrap_or_default()); let cred_helper = BearerCredentialHelper::new(token, ""); - let hub_client = - HubClient::new(&endpoint, &self.overrides.repo_type, &self.overrides.repo_id, "xtool", "", cred_helper)?; + let hub_client = HubClient::new( + &endpoint, + RepoInfo::try_from(&self.overrides.repo_type, &self.overrides.repo_id)?, + Some("main".to_owned()), + "xtool", + "", + cred_helper, + )?; self.command.run(hub_client).await } diff --git a/data/src/configurations.rs b/data/src/configurations.rs index 840b0c0d..0ed00cb6 100644 --- a/data/src/configurations.rs +++ b/data/src/configurations.rs @@ -121,4 +121,15 @@ impl TranslatorConfig { ..self } } + + pub fn with_session_id(self, session_id: &str) -> Self { + if session_id.is_empty() { + return self; + } + + Self { + session_id: Some(session_id.to_owned()), + ..self + } + } } diff --git a/data/src/migration_tool/migrate.rs b/data/src/migration_tool/migrate.rs index 2e84647b..91297ce5 100644 --- a/data/src/migration_tool/migrate.rs +++ b/data/src/migration_tool/migrate.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use anyhow::Result; use cas_object::CompressionScheme; -use hub_client::{BearerCredentialHelper, HubClient, Operation}; +use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo}; use mdb_shard::file_structs::MDBFileInfo; use tracing::{Instrument, Span, info_span, instrument}; use utils::auth::TokenRefresher; @@ -33,7 +33,8 @@ pub async fn migrate_with_external_runtime( repo_id: &str, ) -> Result<()> { let cred_helper = BearerCredentialHelper::new(hub_token.to_owned(), ""); - let hub_client = HubClient::new(hub_endpoint, repo_type, repo_id, "xtool", "", cred_helper)?; + let hub_client = + HubClient::new(hub_endpoint, RepoInfo::try_from(repo_type, repo_id)?, None, "xtool", "", cred_helper)?; migrate_files_impl(file_paths, false, hub_client, cas_endpoint, None, false).await?; diff --git a/git_xet/src/app/install.rs b/git_xet/src/app/install.rs index 04aa52ff..0514deea 100644 --- a/git_xet/src/app/install.rs +++ b/git_xet/src/app/install.rs @@ -2,7 +2,7 @@ use std::path::PathBuf; use crate::app::Command::Transfer; use crate::constants::{GIT_LFS_CUSTOM_TRANSFER_AGENT_NAME, GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM}; -use crate::errors::{Result, config_error}; +use crate::errors::{GitXetError, Result}; use crate::git_process_wrapping::run_git_captured; #[derive(Default)] @@ -48,7 +48,7 @@ fn install_impl(location: ConfigLocation, concurrency: Option) -> Result<() let concurrent = if let Some(c) = concurrency { if c == 0 { - return Err(config_error("concurrency can't be 0")); + return Err(GitXetError::config_error("concurrency can't be 0")); } if c == 1 { "false" diff --git a/git_xet/src/app/xet_agent.rs b/git_xet/src/app/xet_agent.rs index 00bec246..d5d600c9 100644 --- a/git_xet/src/app/xet_agent.rs +++ b/git_xet/src/app/xet_agent.rs @@ -7,14 +7,18 @@ use data::FileUploadSession; use data::data_client::{clean_file, default_config}; use hub_client::Operation; use progress_tracking::{ProgressUpdate, TrackingProgressUpdater}; +use utils::auth::TokenRefresher; -use crate::constants::{HF_ENDPOINT_ENV, XET_ACCESS_TOKEN_HEADER, XET_TOKEN_EXPIRATION_HEADER}; -use crate::errors::{Result, config_error, internal, not_supported}; +use crate::constants::{ + HF_ENDPOINT_ENV, XET_ACCESS_TOKEN_HEADER, XET_CAS_URL, XET_SESSION_ID, XET_TOKEN_EXPIRATION_HEADER, +}; +use crate::errors::{GitXetError, Result}; use crate::git_repo::GitRepo; use crate::git_url::{GitUrl, Scheme}; -use crate::hub_client_token_refresher::HubClientTokenRefresher; -use crate::lfs_agent_protocol::errors::bad_syntax; -use crate::lfs_agent_protocol::{InitRequestInner, ProgressUpdater, TransferAgent, TransferRequest}; +use crate::lfs_agent_protocol::{ + GitLFSProtocolError, InitRequestInner, ProgressUpdater, TransferAgent, TransferRequest, +}; +use crate::token_refresher::DirectRefreshRouteTokenRefresher; // This implements a Git LFS custom transfer agent that uploads and downloads files using the Xet protocol. #[derive(Default)] @@ -38,7 +42,7 @@ impl TransferAgent for XetAgent { let hf_endpoint = if !matches!(remote_url.scheme(), Scheme::Http | Scheme::Https) && remote_url.port().is_some() { Some(std::env::var(HF_ENDPOINT_ENV).map_err(|_| { - config_error( + GitXetError::config_error( r#"This repository has a non-standard Hugging Face remote URL, please specify the Hugging Face server endpoint using environment variable "HF_ENDPOINT""#, ) @@ -55,7 +59,7 @@ impl TransferAgent for XetAgent { } async fn init_download(&mut self, _: &InitRequestInner) -> Result<()> { - Err(not_supported( + Err(GitXetError::not_supported( "custom transfer for download is not implemented yet. Downloads should operate through standard git-lfs download protocol. If you encounter errors downloading, contact Xet Team at Hugging Face.", )) @@ -70,14 +74,15 @@ impl TransferAgent for XetAgent { // so that if the internal git credential helper needs to prompt the user for credential, // only one prompt is presented. let repo = self.repo.get().unwrap(); // protocol state guarantees self.repo is set. - let token_refresher = HubClientTokenRefresher::new( + + let session_id = req.action.header.get(XET_SESSION_ID).map(|s| s.as_str()).unwrap_or_default(); + let token_refresher: Arc = Arc::new(DirectRefreshRouteTokenRefresher::new( repo, self.remote_url.clone(), - self.hf_endpoint.clone(), + &req.action.href, Operation::Upload, - "", - )?; - + session_id, + )?); // From git-lfs: // > First worker is the only one allowed to start immediately. // > The rest wait until successful response from 1st worker to @@ -94,16 +99,33 @@ impl TransferAgent for XetAgent { updater: progress_updater, }; - let cas_url = req.action.href.clone(); - let token = req.action.header[XET_ACCESS_TOKEN_HEADER].clone(); - let token_expiry: u64 = req.action.header[XET_TOKEN_EXPIRATION_HEADER].parse().map_err(internal)?; + let cas_url = req + .action + .header + .get(XET_CAS_URL) + .ok_or_else(|| GitXetError::internal("Hugging Face Hub didn't provide a CAS URL"))? + .clone(); + let token = req + .action + .header + .get(XET_ACCESS_TOKEN_HEADER) + .ok_or_else(|| GitXetError::internal("Hugging Face Hub didn't provide a CAS access token"))? + .clone(); + let token_expiry: u64 = req + .action + .header + .get(XET_TOKEN_EXPIRATION_HEADER) + .ok_or_else(|| GitXetError::internal("Hugging Face Hub didn't provide a CAS access token expiration"))? + .parse() + .map_err(GitXetError::internal)?; - let config = default_config(cas_url, None, Some((token, token_expiry)), Some(Arc::new(token_refresher)))? - .disable_progress_aggregation(); // upload one file at a time so no need for the heavy progress aggregator + let config = default_config(cas_url, None, Some((token, token_expiry)), Some(token_refresher))? + .disable_progress_aggregation() + .with_session_id(session_id); // upload one file at a time so no need for the heavy progress aggregator let session = FileUploadSession::new(config.into(), Some(Arc::new(xet_updater))).await?; let Some(file_path) = &req.path else { - return Err(bad_syntax("file path not provided for upload request").into()); + return Err(GitLFSProtocolError::bad_syntax("file path not provided for upload request").into()); }; clean_file(session.clone(), file_path).await?; diff --git a/git_xet/src/auth.rs b/git_xet/src/auth.rs index 98019426..c17bac44 100644 --- a/git_xet/src/auth.rs +++ b/git_xet/src/auth.rs @@ -5,7 +5,7 @@ use hub_client::{BearerCredentialHelper, CredentialHelper, NoopCredentialHelper, use netrc::Netrc; use crate::constants::HF_TOKEN_ENV; -use crate::errors::{GitXetError, Result, config_error}; +use crate::errors::{GitXetError, Result}; use crate::git_repo::GitRepo; use crate::git_url::{GitUrl, Scheme}; @@ -58,7 +58,7 @@ impl FromStr for AccessMode { "private" => Ok(AccessMode::Private), "negotiate" => Ok(AccessMode::Negotiate), "" => Ok(AccessMode::Empty), - _ => Err(config_error(format!("invalid \"lfs..access\" type: {s}"))), + _ => Err(GitXetError::config_error(format!("invalid \"lfs..access\" type: {s}"))), } } } @@ -129,7 +129,7 @@ pub fn get_credential(repo: &GitRepo, remote_url: &GitUrl, operation: Operation) #[cfg(unix)] return Ok(SSHCredentialHelper::new(remote_url, operation)); #[cfg(not(unix))] - return Err(crate::errors::not_supported(format!( + return Err(GitXetError::not_supported(format!( "using {} in a repository with SSH Git URL is under development; please check back for upgrades or contact Xet Team at Hugging Face.", crate::constants::GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM diff --git a/git_xet/src/auth/git.rs b/git_xet/src/auth/git.rs index edc6c60e..9ac92123 100644 --- a/git_xet/src/auth/git.rs +++ b/git_xet/src/auth/git.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use hub_client::BearerCredentialHelper; -use crate::errors::{Result, config_error}; +use crate::errors::{GitXetError, Result}; use crate::git_process_wrapping::run_git_captured_with_input_and_output; // This implements the mechanism to get credential stored in configured git credential helpers, including @@ -76,7 +76,7 @@ impl GitCredentialHelper { } } - Err(config_error(format!("failed to find authentication for {host_url}"))) + Err(GitXetError::config_error(format!("failed to find authentication for {host_url}"))) } } diff --git a/git_xet/src/auth/ssh.rs b/git_xet/src/auth/ssh.rs index 7b02a12e..7f4433c8 100644 --- a/git_xet/src/auth/ssh.rs +++ b/git_xet/src/auth/ssh.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use async_trait::async_trait; -use hub_client::{CredentialHelper, Operation, Result, credential_helper_error}; +use hub_client::{CredentialHelper, HubClientError, Operation, Result}; #[cfg(unix)] use openssh::{KnownHosts, Session}; use reqwest::header; @@ -46,11 +46,11 @@ impl SSHCredentialHelper { #[cfg(unix)] async fn authenticate(&self) -> Result { - let host_url = self.remote_url.host_url().map_err(credential_helper_error)?; + let host_url = self.remote_url.host_url().map_err(HubClientError::credential_helper_error)?; let full_repo_path = self.remote_url.full_repo_path(); let session = Session::connect(&host_url, KnownHosts::Add) .await - .map_err(credential_helper_error)?; + .map_err(HubClientError::credential_helper_error)?; let output = session .command("git-lfs-authenticate") @@ -58,9 +58,9 @@ impl SSHCredentialHelper { .arg(self.operation.as_str()) .output() .await - .map_err(credential_helper_error)?; + .map_err(HubClientError::credential_helper_error)?; - serde_json::from_slice(&output.stdout).map_err(credential_helper_error) + serde_json::from_slice(&output.stdout).map_err(HubClientError::credential_helper_error) } #[cfg(not(unix))] diff --git a/git_xet/src/constants.rs b/git_xet/src/constants.rs index 35afd8d5..2a5bb7fb 100644 --- a/git_xet/src/constants.rs +++ b/git_xet/src/constants.rs @@ -7,8 +7,10 @@ pub const GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM: &str = "git-xet"; pub const CURRENT_VERSION: &str = env!("CARGO_PKG_VERSION"); // Moon-landing Xet service headers +pub const XET_CAS_URL: &str = "X-Xet-Cas-Url"; pub const XET_ACCESS_TOKEN_HEADER: &str = "X-Xet-Access-Token"; pub const XET_TOKEN_EXPIRATION_HEADER: &str = "X-Xet-Token-Expiration"; +pub const XET_SESSION_ID: &str = "X-Xet-Session-Id"; // Environment variable names pub const HF_TOKEN_ENV: &str = "HF_TOKEN"; diff --git a/git_xet/src/errors.rs b/git_xet/src/errors.rs index 05f4e008..67d840a8 100644 --- a/git_xet/src/errors.rs +++ b/git_xet/src/errors.rs @@ -48,23 +48,25 @@ pub enum GitXetError { pub type Result = std::result::Result; -pub(crate) fn git_cmd_failed(e: impl Display, source: Option) -> GitXetError { - GitXetError::GitCommandFailed { - reason: e.to_string(), - source, +impl GitXetError { + pub(crate) fn git_cmd_failed(e: impl Display, source: Option) -> GitXetError { + GitXetError::GitCommandFailed { + reason: e.to_string(), + source, + } } -} -pub(crate) fn not_supported(e: impl Display) -> GitXetError { - GitXetError::NotSupported(e.to_string()) -} + pub(crate) fn not_supported(e: impl Display) -> GitXetError { + GitXetError::NotSupported(e.to_string()) + } -pub(crate) fn config_error(e: impl Display) -> GitXetError { - GitXetError::InvalidGitConfig(e.to_string()) -} + pub(crate) fn config_error(e: impl Display) -> GitXetError { + GitXetError::InvalidGitConfig(e.to_string()) + } -pub(crate) fn internal(e: impl Display) -> GitXetError { - GitXetError::Internal(e.to_string()) + pub(crate) fn internal(e: impl Display) -> GitXetError { + GitXetError::Internal(e.to_string()) + } } impl From for GitXetError { diff --git a/git_xet/src/git_process_wrapping.rs b/git_xet/src/git_process_wrapping.rs index f49f07bc..fe22f068 100644 --- a/git_xet/src/git_process_wrapping.rs +++ b/git_xet/src/git_process_wrapping.rs @@ -3,7 +3,7 @@ use std::path::Path; use std::process::{Child, ChildStdin, Command, Stdio}; use crate::constants::GIT_EXECUTABLE; -use crate::errors::{Result, git_cmd_failed, internal}; +use crate::errors::{GitXetError, Result}; // This mod implements utilities to invoke Git commands through child processes from the `git` program. @@ -74,8 +74,8 @@ impl CapturedCommand { // From past experience, if the "git" program is not found the underlying error // only says "Not Found" and is not very helpful to identify the cause. We thus // capture this error and make the message more explicit. - std::io::ErrorKind::NotFound => git_cmd_failed(r#"program "git" not found"#, Some(e)), - _ => git_cmd_failed("internal", Some(e)), + std::io::ErrorKind::NotFound => GitXetError::git_cmd_failed(r#"program "git" not found"#, Some(e)), + _ => GitXetError::git_cmd_failed("internal", Some(e)), })?, }) } @@ -94,7 +94,7 @@ impl CapturedCommand { self.child_process .stdin .take() - .ok_or_else(|| internal("stdin of child process is not captured")) + .ok_or_else(|| GitXetError::internal("stdin of child process is not captured")) } /// Synchronously wait for the child to exit completely, returning `Ok(())` if the child exits with status code 0; @@ -117,7 +117,7 @@ impl CapturedCommand { _ => { let stdout = std::str::from_utf8(&ret.stdout).unwrap_or("").trim(); let stderr = std::str::from_utf8(&ret.stderr).unwrap_or("").trim(); - Err(git_cmd_failed( + Err(GitXetError::git_cmd_failed( format!("err_code = {:?}, stdout = \"{}\", stderr = \"{}\"", ret.status.code(), stdout, stderr), None, )) diff --git a/git_xet/src/git_repo.rs b/git_xet/src/git_repo.rs index 48f6e03e..e3060e64 100644 --- a/git_xet/src/git_repo.rs +++ b/git_xet/src/git_repo.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, Mutex}; use git2::{Config, Repository}; -use crate::errors::{GitXetError, Result, config_error, internal}; +use crate::errors::{GitXetError, Result}; use crate::git_url::GitUrl; #[derive(Clone)] @@ -33,10 +33,27 @@ impl GitRepo { // Returns the path to the .git folder for normal repositories or the repository itself for bare repositories. pub fn git_path(&self) -> Result { - let repo = self.repo.lock().map_err(internal)?; + let repo = self.repo.lock().map_err(GitXetError::internal)?; Ok(repo.path().to_path_buf()) } + // Resolves the reference pointed at by HEAD, returns branch name if it is a branch. + pub fn branch_name(&self) -> Result> { + let repo = self.repo.lock().map_err(GitXetError::internal)?; + + let maybe_head_ref = repo.head(); + Ok(maybe_head_ref.ok().and_then(|head_ref| { + if head_ref.is_branch() { + head_ref + .name() + .and_then(|refs_heads_branch| refs_heads_branch.strip_prefix("refs/heads/")) + .map(|branch| branch.to_owned()) + } else { + None + } + })) + } + // Returns the remote that a git push/fetch/pull operation // is targeted at, based on: // 1. The currently tracked remote branch, if present @@ -44,20 +61,9 @@ impl GitRepo { // 3. Any other SINGLE remote defined in .git/config // 4. Use "origin" as a fallback. pub fn remote_name(&self) -> Result { - let repo = self.repo.lock().map_err(internal)?; - - let maybe_head_ref = repo.head(); - let maybe_branch_name = maybe_head_ref.ok().and_then(|head_ref| { - if head_ref.is_branch() { - head_ref - .name() - .and_then(|refs_heads_branch| refs_heads_branch.rsplit('/').next()) - .map(|branch| branch.to_owned()) - } else { - None - } - }); + let maybe_branch_name = self.branch_name()?; + let repo = self.repo.lock().map_err(GitXetError::internal)?; let config = repo.config()?.snapshot()?; // try tracking remote @@ -87,13 +93,13 @@ impl GitRepo { // Returns the URL for a specific remote name. pub fn remote_name_to_url(&self, remote: &str) -> Result { - let repo = self.repo.lock().map_err(internal)?; + let repo = self.repo.lock().map_err(GitXetError::internal)?; let url: GitUrl = repo .find_remote(remote)? .url() .map(|s| s.to_string()) - .ok_or_else(|| config_error(format!("no url for remote \"{remote}\"")))? + .ok_or_else(|| GitXetError::config_error(format!("no url for remote \"{remote}\"")))? .parse()?; Ok(url) @@ -108,7 +114,7 @@ impl GitRepo { // Returns a snapshot of the current Git repo config. pub fn config(&self) -> Result { - let repo = self.repo.lock().map_err(internal)?; + let repo = self.repo.lock().map_err(GitXetError::internal)?; Ok(repo.config()?.snapshot()?) } @@ -122,6 +128,25 @@ mod tests { use crate::git_repo::GitRepo; use crate::test_utils::TestRepo; + #[test] + #[serial(env_var_write_read)] + fn test_get_ref_name() -> Result<()> { + let test_repo = TestRepo::new("main")?; + let repo = GitRepo::open(test_repo.path())?; + + test_repo.new_commit("data", "hello".as_bytes(), "add new file")?; + assert_eq!(repo.branch_name()?, Some("main".to_owned())); + + test_repo.new_branch("pr/1", "main")?; + test_repo.new_commit("data", "world".as_bytes(), "update file")?; + assert_eq!(repo.branch_name()?, Some("pr/1".to_owned())); + + test_repo.checkout(&["HEAD^"])?; + assert_eq!(repo.branch_name()?, None); + + Ok(()) + } + #[test] #[serial(env_var_write_read)] fn test_get_remote_from_local_config() -> Result<()> { diff --git a/git_xet/src/git_url.rs b/git_xet/src/git_url.rs index ccfee90e..6c9cd20a 100644 --- a/git_xet/src/git_url.rs +++ b/git_xet/src/git_url.rs @@ -1,10 +1,10 @@ -use std::fmt::Display; use std::str::FromStr; use git_url_parse::GitUrl as innerGitUrl; pub use git_url_parse::Scheme; +use hub_client::{HFRepoType, RepoInfo}; -use crate::errors::{GitXetError, Result, config_error, not_supported}; +use crate::errors::{GitXetError, Result}; // This mod implements funtionalities to handle Git remote URLs, especially tailored for // Git LFS and Hugging Face repo needs, including deriving Git LFS server endpoint from @@ -62,7 +62,7 @@ impl GitUrl { .inner .host .as_ref() - .ok_or_else(|| config_error("remote URL missing host name"))?; + .ok_or_else(|| GitXetError::config_error("remote URL missing host name"))?; let port = self.inner.port; let port_str = if translated || port.is_none() { @@ -86,7 +86,7 @@ impl GitUrl { .inner .host .as_ref() - .ok_or_else(|| config_error("remote URL missing host name"))?; + .ok_or_else(|| GitXetError::config_error("remote URL missing host name"))?; let port = self.inner.port; let port_str = if translated || port.is_none() { @@ -105,11 +105,14 @@ impl GitUrl { match self.inner.scheme { Scheme::Http => Ok(("http", false)), Scheme::Https => Ok(("https", false)), - Scheme::File | Scheme::Ftp | Scheme::Ftps => { - Err(not_supported(format!("cannot convert from scheme \"{}://\" to \"http(s)://\"", self.inner.scheme))) - }, + Scheme::File | Scheme::Ftp | Scheme::Ftps => Err(GitXetError::not_supported(format!( + "cannot convert from scheme \"{}://\" to \"http(s)://\"", + self.inner.scheme + ))), Scheme::Git | Scheme::GitSsh | Scheme::Ssh => Ok(("https", true)), - Scheme::Unspecified => Err(not_supported("cannot convert from unspecified scheme to \"http(s)://\"")), + Scheme::Unspecified => { + Err(GitXetError::not_supported("cannot convert from unspecified scheme to \"http(s)://\"")) + }, } } @@ -143,7 +146,7 @@ impl GitUrl { .inner .host .as_ref() - .ok_or_else(|| config_error("remote URL missing host name"))?; + .ok_or_else(|| GitXetError::config_error("remote URL missing host name"))?; let port_str = if let Some(p) = self.inner.port { format!(":{}", p) @@ -171,6 +174,7 @@ impl GitUrl { } // Returns the parsed full repo path into `RepoInfo`. + #[allow(unused)] pub fn repo_info(&self) -> Result { let path = self.full_repo_path(); let full_name = self.inner.fullname.clone(); // The full name of the repo, formatted as "owner/name" @@ -181,58 +185,6 @@ impl GitUrl { } } -// This defines the exact three types of repos served on HF Hub. -#[derive(Debug, PartialEq)] -pub enum HFRepoType { - Model, - Dataset, - Space, -} - -impl FromStr for HFRepoType { - type Err = GitXetError; - - fn from_str(s: &str) -> std::result::Result { - match s.to_lowercase().as_str() { - "" => Ok(HFRepoType::Model), // when repo type is omitted from the URL the default type is "model" - "model" | "models" => Ok(HFRepoType::Model), - "dataset" | "datasets" => Ok(HFRepoType::Dataset), - "space" | "spaces" => Ok(HFRepoType::Space), - t => Err(config_error(format!("invalid repo type {t}"))), - } - } -} - -impl HFRepoType { - pub fn as_str(&self) -> &str { - match self { - HFRepoType::Model => "model", - HFRepoType::Dataset => "dataset", - HFRepoType::Space => "space", - } - } -} - -impl Display for HFRepoType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.as_str()) - } -} - -#[derive(Debug, PartialEq)] -pub struct RepoInfo { - // The type of a repo, one of "model | dataset | space" - pub repo_type: HFRepoType, - // The full name of a repo, formatted as "owner/name" - pub full_name: String, -} - -impl Display for RepoInfo { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}/{}", self.repo_type, self.full_name) - } -} - #[cfg(test)] mod test_lfs_server_discovery { use super::GitUrl; diff --git a/git_xet/src/hub_client_token_refresher.rs b/git_xet/src/hub_client_token_refresher.rs deleted file mode 100644 index 258e2a78..00000000 --- a/git_xet/src/hub_client_token_refresher.rs +++ /dev/null @@ -1,67 +0,0 @@ -use std::sync::Arc; - -use async_trait::async_trait; -use hub_client::{HubClient, Operation}; -use utils::auth::{TokenInfo, TokenRefresher}; -use utils::errors::AuthError; - -use crate::auth::get_credential; -use crate::constants::GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM; -use crate::errors::Result; -use crate::git_repo::GitRepo; -use crate::git_url::GitUrl; - -pub struct HubClientTokenRefresher { - operation: Operation, - client: Arc, -} - -impl HubClientTokenRefresher { - pub fn new( - repo: &GitRepo, - remote_url: Option, - token_endpoint: Option, - operation: Operation, - session_id: &str, - ) -> Result { - let remote_url = match remote_url { - Some(r) => r, - None => repo.remote_url()?, - }; - let repo_info = remote_url.repo_info()?; - - let endpoint = match token_endpoint { - Some(e) => e, - None => remote_url.to_derived_http_host_url()?, - }; - - let cred_helper = get_credential(repo, &remote_url, operation)?; - - let client = HubClient::new( - &endpoint, - repo_info.repo_type.as_str(), - &repo_info.full_name, - GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM, - session_id, - cred_helper, - )?; - - Ok(Self { - operation, - client: Arc::new(client), - }) - } -} - -#[async_trait] -impl TokenRefresher for HubClientTokenRefresher { - async fn refresh(&self) -> std::result::Result { - let jwt_info = self - .client - .get_cas_jwt(self.operation) - .await - .map_err(AuthError::token_refresh_failure)?; - - Ok((jwt_info.access_token, jwt_info.exp)) - } -} diff --git a/git_xet/src/lfs_agent_protocol.rs b/git_xet/src/lfs_agent_protocol.rs index 0179b3d1..317ee15b 100644 --- a/git_xet/src/lfs_agent_protocol.rs +++ b/git_xet/src/lfs_agent_protocol.rs @@ -2,7 +2,7 @@ use std::io::{BufRead, Write}; use std::path::PathBuf; use std::sync::{Arc, Mutex}; -use crate::errors::{Result, internal}; +use crate::errors::{GitXetError, Result}; mod agent_state; pub mod errors; @@ -116,7 +116,7 @@ where }, }; - stdout.lock().map_err(internal)?.write_all(response.as_bytes())?; + stdout.lock().map_err(GitXetError::internal)?.write_all(response.as_bytes())?; } Ok(()) diff --git a/git_xet/src/lfs_agent_protocol/agent_state.rs b/git_xet/src/lfs_agent_protocol/agent_state.rs index 7dd9eee1..f9740798 100644 --- a/git_xet/src/lfs_agent_protocol/agent_state.rs +++ b/git_xet/src/lfs_agent_protocol/agent_state.rs @@ -1,4 +1,4 @@ -use super::errors::{Result, bad_state}; +use super::errors::{GitLFSProtocolError, Result}; // This defines the state of a transfer agent to make sure that request events are initiated // in the correct order. Unlike a traditional state machines, we don't define a "terminated" @@ -18,27 +18,27 @@ impl LFSAgentState { match self { Self::PendingInit => match to { Self::InitedForUpload | Self::InitedForDownload => (), - _ => return Err(bad_state("init event not yet received")), + _ => return Err(GitLFSProtocolError::bad_state("init event not yet received")), }, Self::InitedForUpload => match to { Self::Uploading => (), - Self::Downloading => return Err(bad_state("agent initiated for upload")), - _ => return Err(bad_state("init event already received")), + Self::Downloading => return Err(GitLFSProtocolError::bad_state("agent initiated for upload")), + _ => return Err(GitLFSProtocolError::bad_state("init event already received")), }, Self::InitedForDownload => match to { Self::Downloading => (), - Self::Uploading => return Err(bad_state("agent initiated for download")), - _ => return Err(bad_state("init event already received")), + Self::Uploading => return Err(GitLFSProtocolError::bad_state("agent initiated for download")), + _ => return Err(GitLFSProtocolError::bad_state("init event already received")), }, Self::Uploading => match to { Self::Uploading => (), - Self::Downloading => return Err(bad_state("agent initiated for upload")), - _ => return Err(bad_state("data transfer already in progress")), + Self::Downloading => return Err(GitLFSProtocolError::bad_state("agent initiated for upload")), + _ => return Err(GitLFSProtocolError::bad_state("data transfer already in progress")), }, Self::Downloading => match to { Self::Downloading => (), - Self::Uploading => return Err(bad_state("agent initiated for download")), - _ => return Err(bad_state("data transfer already in progress")), + Self::Uploading => return Err(GitLFSProtocolError::bad_state("agent initiated for download")), + _ => return Err(GitLFSProtocolError::bad_state("data transfer already in progress")), }, }; diff --git a/git_xet/src/lfs_agent_protocol/errors.rs b/git_xet/src/lfs_agent_protocol/errors.rs index ec8e7252..cf205231 100644 --- a/git_xet/src/lfs_agent_protocol/errors.rs +++ b/git_xet/src/lfs_agent_protocol/errors.rs @@ -22,14 +22,16 @@ pub enum GitLFSProtocolError { pub(super) type Result = std::result::Result; -pub(crate) fn bad_syntax(e: impl Display) -> GitLFSProtocolError { - GitLFSProtocolError::Syntax(e.to_string()) -} +impl GitLFSProtocolError { + pub(crate) fn bad_syntax(e: impl Display) -> GitLFSProtocolError { + GitLFSProtocolError::Syntax(e.to_string()) + } -pub(crate) fn bad_argument(e: impl Display) -> GitLFSProtocolError { - GitLFSProtocolError::Argument(e.to_string()) -} + pub(crate) fn bad_argument(e: impl Display) -> GitLFSProtocolError { + GitLFSProtocolError::Argument(e.to_string()) + } -pub(crate) fn bad_state(e: impl Display) -> GitLFSProtocolError { - GitLFSProtocolError::State(e.to_string()) + pub(crate) fn bad_state(e: impl Display) -> GitLFSProtocolError { + GitLFSProtocolError::State(e.to_string()) + } } diff --git a/git_xet/src/lfs_agent_protocol/protocol_spec.rs b/git_xet/src/lfs_agent_protocol/protocol_spec.rs index ae742a27..9654a7e9 100644 --- a/git_xet/src/lfs_agent_protocol/protocol_spec.rs +++ b/git_xet/src/lfs_agent_protocol/protocol_spec.rs @@ -6,7 +6,7 @@ use std::str::FromStr; use serde::{Deserialize, Serialize}; use serde_json::json; -use super::errors::{GitLFSProtocolError, Result, bad_argument, bad_syntax}; +use super::errors::{GitLFSProtocolError, Result}; // This file defines the protocol that Git LFS uses to talk to // custom transfer agents. This implementation follows the protocol specification @@ -105,37 +105,45 @@ impl LFSProtocolRequestEvent { InitRequest::Download(inner) => inner, }; if inner.remote.is_empty() { - return Err(bad_argument("invalid remote")); + return Err(GitLFSProtocolError::bad_argument("invalid remote")); } Ok(()) }, LFSProtocolRequestEvent::Upload(req) => { if req.oid.len() != OID_LEN { - return Err(bad_argument("invalid oid")); + return Err(GitLFSProtocolError::bad_argument("invalid oid")); } if req.size == 0 { - return Err(bad_argument("invalid size")); + return Err(GitLFSProtocolError::bad_argument("invalid size")); } if req.path.is_none() { - return Err(bad_syntax("file path not provided for upload request")); + return Err(GitLFSProtocolError::bad_syntax("file path not provided for upload request")); + } + + if req.action.href.is_empty() { + return Err(GitLFSProtocolError::bad_argument("empty action.href in server response")); } Ok(()) }, LFSProtocolRequestEvent::Download(req) => { if req.oid.len() != OID_LEN { - return Err(bad_argument("invalid oid")); + return Err(GitLFSProtocolError::bad_argument("invalid oid")); } if req.size == 0 { - return Err(bad_argument("invalid size")); + return Err(GitLFSProtocolError::bad_argument("invalid size")); } if req.path.is_some() { - return Err(bad_syntax("file path provided for download request")); + return Err(GitLFSProtocolError::bad_syntax("file path provided for download request")); + } + + if req.action.href.is_empty() { + return Err(GitLFSProtocolError::bad_argument("empty action.href in server response")); } Ok(()) @@ -149,7 +157,7 @@ impl FromStr for LFSProtocolRequestEvent { type Err = GitLFSProtocolError; fn from_str(s: &str) -> std::result::Result { - let req: LFSProtocolRequestEvent = serde_json::from_str(s).map_err(bad_syntax)?; + let req: LFSProtocolRequestEvent = serde_json::from_str(s).map_err(GitLFSProtocolError::bad_syntax)?; req.validate()?; Ok(req) } diff --git a/git_xet/src/lib.rs b/git_xet/src/lib.rs index cac92191..4a9917cd 100644 --- a/git_xet/src/lib.rs +++ b/git_xet/src/lib.rs @@ -5,6 +5,6 @@ mod errors; mod git_process_wrapping; mod git_repo; mod git_url; -mod hub_client_token_refresher; mod lfs_agent_protocol; mod test_utils; +mod token_refresher; diff --git a/git_xet/src/test_utils/test_repo.rs b/git_xet/src/test_utils/test_repo.rs index 8d45e990..0bfb96bd 100644 --- a/git_xet/src/test_utils/test_repo.rs +++ b/git_xet/src/test_utils/test_repo.rs @@ -68,6 +68,20 @@ impl TestRepo { Ok(()) } + // Create a new branch `new_branch_name` off `base`. + pub fn new_branch(&self, new_branch_name: &str, base: &str) -> Result<()> { + run_git_captured(&self.repo_path, "checkout", &[base, "-b", new_branch_name])?; + + Ok(()) + } + + // Run the versatile checkout command. + pub fn checkout(&self, args: &[&str]) -> Result<()> { + run_git_captured(&self.repo_path, "checkout", args)?; + + Ok(()) + } + // Create a new local branch `local_branch_name` off HEAD in this repo that tracks a remote branch // `remote`:`remote_branch_name`. pub fn new_branch_tracking_remote( @@ -76,8 +90,8 @@ impl TestRepo { remote_branch_name: &str, local_branch_name: &str, ) -> Result<()> { + self.new_branch(local_branch_name, "HEAD")?; run_git_captured(&self.repo_path, "fetch", &[remote, remote_branch_name])?; - run_git_captured(&self.repo_path, "checkout", &["-b", local_branch_name])?; run_git_captured( &self.repo_path, "branch", diff --git a/git_xet/src/token_refresher.rs b/git_xet/src/token_refresher.rs new file mode 100644 index 00000000..61bdddb1 --- /dev/null +++ b/git_xet/src/token_refresher.rs @@ -0,0 +1,65 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use cas_client::{Api, RetryConfig, build_http_client}; +use hub_client::{CasJWTInfo, CredentialHelper, Operation}; +use reqwest::header; +use reqwest_middleware::ClientWithMiddleware; +use utils::auth::{TokenInfo, TokenRefresher}; +use utils::errors::AuthError; + +use crate::auth::get_credential; +use crate::constants::GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM; +use crate::errors::Result; +use crate::git_repo::GitRepo; +use crate::git_url::GitUrl; + +pub struct DirectRefreshRouteTokenRefresher { + refresh_route: String, + client: ClientWithMiddleware, + cred_helper: Arc, +} + +impl DirectRefreshRouteTokenRefresher { + pub fn new( + repo: &GitRepo, + remote_url: Option, + refresh_route: &str, + operation: Operation, + session_id: &str, + ) -> Result { + let remote_url = match remote_url { + Some(r) => r, + None => repo.remote_url()?, + }; + + let cred_helper = get_credential(repo, &remote_url, operation)?; + + Ok(Self { + refresh_route: refresh_route.to_owned(), + client: build_http_client(RetryConfig::default(), session_id)?, + cred_helper, + }) + } +} + +#[async_trait] +impl TokenRefresher for DirectRefreshRouteTokenRefresher { + async fn refresh(&self) -> std::result::Result { + let req = self + .client + .get(&self.refresh_route) + .with_extension(Api("xet-token")) + .header(header::USER_AGENT, GIT_LFS_CUSTOM_TRANSFER_AGENT_PROGRAM); + let req = self + .cred_helper + .fill_credential(req) + .await + .map_err(AuthError::token_refresh_failure)?; + let response = req.send().await.map_err(AuthError::token_refresh_failure)?; + + let jwt_info: CasJWTInfo = response.json().await.map_err(AuthError::token_refresh_failure)?; + + Ok((jwt_info.access_token, jwt_info.exp)) + } +} diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock index 4b39d3c7..3e5088c3 100644 --- a/hf_xet/Cargo.lock +++ b/hf_xet/Cargo.lock @@ -1355,6 +1355,7 @@ dependencies = [ "reqwest-middleware", "serde", "thiserror 2.0.15", + "urlencoding", ] [[package]] @@ -3793,6 +3794,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "urlencoding" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/hub_client/Cargo.toml b/hub_client/Cargo.toml index 528788d1..e958310c 100644 --- a/hub_client/Cargo.toml +++ b/hub_client/Cargo.toml @@ -13,6 +13,7 @@ reqwest = { workspace = true } reqwest-middleware = { workspace = true } serde = { workspace = true } thiserror = { workspace = true } +urlencoding = { workspace = true } [dev-dependencies] serde_json = { workspace = true } diff --git a/hub_client/src/client.rs b/hub_client/src/client.rs index 2a2b3d62..80bc9cd2 100644 --- a/hub_client/src/client.rs +++ b/hub_client/src/client.rs @@ -3,10 +3,11 @@ use std::sync::Arc; use cas_client::exports::ClientWithMiddleware; use cas_client::{Api, RetryConfig, build_http_client}; use http::header; +use urlencoding::encode; use crate::auth::CredentialHelper; use crate::errors::*; -use crate::types::CasJWTInfo; +use crate::types::{CasJWTInfo, RepoInfo}; /// The type of operation to perform, either to upload files or to download files. /// Different operations lead to CAS access token with different authorization levels. @@ -34,8 +35,8 @@ impl Operation { pub struct HubClient { endpoint: String, - repo_type: String, - repo_id: String, + repo_info: RepoInfo, + reference: Option, user_agent: String, client: ClientWithMiddleware, cred_helper: Arc, @@ -44,16 +45,16 @@ pub struct HubClient { impl HubClient { pub fn new( endpoint: &str, - repo_type: &str, - repo_id: &str, + repo_info: RepoInfo, + reference: Option, user_agent: &str, session_id: &str, cred_helper: Arc, ) -> Result { Ok(HubClient { endpoint: endpoint.to_owned(), - repo_type: repo_type.to_owned(), - repo_id: repo_id.to_owned(), + repo_info, + reference, user_agent: user_agent.to_owned(), client: build_http_client(RetryConfig::default(), session_id)?, cred_helper, @@ -63,12 +64,27 @@ impl HubClient { // Get CAS access token from Hub access token. pub async fn get_cas_jwt(&self, operation: Operation) -> Result { let endpoint = self.endpoint.as_str(); - let repo_type = self.repo_type.as_str(); - let repo_id = self.repo_id.as_str(); + let repo_type = self.repo_info.repo_type.as_str(); + let repo_id = self.repo_info.full_name.as_str(); let token_type = operation.token_type(); + // The reference may contain "/" but the "xet-[]-token" API only parses "rev" from a single component, + // thus we encode the reference. It defaults to "main" if not specified by caller because the + // API route expects a "rev" component. + let rev = encode(self.reference.as_deref().unwrap_or("main")); + + // Clients can get a xet write token, if + // - the "rev" is a regular branch, with a HF write token; + // - the "rev" is a pr branch, with a HF write or read token; + // - it intends to create a pr and repo is enabled for discussion, with a HF write or read token. + let query = if matches!(operation, Operation::Upload) && self.reference.is_none() { + "?create_pr=1" + } else { + "" + }; + // note that this API doesn't take a Basic auth - let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/main"); + let url = format!("{endpoint}/api/{repo_type}s/{repo_id}/xet-{token_type}-token/{rev}{query}"); let req = self .client @@ -92,17 +108,79 @@ impl HubClient { mod tests { use super::HubClient; use crate::errors::Result; - use crate::{BearerCredentialHelper, Operation}; + use crate::{BearerCredentialHelper, HFRepoType, Operation, RepoInfo}; #[tokio::test] - #[ignore = "need valid token"] - async fn test_get_jwt_token() -> Result<()> { - let cred_helper = BearerCredentialHelper::new("[hf_token]".to_owned(), ""); - let hub_client = HubClient::new("https://huggingface.co", "model", "seanses/tm", "xtool", "", cred_helper)?; + #[ignore = "need valid write token"] + async fn test_get_jwt_token_with_hf_write_token() -> Result<()> { + let cred_helper = BearerCredentialHelper::new("[hf_write_token]".to_owned(), ""); + let hub_client = HubClient::new( + "https://huggingface.co", + RepoInfo { + repo_type: HFRepoType::Model, + full_name: "seanses/tm".into(), + }, + Some("main".into()), + "xtool", + "", + cred_helper, + )?; - let read_info = hub_client.get_cas_jwt(Operation::Download).await?; + let read_info = hub_client.get_cas_jwt(Operation::Upload).await?; - println!("{:?}", read_info); + assert!(read_info.access_token.len() > 0); + assert!(read_info.cas_url.len() > 0); + assert!(read_info.exp > 0); + + Ok(()) + } + + #[tokio::test] + #[ignore = "need valid read token and pr created on hub"] + async fn test_get_jwt_token_with_hf_read_token_pr_branch() -> Result<()> { + let cred_helper = BearerCredentialHelper::new("[hf_read_token]".to_owned(), ""); + let hub_client = HubClient::new( + "https://huggingface.co", + RepoInfo { + repo_type: HFRepoType::Model, + full_name: "seanses/tm".into(), + }, + Some("refs/pr/1".into()), + "xtool", + "", + cred_helper, + )?; + + let read_info = hub_client.get_cas_jwt(Operation::Upload).await?; + + assert!(read_info.access_token.len() > 0); + assert!(read_info.cas_url.len() > 0); + assert!(read_info.exp > 0); + + Ok(()) + } + + #[tokio::test] + #[ignore = "need valid read token"] + async fn test_get_jwt_token_with_hf_read_token_create_pr() -> Result<()> { + let cred_helper = BearerCredentialHelper::new("[hf_read_token]".to_owned(), ""); + let hub_client = HubClient::new( + "https://huggingface.co", + RepoInfo { + repo_type: HFRepoType::Model, + full_name: "seanses/tm".into(), + }, + None, + "xtool", + "", + cred_helper, + )?; + + let read_info = hub_client.get_cas_jwt(Operation::Upload).await?; + + assert!(read_info.access_token.len() > 0); + assert!(read_info.cas_url.len() > 0); + assert!(read_info.exp > 0); Ok(()) } diff --git a/hub_client/src/errors.rs b/hub_client/src/errors.rs index b0456c82..749d56fb 100644 --- a/hub_client/src/errors.rs +++ b/hub_client/src/errors.rs @@ -13,10 +13,15 @@ pub enum HubClientError { #[error("Credential helper error: {0}")] CredentialHelper(anyhow::Error), + + #[error("Invalid repo type: {0}")] + InvalidRepoType(String), } pub type Result = std::result::Result; -pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> HubClientError { - HubClientError::CredentialHelper(e.into()) +impl HubClientError { + pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> HubClientError { + HubClientError::CredentialHelper(e.into()) + } } diff --git a/hub_client/src/lib.rs b/hub_client/src/lib.rs index a8a8cfe8..f6b7015f 100644 --- a/hub_client/src/lib.rs +++ b/hub_client/src/lib.rs @@ -5,5 +5,5 @@ mod types; pub use auth::{BearerCredentialHelper, CredentialHelper, NoopCredentialHelper}; pub use client::{HubClient, Operation}; -pub use errors::{HubClientError, Result, credential_helper_error}; -pub use types::CasJWTInfo; +pub use errors::{HubClientError, Result}; +pub use types::{CasJWTInfo, HFRepoType, RepoInfo}; diff --git a/hub_client/src/types.rs b/hub_client/src/types.rs index e9850ae5..76c33c22 100644 --- a/hub_client/src/types.rs +++ b/hub_client/src/types.rs @@ -1,5 +1,10 @@ +use std::fmt::Display; +use std::str::FromStr; + use serde::Deserialize; +use crate::errors::{HubClientError, Result}; + /// This defines the response format from the Huggingface Hub Xet CAS access token API. #[derive(Deserialize, Debug)] #[serde(rename_all = "camelCase")] @@ -9,6 +14,67 @@ pub struct CasJWTInfo { pub access_token: String, } +// This defines the exact three types of repos served on HF Hub. +#[derive(Debug, PartialEq)] +pub enum HFRepoType { + Model, + Dataset, + Space, +} + +impl FromStr for HFRepoType { + type Err = HubClientError; + + fn from_str(s: &str) -> std::result::Result { + match s.to_lowercase().as_str() { + "" => Ok(HFRepoType::Model), // when repo type is omitted from the URL the default type is "model" + "model" | "models" => Ok(HFRepoType::Model), + "dataset" | "datasets" => Ok(HFRepoType::Dataset), + "space" | "spaces" => Ok(HFRepoType::Space), + t => Err(HubClientError::InvalidRepoType(t.to_owned())), + } + } +} + +impl HFRepoType { + pub fn as_str(&self) -> &str { + match self { + HFRepoType::Model => "model", + HFRepoType::Dataset => "dataset", + HFRepoType::Space => "space", + } + } +} + +impl Display for HFRepoType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +#[derive(Debug, PartialEq)] +pub struct RepoInfo { + // The type of a repo, one of "model | dataset | space" + pub repo_type: HFRepoType, + // The full name of a repo, formatted as "owner/name" + pub full_name: String, +} + +impl RepoInfo { + pub fn try_from(repo_type: &str, repo_id: &str) -> Result { + Ok(Self { + repo_type: repo_type.parse()?, + full_name: repo_id.into(), + }) + } +} + +impl Display for RepoInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}/{}", self.repo_type, self.full_name) + } +} + #[cfg(test)] mod tests { use anyhow::Result;