mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
Fixed race condition with is_complete
This commit is contained in:
@@ -221,9 +221,13 @@ impl UnorderedDownloadStream {
|
||||
self.progress.terms_in_progress()
|
||||
}
|
||||
|
||||
/// Returns `true` if all data has been fetched and the writer has finished.
|
||||
/// Returns `true` once the stream has reached terminal state.
|
||||
///
|
||||
/// This flips to `true` after [`next`](Self::next) / [`blocking_next`](Self::blocking_next)
|
||||
/// has observed the end-of-stream (`None`), or after [`cancel`](Self::cancel).
|
||||
/// Buffered but unconsumed channel items do not count as complete.
|
||||
pub fn is_complete(&self) -> bool {
|
||||
self.progress.is_finished() && self.progress.terms_in_progress() == 0
|
||||
self.finished
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
|
||||
use bytes::Bytes;
|
||||
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
|
||||
@@ -26,7 +26,6 @@ pub(crate) struct CompletedTerm {
|
||||
pub(crate) struct UnorderedWriterProgress {
|
||||
pub terms_in_progress: AtomicU64,
|
||||
pub bytes_in_progress: AtomicU64,
|
||||
pub finished: AtomicBool,
|
||||
}
|
||||
|
||||
impl UnorderedWriterProgress {
|
||||
@@ -37,10 +36,6 @@ impl UnorderedWriterProgress {
|
||||
pub fn bytes_in_progress(&self) -> u64 {
|
||||
self.bytes_in_progress.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn is_finished(&self) -> bool {
|
||||
self.finished.load(Ordering::Acquire)
|
||||
}
|
||||
}
|
||||
|
||||
/// Writer that delivers completed data terms in arbitrary order.
|
||||
@@ -146,14 +141,12 @@ impl DataWriter for UnorderedWriter {
|
||||
async fn finish(mut self: Box<Self>) -> Result<u64> {
|
||||
self.run_state.check_error()?;
|
||||
|
||||
self.finished = true;
|
||||
self.progress.finished.store(true, Ordering::Release);
|
||||
|
||||
while let Some(result) = self.task_set.join_next().await {
|
||||
self.total_bytes_sent +=
|
||||
result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
|
||||
}
|
||||
|
||||
self.finished = true;
|
||||
Ok(self.total_bytes_sent)
|
||||
}
|
||||
}
|
||||
@@ -175,7 +168,6 @@ impl UnorderedWriter {
|
||||
let progress = Arc::new(UnorderedWriterProgress {
|
||||
terms_in_progress: AtomicU64::new(0),
|
||||
bytes_in_progress: AtomicU64::new(0),
|
||||
finished: AtomicBool::new(false),
|
||||
});
|
||||
|
||||
let writer = Box::new(UnorderedWriter {
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
mod tests {
|
||||
use std::fs;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tempfile::TempDir;
|
||||
use xet_client::cas_client::LocalTestServerBuilder;
|
||||
@@ -147,6 +148,33 @@ mod tests {
|
||||
assert_eq!(stream.terms_in_progress(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_unordered_stream_is_complete_loop_drains_all_data() {
|
||||
let server = LocalTestServerBuilder::new().start().await;
|
||||
let base_dir = TempDir::new().unwrap();
|
||||
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
|
||||
|
||||
let original_data: Vec<u8> = (0..131072u32).map(|i| (i % 251) as u8).collect();
|
||||
|
||||
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
|
||||
let xfi = upload_bytes(&upload_session, "is_complete_loop", &original_data).await;
|
||||
upload_session.finalize().await.unwrap();
|
||||
|
||||
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
|
||||
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
|
||||
|
||||
let mut chunks = Vec::new();
|
||||
while !stream.is_complete() {
|
||||
if let Some((offset, chunk)) = stream.next().await.unwrap() {
|
||||
chunks.push((offset, chunk));
|
||||
tokio::time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
}
|
||||
|
||||
let assembled = reassemble(chunks, original_data.len());
|
||||
assert_eq!(assembled, original_data);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_unordered_stream_next_returns_none_after_complete() {
|
||||
let server = LocalTestServerBuilder::new().start().await;
|
||||
|
||||
Reference in New Issue
Block a user