mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
1. This PR updates the hub client xet access token request to use a custom
rev in addition to the default "main". This better supports the
"xet-write-token" API authorization model:
Clients can get a xet write token, if
- the "rev" is a regular branch, with a HF write token;
- the "rev" is a pr branch with a corresponding open PR, with a HF
write or read token;
- it intends to create a pr and the repo is enabled for discussion, with a
HF write or read token.
2. Fixed a bug when getting the current branch name in a repo, which
didn't parse branch names with "/" correctly: change
`refs_heads_branch.rsplit('/').next()` to
`refs_heads_branch.strip_prefix("refs/heads/")`.
3. Also updated xet transfer agent to use the refresh route in the LFS
Batch Api
[response](e3be2b3c8f/server/app/gitHostingRoutes.ts (L1713)).
4. Use the session id in the LFS Batch Api
[response](e3be2b3c8f/server/app/gitHostingRoutes.ts (L1657))
for token refresh and CAS requests.
124 lines
3.9 KiB
Rust
124 lines
3.9 KiB
Rust
use std::sync::Arc;
|
|
|
|
use async_trait::async_trait;
|
|
use hub_client::{CredentialHelper, HubClientError, Operation, Result};
|
|
#[cfg(unix)]
|
|
use openssh::{KnownHosts, Session};
|
|
use reqwest::header;
|
|
use reqwest_middleware::RequestBuilder;
|
|
use serde::Deserialize;
|
|
|
|
use crate::git_url::GitUrl;
|
|
|
|
#[derive(Deserialize)]
|
|
struct GitLFSAuthentationResponseHeader {
|
|
#[serde(rename = "Authorization")]
|
|
authorization: String,
|
|
}
|
|
|
|
// This struct represents the JSON format of the `git-lfs-authenticate` command response over an
|
|
// SSH channel to the remote Git server. For details see `crate::auth.rs`.
|
|
#[derive(Deserialize)]
|
|
#[allow(unused)]
|
|
struct GitLFSAuthenticateResponse {
|
|
header: GitLFSAuthentationResponseHeader,
|
|
href: String,
|
|
expires_in: u32,
|
|
}
|
|
|
|
// This credential helper calls a remote command `git-lfs-authenticate` over an SSH channel
|
|
// to the remote Git server.
|
|
// We can't cache the authorization token from ssh authentication because
|
|
// it has a shorter TTL than that of a Xet CAS JWT.
|
|
pub struct SSHCredentialHelper {
|
|
remote_url: GitUrl,
|
|
operation: Operation,
|
|
}
|
|
|
|
impl SSHCredentialHelper {
|
|
#[allow(clippy::new_ret_no_self)]
|
|
pub fn new(remote_url: &GitUrl, operation: Operation) -> Arc<Self> {
|
|
Arc::new(Self {
|
|
remote_url: remote_url.clone(),
|
|
operation,
|
|
})
|
|
}
|
|
|
|
#[cfg(unix)]
|
|
async fn authenticate(&self) -> Result<GitLFSAuthenticateResponse> {
|
|
let host_url = self.remote_url.host_url().map_err(HubClientError::credential_helper_error)?;
|
|
let full_repo_path = self.remote_url.full_repo_path();
|
|
let session = Session::connect(&host_url, KnownHosts::Add)
|
|
.await
|
|
.map_err(HubClientError::credential_helper_error)?;
|
|
|
|
let output = session
|
|
.command("git-lfs-authenticate")
|
|
.arg(full_repo_path)
|
|
.arg(self.operation.as_str())
|
|
.output()
|
|
.await
|
|
.map_err(HubClientError::credential_helper_error)?;
|
|
|
|
serde_json::from_slice(&output.stdout).map_err(HubClientError::credential_helper_error)
|
|
}
|
|
|
|
#[cfg(not(unix))]
|
|
async fn authenticate(&self) -> Result<GitLFSAuthenticateResponse> {
|
|
unimplemented!()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl CredentialHelper for SSHCredentialHelper {
|
|
async fn fill_credential(&self, req: RequestBuilder) -> anyhow::Result<RequestBuilder> {
|
|
let authenticated = self.authenticate().await?;
|
|
Ok(req.header(header::AUTHORIZATION, authenticated.header.authorization))
|
|
}
|
|
|
|
fn whoami(&self) -> &str {
|
|
"ssh"
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use anyhow::Result;
    use hub_client::Operation;

    use super::SSHCredentialHelper;
    use crate::git_url::GitUrl;

    // Ignored by default: needs an SSH Git server reachable at localhost:2222.
    #[tokio::test]
    #[ignore = "need ssh server"]
    async fn test_ssh_cred_helper_local() -> Result<()> {
        let url: GitUrl = "ssh://git@localhost:2222/datasets/test/td".parse()?;
        let helper = SSHCredentialHelper::new(&url, Operation::Download);

        let auth = helper.authenticate().await?;

        assert!(auth.header.authorization.starts_with("Basic"));
        assert_eq!(auth.href, "http://localhost:5564/datasets/test/td.git/info/lfs");
        assert!(auth.expires_in > 0);

        Ok(())
    }

    // Ignored by default: needs an SSH key accepted by the remote host.
    #[tokio::test]
    #[ignore = "need ssh key"]
    async fn test_ssh_cred_helper_remote() -> Result<()> {
        // it seems that ssh port is not open on "huggingface.co"
        let url: GitUrl = "ssh://git@hf.co/seanses/tm".parse()?;
        let helper = SSHCredentialHelper::new(&url, Operation::Upload);

        let auth = helper.authenticate().await?;

        assert!(auth.header.authorization.starts_with("Basic"));
        assert_eq!(auth.href, "https://huggingface.co/seanses/tm.git/info/lfs");
        assert!(auth.expires_in > 0);

        Ok(())
    }
}
|