mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
Adding hf_xet integration (#10)
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,8 +1,8 @@
|
||||
.idea
|
||||
**/.idea
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
debug/
|
||||
target/
|
||||
**/target/
|
||||
|
||||
# These are backup files generated by rustfmt
|
||||
**/*.rs.bk
|
||||
@@ -15,3 +15,4 @@ target/
|
||||
|
||||
# VS Code configs
|
||||
.vscode
|
||||
venv
|
||||
|
||||
@@ -21,6 +21,10 @@ members = [
|
||||
"cas_types",
|
||||
]
|
||||
|
||||
exclude = [
|
||||
"hf_xet",
|
||||
]
|
||||
|
||||
[profile.release]
|
||||
opt-level = 3
|
||||
lto = true
|
||||
|
||||
@@ -6,7 +6,7 @@ mod clean;
|
||||
pub mod configurations;
|
||||
mod constants;
|
||||
mod data_processing;
|
||||
mod errors;
|
||||
pub mod errors;
|
||||
mod metrics;
|
||||
mod pointer_file;
|
||||
mod remote_shard_interface;
|
||||
|
||||
@@ -12,7 +12,7 @@ const CURRENT_VERSION: &str = "0";
|
||||
/// A struct that wraps a Xet pointer file.
|
||||
/// Xet pointer file format is a TOML file,
|
||||
/// and the first line must be of the form "# xet version <x.y>"
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Default)]
|
||||
pub struct PointerFile {
|
||||
/// The version string of the pointer file
|
||||
version_string: String,
|
||||
|
||||
4625
hf_xet/Cargo.lock
generated
Normal file
4625
hf_xet/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
20
hf_xet/Cargo.toml
Normal file
20
hf_xet/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "hf_xet"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[lib]
|
||||
name = "hf_xet"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
pyo3 = { version = "0.20.2", features = [
|
||||
"extension-module",
|
||||
"abi3-py37",
|
||||
] }
|
||||
|
||||
data = { path = "../data" }
|
||||
tokio = { version = "1.36", features = ["full"] }
|
||||
parutils = { path = "../parutils" }
|
||||
|
||||
4
hf_xet/README.md
Normal file
4
hf_xet/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# Development Notes
|
||||
|
||||
* `pip install maturin`
|
||||
* from this directory: `maturin develop`
|
||||
20
hf_xet/pyproject.toml
Normal file
20
hf_xet/pyproject.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[build-system]
|
||||
requires = ["maturin>=1.7,<2.0"]
|
||||
build-backend = "maturin"
|
||||
|
||||
[project]
|
||||
name = "hfxet"
|
||||
requires-python = ">=3.8"
|
||||
classifiers = [
|
||||
"Programming Language :: Rust",
|
||||
"Programming Language :: Python :: Implementation :: CPython",
|
||||
"Programming Language :: Python :: Implementation :: PyPy",
|
||||
]
|
||||
dynamic = ["version"]
|
||||
[project.optional-dependencies]
|
||||
tests = [
|
||||
"pytest",
|
||||
]
|
||||
[tool.maturin]
|
||||
python-source = "python"
|
||||
features = ["pyo3/extension-module"]
|
||||
55
hf_xet/src/config.rs
Normal file
55
hf_xet/src/config.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use std::env::current_dir;
|
||||
use std::fs;
|
||||
use data::configurations::{Auth, CacheConfig, DedupConfig, Endpoint, FileQueryPolicy, RepoInfo, StorageConfig, TranslatorConfig};
|
||||
use data::{DEFAULT_BLOCK_SIZE, errors};
|
||||
|
||||
pub const SMALL_FILE_THRESHOLD: usize = 1;
|
||||
|
||||
pub fn default_config(endpoint: String) -> errors::Result<TranslatorConfig> {
|
||||
let path = current_dir()?.join(".xet");
|
||||
fs::create_dir_all(&path)?;
|
||||
|
||||
let translator_config = TranslatorConfig {
|
||||
file_query_policy: FileQueryPolicy::ServerOnly,
|
||||
cas_storage_config: StorageConfig {
|
||||
endpoint: Endpoint::Server(endpoint.clone()),
|
||||
auth: Auth {
|
||||
user_id: "".into(),
|
||||
login_id: "".into(),
|
||||
},
|
||||
prefix: "default".into(),
|
||||
cache_config: Some(CacheConfig {
|
||||
cache_directory: path.join("cache"),
|
||||
cache_size: 10 * 1024 * 1024 * 1024, // 10 GiB
|
||||
cache_blocksize: DEFAULT_BLOCK_SIZE,
|
||||
}),
|
||||
staging_directory: None,
|
||||
},
|
||||
shard_storage_config: StorageConfig {
|
||||
endpoint: Endpoint::Server(endpoint),
|
||||
auth: Auth {
|
||||
user_id: "".into(),
|
||||
login_id: "".into(),
|
||||
},
|
||||
prefix: "default-merkledb".into(),
|
||||
cache_config: Some(CacheConfig {
|
||||
cache_directory: path.join("shard-cache"),
|
||||
cache_size: 0, // ignored
|
||||
cache_blocksize: 0, // ignored
|
||||
}),
|
||||
staging_directory: Some(path.join("shard-session")),
|
||||
},
|
||||
dedup_config: Some(DedupConfig {
|
||||
repo_salt: None,
|
||||
small_file_threshold: SMALL_FILE_THRESHOLD,
|
||||
global_dedup_policy: Default::default(),
|
||||
}),
|
||||
repo_info: Some(RepoInfo {
|
||||
repo_paths: vec!["".into()],
|
||||
}),
|
||||
};
|
||||
|
||||
translator_config.validate()?;
|
||||
|
||||
Ok(translator_config)
|
||||
}
|
||||
129
hf_xet/src/data_client.rs
Normal file
129
hf_xet/src/data_client.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
use std::io::{BufReader, Read};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use data::{errors, PointerFile, PointerFileTranslator};
|
||||
use data::errors::DataProcessingError;
|
||||
use parutils::{ParallelError, tokio_par_for_each};
|
||||
use crate::config::default_config;
|
||||
|
||||
/// The maximum git filter protocol packet size
|
||||
pub const GIT_MAX_PACKET_SIZE: usize = 65516;
|
||||
pub const MAX_CONCURRENT_UPLOADS: usize = 8; // TODO
|
||||
pub const MAX_CONCURRENT_DOWNLOADS: usize = 8; // TODO
|
||||
|
||||
const DEFAULT_CAS_ENDPOINT: &str = "https://localhost:4884";
|
||||
const READ_BLOCK_SIZE: usize = 1024 * 1024;
|
||||
|
||||
pub async fn upload_async(file_paths: Vec<String>) -> errors::Result<Vec<PointerFile>> {
|
||||
// chunk files
|
||||
// produce Xorbs + Shards
|
||||
// upload shards and xorbs
|
||||
// for each file, return the filehash
|
||||
|
||||
let config = default_config(DEFAULT_CAS_ENDPOINT.to_string())?;
|
||||
let processor = Arc::new(PointerFileTranslator::new(config).await?);
|
||||
let processor = &processor;
|
||||
// for all files, clean them, producing pointer files.
|
||||
let pointers = tokio_par_for_each(
|
||||
file_paths,
|
||||
MAX_CONCURRENT_UPLOADS,
|
||||
|f, _| async {
|
||||
let proc = processor.clone();
|
||||
clean_file(&proc, f).await
|
||||
},
|
||||
).await.map_err(|e| match e {
|
||||
ParallelError::JoinError => {
|
||||
DataProcessingError::InternalError("Join error".to_string())
|
||||
}
|
||||
ParallelError::TaskError(e) => e,
|
||||
})?;
|
||||
|
||||
// Push the CAS blocks and flush the mdb to disk
|
||||
processor.finalize_cleaning().await?;
|
||||
|
||||
Ok(pointers)
|
||||
}
|
||||
|
||||
pub async fn download_async(pointer_files: Vec<PointerFile>) -> errors::Result<Vec<String>> {
|
||||
let config = default_config(DEFAULT_CAS_ENDPOINT.to_string())?;
|
||||
let processor = Arc::new(PointerFileTranslator::new(config).await?);
|
||||
let processor = &processor;
|
||||
let paths = tokio_par_for_each(
|
||||
pointer_files,
|
||||
MAX_CONCURRENT_DOWNLOADS,
|
||||
|pointer_file, _| async move {
|
||||
let proc = processor.clone();
|
||||
smudge_file(&proc, &pointer_file).await
|
||||
},
|
||||
).await.map_err(|e| match e {
|
||||
ParallelError::JoinError => {
|
||||
DataProcessingError::InternalError("Join error".to_string())
|
||||
}
|
||||
ParallelError::TaskError(e) => e,
|
||||
})?;
|
||||
|
||||
Ok(paths)
|
||||
}
|
||||
|
||||
async fn clean_file(processor: &PointerFileTranslator, f: String) -> errors::Result<PointerFile> {
|
||||
let mut read_buf = vec![0u8; READ_BLOCK_SIZE];
|
||||
let path = PathBuf::from(f);
|
||||
let mut reader = BufReader::new(File::open(path.clone())?);
|
||||
|
||||
let handle = processor.start_clean(1024, None).await?;
|
||||
|
||||
loop {
|
||||
let bytes = reader.read(&mut read_buf)?;
|
||||
if bytes == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
handle.add_bytes(read_buf[0..bytes].to_vec()).await?;
|
||||
}
|
||||
|
||||
let pf_str = handle.result().await?;
|
||||
let pf = PointerFile::init_from_string(&pf_str, "");
|
||||
Ok(pf)
|
||||
}
|
||||
|
||||
async fn smudge_file(proc: &PointerFileTranslator, pointer_file: &PointerFile) -> errors::Result<String> {
|
||||
let path = PathBuf::from(pointer_file.path());
|
||||
if let Some(parent_dir) = path.parent() {
|
||||
fs::create_dir_all(parent_dir)?;
|
||||
}
|
||||
let mut f = File::create(&path)?;
|
||||
proc.smudge_file_from_pointer(&pointer_file, &mut f, None).await?;
|
||||
Ok(pointer_file.path().to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::env::current_dir;
|
||||
use std::fs::canonicalize;
|
||||
use super::*;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn upload_files() {
|
||||
let cwd = current_dir().unwrap();
|
||||
let path = cwd.join("src").join("data_client.rs");
|
||||
let abs_path = canonicalize(path).unwrap();
|
||||
let s = abs_path.to_string_lossy();
|
||||
let files = vec![
|
||||
s.to_string(),
|
||||
];
|
||||
let pointers = upload_async(files).await.unwrap();
|
||||
println!("files: {pointers:?}");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn download_files() {
|
||||
let pointers = vec![
|
||||
PointerFile::init_from_info("/tmp/foo.rs", "e12be5e7cf55a47b78089bdf6fa1ebafe1836ef2b3ea7206b08ca37398f98a6f", 12700),
|
||||
];
|
||||
let paths = download_async(pointers).await.unwrap();
|
||||
println!("paths: {paths:?}");
|
||||
}
|
||||
}
|
||||
|
||||
89
hf_xet/src/lib.rs
Normal file
89
hf_xet/src/lib.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
mod data_client;
|
||||
mod config;
|
||||
|
||||
use pyo3::{pyfunction, PyResult};
|
||||
use pyo3::exceptions::PyException;
|
||||
use pyo3::prelude::*;
|
||||
use data::PointerFile;
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (file_paths), text_signature = "(file_paths: List[str]) -> List[PyPointerFile]")]
|
||||
pub fn upload_files(file_paths: Vec<String>) -> PyResult<Vec<PyPointerFile>> {
|
||||
Ok(tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()?
|
||||
.block_on(async {
|
||||
data_client::upload_async(file_paths).await
|
||||
}).map_err(|e| PyException::new_err(format!("{e:?}")))?
|
||||
.into_iter()
|
||||
.map(PyPointerFile::from)
|
||||
.collect())
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(signature = (files), text_signature = "(files: List[PyPointerFile]) -> List[str]")]
|
||||
pub fn download_files(files: Vec<PyPointerFile>) -> PyResult<Vec<String>> {
|
||||
let pfs = files.into_iter().map(PointerFile::from)
|
||||
.collect();
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()?
|
||||
.block_on(async move {
|
||||
data_client::download_async(pfs).await
|
||||
}).map_err(|e| PyException::new_err(format!("{e:?}")))
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PyPointerFile {
|
||||
#[pyo3(get, set)]
|
||||
path: String,
|
||||
#[pyo3(get)]
|
||||
hash: String,
|
||||
#[pyo3(get)]
|
||||
filesize: u64,
|
||||
}
|
||||
|
||||
impl From<PointerFile> for PyPointerFile {
|
||||
fn from(pf: PointerFile) -> Self {
|
||||
Self {
|
||||
path: pf.path().to_string(),
|
||||
hash: pf.hash_string().to_string(),
|
||||
filesize: pf.filesize(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PyPointerFile> for PointerFile {
|
||||
fn from(pf: PyPointerFile) -> Self {
|
||||
PointerFile::init_from_info(&pf.path, &pf.hash, pf.filesize)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyPointerFile {
|
||||
#[new]
|
||||
pub fn new(path: String, hash: String, filesize: u64) -> Self {
|
||||
Self {
|
||||
path,
|
||||
hash,
|
||||
filesize,
|
||||
}
|
||||
}
|
||||
|
||||
fn __str__(&self) -> String {
|
||||
format!("{self:?}")
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
format!("PyPointerFile({}, {}, {})", self.path, self.hash, self.filesize)
|
||||
}
|
||||
}
|
||||
|
||||
#[pymodule]
|
||||
pub fn hf_xet(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(upload_files, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(download_files, m)?)?;
|
||||
m.add_class::<PyPointerFile>()?;
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user