From 126f30b9816b7cc737f38c02017867264472c933 Mon Sep 17 00:00:00 2001 From: Di Xiao Date: Mon, 30 Mar 2026 11:36:02 -0700 Subject: [PATCH] Prevent multiple XetSession (XetRuntime) attach to the same external tokio runtime (#757) When `XetRuntime` wraps an external tokio handle (External mode), it registers the handle ID in `EXTERNAL_RUNTIME_REGISTRY` so `XetRuntime::current()` can look up the correct instance from any task, and thus obtain the associated `XetConfig`. Previously, a second `from_external_with_config` call with the same handle would silently overwrite the registry entry, breaking the first runtime's `current()` lookups. As a result, tasks spawned off the first `XetRuntime` can no longer access their specific `XetRuntime` and its configs, and this is not expected behavior. This PR makes the second call fail with an explicit error instead. This PR checks if `EXTERNAL_RUNTIME_REGISTRY` already contains an entry with key being the Id of the tokio runtime Id it tries to attach to, and returns `RuntimeError::InvalidRuntime` error if it does to prevent the above issue. --- .../src/xet_session/file_download_group.rs | 43 ++-- xet_pkg/src/xet_session/session.rs | 90 +++++-- xet_pkg/src/xet_session/upload_commit.rs | 242 ++++++++++++++++-- xet_pkg/tests/test_xet_session.rs | 175 ++++++------- xet_runtime/src/core/runtime.rs | 85 +++++- xet_runtime/src/error.rs | 3 + 6 files changed, 462 insertions(+), 176 deletions(-) diff --git a/xet_pkg/src/xet_session/file_download_group.rs b/xet_pkg/src/xet_session/file_download_group.rs index e535d5d7..e953136d 100644 --- a/xet_pkg/src/xet_session/file_download_group.rs +++ b/xet_pkg/src/xet_session/file_download_group.rs @@ -430,7 +430,7 @@ mod tests { use super::*; use crate::xet_session::session::{XetSession, XetSessionBuilder}; - async fn local_session(temp: &TempDir) -> Result { + fn local_session(temp: &TempDir) -> Result { let cas_path = temp.path().join("cas"); Ok(XetSessionBuilder::new() .with_endpoint(format!("local://{}", cas_path.display())) @@ -629,7 +629,7 @@ mod tests { // Downloading a previously uploaded file produces byte-identical content at the destination. async fn test_download_file_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let original = b"Hello, download round-trip!"; let file_info = upload_bytes(&session, original, "payload.bin").await.unwrap(); @@ -647,7 +647,7 @@ mod tests { // A download task that fails transitions to Error status. async fn test_download_status_failed_for_invalid_file_info() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let group = session.new_file_download_group().unwrap().build().await.unwrap(); let handle = group .download_file_to_path( @@ -669,7 +669,7 @@ mod tests { // task_id returned by download_file_to_path must match the per-item progress entry id. async fn test_download_task_id_matches_progress_item_id() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let original = b"download id match"; let file_info = upload_bytes(&session, original, "id.bin").await.unwrap(); @@ -695,7 +695,7 @@ mod tests { // Downloading multiple files from a single group produces correct content for each. async fn test_download_multiple_files() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let data_a = b"First file content"; let data_b = b"Second file content - different"; @@ -732,7 +732,7 @@ mod tests { // After a successful finish the aggregate download progress reflects bytes received. async fn test_download_progress_reflects_bytes_after_finish() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let original = b"download progress tracking data"; let file_info = upload_bytes(&session, original, "prog.bin").await.unwrap(); @@ -767,7 +767,7 @@ mod tests { // Pattern 1: per-task result is accessible via task_id in the finish report downloads map. async fn test_download_result_accessible_via_task_id_in_finish_map() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let data = b"result via task_id in finish map"; let file_info = upload_bytes(&session, data, "file.bin").await.unwrap(); let dest = temp.path().join("out.bin"); @@ -785,7 +785,7 @@ mod tests { // XetFileDownload::result() returns None before finish() is called. async fn test_download_result_none_before_finish() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes(&session, b"some data", "file.bin").await.unwrap(); let dest = temp.path().join("out.bin"); let group = session.new_file_download_group().unwrap().build().await.unwrap(); @@ -798,7 +798,7 @@ mod tests { // XetFileDownload::result() returns Some after finish() completes. async fn test_download_result_some_after_finish() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let data = b"download result test data"; let file_info = upload_bytes(&session, data, "file.bin").await.unwrap(); let dest = temp.path().join("out.bin"); @@ -814,7 +814,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_download_finish_second_call_returns_cached_result() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let data = b"download finish cache test"; let file_info = upload_bytes(&session, data, "cache.bin").await.unwrap(); let dest = temp.path().join("cache.out"); @@ -837,7 +837,7 @@ mod tests { let temp = tempdir().unwrap(); futures::executor::block_on(async { - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from futures executor"; @@ -865,7 +865,7 @@ mod tests { let temp = tempdir().unwrap(); smol::block_on(async { - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from smol executor"; @@ -893,7 +893,7 @@ mod tests { let temp = tempdir().unwrap(); async_std::task::block_on(async { - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from async-std executor"; @@ -917,13 +917,6 @@ mod tests { // ── Blocking API tests ──────────────────────────────────────────────────── - fn local_session_sync(temp: &TempDir) -> Result { - let cas_path = temp.path().join("cas"); - Ok(XetSessionBuilder::new() - .with_endpoint(format!("local://{}", cas_path.display())) - .build()?) - } - fn upload_bytes_blocking(session: &XetSession, data: &[u8], name: &str) -> Result { let commit = session.new_upload_commit()?.build_blocking()?; let _handle = commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some(name.into()))?; @@ -935,7 +928,7 @@ mod tests { #[test] fn test_blocking_download_file_round_trip() -> Result<()> { let temp = tempdir()?; - let session = local_session_sync(&temp)?; + let session = local_session(&temp)?; let original = b"Hello, download round-trip!"; let file_info = upload_bytes_blocking(&session, original, "payload.bin")?; @@ -951,7 +944,7 @@ mod tests { #[test] fn test_blocking_download_multiple_files() -> Result<()> { let temp = tempdir()?; - let session = local_session_sync(&temp)?; + let session = local_session(&temp)?; let data_a = b"First file content"; let data_b = b"Second file content - different"; @@ -984,7 +977,7 @@ mod tests { #[test] fn test_blocking_download_progress_reflects_bytes_after_finish() -> Result<()> { let temp = tempdir()?; - let session = local_session_sync(&temp)?; + let session = local_session(&temp)?; let original = b"download progress tracking data"; let file_info = upload_bytes_blocking(&session, original, "prog.bin")?; @@ -1014,7 +1007,7 @@ mod tests { #[test] fn test_blocking_download_result_access_patterns() -> Result<()> { let temp = tempdir()?; - let session = local_session_sync(&temp)?; + let session = local_session(&temp)?; let data = b"download result access patterns"; let file_info = upload_bytes_blocking(&session, data, "file.bin")?; let dest = temp.path().join("out.bin"); @@ -1046,7 +1039,7 @@ mod tests { R: FnOnce(std::pin::Pin>>), { let temp = tempdir().unwrap(); - let session = local_session_sync(&temp).unwrap(); + let session = local_session(&temp).unwrap(); run(Box::pin(async move { let data = b"download from smol executor"; diff --git a/xet_pkg/src/xet_session/session.rs b/xet_pkg/src/xet_session/session.rs index 74c06039..1f8a89ce 100644 --- a/xet_pkg/src/xet_session/session.rs +++ b/xet_pkg/src/xet_session/session.rs @@ -7,6 +7,7 @@ use http::HeaderMap; use tracing::info; use ulid::Ulid; use xet_data::progress_tracking::UniqueID; +use xet_runtime::RuntimeError; use xet_runtime::config::XetConfig; use xet_runtime::core::XetRuntime; @@ -150,6 +151,10 @@ impl XetSessionBuilder { /// If the handle does **not** meet requirements (e.g. `current_thread` flavor or missing /// drivers), it is silently ignored and [`build`](Self::build) will fall back to creating /// an owned thread pool instead. + /// + /// If the handle is already in use by another live `XetSession`, [`build`](Self::build) will + /// also fall back to creating an owned thread pool — the duplicate is logged at `INFO` level + /// and no error is returned. pub fn with_tokio_handle(self, handle: tokio::runtime::Handle) -> Self { let accept = XetRuntime::handle_meets_requirements(&handle); if !accept { @@ -169,6 +174,10 @@ impl XetSessionBuilder { /// it — no second thread pool is created. Otherwise, an owned multi-thread /// runtime is created; async methods use an internal bridge and work from /// any executor, and `_blocking` methods are available. + /// + /// If the detected or provided handle is already registered to another live `XetSession`, + /// the duplicate attach is silently rejected and an owned runtime is created instead. + /// This prevents two sessions from fighting over the same tokio runtime's task scheduler. pub fn build(self) -> Result { let handle = self.tokio_handle.or_else(|| { tokio::runtime::Handle::try_current() @@ -177,8 +186,24 @@ impl XetSessionBuilder { }); let runtime = match handle { - Some(h) => XetRuntime::from_external_with_config(h, self.config.clone()), - None => XetRuntime::new_with_config(self.config.clone())?, + Some(h) => { + info!("XetSession using External runtime (wrapping caller's tokio handle)"); + let result = XetRuntime::from_external_with_config(h, self.config.clone()); + match result { + Ok(runtime) => runtime, + Err(RuntimeError::ExternalAlreadyAttached(_)) => { + info!( + "An existing XetSession already wraps caller's tokio handle, switching to creating Owned runtime" + ); + XetRuntime::new_with_config(self.config.clone())? + }, + Err(e) => Err(e)?, + } + }, + None => { + info!("XetSession creating Owned runtime (new thread pool)"); + XetRuntime::new_with_config(self.config.clone())? + }, }; let session = XetSession::new(self.config, self.endpoint, self.custom_headers, runtime); @@ -375,7 +400,8 @@ impl XetSession { #[cfg(test)] mod tests { - use xet_data::processing::XetFileInfo; + use tempfile::{TempDir, tempdir}; + use xet_data::processing::{Sha256Policy, XetFileInfo}; use xet_runtime::core::{RuntimeMode, XetRuntime}; use super::*; @@ -694,17 +720,7 @@ mod tests { // ── Streaming download round-trip tests ───────────────────────────────── - use tempfile::{TempDir, tempdir}; - use xet_data::processing::Sha256Policy; - - async fn local_session(temp: &TempDir) -> Result> { - let cas_path = temp.path().join("cas"); - Ok(XetSessionBuilder::new() - .with_endpoint(format!("local://{}", cas_path.display())) - .build()?) - } - - fn local_session_sync(temp: &TempDir) -> Result> { + fn local_session(temp: &TempDir) -> Result> { let cas_path = temp.path().join("cas"); Ok(XetSessionBuilder::new() .with_endpoint(format!("local://{}", cas_path.display())) @@ -741,7 +757,7 @@ mod tests { // Async streaming download round-trip: upload, stream, verify content. async fn test_download_stream_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let original = b"Hello, streaming download!"; let file_info = upload_bytes(&session, original, "stream.bin").await.unwrap(); @@ -765,7 +781,7 @@ mod tests { // Blocking streaming download round-trip: upload, stream, verify content. fn test_download_stream_blocking_round_trip() { let temp = tempdir().unwrap(); - let session = local_session_sync(&temp).unwrap(); + let session = local_session(&temp).unwrap(); let original = b"Hello, blocking streaming download!"; let file_info = upload_bytes_blocking(&session, original, "stream.bin").unwrap(); @@ -788,7 +804,7 @@ mod tests { // progress() reports correct totals after consuming the stream. async fn test_download_stream_progress_reports_completion() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let original = b"progress tracking test data for streaming"; let file_info = upload_bytes(&session, original, "progress.bin").await.unwrap(); @@ -820,7 +836,7 @@ mod tests { // progress() works correctly in blocking mode. fn test_download_stream_blocking_progress_reports_completion() { let temp = tempdir().unwrap(); - let session = local_session_sync(&temp).unwrap(); + let session = local_session(&temp).unwrap(); let original = b"blocking progress tracking test data"; let file_info = upload_bytes_blocking(&session, original, "progress.bin").unwrap(); @@ -847,7 +863,7 @@ mod tests { // Multiple sequential streaming downloads share a single group's connection pool. async fn test_download_stream_multiple_sequential() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let data_a = b"first stream payload"; let data_b = b"second stream payload"; let info_a = upload_bytes(&session, data_a, "a.bin").await.unwrap(); @@ -869,4 +885,40 @@ mod tests { } assert_eq!(collected_b, data_b); } + + // ── Duplicate tokio handle rejection ───────────────────────────────────── + + #[test] + // Building a second session with the same tokio handle while the first is alive must + // fall back to Owned mode rather than returning an error — the duplicate is handled + // gracefully so callers do not need to track handle ownership. + fn test_build_with_same_handle_falls_back_to_owned() { + let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); + let handle = tokio_rt.handle().clone(); + + let first = XetSessionBuilder::new().with_tokio_handle(handle.clone()).build().unwrap(); + assert_eq!(first.inner.runtime.mode(), RuntimeMode::External, "first build must use External runtime"); + + let second = XetSessionBuilder::new().with_tokio_handle(handle).build(); + assert!(second.is_ok(), "second build with the same tokio handle must still succeed"); + assert_eq!( + second.unwrap().inner.runtime.mode(), + RuntimeMode::Owned, + "second build must fall back to Owned runtime when External handle is already in use" + ); + } + + #[test] + // After the first session is dropped (deregistering the handle), a new session can + // attach to the same tokio handle successfully. + fn test_build_with_same_handle_succeeds_after_first_is_dropped() { + let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); + let handle = tokio_rt.handle().clone(); + + let first = XetSessionBuilder::new().with_tokio_handle(handle.clone()).build().unwrap(); + drop(first); + + let second = XetSessionBuilder::new().with_tokio_handle(handle).build(); + assert!(second.is_ok(), "build must succeed after the previous session holding the same handle is dropped"); + } } diff --git a/xet_pkg/src/xet_session/upload_commit.rs b/xet_pkg/src/xet_session/upload_commit.rs index 475116e2..01b4a6ed 100644 --- a/xet_pkg/src/xet_session/upload_commit.rs +++ b/xet_pkg/src/xet_session/upload_commit.rs @@ -613,7 +613,7 @@ mod tests { use super::*; use crate::xet_session::session::{XetSession, XetSessionBuilder}; - async fn local_session(temp: &TempDir) -> Result> { + fn local_session(temp: &TempDir) -> Result> { let cas_path = temp.path().join("cas"); Ok(XetSessionBuilder::new() .with_endpoint(format!("local://{}", cas_path.display())) @@ -837,7 +837,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_bytes_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let data = b"Hello, upload commit round-trip!"; let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let task_handle = commit @@ -858,7 +858,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_bytes_task_id_matches_progress() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let handle = commit @@ -879,7 +879,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_handle_file_path_none_for_bytes_upload() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let handle = commit .upload_bytes(b"no-path".to_vec(), Sha256Policy::Compute, Some("bytes.bin".into())) @@ -891,7 +891,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_from_path_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let src = temp.path().join("data.bin"); let data = b"file path upload content"; std::fs::write(&src, data).unwrap(); @@ -908,7 +908,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_handle_file_path_for_path_upload() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let src = temp.path().join("path_meta.bin"); std::fs::write(&src, b"path metadata").unwrap(); let absolute = std::path::absolute(&src).unwrap(); @@ -920,7 +920,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_bytes_sha256_policy_metadata() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let provided_sha256 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(); @@ -955,7 +955,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_finish_returns_result_before_commit() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let data = b"finish before commit"; let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let handle = commit @@ -970,7 +970,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_finish_second_call_returns_cached_result() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let handle = commit .upload_bytes(b"idem".to_vec(), Sha256Policy::Compute, None) @@ -987,7 +987,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_finish_includes_dedup_metrics() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let data = b"dedup metrics check"; let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let handle = commit.upload_bytes(data.to_vec(), Sha256Policy::Compute, None).await.unwrap(); @@ -1001,7 +1001,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_streaming_round_trip() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let data = b"streamed upload bytes"; let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let handle = commit @@ -1019,7 +1019,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_commit_errors_when_stream_not_finished() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let handle = commit .upload_stream(Some("unfinished.bin".into()), Sha256Policy::Compute) @@ -1033,7 +1033,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_stream_finish_second_call_is_already_completed() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let handle = commit .upload_stream(Some("idem.bin".into()), Sha256Policy::Compute) @@ -1051,7 +1051,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_stream_write_after_finish_errors() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let handle = commit .upload_stream(Some("done.bin".into()), Sha256Policy::Compute) @@ -1084,7 +1084,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_multiple_files_in_one_commit() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let h1 = commit .upload_bytes(b"file one".to_vec(), Sha256Policy::Compute, Some("a.bin".into())) @@ -1109,7 +1109,7 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_upload_progress_reflects_bytes_after_commit() { let temp = tempdir().unwrap(); - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); let data = b"progress tracking upload data"; let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let progress_observer = commit.clone(); @@ -1132,7 +1132,7 @@ mod tests { let temp = tempdir().unwrap(); futures::executor::block_on(async { - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from non-tokio executor"; @@ -1153,7 +1153,7 @@ mod tests { let temp = tempdir().unwrap(); smol::block_on(async { - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from smol executor"; @@ -1174,7 +1174,7 @@ mod tests { let temp = tempdir().unwrap(); async_std::task::block_on(async { - let session = local_session(&temp).await.unwrap(); + let session = local_session(&temp).unwrap(); assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned); let data = b"hello from async-std executor"; @@ -1192,26 +1192,212 @@ mod tests { // ── Blocking API tests ──────────────────────────────────────────────────── - fn local_session_sync(temp: &TempDir) -> Result> { - let cas_path = temp.path().join("cas"); - Ok(XetSessionBuilder::new() - .with_endpoint(format!("local://{}", cas_path.display())) - .build()?) - } - #[test] fn test_blocking_upload_bytes_round_trip() -> Result<(), Box> { let temp = tempdir()?; - let session = local_session_sync(&temp)?; + let session = local_session(&temp)?; let data = b"Hello, upload commit round-trip!"; let commit = session.new_upload_commit()?.build_blocking()?; let task_handle = commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some("hello.bin".into()))?; + let results = commit.commit_blocking()?; + assert_eq!(results.uploads.len(), 1); + let meta = results.uploads.get(&task_handle.task_id()).unwrap(); + assert_eq!(meta.xet_info.file_size, Some(data.len() as u64)); + assert!(!meta.xet_info.hash.is_empty()); + Ok(()) + } + + #[test] + fn test_blocking_upload_from_path_round_trip() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let src = temp.path().join("data.bin"); + let data = b"file path upload content"; + std::fs::write(&src, data)?; + let commit = session.new_upload_commit()?.build_blocking()?; + let handle = commit.upload_from_path_blocking(src, Sha256Policy::Compute)?; commit.commit_blocking()?; - let meta = task_handle.try_finish().unwrap(); + let meta = handle.try_finish().unwrap(); assert_eq!(meta.xet_info.file_size, Some(data.len() as u64)); assert!(!meta.xet_info.hash.is_empty()); assert!(meta.xet_info.sha256.is_some()); Ok(()) } + + #[test] + fn test_blocking_upload_result_access_patterns() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let data = b"result access patterns"; + let src = temp.path().join("data.bin"); + std::fs::write(&src, data)?; + let commit = session.new_upload_commit()?.build_blocking()?; + let handle = commit.upload_from_path_blocking(src, Sha256Policy::Compute)?; + + // Before commit, per-task result is not available yet. + assert!(handle.try_finish().is_none()); + + let results = commit.commit_blocking()?; + + // Result should be available in the commit map by task id. + let map_result = results + .uploads + .get(&handle.task_id()) + .expect("task_id must be present in results"); + assert_eq!(map_result.xet_info.file_size, Some(data.len() as u64)); + + // Result should also be available via the task handle. + let handle_result = handle.try_finish().expect("result must be set after commit"); + assert_eq!(handle_result.xet_info.file_size, Some(data.len() as u64)); + Ok(()) + } + + #[test] + fn test_blocking_upload_streaming_round_trip() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let data = b"streamed upload bytes"; + let commit = session.new_upload_commit()?.build_blocking()?; + let stream = commit.upload_stream_blocking(Some("stream.bin".into()), Sha256Policy::Compute)?; + stream.write_blocking(data.to_vec())?; + let meta = stream.finish_blocking()?; + let results = commit.commit_blocking()?; + assert_eq!(results.uploads.len(), 1); + assert_eq!(meta.xet_info.file_size, Some(data.len() as u64)); + assert!(!meta.xet_info.hash.is_empty()); + Ok(()) + } + + #[test] + fn test_blocking_upload_multiple_files_in_one_commit() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let commit = session.new_upload_commit()?.build_blocking()?; + commit.upload_bytes_blocking(b"file one".to_vec(), Sha256Policy::Compute, Some("a.bin".into()))?; + commit.upload_bytes_blocking(b"file two".to_vec(), Sha256Policy::Compute, Some("b.bin".into()))?; + commit.upload_bytes_blocking(b"file three".to_vec(), Sha256Policy::Compute, Some("c.bin".into()))?; + let results = commit.commit_blocking()?; + assert_eq!(results.uploads.len(), 3); + Ok(()) + } + + #[test] + fn test_blocking_upload_progress_reflects_bytes_after_commit() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let data = b"progress tracking upload data"; + let commit = session.new_upload_commit()?.build_blocking()?; + let progress_observer = commit.clone(); + commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some("prog.bin".into()))?; + commit.commit_blocking()?; + let snapshot = progress_observer.progress(); + assert_eq!(snapshot.total_bytes, data.len() as u64); + assert_eq!(snapshot.total_bytes_completed, data.len() as u64); + assert_eq!(snapshot.total_transfer_bytes, snapshot.total_transfer_bytes_completed); + assert!(snapshot.total_transfer_bytes_completed <= data.len() as u64); + Ok(()) + } + + #[test] + fn test_blocking_upload_file_returns_handle_without_status() -> Result<(), Box> { + let temp = tempdir()?; + let session = local_session(&temp)?; + let commit = session.new_upload_commit()?.build_blocking()?; + let handle = commit.upload_stream_blocking(Some("stream.bin".into()), Sha256Policy::Compute)?; + assert!(handle.try_finish().is_none()); + Ok(()) + } + + fn assert_blocking_upload_round_trip(run: R) + where + R: FnOnce(std::pin::Pin>>), + { + let temp = tempdir().unwrap(); + let session = local_session(&temp).unwrap(); + + run(Box::pin(async move { + let data = b"upload from smol executor"; + let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); + let handle = commit + .upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some("test.bin".into())) + .unwrap(); + let results = commit.commit_blocking().unwrap(); + let meta = results.uploads.get(&handle.task_id()).unwrap(); + assert_eq!(meta.xet_info.file_size, Some(data.len() as u64)); + assert!(!meta.xet_info.hash.is_empty()); + })); + } + + #[test] + fn test_blocking_upload_round_trip_in_smol() { + assert_blocking_upload_round_trip(|fut| smol::block_on(fut)); + } + + #[test] + fn test_blocking_upload_round_trip_in_futures_executor() { + assert_blocking_upload_round_trip(|fut| futures::executor::block_on(fut)); + } + + #[test] + fn test_blocking_upload_round_trip_in_async_std() { + assert_blocking_upload_round_trip(|fut| async_std::task::block_on(fut)); + } + + // ── External-mode _blocking guard ──────────────────────────────────────── + + #[tokio::test(flavor = "multi_thread")] + async fn test_upload_blocking_methods_error_in_external_mode() { + let session = XetSessionBuilder::new().build().unwrap(); + let commit = session.new_upload_commit().unwrap().build().await.unwrap(); + + let err = commit + .upload_from_path_blocking(PathBuf::from("/nonexistent"), Sha256Policy::Compute) + .err() + .unwrap(); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); + + let err = commit.upload_bytes_blocking(vec![], Sha256Policy::Compute, None).err().unwrap(); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); + + let err = commit.upload_stream_blocking(None, Sha256Policy::Compute).err().unwrap(); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); + } + + // ── Owned-mode _blocking panic guard ───────────────────────────────────── + + #[test] + fn test_upload_from_path_blocking_panics_in_async_context() { + let session = XetSessionBuilder::new().build().unwrap(); + let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); + let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + rt.block_on(async { + commit.upload_from_path_blocking(PathBuf::from("/nonexistent"), Sha256Policy::Compute) + }) + })); + assert!(result.is_err(), "upload_from_path_blocking() must panic when called from async"); + } + + #[test] + fn test_upload_bytes_blocking_panics_in_async_context() { + let session = XetSessionBuilder::new().build().unwrap(); + let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); + let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + rt.block_on(async { commit.upload_bytes_blocking(vec![], Sha256Policy::Compute, None) }) + })); + assert!(result.is_err(), "upload_bytes_blocking() must panic when called from async"); + } + + #[test] + fn test_upload_file_blocking_panics_in_async_context() { + let session = XetSessionBuilder::new().build().unwrap(); + let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); + let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + rt.block_on(async { commit.upload_stream_blocking(None, Sha256Policy::Compute) }) + })); + assert!(result.is_err(), "upload_stream_blocking() must panic when called from async"); + } } diff --git a/xet_pkg/tests/test_xet_session.rs b/xet_pkg/tests/test_xet_session.rs index 7eb7b0f4..66461878 100644 --- a/xet_pkg/tests/test_xet_session.rs +++ b/xet_pkg/tests/test_xet_session.rs @@ -26,17 +26,11 @@ use xet::xet_session::{ // ── Helpers ────────────────────────────────────────────────────────────── -fn local_endpoint(temp: &TempDir) -> String { +fn local_session(temp: &TempDir) -> Result> { let cas_path = temp.path().join("cas"); - format!("local://{}", cas_path.display()) -} - -fn async_session(temp: &TempDir) -> XetSession { - XetSessionBuilder::new().with_endpoint(local_endpoint(temp)).build().unwrap() -} - -fn sync_session(temp: &TempDir) -> XetSession { - XetSessionBuilder::new().with_endpoint(local_endpoint(temp)).build().unwrap() + Ok(XetSessionBuilder::new() + .with_endpoint(format!("local://{}", cas_path.display())) + .build()?) } fn to_file_info(meta: &XetFileMetadata) -> XetFileInfo { @@ -216,14 +210,14 @@ fn deficient_runtime_cases() -> Vec<(&'static str, RuntimeBuilder)> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_upload_bytes_roundtrip() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); assert_roundtrip_async(&session, &temp, b"async upload bytes test", "bytes").await; } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_upload_from_path_roundtrip() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let src = temp.path().join("source.bin"); let data = b"upload from path integration test content"; @@ -252,7 +246,7 @@ async fn async_upload_from_path_roundtrip() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_multiple_files_in_one_commit() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let files: Vec<(&str, &[u8])> = vec![ ("alpha.bin", b"alpha content"), @@ -294,7 +288,7 @@ async fn async_multiple_files_in_one_commit() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_sha256_policy_variants() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let provided_sha256 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa".to_string(); let commit = session.new_upload_commit().unwrap().build().await.unwrap(); @@ -328,7 +322,7 @@ async fn async_sha256_policy_variants() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_large_file_roundtrip() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); assert_roundtrip_async(&session, &temp, &data, "large").await; } @@ -336,7 +330,7 @@ async fn async_large_file_roundtrip() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_multiple_commits_and_groups() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let info_a = upload_bytes_async(&session, b"commit A data", "a.bin").await; let info_b = upload_bytes_async(&session, b"commit B data", "b.bin").await; @@ -359,7 +353,7 @@ async fn async_multiple_commits_and_groups() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_task_status_transitions() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let commit = session.new_upload_commit().unwrap().build().await.unwrap(); let handle = commit @@ -378,7 +372,7 @@ async fn async_task_status_transitions() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_progress_tracking() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"progress tracking integration test data"; let commit = session.new_upload_commit().unwrap().build().await.unwrap(); @@ -399,7 +393,7 @@ async fn async_progress_tracking() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_download_unknown_size_roundtrip() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"download with unknown size via xet_pkg"; let file_info = upload_bytes_async(&session, data, "unknown_size.bin").await; @@ -419,7 +413,7 @@ async fn async_download_unknown_size_roundtrip() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_download_invalid_hash_fails() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let group = session.new_file_download_group().unwrap().build().await.unwrap(); let handle = group @@ -441,7 +435,7 @@ async fn async_download_invalid_hash_fails() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_upload_from_path_multiple_files() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let src_a = temp.path().join("src_a.bin"); let src_b = temp.path().join("src_b.bin"); @@ -474,14 +468,14 @@ async fn async_upload_from_path_multiple_files() { #[test] fn blocking_upload_bytes_roundtrip() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); assert_roundtrip_sync(&session, &temp, b"blocking upload bytes test", "bytes"); } #[test] fn blocking_upload_from_path_roundtrip() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); assert_upload_from_path_roundtrip_sync( &session, &temp, @@ -494,7 +488,7 @@ fn blocking_upload_from_path_roundtrip() { #[test] fn blocking_multiple_files_roundtrip() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let data_a = b"blocking file A"; let data_b = b"blocking file B is longer"; @@ -527,7 +521,7 @@ fn blocking_multiple_files_roundtrip() { #[test] fn blocking_large_file_roundtrip() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); assert_roundtrip_sync(&session, &temp, &data, "large"); } @@ -535,7 +529,7 @@ fn blocking_large_file_roundtrip() { #[test] fn blocking_task_status_transitions() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); let handle = commit @@ -549,7 +543,7 @@ fn blocking_task_status_transitions() { #[test] fn blocking_progress_tracking() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"blocking progress tracking data"; let commit = session.new_upload_commit().unwrap().build_blocking().unwrap(); @@ -567,7 +561,7 @@ fn blocking_progress_tracking() { #[test] fn blocking_multiple_commits_and_groups() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let info_a = upload_bytes_sync(&session, b"blocking commit A", "a.bin"); let info_b = upload_bytes_sync(&session, b"blocking commit B", "b.bin"); @@ -597,7 +591,7 @@ fn bridge_upload_download_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let payload = format!("{tag} executor roundtrip"); assert_roundtrip_async(&session, &temp, payload.as_bytes(), &tag).await; }) @@ -610,7 +604,7 @@ fn bridge_multiple_files() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let files: Vec<(String, Vec)> = vec![ (format!("{tag}_a.bin"), format!("{tag} A").into_bytes()), @@ -654,7 +648,7 @@ fn bridge_upload_from_path_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let payload = format!("{tag} upload from path"); assert_upload_from_path_roundtrip_async( &session, @@ -674,7 +668,7 @@ fn bridge_large_file_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); assert_roundtrip_async(&session, &temp, &data, &format!("large_{tag}")).await; }) @@ -695,7 +689,7 @@ fn deficient_tokio_async_roundtrip_matrix() { let rt = builder(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let payload = format!("{label} async roundtrip"); assert_roundtrip_async(&session, &temp, payload.as_bytes(), label).await; }); @@ -707,7 +701,7 @@ fn deficient_tokio_no_drivers_multiple_files() { let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let (info_a, info_b) = { let commit = session.new_upload_commit().unwrap().build().await.unwrap(); @@ -742,7 +736,7 @@ fn deficient_tokio_no_drivers_upload_from_path() { let rt = build_rt_no_drivers(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); assert_upload_from_path_roundtrip_async( &session, &temp, @@ -759,7 +753,7 @@ fn deficient_tokio_no_drivers_large_file() { let rt = build_rt_no_drivers(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); assert_roundtrip_async(&session, &temp, &data, "large_deficient").await; }); @@ -776,8 +770,7 @@ fn deficient_tokio_handle_auto_fallback_blocking_roundtrip() { ] { let rt = builder(); let temp = tempdir().unwrap(); - let session = - rt.block_on(async { XetSessionBuilder::new().with_endpoint(local_endpoint(&temp)).build().unwrap() }); + let session = rt.block_on(async { local_session(&temp).unwrap() }); let payload = format!("{label} handle blocking roundtrip"); assert_roundtrip_sync(&session, &temp, payload.as_bytes(), &format!("{label}_blocking")); @@ -794,7 +787,7 @@ fn deficient_tokio_handle_auto_fallback_blocking_roundtrip() { fn blocking_in_non_tokio_executor_roundtrip() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let tag = executor.label().to_string(); Box::pin(async move { let payload = format!("blocking in {tag}"); @@ -807,7 +800,7 @@ fn blocking_in_non_tokio_executor_roundtrip() { fn blocking_in_non_tokio_executor_upload_from_path() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let tag = executor.label().to_string(); Box::pin(async move { let payload = format!("blocking {tag} upload from path"); @@ -910,7 +903,7 @@ async fn async_abort_rejects_download_on_existing_group() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_duplicate_content_produces_same_hash() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"deduplication test content"; let info1 = upload_bytes_async(&session, data, "first.bin").await; @@ -926,8 +919,8 @@ async fn async_duplicate_content_produces_same_hash() { async fn async_separate_sessions_are_isolated() { let temp1 = tempdir().unwrap(); let temp2 = tempdir().unwrap(); - let session1 = async_session(&temp1); - let session2 = async_session(&temp2); + let session1 = local_session(&temp1).unwrap(); + let session2 = local_session(&temp2).unwrap(); let info1 = upload_bytes_async(&session1, b"session 1 data", "s1.bin").await; @@ -969,7 +962,7 @@ fn collect_stream_blocking(stream: &mut XetDownloadStream) -> Vec { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_roundtrip() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"async streaming download roundtrip"; let file_info = upload_bytes_async(&session, data, "stream.bin").await; @@ -981,7 +974,7 @@ async fn async_stream_roundtrip() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_large_file() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); let file_info = upload_bytes_async(&session, &data, "large_stream.bin").await; @@ -993,7 +986,7 @@ async fn async_stream_large_file() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_progress_tracking() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"stream progress tracking integration test"; let file_info = upload_bytes_async(&session, data, "progress_stream.bin").await; @@ -1014,7 +1007,7 @@ async fn async_stream_progress_tracking() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_multiple_sequential() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data_a = b"stream sequential A"; let data_b = b"stream sequential B is different"; @@ -1032,7 +1025,7 @@ async fn async_stream_multiple_sequential() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_cancel_before_consuming() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"stream cancel test data"; let file_info = upload_bytes_async(&session, data, "cancel_stream.bin").await; @@ -1044,7 +1037,8 @@ async fn async_stream_cancel_before_consuming() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_aborted_session() { - let session = XetSessionBuilder::new().build().unwrap(); + let temp = tempdir().unwrap(); + let session = local_session(&temp).unwrap(); session.abort().unwrap(); let result = session.new_download_stream_group(); assert!(matches!(result, Err(SessionError::UserCancelled(_)))); @@ -1053,7 +1047,7 @@ async fn async_stream_aborted_session() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_abort_cancels_active_stream() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); let file_info = upload_bytes_async(&session, &data, "abort_stream.bin").await; @@ -1072,7 +1066,7 @@ async fn async_stream_abort_cancels_active_stream() { #[test] fn blocking_stream_roundtrip() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"blocking streaming download roundtrip"; let file_info = upload_bytes_sync(&session, data, "stream.bin"); @@ -1084,7 +1078,7 @@ fn blocking_stream_roundtrip() { #[test] fn blocking_stream_large_file() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); let file_info = upload_bytes_sync(&session, &data, "large_stream.bin"); @@ -1096,7 +1090,7 @@ fn blocking_stream_large_file() { #[test] fn blocking_stream_progress_tracking() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"blocking stream progress integration test"; let file_info = upload_bytes_sync(&session, data, "progress_stream.bin"); @@ -1112,7 +1106,7 @@ fn blocking_stream_progress_tracking() { #[test] fn blocking_stream_multiple_sequential() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let data_a = b"blocking stream seq A"; let data_b = b"blocking stream seq B is longer"; @@ -1156,7 +1150,7 @@ fn bridge_stream_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let payload = format!("{tag} stream roundtrip"); let file_info = upload_bytes_async(&session, payload.as_bytes(), &format!("{tag}_stream.bin")).await; @@ -1173,7 +1167,7 @@ fn deficient_tokio_stream_roundtrip() { let rt = builder(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let payload = format!("{label} deficient stream"); let file_info = upload_bytes_async(&session, payload.as_bytes(), &format!("{label}_stream.bin")).await; @@ -1188,7 +1182,7 @@ fn deficient_tokio_stream_roundtrip() { fn blocking_stream_in_non_tokio_executor() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let tag = executor.label().to_string(); Box::pin(async move { let payload = format!("blocking stream in {tag}"); @@ -1230,7 +1224,7 @@ fn collect_unordered_stream_blocking(stream: &mut XetUnorderedDownloadStream, ex #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_roundtrip() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"async unordered streaming download roundtrip"; let file_info = upload_bytes_async(&session, data, "unordered.bin").await; @@ -1242,7 +1236,7 @@ async fn async_unordered_stream_roundtrip() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_large_file() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); let file_info = upload_bytes_async(&session, &data, "large_unordered.bin").await; @@ -1254,7 +1248,7 @@ async fn async_unordered_stream_large_file() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_progress_tracking() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"unordered stream progress tracking integration test"; let file_info = upload_bytes_async(&session, data, "progress_unordered.bin").await; @@ -1275,7 +1269,7 @@ async fn async_unordered_stream_progress_tracking() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_cancel_before_consuming() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"unordered stream cancel test data"; let file_info = upload_bytes_async(&session, data, "cancel_unordered.bin").await; @@ -1287,7 +1281,8 @@ async fn async_unordered_stream_cancel_before_consuming() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_aborted_session() { - let session = XetSessionBuilder::new().build().unwrap(); + let temp = tempdir().unwrap(); + let session = local_session(&temp).unwrap(); session.abort().unwrap(); let result = session.new_download_stream_group(); assert!(matches!(result, Err(SessionError::UserCancelled(_)))); @@ -1296,7 +1291,7 @@ async fn async_unordered_stream_aborted_session() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_abort_cancels_active_stream() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); let file_info = upload_bytes_async(&session, &data, "abort_unordered_stream.bin").await; @@ -1315,7 +1310,7 @@ async fn async_unordered_stream_abort_cancels_active_stream() { #[test] fn blocking_unordered_stream_roundtrip() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"blocking unordered streaming download roundtrip"; let file_info = upload_bytes_sync(&session, data, "unordered.bin"); @@ -1327,7 +1322,7 @@ fn blocking_unordered_stream_roundtrip() { #[test] fn blocking_unordered_stream_large_file() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); let file_info = upload_bytes_sync(&session, &data, "large_unordered.bin"); @@ -1339,7 +1334,7 @@ fn blocking_unordered_stream_large_file() { #[test] fn blocking_unordered_stream_progress_tracking() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let data = b"blocking unordered stream progress integration test"; let file_info = upload_bytes_sync(&session, data, "progress_unordered.bin"); @@ -1381,7 +1376,7 @@ fn bridge_unordered_stream_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let payload = format!("{tag} unordered stream roundtrip"); let file_info = upload_bytes_async(&session, payload.as_bytes(), &format!("{tag}_unordered.bin")).await; @@ -1398,7 +1393,7 @@ fn deficient_tokio_unordered_stream_roundtrip() { let rt = builder(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let payload = format!("{label} deficient unordered stream"); let file_info = upload_bytes_async(&session, payload.as_bytes(), &format!("{label}_unordered.bin")).await; @@ -1413,7 +1408,7 @@ fn deficient_tokio_unordered_stream_roundtrip() { fn blocking_unordered_stream_in_non_tokio_executor() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let tag = executor.label().to_string(); Box::pin(async move { let payload = format!("blocking unordered stream in {tag}"); @@ -1441,7 +1436,7 @@ const RANGE_TEST_DATA: &[u8; 256] = &{ #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_middle() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "range.bin").await; let group = async_stream_group(&session).await; @@ -1452,7 +1447,7 @@ async fn async_stream_range_middle() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_from_start() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "range_start.bin").await; let group = async_stream_group(&session).await; @@ -1463,7 +1458,7 @@ async fn async_stream_range_from_start() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_to_end() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "range_end.bin").await; let group = async_stream_group(&session).await; @@ -1474,7 +1469,7 @@ async fn async_stream_range_to_end() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_full() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "range_full.bin").await; let group = async_stream_group(&session).await; @@ -1485,7 +1480,7 @@ async fn async_stream_range_full() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_progress() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "range_progress.bin").await; let group = async_stream_group(&session).await; @@ -1505,7 +1500,7 @@ async fn async_stream_range_progress() { #[test] fn blocking_stream_range_middle() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, "range.bin"); let group = sync_stream_group(&session); @@ -1516,7 +1511,7 @@ fn blocking_stream_range_middle() { #[test] fn blocking_stream_range_progress() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, "range_progress.bin"); let group = sync_stream_group(&session); @@ -1531,7 +1526,7 @@ fn blocking_stream_range_progress() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_range_middle() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "unord_range.bin").await; let group = async_stream_group(&session).await; @@ -1542,7 +1537,7 @@ async fn async_unordered_stream_range_middle() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_range_from_start() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "unord_range_start.bin").await; let group = async_stream_group(&session).await; @@ -1553,7 +1548,7 @@ async fn async_unordered_stream_range_from_start() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_range_to_end() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "unord_range_end.bin").await; let group = async_stream_group(&session).await; @@ -1564,7 +1559,7 @@ async fn async_unordered_stream_range_to_end() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_range_progress() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, "unord_range_progress.bin").await; let group = async_stream_group(&session).await; @@ -1584,7 +1579,7 @@ async fn async_unordered_stream_range_progress() { #[test] fn blocking_unordered_stream_range_middle() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, "unord_range.bin"); let group = sync_stream_group(&session); @@ -1595,7 +1590,7 @@ fn blocking_unordered_stream_range_middle() { #[test] fn blocking_unordered_stream_range_progress() { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, "unord_range_progress.bin"); let group = sync_stream_group(&session); @@ -1613,7 +1608,7 @@ fn bridge_stream_range_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, &format!("{tag}_range_stream.bin")).await; let group = async_stream_group(&session).await; @@ -1629,7 +1624,7 @@ fn bridge_unordered_stream_range_roundtrip() { let temp = tempdir().unwrap(); let tag = executor.label().to_string(); Box::pin(async move { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, &format!("{tag}_range_unord.bin")).await; let group = async_stream_group(&session).await; @@ -1645,7 +1640,7 @@ fn deficient_tokio_stream_range_roundtrip() { let rt = builder(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, &format!("{label}_range_stream.bin")).await; let group = async_stream_group(&session).await; @@ -1661,7 +1656,7 @@ fn deficient_tokio_unordered_stream_range_roundtrip() { let rt = builder(); let temp = tempdir().unwrap(); rt.block_on(async { - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let file_info = upload_bytes_async(&session, RANGE_TEST_DATA, &format!("{label}_range_unord.bin")).await; let group = async_stream_group(&session).await; @@ -1675,7 +1670,7 @@ fn deficient_tokio_unordered_stream_range_roundtrip() { fn blocking_stream_range_in_non_tokio_executor() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let tag = executor.label().to_string(); Box::pin(async move { let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, &format!("{tag}_range_stream.bin")); @@ -1691,7 +1686,7 @@ fn blocking_stream_range_in_non_tokio_executor() { fn blocking_unordered_stream_range_in_non_tokio_executor() { run_on_all_non_tokio_executors(|executor| { let temp = tempdir().unwrap(); - let session = sync_session(&temp); + let session = local_session(&temp).unwrap(); let tag = executor.label().to_string(); Box::pin(async move { let file_info = upload_bytes_sync(&session, RANGE_TEST_DATA, &format!("{tag}_range_unord.bin")); @@ -1706,7 +1701,7 @@ fn blocking_unordered_stream_range_in_non_tokio_executor() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_stream_range_large_file() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); let file_info = upload_bytes_async(&session, &data, "range_large.bin").await; @@ -1718,7 +1713,7 @@ async fn async_stream_range_large_file() { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn async_unordered_stream_range_large_file() { let temp = tempdir().unwrap(); - let session = async_session(&temp); + let session = local_session(&temp).unwrap(); let data: Vec = (0..65536u64).map(|i| (i % 251) as u8).collect(); let file_info = upload_bytes_async(&session, &data, "range_large_unord.bin").await; diff --git a/xet_runtime/src/core/runtime.rs b/xet_runtime/src/core/runtime.rs index b6b51755..18a4b52f 100644 --- a/xet_runtime/src/core/runtime.rs +++ b/xet_runtime/src/core/runtime.rs @@ -15,9 +15,9 @@ use reqwest::Client; use tokio::runtime::{Builder as TokioRuntimeBuilder, Handle as TokioRuntimeHandle, Runtime as TokioRuntime}; use tokio::sync::oneshot; use tokio::task::JoinHandle; +use tracing::debug; #[cfg(not(target_family = "wasm"))] use tracing::info; -use tracing::{debug, warn}; use super::XetCommon; use crate::config::XetConfig; @@ -346,8 +346,11 @@ impl XetRuntime { /// Wrap a caller-provided tokio handle after validating that it meets requirements. /// - /// Returns [`RuntimeError::InvalidRuntime`] if the handle lacks multi-thread - /// flavor, time driver, or IO driver. + /// # Errors + /// + /// - [`RuntimeError::InvalidRuntime`] — the handle lacks multi-thread flavor, time driver, or IO driver. + /// - [`RuntimeError::ExternalAlreadyAttached`] — a live `XetRuntime` is already registered for this handle (checked + /// inside [`from_external_with_config`](Self::from_external_with_config)). /// /// Not available on WASM targets. #[cfg(not(target_family = "wasm"))] @@ -362,7 +365,7 @@ impl XetRuntime { .into(), )); } - Ok(Self::from_external_with_config(rt_handle, config)) + Self::from_external_with_config(rt_handle, config) } /// Wrap an existing tokio [`TokioRuntimeHandle`] with a [`XetRuntime`] using the provided @@ -373,8 +376,25 @@ impl XetRuntime { /// [`XetRuntime::current()`] called from tasks running on `rt_handle`'s threads will return /// this instance (with the correct config and shared `XetCommon`) rather than a default /// throwaway. The entry is removed when the last `Arc` is dropped. - pub fn from_external_with_config(rt_handle: TokioRuntimeHandle, config: XetConfig) -> Arc { + /// + /// # Errors + /// + /// - [`RuntimeError::ExternalAlreadyAttached`] — a live `XetRuntime` is already registered for `rt_handle`'s tokio + /// runtime ID (i.e. the same handle was wrapped twice while the first is still alive). Drop the existing + /// `XetRuntime` first, or use a different handle. + pub fn from_external_with_config( + rt_handle: TokioRuntimeHandle, + config: XetConfig, + ) -> Result, RuntimeError> { let id = rt_handle.id(); + + let mut reg = EXTERNAL_RUNTIME_REGISTRY.write()?; + if let Some(existing) = reg.get(&id) + && existing.upgrade().is_some() + { + return Err(RuntimeError::ExternalAlreadyAttached(id)); + } + let rt = Arc::new(Self { backend: RuntimeBackend::External { handle_id: Some(id) }, handle_ref: rt_handle.into(), @@ -395,17 +415,19 @@ impl XetRuntime { .flatten(), config: Arc::new(config), }); - if let Ok(mut reg) = EXTERNAL_RUNTIME_REGISTRY.write() { - reg.insert(id, Arc::downgrade(&rt)); - } else { - warn!("EXTERNAL_RUNTIME_REGISTRY poisoned; skipping registration"); - } - rt + + reg.insert(id, Arc::downgrade(&rt)); + Ok(rt) } /// Wrap an existing tokio [`TokioRuntimeHandle`] with a [`XetRuntime`] using a default /// [`XetConfig`]. Prefer [`from_external_with_config`](Self::from_external_with_config) when /// you have a config available. + /// + /// Unlike [`from_external_with_config`](Self::from_external_with_config), this function does + /// **not** register the runtime in `EXTERNAL_RUNTIME_REGISTRY` and therefore performs no + /// duplicate-handle check. It is intended for lightweight, short-lived wrapping where + /// registry lookup via [`XetRuntime::current()`] is not required. pub fn from_external(rt_handle: TokioRuntimeHandle) -> Arc { let config = XetConfig::new(); Arc::new(Self { @@ -844,7 +866,7 @@ mod tests { let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); let mut config = XetConfig::new(); config.data.default_cas_endpoint = "https://test-endpoint.example.com".into(); - let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), config); + let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), config).unwrap(); // current_if_exists() from within the runtime must find the registered entry. tokio_rt.block_on(async { @@ -875,7 +897,7 @@ mod tests { #[test] fn test_bridge_async_external_mode_runs_directly() { let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); - let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()); + let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap(); assert_eq!(xet_rt.mode(), RuntimeMode::External); let result = tokio_rt.block_on(async { xet_rt.bridge_async("test", async { 99 }).await.unwrap() }); @@ -902,7 +924,7 @@ mod tests { #[test] fn test_bridge_sync_external_mode_returns_error() { let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); - let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()); + let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap(); assert_eq!(xet_rt.mode(), RuntimeMode::External); let result = xet_rt.bridge_sync(async { 789 }); @@ -967,4 +989,39 @@ mod tests { let err = result.unwrap().unwrap_err(); assert!(matches!(err, RuntimeError::TaskPanic(_))); } + + #[test] + // Wrapping the same tokio handle a second time (while the first XetRuntime is alive) + // must return ExternalAlreadyAttached. + fn test_from_external_with_config_duplicate_handle_fails() { + let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); + let _first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap(); + let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()); + assert!( + matches!(second, Err(RuntimeError::ExternalAlreadyAttached(_))), + "expected ExternalAlreadyAttached for duplicate handle, got: {second:?}" + ); + } + + #[test] + // After the first XetRuntime is dropped (deregistered), wrapping the same handle again + // must succeed. + fn test_from_external_with_config_reuse_handle_after_drop() { + let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); + let first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap(); + drop(first); + let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()); + assert!(second.is_ok(), "expected Ok after previous XetRuntime was dropped, got: {second:?}"); + } + + #[test] + // Two distinct tokio runtimes must each accept their own XetRuntime without conflict. + fn test_from_external_with_config_distinct_handles_both_succeed() { + let rt_a = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); + let rt_b = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap(); + let xet_a = XetRuntime::from_external_with_config(rt_a.handle().clone(), XetConfig::new()); + let xet_b = XetRuntime::from_external_with_config(rt_b.handle().clone(), XetConfig::new()); + assert!(xet_a.is_ok()); + assert!(xet_b.is_ok()); + } } diff --git a/xet_runtime/src/error.rs b/xet_runtime/src/error.rs index 1195b9c8..06eb5705 100644 --- a/xet_runtime/src/error.rs +++ b/xet_runtime/src/error.rs @@ -9,6 +9,9 @@ pub enum RuntimeError { #[error("Invalid runtime: {0}")] InvalidRuntime(String), + #[error("A XetRuntime is already attached to this tokio runtime handle with Id {0}")] + ExternalAlreadyAttached(tokio::runtime::Id), + #[error("Task panic: {0:?}")] TaskPanic(String),