mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
Fix for incorrect error propagation on truncated download stream. (#683)
Currently, the async stream logic silently swallows an UnexpectedEOF, treating it the same as an EOF. This is a bug; this PR fixes it to propagate UnexpectedEOF while handling correct EOF as the end of the stream.
This commit is contained in:
@@ -13,7 +13,7 @@ use mdb_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDB
|
||||
use merklehash::MerkleHash;
|
||||
use reqwest::{Body, Response, StatusCode, Url};
|
||||
use reqwest_middleware::ClientWithMiddleware;
|
||||
use tracing::{event, info, instrument, warn};
|
||||
use tracing::{event, info, instrument};
|
||||
use utils::auth::AuthConfig;
|
||||
use xet_runtime::xet_config;
|
||||
use xorb_object::SerializedXorbObject;
|
||||
@@ -372,17 +372,13 @@ impl Client for RemoteClient {
|
||||
|
||||
match result {
|
||||
Ok((_compressed_len, chunk_byte_indices)) => {
|
||||
if let Some(expected) = uncompressed_size_if_known {
|
||||
debug_assert_eq!(
|
||||
buffer.len(),
|
||||
expected,
|
||||
"get_file_term_data: expected {} bytes, got {}",
|
||||
expected,
|
||||
if let Some(expected) = uncompressed_size_if_known
|
||||
&& expected != buffer.len()
|
||||
{
|
||||
return Err(RetryableReqwestError::RetryableError(CasClientError::Other(format!(
|
||||
"get_file_term_data: expected {expected} uncompressed bytes, got {}",
|
||||
buffer.len()
|
||||
);
|
||||
if expected != buffer.len() {
|
||||
warn!("get_file_term_data: expected {} bytes, got {}", expected, buffer.len());
|
||||
}
|
||||
))));
|
||||
}
|
||||
Ok((Bytes::from(buffer), chunk_byte_indices))
|
||||
},
|
||||
|
||||
@@ -163,6 +163,14 @@ pub fn deserialize_chunk_to_writer<R: Read, W: Write>(
|
||||
writer: &mut W,
|
||||
) -> Result<(usize, u32), XorbObjectError> {
|
||||
let header = deserialize_chunk_header(reader)?;
|
||||
deserialize_chunk_with_header_to_writer(reader, writer, header)
|
||||
}
|
||||
|
||||
fn deserialize_chunk_with_header_to_writer<R: Read, W: Write>(
|
||||
reader: &mut R,
|
||||
writer: &mut W,
|
||||
header: XorbChunkHeader,
|
||||
) -> Result<(usize, u32), XorbObjectError> {
|
||||
let mut compressed_data_reader = reader.take(header.get_compressed_length().into());
|
||||
|
||||
let uncompressed_len = header
|
||||
@@ -184,6 +192,24 @@ pub fn deserialize_chunks<R: Read>(reader: &mut R) -> Result<(Vec<u8>, Vec<u32>)
|
||||
Ok((buf, chunk_byte_indices))
|
||||
}
|
||||
|
||||
/// Reads the next chunk header, returning `None` on clean EOF.
|
||||
///
|
||||
/// Uses a single `read()` call to detect EOF (returns 0), then completes
|
||||
/// any partial header with `read_exact`. An `UnexpectedEof` from `read_exact`
|
||||
/// means the stream was truncated mid-header.
|
||||
fn try_read_chunk_header<R: Read>(reader: &mut R) -> Result<Option<XorbChunkHeader>, XorbObjectError> {
|
||||
let mut header_buf = [0u8; XORB_CHUNK_HEADER_LENGTH];
|
||||
let n = match reader.read(&mut header_buf) {
|
||||
Ok(0) => return Ok(None),
|
||||
Ok(n) => n,
|
||||
Err(e) => return Err(XorbObjectError::InternalIOError(e)),
|
||||
};
|
||||
if n < XORB_CHUNK_HEADER_LENGTH {
|
||||
reader.read_exact(&mut header_buf[n..])?;
|
||||
}
|
||||
parse_chunk_header(header_buf).map(Some)
|
||||
}
|
||||
|
||||
pub fn deserialize_chunks_to_writer<R: Read, W: Write>(
|
||||
reader: &mut R,
|
||||
writer: &mut W,
|
||||
@@ -196,21 +222,11 @@ pub fn deserialize_chunks_to_writer<R: Read, W: Write>(
|
||||
let mut chunk_byte_indices = Vec::<u32>::new();
|
||||
chunk_byte_indices.push(num_uncompressed_written);
|
||||
|
||||
loop {
|
||||
match deserialize_chunk_to_writer(reader, writer) {
|
||||
Ok((delta_written, uncompressed_chunk_len)) => {
|
||||
num_compressed_written += delta_written;
|
||||
num_uncompressed_written += uncompressed_chunk_len;
|
||||
chunk_byte_indices.push(num_uncompressed_written); // record end of current chunk
|
||||
},
|
||||
Err(XorbObjectError::InternalIOError(e)) => {
|
||||
if e.kind() == std::io::ErrorKind::UnexpectedEof {
|
||||
break;
|
||||
}
|
||||
return Err(XorbObjectError::InternalIOError(e));
|
||||
},
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
while let Some(header) = try_read_chunk_header(reader)? {
|
||||
let (delta_written, uncompressed_chunk_len) = deserialize_chunk_with_header_to_writer(reader, writer, header)?;
|
||||
num_compressed_written += delta_written;
|
||||
num_uncompressed_written += uncompressed_chunk_len;
|
||||
chunk_byte_indices.push(num_uncompressed_written);
|
||||
}
|
||||
|
||||
Ok((num_compressed_written, chunk_byte_indices))
|
||||
@@ -321,4 +337,40 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_truncated_stream_returns_error() {
|
||||
let (_, xorb_data, _, _) = build_xorb_object(3, ChunkSize::Fixed(1024), CompressionScheme::None);
|
||||
|
||||
// Truncate mid-header (e.g. 2 bytes into the second chunk's header)
|
||||
let first_chunk_end = XORB_CHUNK_HEADER_LENGTH + 1024;
|
||||
let mid_header = first_chunk_end + 2;
|
||||
let truncated = &xorb_data[..mid_header];
|
||||
let res = deserialize_chunks_to_writer(&mut Cursor::new(truncated), &mut Vec::new());
|
||||
assert!(res.is_err(), "truncation mid-header should error, not silently succeed");
|
||||
|
||||
// Truncate mid-data (header present but compressed payload cut short)
|
||||
let mid_data = first_chunk_end + XORB_CHUNK_HEADER_LENGTH + 10;
|
||||
let truncated = &xorb_data[..mid_data];
|
||||
let res = deserialize_chunks_to_writer(&mut Cursor::new(truncated), &mut Vec::new());
|
||||
assert!(res.is_err(), "truncation mid-data should error, not silently succeed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exact_eof_after_complete_chunk_succeeds() {
|
||||
let (_, xorb_data, raw_data, raw_chunk_boundaries) =
|
||||
build_xorb_object(3, ChunkSize::Fixed(1024), CompressionScheme::None);
|
||||
|
||||
// Truncate exactly at the end of chunk 0. This should be treated as clean EOF.
|
||||
let first_chunk_end = XORB_CHUNK_HEADER_LENGTH + 1024;
|
||||
let truncated = &xorb_data[..first_chunk_end];
|
||||
|
||||
let mut out = Vec::new();
|
||||
let (num_read, chunk_byte_indices) =
|
||||
deserialize_chunks_to_writer(&mut Cursor::new(truncated), &mut out).unwrap();
|
||||
|
||||
assert_eq!(num_read, first_chunk_end);
|
||||
assert_eq!(chunk_byte_indices, vec![0, raw_chunk_boundaries[0].1]);
|
||||
assert_eq!(&out[..], &raw_data[..1024]);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,14 @@ pub async fn deserialize_chunk_to_writer<R: AsyncRead + Unpin, W: Write>(
|
||||
writer: &mut W,
|
||||
) -> Result<(usize, u32), XorbObjectError> {
|
||||
let header = deserialize_chunk_header(reader).await?;
|
||||
deserialize_chunk_with_header_to_writer(reader, writer, header).await
|
||||
}
|
||||
|
||||
async fn deserialize_chunk_with_header_to_writer<R: AsyncRead + Unpin, W: Write>(
|
||||
reader: &mut R,
|
||||
writer: &mut W,
|
||||
header: XorbChunkHeader,
|
||||
) -> Result<(usize, u32), XorbObjectError> {
|
||||
let mut compressed_data = vec![0u8; header.get_compressed_length() as usize];
|
||||
reader.read_exact(&mut compressed_data).await?;
|
||||
|
||||
@@ -46,6 +54,26 @@ pub async fn deserialize_chunk<R: AsyncRead + Unpin>(reader: &mut R) -> Result<(
|
||||
Ok((buf, compressed_len, uncompressed_len))
|
||||
}
|
||||
|
||||
/// Reads the next chunk header from an async reader, returning `None` on clean EOF.
|
||||
///
|
||||
/// Uses a single `read()` call to detect EOF (returns 0), then completes
|
||||
/// any partial header with `read_exact`. An `UnexpectedEof` from `read_exact`
|
||||
/// means the stream was truncated mid-header.
|
||||
async fn try_read_chunk_header_async<R: AsyncRead + Unpin>(
|
||||
reader: &mut R,
|
||||
) -> Result<Option<XorbChunkHeader>, XorbObjectError> {
|
||||
let mut header_buf = [0u8; XORB_CHUNK_HEADER_LENGTH];
|
||||
let n = match AsyncReadExt::read(reader, &mut header_buf).await {
|
||||
Ok(0) => return Ok(None),
|
||||
Ok(n) => n,
|
||||
Err(e) => return Err(XorbObjectError::InternalIOError(e)),
|
||||
};
|
||||
if n < XORB_CHUNK_HEADER_LENGTH {
|
||||
reader.read_exact(&mut header_buf[n..]).await?;
|
||||
}
|
||||
parse_chunk_header(header_buf).map(Some)
|
||||
}
|
||||
|
||||
pub async fn deserialize_chunks_to_writer_from_async_read<R: AsyncRead + Unpin, W: Write>(
|
||||
reader: &mut R,
|
||||
writer: &mut W,
|
||||
@@ -58,21 +86,12 @@ pub async fn deserialize_chunks_to_writer_from_async_read<R: AsyncRead + Unpin,
|
||||
let mut chunk_byte_indices = Vec::<u32>::new();
|
||||
chunk_byte_indices.push(num_uncompressed_written);
|
||||
|
||||
loop {
|
||||
match deserialize_chunk_to_writer(reader, writer).await {
|
||||
Ok((delta_written, uncompressed_chunk_len)) => {
|
||||
num_compressed_written += delta_written;
|
||||
num_uncompressed_written += uncompressed_chunk_len;
|
||||
chunk_byte_indices.push(num_uncompressed_written); // record end of current chunk
|
||||
},
|
||||
Err(XorbObjectError::InternalIOError(e)) => {
|
||||
if e.kind() == std::io::ErrorKind::UnexpectedEof {
|
||||
break;
|
||||
}
|
||||
return Err(XorbObjectError::InternalIOError(e));
|
||||
},
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
while let Some(header) = try_read_chunk_header_async(reader).await? {
|
||||
let (delta_written, uncompressed_chunk_len) =
|
||||
deserialize_chunk_with_header_to_writer(reader, writer, header).await?;
|
||||
num_compressed_written += delta_written;
|
||||
num_uncompressed_written += uncompressed_chunk_len;
|
||||
chunk_byte_indices.push(num_uncompressed_written);
|
||||
}
|
||||
|
||||
Ok((num_compressed_written, chunk_byte_indices))
|
||||
@@ -177,4 +196,46 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_truncated_stream_returns_error() {
|
||||
use crate::XORB_CHUNK_HEADER_LENGTH;
|
||||
|
||||
let rng = &mut rng();
|
||||
let data = get_chunks(rng, 3, CompressionScheme::None);
|
||||
|
||||
let first_chunk_end = XORB_CHUNK_HEADER_LENGTH + CHUNK_SIZE;
|
||||
|
||||
// Truncate mid-header
|
||||
let mid_header = first_chunk_end + 2;
|
||||
let stream = futures::stream::iter(vec![Ok::<_, std::io::Error>(Bytes::copy_from_slice(&data[..mid_header]))]);
|
||||
let res = deserialize_chunks_to_writer_from_stream(stream, &mut Vec::new()).await;
|
||||
assert!(res.is_err(), "truncation mid-header should error, not silently succeed");
|
||||
|
||||
// Truncate mid-data
|
||||
let mid_data = first_chunk_end + XORB_CHUNK_HEADER_LENGTH + 10;
|
||||
let stream = futures::stream::iter(vec![Ok::<_, std::io::Error>(Bytes::copy_from_slice(&data[..mid_data]))]);
|
||||
let res = deserialize_chunks_to_writer_from_stream(stream, &mut Vec::new()).await;
|
||||
assert!(res.is_err(), "truncation mid-data should error, not silently succeed");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_exact_eof_after_complete_chunk_succeeds() {
|
||||
use crate::XORB_CHUNK_HEADER_LENGTH;
|
||||
|
||||
let rng = &mut rng();
|
||||
let data = get_chunks(rng, 3, CompressionScheme::None);
|
||||
let first_chunk_end = XORB_CHUNK_HEADER_LENGTH + CHUNK_SIZE;
|
||||
|
||||
// Truncate exactly at end of first chunk. This should be clean EOF.
|
||||
let stream = futures::stream::iter(vec![Ok::<_, std::io::Error>(Bytes::copy_from_slice(
|
||||
&data[..first_chunk_end],
|
||||
))]);
|
||||
let mut out = Vec::new();
|
||||
let (num_read, chunk_byte_indices) = deserialize_chunks_to_writer_from_stream(stream, &mut out).await.unwrap();
|
||||
|
||||
assert_eq!(num_read, first_chunk_end);
|
||||
assert_eq!(chunk_byte_indices, vec![0, CHUNK_SIZE as u32]);
|
||||
assert_eq!(out.len(), CHUNK_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user