Fixed race condition with is_complete

This commit is contained in:
Hoyt Koepke
2026-03-20 13:49:32 -07:00
parent 2291c0c4bb
commit e02890aa4b
3 changed files with 36 additions and 12 deletions

View File

@@ -221,9 +221,13 @@ impl UnorderedDownloadStream {
self.progress.terms_in_progress()
}
/// Returns `true` if all data has been fetched and the writer has finished.
/// Returns `true` once the stream has reached a terminal state.
///
/// This flips to `true` after [`next`](Self::next) / [`blocking_next`](Self::blocking_next)
/// has observed the end-of-stream (`None`), or after [`cancel`](Self::cancel).
/// The stream is not complete while buffered channel items remain unconsumed.
pub fn is_complete(&self) -> bool {
self.progress.is_finished() && self.progress.terms_in_progress() == 0
self.finished
}
}

View File

@@ -1,5 +1,5 @@
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::atomic::{AtomicU64, Ordering};
use bytes::Bytes;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
@@ -26,7 +26,6 @@ pub(crate) struct CompletedTerm {
pub(crate) struct UnorderedWriterProgress {
pub terms_in_progress: AtomicU64,
pub bytes_in_progress: AtomicU64,
pub finished: AtomicBool,
}
impl UnorderedWriterProgress {
@@ -37,10 +36,6 @@ impl UnorderedWriterProgress {
pub fn bytes_in_progress(&self) -> u64 {
self.bytes_in_progress.load(Ordering::Relaxed)
}
pub fn is_finished(&self) -> bool {
self.finished.load(Ordering::Acquire)
}
}
/// Writer that delivers completed data terms in arbitrary order.
@@ -146,14 +141,12 @@ impl DataWriter for UnorderedWriter {
async fn finish(mut self: Box<Self>) -> Result<u64> {
self.run_state.check_error()?;
self.finished = true;
self.progress.finished.store(true, Ordering::Release);
while let Some(result) = self.task_set.join_next().await {
self.total_bytes_sent +=
result.map_err(|e| FileReconstructionError::InternalError(format!("Task join error: {e}")))??;
}
self.finished = true;
Ok(self.total_bytes_sent)
}
}
@@ -175,7 +168,6 @@ impl UnorderedWriter {
let progress = Arc::new(UnorderedWriterProgress {
terms_in_progress: AtomicU64::new(0),
bytes_in_progress: AtomicU64::new(0),
finished: AtomicBool::new(false),
});
let writer = Box::new(UnorderedWriter {

View File

@@ -2,6 +2,7 @@
mod tests {
use std::fs;
use std::sync::Arc;
use std::time::Duration;
use tempfile::TempDir;
use xet_client::cas_client::LocalTestServerBuilder;
@@ -147,6 +148,33 @@ mod tests {
assert_eq!(stream.terms_in_progress(), 0);
}
/// Checks that a consumer loop driven by `is_complete()` still receives every
/// chunk of the download before the loop exits.
///
/// Uploads 128 KiB of patterned bytes, downloads them through the unordered
/// stream, and asserts that the reassembled chunks equal the original data.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_is_complete_loop_drains_all_data() {
// Start a local test server and build a translator config that targets it,
// using a temporary directory as the base path.
let server = LocalTestServerBuilder::new().start().await;
let base_dir = TempDir::new().unwrap();
let config = Arc::new(TranslatorConfig::test_server_config(server.http_endpoint(), base_dir.path()).unwrap());
// 131072 bytes whose values cycle through 0..=250.
let original_data: Vec<u8> = (0..131072u32).map(|i| (i % 251) as u8).collect();
// Upload the data and finalize the upload session before downloading.
let upload_session = FileUploadSession::new(config.clone()).await.unwrap();
let xfi = upload_bytes(&upload_session, "is_complete_loop", &original_data).await;
upload_session.finalize().await.unwrap();
// Open an unordered download stream for the uploaded file.
let download_session = FileDownloadSession::new(config.clone()).await.unwrap();
let (_id, mut stream) = download_session.download_unordered_stream(&xfi, None).await.unwrap();
let mut chunks = Vec::new();
// Keep pulling until the stream reports completion. A `None` from `next()`
// is not handled specially here; the loop condition is simply re-checked.
while !stream.is_complete() {
if let Some((offset, chunk)) = stream.next().await.unwrap() {
chunks.push((offset, chunk));
// NOTE(review): this short pause presumably lets the producer get ahead
// of the consumer so that items are buffered when `is_complete()` is
// polled — confirm the intent.
tokio::time::sleep(Duration::from_millis(1)).await;
}
}
// Rebuild the file from the (offset, chunk) pairs and compare it to the input.
let assembled = reassemble(chunks, original_data.len());
assert_eq!(assembled, original_data);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_unordered_stream_next_returns_none_after_complete() {
let server = LocalTestServerBuilder::new().start().await;