No re-computation of sha256 if provided (#570)

This PR let File cleaner skip re-computing sha256 if provided for files
-- sha256 computation isn't blazing fast, and there's no need to
re-compute. The sha256 is not verified anywhere, only serving as a
foreign key in the Xet file table to the file id in moon-landing
lfs-files table.

- Updates git-xet and migration utility to pass in the sha256 which is
already available.
- Will update huggingface_hub for the `upload_large_folder` case later.
- Related PR in repo-scanner:
https://github.com/huggingface-internal/repository-scanner/pull/368
(update the commit id when this merges).

Fix XET-200
This commit is contained in:
Di Xiao
2025-11-21 06:21:15 +08:00
committed by GitHub
parent b5563ecd93
commit 4baae2f006
14 changed files with 142 additions and 55 deletions

View File

@@ -92,7 +92,7 @@ async fn clean(mut reader: impl Read, mut writer: impl Write, size: u64) -> Resu
FileUploadSession::new(TranslatorConfig::local_config(std::env::current_dir()?)?.into(), None).await?;
let mut size_read = 0;
let mut handle = translator.start_clean(None, size).await;
let mut handle = translator.start_clean(None, size, None).await;
loop {
let bytes = reader.read(&mut read_buf)?;

View File

@@ -125,6 +125,7 @@ impl Command {
let (all_file_info, clean_ret, total_bytes_trans) = migrate_files_impl(
file_paths,
None,
arg.sequential,
hub_client,
None,

View File

@@ -8,6 +8,7 @@ use cas_client::{CacheConfig, SeekingOutputProvider, SequentialOutput, sequentia
use cas_object::CompressionScheme;
use deduplication::DeduplicationMetrics;
use lazy_static::lazy_static;
use mdb_shard::Sha256;
use progress_tracking::TrackingProgressUpdater;
use progress_tracking::item_tracking::ItemProgressUpdater;
use tracing::{Instrument, Span, info, info_span, instrument};
@@ -132,6 +133,7 @@ pub async fn upload_bytes_async(
Ok(files)
}
// The sha256, if provided and valid, will be directly used in shard upload to avoid redundant computation.
#[instrument(skip_all, name = "data_client::upload_files",
fields(session_id = tracing::field::Empty,
num_files=file_paths.len(),
@@ -144,6 +146,7 @@ pub async fn upload_bytes_async(
))]
pub async fn upload_async(
file_paths: Vec<String>,
sha256s: Option<Vec<String>>,
endpoint: Option<String>,
token_info: Option<(String, u64)>,
token_refresher: Option<Arc<dyn TokenRefresher>>,
@@ -168,7 +171,23 @@ pub async fn upload_async(
let upload_session = FileUploadSession::new(config.into(), progress_updater).await?;
let ret = upload_session.upload_files(&file_paths).await?;
// Parse sha256 hex string and ignore invalid ones, or if no sha256 is provided,
// create an iterator of infinite number of "None"s.
let sha256s: Box<dyn Iterator<Item = Option<Sha256>> + Send> = match &sha256s {
Some(v) => {
if v.len() != file_paths.len() {
return Err(DataProcessingError::ParameterError(
"mistached length of the file list and the sha256 list".into(),
));
}
Box::new(v.iter().map(|s| Sha256::from_hex(s).ok()))
},
None => Box::new(std::iter::repeat(None)),
};
let files_and_sha256s = file_paths.into_iter().zip(sha256s);
let ret = upload_session.upload_files(files_and_sha256s).await?;
// Push the CAS blocks and flush the mdb to disk
let metrics = upload_session.finalize().await?;
@@ -229,25 +248,29 @@ pub async fn clean_bytes(
processor: Arc<FileUploadSession>,
bytes: Vec<u8>,
) -> errors::Result<(XetFileInfo, DeduplicationMetrics)> {
let mut handle = processor.start_clean(None, bytes.len() as u64).await;
let mut handle = processor.start_clean(None, bytes.len() as u64, None).await;
handle.add_data(&bytes).await?;
handle.finish().await
}
// The provided sha256, if valid, will be directly used in shard upload to avoid redundant computation.
#[instrument(skip_all, name = "clean_file", fields(file.name = tracing::field::Empty, file.len = tracing::field::Empty))]
pub async fn clean_file(
processor: Arc<FileUploadSession>,
filename: impl AsRef<Path>,
sha256: impl AsRef<str>,
) -> errors::Result<(XetFileInfo, DeduplicationMetrics)> {
let mut reader = File::open(&filename)?;
let n = reader.metadata()?.len();
let filesize = reader.metadata()?.len();
let span = Span::current();
span.record("file.name", filename.as_ref().to_str());
span.record("file.len", n);
let mut buffer = vec![0u8; u64::min(n, *xet_config().data.ingestion_block_size) as usize];
span.record("file.len", filesize);
let mut buffer = vec![0u8; u64::min(filesize, *xet_config().data.ingestion_block_size) as usize];
let mut handle = processor.start_clean(Some(filename.as_ref().to_string_lossy().into()), n).await;
let mut handle = processor
.start_clean(Some(filename.as_ref().to_string_lossy().into()), filesize, Sha256::from_hex(sha256.as_ref()).ok())
.await;
loop {
let bytes = reader.read(&mut buffer)?;

View File

@@ -5,8 +5,8 @@ use std::sync::Arc;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use deduplication::{Chunk, Chunker, DeduplicationMetrics, FileDeduper};
use mdb_shard::Sha256;
use mdb_shard::file_structs::FileMetadataExt;
use merklehash::MerkleHash;
use progress_tracking::upload_tracking::CompletionTrackerFileId;
use tracing::{Instrument, debug_span, info, instrument};
use xet_runtime::xet_config;
@@ -43,9 +43,11 @@ pub struct SingleFileCleaner {
}
impl SingleFileCleaner {
// If a sha256 value is given in the parameter, the cleaner avoids computing the sha256 again internally.
pub(crate) fn new(
file_name: Option<Arc<str>>,
file_id: CompletionTrackerFileId,
sha256: Option<Sha256>,
session: Arc<FileUploadSession>,
) -> Self {
let deduper = FileDeduper::new(UploadSessionDataManager::new(session.clone(), file_id), file_id);
@@ -56,7 +58,7 @@ impl SingleFileCleaner {
dedup_manager_fut: Box::pin(async move { Ok(deduper) }),
session,
chunker: deduplication::Chunker::default(),
sha_generator: ShaGenerator::new(),
sha_generator: sha256.map(ShaGenerator::ProvidedValue).unwrap_or_else(ShaGenerator::generate),
start_time: Utc::now(),
}
}
@@ -148,7 +150,7 @@ impl SingleFileCleaner {
}
// Finalize the sha256 hashing and create the metadata extension
let sha256: MerkleHash = self.sha_generator.finalize().await?;
let sha256: Sha256 = self.sha_generator.finalize().await?;
let metadata_ext = FileMetadataExt::new(sha256);
let (file_hash, remaining_file_data, deduplication_metrics) =

View File

@@ -12,6 +12,7 @@ use deduplication::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
use deduplication::{DataAggregator, DeduplicationMetrics, RawXorbData};
use jsonwebtoken::{DecodingKey, Validation, decode};
use lazy_static::lazy_static;
use mdb_shard::Sha256;
use mdb_shard::file_structs::MDBFileInfo;
use more_asserts::*;
use progress_tracking::aggregator::AggregatingProgressUpdater;
@@ -189,10 +190,13 @@ impl FileUploadSession {
}))
}
pub async fn upload_files(self: &Arc<Self>, files: &[impl AsRef<Path>]) -> Result<Vec<XetFileInfo>> {
let mut cleaning_tasks: Vec<JoinHandle<_>> = Vec::with_capacity(files.len());
pub async fn upload_files(
self: &Arc<Self>,
files_and_sha256s: impl IntoIterator<Item = (impl AsRef<Path>, Option<Sha256>)> + Send,
) -> Result<Vec<XetFileInfo>> {
let mut cleaning_tasks: Vec<JoinHandle<_>> = vec![];
for f in files {
for (f, sha256) in files_and_sha256s.into_iter() {
let file_path = f.as_ref().to_owned();
let file_name: Arc<str> = Arc::from(file_path.to_string_lossy());
@@ -229,7 +233,7 @@ impl FileUploadSession {
let mut reader = File::open(&file_path)?;
// Start the clean process for each file.
let mut cleaner = SingleFileCleaner::new(Some(file_name), file_id, session);
let mut cleaner = SingleFileCleaner::new(Some(file_name), file_id, sha256, session);
let mut bytes_read = 0;
while bytes_read < file_size {
@@ -279,7 +283,7 @@ impl FileUploadSession {
}
// Join all the cleaning tasks.
let mut ret = Vec::with_capacity(files.len());
let mut ret = Vec::with_capacity(cleaning_tasks.len());
for task in cleaning_tasks {
ret.push(task.await??);
@@ -294,14 +298,22 @@ impl FileUploadSession {
///
/// The caller is responsible for memory usage management, the parameter "buffer_size"
/// indicates the maximum number of Vec<u8> in the internal buffer.
pub async fn start_clean(self: &Arc<Self>, file_name: Option<Arc<str>>, size: u64) -> SingleFileCleaner {
///
/// If a sha256 is provided, the value will be directly used in shard upload to
/// avoid redundant computation.
pub async fn start_clean(
self: &Arc<Self>,
file_name: Option<Arc<str>>,
size: u64,
sha256: Option<Sha256>,
) -> SingleFileCleaner {
// Get a new file id for the completion tracking
let file_id = self
.completion_tracker
.register_new_file(file_name.clone().unwrap_or_default(), size)
.await;
SingleFileCleaner::new(file_name, file_id, self.clone())
SingleFileCleaner::new(file_name, file_id, sha256, self.clone())
}
/// Registers a new xorb for upload, returning true if the xorb was added to the upload queue and false
@@ -604,7 +616,9 @@ mod tests {
.await
.unwrap();
let mut cleaner = upload_session.start_clean(Some("test".into()), read_data.len() as u64).await;
let mut cleaner = upload_session
.start_clean(Some("test".into()), read_data.len() as u64, None)
.await;
// Read blocks from the source file and hand them to the cleaning handle
cleaner.add_data(&read_data[..]).await.unwrap();

View File

@@ -1,6 +1,6 @@
use std::sync::Arc;
use anyhow::Result;
use anyhow::{Result, anyhow};
use cas_object::CompressionScheme;
use hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
use mdb_shard::file_structs::MDBFileInfo;
@@ -28,6 +28,7 @@ const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VE
/// ```
pub async fn migrate_with_external_runtime(
file_paths: Vec<String>,
sha256s: Option<Vec<String>>,
hub_endpoint: &str,
cas_endpoint: Option<String>,
hub_token: &str,
@@ -44,7 +45,7 @@ pub async fn migrate_with_external_runtime(
cred_helper,
)?;
migrate_files_impl(file_paths, false, hub_client, cas_endpoint, None, false).await?;
migrate_files_impl(file_paths, sha256s, false, hub_client, cas_endpoint, None, false).await?;
Ok(())
}
@@ -55,6 +56,7 @@ pub type MigrationInfo = (Vec<MDBFileInfo>, Vec<(XetFileInfo, u64)>, u64);
#[instrument(skip_all, name = "migrate_files", fields(session_id = tracing::field::Empty, num_files = file_paths.len()))]
pub async fn migrate_files_impl(
file_paths: Vec<String>,
sha256s: Option<Vec<String>>,
sequential: bool,
hub_client: HubClient,
cas_endpoint: Option<String>,
@@ -89,11 +91,20 @@ pub async fn migrate_files_impl(
FileUploadSession::new(config.into(), None).await?
};
// let file_paths_with_spans = add_spans(file_paths, || info_span!("migration::clean_file"));
let clean_futs = file_paths.into_iter().map(|file_path| {
let sha256s: Box<dyn Iterator<Item = String> + Send> = match sha256s {
Some(v) => {
if v.len() != file_paths.len() {
return Err(anyhow!("mistached length of the file list and the sha256 list"));
}
Box::new(v.into_iter())
},
None => Box::new(std::iter::repeat(String::new())),
};
let clean_futs = file_paths.into_iter().zip(sha256s).map(|(file_path, sha256)| {
let proc = processor.clone();
async move {
let (pf, metrics) = clean_file(proc, file_path).await?;
let (pf, metrics) = clean_file(proc, file_path, sha256).await?;
Ok::<(XetFileInfo, u64), DataProcessingError>((pf, metrics.new_bytes))
}
.instrument(info_span!("clean_file"))

View File

@@ -1,23 +1,44 @@
use merklehash::MerkleHash;
use sha2::{Digest, Sha256};
use mdb_shard::Sha256;
use sha2::{Digest, Sha256 as sha2Sha256};
use tokio::task::{JoinError, JoinHandle};
/// Helper struct to generate a sha256 hash as a MerkleHash.
#[derive(Debug)]
pub struct ShaGenerator {
hasher: Option<JoinHandle<Result<Sha256, JoinError>>>,
pub enum ShaGenerator {
Generate(Sha256Generator),
ProvidedValue(Sha256),
}
impl ShaGenerator {
pub fn new() -> Self {
Self { hasher: None }
pub async fn update(&mut self, new_data: impl AsRef<[u8]> + Send + Sync + 'static) -> Result<(), JoinError> {
match self {
Self::Generate(generator) => generator.update(new_data).await,
Self::ProvidedValue(_) => Ok(()),
}
}
pub async fn finalize(self) -> Result<Sha256, JoinError> {
match self {
Self::Generate(generator) => generator.finalize().await,
Self::ProvidedValue(hash) => Ok(hash),
}
}
pub fn generate() -> Self {
Self::Generate(Sha256Generator::default())
}
}
/// Helper struct to generate a sha256 hash.
#[derive(Debug, Default)]
pub struct Sha256Generator {
hasher: Option<JoinHandle<Result<sha2Sha256, JoinError>>>,
}
impl Sha256Generator {
/// Complete the last block, then hand off the new chunks to the new hasher.
pub async fn update(&mut self, new_data: impl AsRef<[u8]> + Send + Sync + 'static) -> Result<(), JoinError> {
let mut hasher = match self.hasher.take() {
Some(jh) => jh.await??,
None => Sha256::default(),
None => sha2Sha256::default(),
};
// The previous task returns the hasher; we consume that and pass it on.
@@ -32,17 +53,17 @@ impl ShaGenerator {
}
/// Generates a sha256 from the current state of the variant.
pub async fn finalize(mut self) -> Result<MerkleHash, JoinError> {
pub async fn finalize(mut self) -> Result<Sha256, JoinError> {
let current_state = self.hasher.take();
let hasher = match current_state {
Some(jh) => jh.await??,
None => return Ok(MerkleHash::default()),
None => return Ok(Sha256::default()),
};
let sha256 = hasher.finalize();
let hex_str = format!("{sha256:x}");
Ok(MerkleHash::from_hex(&hex_str).expect("Converting sha256 to merklehash."))
Ok(Sha256::from_hex(&hex_str).expect("Converting sha256 to merklehash."))
}
}
@@ -59,7 +80,7 @@ mod sha_tests {
#[tokio::test]
async fn test_sha_generation_builder() {
let mut sha_generator = ShaGenerator::new();
let mut sha_generator = Sha256Generator::default();
sha_generator.update(TEST_DATA.as_bytes()).await.unwrap();
let hash = sha_generator.finalize().await.unwrap();
@@ -68,7 +89,7 @@ mod sha_tests {
#[tokio::test]
async fn test_sha_generation_build_multiple_chunks() {
let mut sha_generator = ShaGenerator::new();
let mut sha_generator = Sha256Generator::default();
let td = TEST_DATA.as_bytes();
sha_generator.update(&td[0..4]).await.unwrap();
sha_generator.update(&td[4..td.len()]).await.unwrap();
@@ -85,7 +106,7 @@ mod sha_tests {
let mut rand_data = [0u8; 4096];
rng().fill(&mut rand_data[..]);
let mut sha_generator = ShaGenerator::new();
let mut sha_generator = Sha256Generator::default();
// Add in random chunks.
let mut pos = 0;
@@ -98,7 +119,7 @@ mod sha_tests {
let out_hash = sha_generator.finalize().await.unwrap();
let ref_hash = format!("{:x}", Sha256::digest(rand_data));
let ref_hash = format!("{:x}", sha2Sha256::digest(rand_data));
assert_eq!(out_hash.hex(), ref_hash);
}

View File

@@ -171,7 +171,7 @@ impl LocalHydrateDehydrateTest {
let upload_session = upload_session.clone();
if sequential {
let (pf, metrics) = clean_file(upload_session.clone(), entry.path()).await.unwrap();
let (pf, metrics) = clean_file(upload_session.clone(), entry.path(), "").await.unwrap();
assert_eq!({ metrics.total_bytes }, entry.metadata().unwrap().len());
std::fs::write(out_file, pf.as_pointer_file().unwrap().as_bytes()).unwrap();
@@ -185,7 +185,10 @@ impl LocalHydrateDehydrateTest {
.map(|entry| self.src_dir.join(entry.unwrap().file_name()))
.collect();
let clean_results = upload_session.upload_files(&files).await.unwrap();
let clean_results = upload_session
.upload_files(files.iter().zip(std::iter::repeat(None)))
.await
.unwrap();
for (i, xf) in clean_results.into_iter().enumerate() {
std::fs::write(self.ptr_dir.join(files[i].file_name().unwrap()), serde_json::to_string(&xf).unwrap())

View File

@@ -66,7 +66,9 @@ mod tests {
.unwrap();
// Feed it half the data, and checkpoint.
let mut cleaner = file_upload_session.start_clean(Some("data".into()), data.len() as u64).await;
let mut cleaner = file_upload_session
.start_clean(Some("data".into()), data.len() as u64, None)
.await;
cleaner.add_data(&data[..half_n]).await.unwrap();
cleaner.checkpoint().await.unwrap();
@@ -82,7 +84,9 @@ mod tests {
let file_upload_session = FileUploadSession::new(config, Some(progress_tracker.clone())).await.unwrap();
// Feed it half the data, and checkpoint.
let mut cleaner = file_upload_session.start_clean(Some("data".into()), data.len() as u64).await;
let mut cleaner = file_upload_session
.start_clean(Some("data".into()), data.len() as u64, None)
.await;
// Add all the data. Roughly the first half should dedup.
cleaner.add_data(&data).await.unwrap();
@@ -136,7 +140,9 @@ mod tests {
.unwrap();
// Feed it half the data, and checkpoint.
let mut cleaner = file_upload_session.start_clean(Some("data".into()), data.len() as u64).await;
let mut cleaner = file_upload_session
.start_clean(Some("data".into()), data.len() as u64, None)
.await;
cleaner.add_data(&data[..rn]).await.unwrap();
cleaner.checkpoint().await.unwrap();
@@ -166,7 +172,9 @@ mod tests {
let file_upload_session = FileUploadSession::new(config, Some(progress_tracker.clone())).await.unwrap();
// Feed it half the data, and checkpoint.
let mut cleaner = file_upload_session.start_clean(Some("data".into()), data.len() as u64).await;
let mut cleaner = file_upload_session
.start_clean(Some("data".into()), data.len() as u64, None)
.await;
// Add all the data. Roughly the first half should dedup.
cleaner.add_data(&data).await.unwrap();

View File

@@ -133,7 +133,7 @@ impl TransferAgent for XetAgent {
return Err(GitLFSProtocolError::bad_syntax("file path not provided for upload request").into());
};
clean_file(session.clone(), file_path).await?;
clean_file(session.clone(), file_path, &req.oid).await?;
// We need to actually upload the shard after each file upload to have the files registered, because
//

View File

@@ -107,6 +107,7 @@ pub fn upload_files(
let out: Vec<PyXetUploadInfo> = data_client::upload_async(
file_paths,
None,
endpoint,
token_info,
refresher.map(|v| v as Arc<_>),

View File

@@ -3,8 +3,8 @@ use std::io::{Cursor, Read, Write};
use std::mem::size_of;
use bytes::Bytes;
use merklehash::MerkleHash;
use merklehash::data_hash::hex;
use merklehash::{DataHash, MerkleHash};
use serde::Serialize;
use utils::serialization_utils::*;
@@ -18,6 +18,8 @@ pub const MDB_FILE_FLAG_VERIFICATION_MASK: u32 = 1 << 31;
pub const MDB_FILE_FLAG_WITH_METADATA_EXT: u32 = 1 << 30;
pub const MDB_FILE_FLAG_METADATA_EXT_MASK: u32 = 1 << 30;
pub type Sha256 = DataHash;
/// Each file consists of a FileDataSequenceHeader following
/// a sequence of FileDataSequenceEntry, maybe a sequence
/// of FileVerificationEntry, and maybe a FileMetadataExt
@@ -303,12 +305,12 @@ impl FileVerificationEntry {
#[derive(Clone, Debug, Default, PartialEq, Serialize)]
pub struct FileMetadataExt {
#[serde(with = "hex::serde")]
pub sha256: MerkleHash,
pub sha256: Sha256,
pub _unused: [u64; 2],
}
impl FileMetadataExt {
pub fn new(sha256: MerkleHash) -> Self {
pub fn new(sha256: Sha256) -> Self {
Self {
sha256,
_unused: Default::default(),

View File

@@ -17,6 +17,7 @@ pub use constants::{
MDB_SHARD_EXPIRATION_BUFFER, MDB_SHARD_GLOBAL_DEDUP_CHUNK_MODULUS, MDB_SHARD_LOCAL_CACHE_EXPIRATION,
hash_is_global_dedup_eligible,
};
pub use file_structs::Sha256;
pub use shard_file_handle::MDBShardFile;
pub use shard_file_manager::ShardFileManager;
pub use shard_format::{MDBShardFileFooter, MDBShardFileHeader, MDBShardInfo};

View File

@@ -3,10 +3,10 @@ use std::mem::{size_of, transmute};
use futures::AsyncReadExt;
use futures::io::AsyncRead;
use merklehash::MerkleHash;
use merklehash::DataHash;
#[inline]
pub fn write_hash<W: Write>(writer: &mut W, m: &MerkleHash) -> Result<(), std::io::Error> {
pub fn write_hash<W: Write>(writer: &mut W, m: &DataHash) -> Result<(), std::io::Error> {
writer.write_all(m.as_bytes())
}
@@ -49,11 +49,11 @@ pub fn write_u64s<W: Write>(writer: &mut W, vs: &[u64]) -> Result<(), std::io::E
}
#[inline]
pub fn read_hash<R: Read>(reader: &mut R) -> Result<MerkleHash, std::io::Error> {
pub fn read_hash<R: Read>(reader: &mut R) -> Result<DataHash, std::io::Error> {
let mut m = [0u8; 32];
reader.read_exact(&mut m)?; // Not endian safe.
Ok(MerkleHash::from(unsafe { transmute::<[u8; 32], [u64; 4]>(m) }))
Ok(DataHash::from(unsafe { transmute::<[u8; 32], [u64; 4]>(m) }))
}
#[inline]
@@ -102,11 +102,11 @@ pub fn read_u64s<R: Read>(reader: &mut R, vs: &mut [u64]) -> Result<(), std::io:
// Async version of the above.
#[inline]
pub async fn read_hash_async<R: AsyncRead + Unpin>(reader: &mut R) -> Result<MerkleHash, std::io::Error> {
pub async fn read_hash_async<R: AsyncRead + Unpin>(reader: &mut R) -> Result<DataHash, std::io::Error> {
let mut m = [0u8; 32];
reader.read_exact(&mut m).await?; // Not endian safe.
Ok(MerkleHash::from(unsafe { transmute::<[u8; 32], [u64; 4]>(m) }))
Ok(DataHash::from(unsafe { transmute::<[u8; 32], [u64; 4]>(m) }))
}
#[inline]