mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
Move XetRuntime model away from thread-local statics (#801)
This PR moves the XetRuntime model away from using thread-local statics
and decouples the XetConfig and XetCommon structs from a single runtime.
It introduces a struct XetContext that gives the runtime context for
operations:
```
struct XetContext {
pub runtime : Arc<XetRuntime>, // The current tokio runtime wrapper, minus the config and common objects..
pub common : Arc<XetCommon>, // The common cache objects, semaphores, rate trackers, etc.
pub config : Arc<XetConfig> // The config
}
```
Now, instead of using functions like `xet_runtime()` and `xet_config()` that examine the thread-local storage, we now explicitly passing through a XetContext instance from the session creation that gets stored in each major processing struct.
This allows decoupling between the runtime, config, and common caches, especially:
- Running multiple config settings and/or endpoints within the same pre-existing tokio runtime.
- Running multiple runtimes that share the same XetCommon object.
This commit is contained in:
107
api_changes/update_260402_runtime_and_threadpool_renames.md
Normal file
107
api_changes/update_260402_runtime_and_threadpool_renames.md
Normal file
@@ -0,0 +1,107 @@
|
||||
# API Update: Split runtime execution from runtime context (2026-04-02)
|
||||
|
||||
## Overview
|
||||
|
||||
This update splits the old monolithic `xet_runtime::core::XetRuntime` into:
|
||||
|
||||
- `XetRuntime`: execution backend (Tokio thread pool or external handle wrapper) with async/sync bridge entry points. Lives in `xet_runtime/src/core/runtime.rs`.
|
||||
- `XetContext`: lightweight clonable context containing `{ runtime: Arc<XetRuntime>, config: Arc<XetConfig>, common: Arc<XetCommon> }`. Lives in `xet_runtime/src/core/context.rs`.
|
||||
|
||||
It also changes default context selection when called inside an existing Tokio runtime.
|
||||
|
||||
---
|
||||
|
||||
## Breaking Changes
|
||||
|
||||
### Runtime type split
|
||||
|
||||
- On `origin/main`, `XetRuntime` was the combined threadpool/executor/config/cache wrapper.
|
||||
- After this update, `XetRuntime` is the pure execution backend.
|
||||
- `XetContext` is the new clonable context wrapper:
|
||||
- `pub runtime: Arc<XetRuntime>`
|
||||
- `pub config: Arc<XetConfig>`
|
||||
- `pub common: Arc<XetCommon>`
|
||||
|
||||
Code that previously accessed config or common state on `XetRuntime`
|
||||
must now go through an `XetContext` and access `ctx.config`, `ctx.common`,
|
||||
or `ctx.runtime` for execution methods.
|
||||
|
||||
### Context construction behavior
|
||||
|
||||
`XetContext::default()` now:
|
||||
|
||||
1. Checks the thread-local for an existing owned `XetRuntime` (TLS reuse).
|
||||
2. Detects a current Tokio handle with `Handle::try_current()` and uses it only if it satisfies runtime requirements (`handle_meets_requirements`).
|
||||
3. Falls back to creating a new owned `XetRuntime` otherwise.
|
||||
|
||||
This means callers running inside a compatible Tokio runtime can receive an External-mode `XetRuntime`
|
||||
rather than always creating a new owned pool.
|
||||
|
||||
### Removed/relocated exports
|
||||
|
||||
- `xet_runtime::core::check_sigint_shutdown` free function is no longer exported.
|
||||
Use `ctx.check_sigint_shutdown()` (or `ctx.runtime.in_sigint_shutdown()` where needed).
|
||||
- `xet_runtime::core::xet_config()` is no longer exported.
|
||||
Pass `&XetContext`/`&XetConfig` explicitly instead of relying on a process-global accessor.
|
||||
- `XORB_CUT_THRESHOLD_BYTES` and `XORB_CUT_THRESHOLD_CHUNKS` global constants are removed;
|
||||
compute thresholds at the point of use via the explicit config (`ctx.config.xorb.simulation_max_bytes`, etc.).
|
||||
|
||||
---
|
||||
|
||||
## Migration Guide
|
||||
|
||||
Update execution calls and context construction:
|
||||
|
||||
```rust
|
||||
// Before (origin/main)
|
||||
use xet_runtime::core::XetRuntime;
|
||||
let rt = XetRuntime::new().unwrap();
|
||||
rt.bridge_sync(async { /* ... */ })?;
|
||||
|
||||
// After
|
||||
use xet_runtime::core::{XetContext, XetRuntime};
|
||||
use xet_runtime::config::XetConfig;
|
||||
let config = XetConfig::new();
|
||||
let runtime = XetRuntime::new(&config)?;
|
||||
let ctx = XetContext::new(config, runtime);
|
||||
ctx.runtime.bridge_sync(async { /* ... */ })?;
|
||||
|
||||
// Or use the default constructor:
|
||||
let ctx = XetContext::default()?;
|
||||
ctx.runtime.bridge_sync(async { /* ... */ })?;
|
||||
```
|
||||
|
||||
Most call sites now accept `&XetContext` and reach the execution backend via `ctx.runtime`.
|
||||
Code that assumes `XetContext::default()` always creates an owned runtime should check
|
||||
`ctx.runtime.mode()` (`RuntimeMode::Owned` vs `RuntimeMode::External`) and adjust
|
||||
blocking/async usage accordingly.
|
||||
|
||||
---
|
||||
|
||||
### Logging configuration
|
||||
|
||||
- `LoggingConfig::default_to_directory` is removed; use `LoggingConfig::from_directory(&XetConfig, ...)` instead.
|
||||
|
||||
### Chunk cache
|
||||
|
||||
- `DiskCache::initialize` and related APIs now take explicit `&XetConfig` instead of
|
||||
reading it from a global accessor.
|
||||
|
||||
### Shared caches
|
||||
|
||||
- Subsystem caches (shard file manager, shard file cache, reqwest client) are moved from
|
||||
process-global statics into `XetCommon::cache_get_or_create`, scoped to the context lifetime.
|
||||
|
||||
---
|
||||
|
||||
## Affected Crates
|
||||
|
||||
- `xet_runtime` — `context.rs` (`XetContext`), `runtime.rs` (`XetRuntime` execution backend), `common.rs` (shared runtime-scoped caches), `config.rs` (removed global accessor), `mod.rs` (re-exports)
|
||||
- `xet_core_structures` — shard manager and session directory APIs now take `&XetContext`
|
||||
- `xet_client` — `http_client`, `RemoteClient`, `RetryWrapper`, `AdaptiveConcurrencyController`, chunk cache, auth, hub client, and simulation clients updated
|
||||
- `xet_data` — `TranslatorConfig`, file upload/download sessions, file reconstruction, deduplication, shard interface
|
||||
- `xet_pkg` — `XetSession`, legacy data client, and all session sub-modules thread `&XetContext`
|
||||
- `git_xet` — LFS agent and token refresher use a process-global `XetContext`
|
||||
- `hf_xet` — Python bindings use `get_or_init_runtime()` to manage a process-global `Arc<XetContext>`
|
||||
- `simulation` — upload simulation uses `XetContext::from_external` for host runtime reuse
|
||||
- `wasm/hf_xet_wasm` — file upload session and cleaner thread `XetContext` context
|
||||
@@ -8,12 +8,19 @@ use xet_client::cas_client::auth::TokenRefresher;
|
||||
use xet_client::hub_client::Operation;
|
||||
use xet_pkg::legacy::progress_tracking::{GroupProgressCallbackUpdater, ProgressUpdate, TrackingProgressUpdater};
|
||||
use xet_pkg::legacy::{FileUploadSession, Sha256Policy, clean_file, default_config};
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use crate::constants::{
|
||||
HF_ENDPOINT_ENV, XET_ACCESS_TOKEN_HEADER, XET_CAS_URL, XET_SESSION_ID, XET_TOKEN_EXPIRATION_HEADER,
|
||||
};
|
||||
|
||||
const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
|
||||
|
||||
fn xet_runtime() -> &'static XetContext {
|
||||
static RUNTIME: OnceLock<XetContext> = OnceLock::new();
|
||||
RUNTIME.get_or_init(|| XetContext::default().expect("xet context"))
|
||||
}
|
||||
|
||||
use crate::errors::{GitXetError, Result};
|
||||
use crate::git_repo::GitRepo;
|
||||
use crate::git_url::{GitUrl, Scheme};
|
||||
@@ -85,6 +92,7 @@ impl TransferAgent for XetAgent {
|
||||
|
||||
let session_id = req.action.header.get(XET_SESSION_ID).map(|s| s.as_str()).unwrap_or_default();
|
||||
let token_refresher: Arc<dyn TokenRefresher> = Arc::new(new_git_token_refresher(
|
||||
xet_runtime(),
|
||||
repo,
|
||||
self.remote_url.clone(),
|
||||
&req.action.href,
|
||||
@@ -130,9 +138,14 @@ impl TransferAgent for XetAgent {
|
||||
|
||||
let headers = user_agent_headers;
|
||||
|
||||
let mut config =
|
||||
default_config(cas_url, Some((token, token_expiry)), Some(token_refresher), Some(Arc::new(headers)))?
|
||||
.disable_progress_aggregation();
|
||||
let mut config = default_config(
|
||||
xet_runtime(),
|
||||
cas_url,
|
||||
Some((token, token_expiry)),
|
||||
Some(token_refresher),
|
||||
Some(Arc::new(headers)),
|
||||
)?
|
||||
.disable_progress_aggregation();
|
||||
if !session_id.is_empty() {
|
||||
config.session.session_id = Some(session_id.to_owned());
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ use http::header::HeaderMap;
|
||||
use xet_client::cas_client::auth::DirectRefreshRouteTokenRefresher;
|
||||
use xet_client::common::http_client::build_http_client;
|
||||
use xet_client::hub_client::Operation;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use crate::auth::get_credential;
|
||||
use crate::errors::Result;
|
||||
@@ -13,6 +14,7 @@ use crate::git_url::GitUrl;
|
||||
/// Build a [`DirectRefreshRouteTokenRefresher`] for the git-xet path,
|
||||
/// deriving credentials from the git repo's credential helper.
|
||||
pub fn new_git_token_refresher(
|
||||
ctx: &XetContext,
|
||||
repo: &GitRepo,
|
||||
remote_url: Option<GitUrl>,
|
||||
refresh_route: &str,
|
||||
@@ -25,6 +27,11 @@ pub fn new_git_token_refresher(
|
||||
None => repo.remote_url()?,
|
||||
};
|
||||
let cred_helper = get_credential(repo, &remote_url, operation)?;
|
||||
let client = build_http_client(session_id, None, custom_headers)?;
|
||||
Ok(DirectRefreshRouteTokenRefresher::new(refresh_route.to_owned(), client, Some(cred_helper)))
|
||||
let client = build_http_client(ctx, session_id, None, custom_headers)?;
|
||||
Ok(DirectRefreshRouteTokenRefresher::new(
|
||||
ctx.clone(),
|
||||
refresh_route.to_owned(),
|
||||
client,
|
||||
Some(cred_helper),
|
||||
))
|
||||
}
|
||||
|
||||
@@ -14,13 +14,13 @@ use pyo3::exceptions::{PyKeyboardInterrupt, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::pyfunction;
|
||||
use rand::Rng;
|
||||
use runtime::async_run;
|
||||
use runtime::{async_run, get_or_init_runtime};
|
||||
use token_refresh::WrappedTokenRefresher;
|
||||
use tracing::debug;
|
||||
use xet_pkg::XetError;
|
||||
use xet_pkg::legacy::progress_tracking::TrackingProgressUpdater;
|
||||
use xet_pkg::legacy::{Sha256Policy, XetFileInfo, data_client};
|
||||
use xet_runtime::core::file_handle_limits;
|
||||
use xet_runtime::core::{XetContext, file_handle_limits};
|
||||
|
||||
use crate::logging::init_logging;
|
||||
use crate::progress_update::WrappedProgressUpdater;
|
||||
@@ -112,7 +112,11 @@ pub fn upload_bytes(
|
||||
};
|
||||
|
||||
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
|
||||
let updater = progress_updater.map(WrappedProgressUpdater::new).transpose()?.map(Arc::new);
|
||||
let runtime = get_or_init_runtime().map_err(convert_xet_error)?;
|
||||
let updater = progress_updater
|
||||
.map(|p| WrappedProgressUpdater::new(p, runtime.clone()))
|
||||
.transpose()?
|
||||
.map(Arc::new);
|
||||
let x: u64 = rand::rng().random();
|
||||
|
||||
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
|
||||
@@ -124,8 +128,8 @@ pub fn upload_bytes(
|
||||
std::process::id(),
|
||||
file_contents.len(),
|
||||
);
|
||||
|
||||
let out: Vec<PyXetUploadInfo> = data_client::upload_bytes_async(
|
||||
&runtime,
|
||||
file_contents,
|
||||
sha256_policies,
|
||||
endpoint,
|
||||
@@ -182,7 +186,11 @@ pub fn upload_files(
|
||||
};
|
||||
|
||||
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
|
||||
let updater = progress_updater.map(WrappedProgressUpdater::new).transpose()?.map(Arc::new);
|
||||
let runtime = get_or_init_runtime().map_err(convert_xet_error)?;
|
||||
let updater = progress_updater
|
||||
.map(|p| WrappedProgressUpdater::new(p, runtime.clone()))
|
||||
.transpose()?
|
||||
.map(Arc::new);
|
||||
|
||||
let file_names = file_paths.iter().take(3).join(", ");
|
||||
|
||||
@@ -198,8 +206,8 @@ pub fn upload_files(
|
||||
file_paths.len(),
|
||||
if file_paths.len() > 3 { "..." } else { "." }
|
||||
);
|
||||
|
||||
let out: Vec<PyXetUploadInfo> = data_client::upload_async(
|
||||
&runtime,
|
||||
file_paths,
|
||||
sha256_policies,
|
||||
endpoint,
|
||||
@@ -249,7 +257,8 @@ pub fn upload_files(
|
||||
#[pyo3(signature = (file_paths), text_signature = "(file_paths: List[str]) -> List[PyXetUploadInfo]")]
|
||||
pub fn hash_files(py: Python, file_paths: Vec<String>) -> PyResult<Vec<PyXetUploadInfo>> {
|
||||
async_run(py, async move {
|
||||
let out: Vec<PyXetUploadInfo> = data_client::hash_files_async(file_paths)
|
||||
let runtime = get_or_init_runtime().map_err(convert_xet_error)?;
|
||||
let out: Vec<PyXetUploadInfo> = data_client::hash_files_async(&runtime, file_paths)
|
||||
.await
|
||||
.map_err(convert_xet_error)?
|
||||
.into_iter()
|
||||
@@ -273,7 +282,8 @@ pub fn download_files(
|
||||
) -> PyResult<Vec<String>> {
|
||||
let file_infos: Vec<_> = files.into_iter().map(<(XetFileInfo, DestinationPath)>::from).collect();
|
||||
let refresher = token_refresher.map(WrappedTokenRefresher::from_func).transpose()?.map(Arc::new);
|
||||
let updaters = progress_updater.map(try_parse_progress_updaters).transpose()?;
|
||||
let runtime = get_or_init_runtime().map_err(convert_xet_error)?;
|
||||
let updaters = progress_updater.map(|f| try_parse_progress_updaters(f, &runtime)).transpose()?;
|
||||
|
||||
// Convert Python dict -> Rust HashMap -> HeaderMap and merge with USER_AGENT
|
||||
let header_map = build_headers_with_user_agent(request_headers)?;
|
||||
@@ -289,8 +299,8 @@ pub fn download_files(
|
||||
file_infos.len(),
|
||||
if file_infos.len() > 3 { "..." } else { "." }
|
||||
);
|
||||
|
||||
let out: Vec<String> = data_client::download_async(
|
||||
&runtime,
|
||||
file_infos,
|
||||
endpoint,
|
||||
token_info,
|
||||
@@ -314,10 +324,13 @@ pub fn force_sigint_shutdown() -> PyResult<()> {
|
||||
Err(PyKeyboardInterrupt::new_err(()))
|
||||
}
|
||||
|
||||
fn try_parse_progress_updaters(funcs: Vec<Py<PyAny>>) -> PyResult<Vec<Arc<dyn TrackingProgressUpdater>>> {
|
||||
fn try_parse_progress_updaters(
|
||||
funcs: Vec<Py<PyAny>>,
|
||||
ctx: &XetContext,
|
||||
) -> PyResult<Vec<Arc<dyn TrackingProgressUpdater>>> {
|
||||
let mut updaters = Vec::with_capacity(funcs.len());
|
||||
for updater_func in funcs {
|
||||
let wrapped = Arc::new(WrappedProgressUpdater::new(updater_func)?);
|
||||
let wrapped = Arc::new(WrappedProgressUpdater::new(updater_func, ctx.clone())?);
|
||||
updaters.push(wrapped as Arc<dyn TrackingProgressUpdater>);
|
||||
}
|
||||
Ok(updaters)
|
||||
|
||||
@@ -36,7 +36,9 @@ pub fn init_logging(py: Python) {
|
||||
let xet_cache_directory = xet_runtime::core::xet_cache_root();
|
||||
let log_dir = xet_cache_directory.join("logs");
|
||||
|
||||
let cfg = LoggingConfig::default_to_directory(version_info, log_dir);
|
||||
// Called before any XetContext is created, so we use a standalone default config for
|
||||
// early-init logging setup.
|
||||
let cfg = LoggingConfig::from_directory(&xet_runtime::config::XetConfig::new(), version_info, log_dir);
|
||||
|
||||
xet_runtime::logging::init(cfg);
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ use pyo3::types::{IntoPyDict, PyList, PyString};
|
||||
use pyo3::{IntoPyObjectExt, Py, PyAny, PyResult, Python, pyclass};
|
||||
use tracing::error;
|
||||
use xet_pkg::legacy::progress_tracking::{ProgressUpdate, TrackingProgressUpdater};
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::error_printer::ErrorPrinter;
|
||||
|
||||
use crate::runtime::convert_multithreading_error;
|
||||
@@ -121,6 +121,8 @@ pub struct PyTotalProgressUpdate {
|
||||
/// passed around as a ProgressUpdater trait object or
|
||||
/// as a template parameter
|
||||
struct WrappedProgressUpdaterImpl {
|
||||
ctx: XetContext,
|
||||
|
||||
/// Is this enabled?
|
||||
progress_updating_enabled: bool,
|
||||
|
||||
@@ -144,7 +146,7 @@ impl Debug for WrappedProgressUpdaterImpl {
|
||||
const DETAILED_PROGRESS_ARG_NAMES: [&str; 2] = ["total_update", "item_updates"];
|
||||
|
||||
impl WrappedProgressUpdaterImpl {
|
||||
pub fn new(py_func: Py<PyAny>) -> PyResult<Self> {
|
||||
pub fn new(py_func: Py<PyAny>, ctx: XetContext) -> PyResult<Self> {
|
||||
// Analyze the function to make sure it's the correct form. If it's 4 arguments with
|
||||
// the appropriate names, than we call it using the detailed progress update; if it's
|
||||
// a single function, we assume it's a global increment function and just pass in the update
|
||||
@@ -157,6 +159,7 @@ impl WrappedProgressUpdaterImpl {
|
||||
// Test if it's enabled first; if None is passed in, then this is disabled.
|
||||
if py_func.is_none(py) {
|
||||
return Ok(Self {
|
||||
ctx,
|
||||
progress_updating_enabled: false,
|
||||
py_func,
|
||||
name: Default::default(),
|
||||
@@ -212,6 +215,7 @@ impl WrappedProgressUpdaterImpl {
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
ctx,
|
||||
progress_updating_enabled: true,
|
||||
py_func,
|
||||
name,
|
||||
@@ -221,71 +225,71 @@ impl WrappedProgressUpdaterImpl {
|
||||
}
|
||||
|
||||
async fn register_updates_impl(self: Arc<Self>, updates: ProgressUpdate) -> PyResult<()> {
|
||||
// Run on compute thread that doesn't block async workers
|
||||
let rt = XetRuntime::current();
|
||||
rt.spawn_blocking(move || {
|
||||
Python::attach(|py| {
|
||||
let f = self.py_func.bind(py);
|
||||
let runtime = self.ctx.runtime.clone();
|
||||
runtime
|
||||
.spawn_blocking(move || {
|
||||
Python::attach(|py| {
|
||||
let f = self.py_func.bind(py);
|
||||
|
||||
if self.update_with_detailed_progress {
|
||||
let total_update_report: Py<PyAny> = Py::new(
|
||||
py,
|
||||
PyTotalProgressUpdate {
|
||||
total_bytes: updates.total_bytes,
|
||||
total_bytes_increment: updates.total_bytes_increment,
|
||||
total_bytes_completed: updates.total_bytes_completed,
|
||||
total_bytes_completion_increment: updates.total_bytes_completion_increment,
|
||||
total_bytes_completion_rate: updates.total_bytes_completion_rate,
|
||||
total_transfer_bytes: updates.total_transfer_bytes,
|
||||
total_transfer_bytes_increment: updates.total_transfer_bytes_increment,
|
||||
total_transfer_bytes_completed: updates.total_transfer_bytes_completed,
|
||||
total_transfer_bytes_completion_increment: updates
|
||||
.total_transfer_bytes_completion_increment,
|
||||
total_transfer_bytes_completion_rate: updates.total_transfer_bytes_completion_rate,
|
||||
},
|
||||
)?
|
||||
.into_py_any(py)?;
|
||||
if self.update_with_detailed_progress {
|
||||
let total_update_report: Py<PyAny> = Py::new(
|
||||
py,
|
||||
PyTotalProgressUpdate {
|
||||
total_bytes: updates.total_bytes,
|
||||
total_bytes_increment: updates.total_bytes_increment,
|
||||
total_bytes_completed: updates.total_bytes_completed,
|
||||
total_bytes_completion_increment: updates.total_bytes_completion_increment,
|
||||
total_bytes_completion_rate: updates.total_bytes_completion_rate,
|
||||
total_transfer_bytes: updates.total_transfer_bytes,
|
||||
total_transfer_bytes_increment: updates.total_transfer_bytes_increment,
|
||||
total_transfer_bytes_completed: updates.total_transfer_bytes_completed,
|
||||
total_transfer_bytes_completion_increment: updates
|
||||
.total_transfer_bytes_completion_increment,
|
||||
total_transfer_bytes_completion_rate: updates.total_transfer_bytes_completion_rate,
|
||||
},
|
||||
)?
|
||||
.into_py_any(py)?;
|
||||
|
||||
let item_updates_v: Vec<Py<PyAny>> = updates
|
||||
.item_updates
|
||||
.into_iter()
|
||||
.map(|u| {
|
||||
Py::new(
|
||||
py,
|
||||
PyItemProgressUpdate {
|
||||
item_name: PyString::new(py, &u.item_name).into(),
|
||||
total_bytes: u.total_bytes,
|
||||
bytes_completed: u.bytes_completed,
|
||||
bytes_completion_increment: u.bytes_completion_increment,
|
||||
},
|
||||
)?
|
||||
.into_py_any(py)
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
let item_updates_v: Vec<Py<PyAny>> = updates
|
||||
.item_updates
|
||||
.into_iter()
|
||||
.map(|u| {
|
||||
Py::new(
|
||||
py,
|
||||
PyItemProgressUpdate {
|
||||
item_name: PyString::new(py, &u.item_name).into(),
|
||||
total_bytes: u.total_bytes,
|
||||
bytes_completed: u.bytes_completed,
|
||||
bytes_completion_increment: u.bytes_completion_increment,
|
||||
},
|
||||
)?
|
||||
.into_py_any(py)
|
||||
})
|
||||
.collect::<PyResult<Vec<_>>>()?;
|
||||
|
||||
let item_updates: Py<PyAny> = PyList::new(py, item_updates_v)?.into_py_any(py)?;
|
||||
let item_updates: Py<PyAny> = PyList::new(py, item_updates_v)?.into_py_any(py)?;
|
||||
|
||||
let argname_total_update: Py<PyAny> = DETAILED_PROGRESS_ARG_NAMES[0].into_py_any(py)?;
|
||||
let argname_item_updates: Py<PyAny> = DETAILED_PROGRESS_ARG_NAMES[1].into_py_any(py)?;
|
||||
let argname_total_update: Py<PyAny> = DETAILED_PROGRESS_ARG_NAMES[0].into_py_any(py)?;
|
||||
let argname_item_updates: Py<PyAny> = DETAILED_PROGRESS_ARG_NAMES[1].into_py_any(py)?;
|
||||
|
||||
let kwargs = [
|
||||
(argname_total_update, total_update_report),
|
||||
(argname_item_updates, item_updates),
|
||||
]
|
||||
.into_py_dict(py)?;
|
||||
let kwargs = [
|
||||
(argname_total_update, total_update_report),
|
||||
(argname_item_updates, item_updates),
|
||||
]
|
||||
.into_py_dict(py)?;
|
||||
|
||||
f.call((), Some(&kwargs))?;
|
||||
} else {
|
||||
let update_increment: u64 =
|
||||
updates.item_updates.iter().map(|pr| pr.bytes_completion_increment).sum();
|
||||
let _ = f.call1((update_increment,))?;
|
||||
}
|
||||
f.call((), Some(&kwargs))?;
|
||||
} else {
|
||||
let update_increment: u64 =
|
||||
updates.item_updates.iter().map(|pr| pr.bytes_completion_increment).sum();
|
||||
let _ = f.call1((update_increment,))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
})
|
||||
.await
|
||||
.map_err(convert_multithreading_error)?
|
||||
.await
|
||||
.map_err(convert_multithreading_error)?
|
||||
}
|
||||
}
|
||||
|
||||
@@ -295,9 +299,9 @@ pub struct WrappedProgressUpdater {
|
||||
}
|
||||
|
||||
impl WrappedProgressUpdater {
|
||||
pub fn new(py_func: Py<PyAny>) -> PyResult<Self> {
|
||||
pub fn new(py_func: Py<PyAny>, ctx: XetContext) -> PyResult<Self> {
|
||||
Ok(Self {
|
||||
inner: Arc::new(WrappedProgressUpdaterImpl::new(py_func)?),
|
||||
inner: Arc::new(WrappedProgressUpdaterImpl::new(py_func, ctx)?),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,13 +8,13 @@ use pyo3::prelude::*;
|
||||
use tracing::info;
|
||||
use xet_pkg::XetError;
|
||||
use xet_runtime::RuntimeError;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::core::sync_primatives::spawn_os_thread;
|
||||
|
||||
lazy_static! {
|
||||
static ref SIGINT_DETECTED: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));
|
||||
static ref SIGINT_HANDLER_INSTALL_PID: (AtomicU32, Mutex<()>) = (AtomicU32::new(0), Mutex::new(()));
|
||||
static ref MULTITHREADED_RUNTIME: RwLock<Option<(u32, Arc<XetRuntime>)>> = RwLock::new(None);
|
||||
static ref MULTITHREADED_RUNTIME: RwLock<Option<(u32, XetContext)>> = RwLock::new(None);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
@@ -101,12 +101,12 @@ pub(crate) fn perform_sigint_shutdown() {
|
||||
let maybe_runtime = MULTITHREADED_RUNTIME.write().unwrap().take();
|
||||
|
||||
// Shut it down gracefully if we own it in this process.
|
||||
if let Some((runtime_pid, ref runtime)) = maybe_runtime {
|
||||
if let Some((runtime_pid, ref ctx)) = maybe_runtime {
|
||||
// Only do anything with the runtime if we're on the right process.
|
||||
// Otherwise, it's none of our business.
|
||||
if runtime_pid == std::process::id() && runtime.external_executor_count() != 0 {
|
||||
if runtime_pid == std::process::id() && ctx.runtime.external_executor_count() != 0 {
|
||||
eprintln!("Cancellation requested; stopping current tasks.");
|
||||
runtime.perform_sigint_shutdown();
|
||||
ctx.runtime.perform_sigint_shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -139,36 +139,33 @@ fn signal_check_background_loop() {
|
||||
}
|
||||
|
||||
// This should be called once on library load.
|
||||
pub fn init_threadpool() -> Result<Arc<XetRuntime>, RuntimeError> {
|
||||
pub fn init_threadpool() -> Result<XetContext, RuntimeError> {
|
||||
// Need to initialize. Upgrade to write lock.
|
||||
let mut guard = MULTITHREADED_RUNTIME.write().unwrap();
|
||||
|
||||
// Has another thread done this already?
|
||||
let pid = std::process::id();
|
||||
|
||||
if let Some((runtime_pid, existing)) = guard.take() {
|
||||
if let Some((runtime_pid, ref existing)) = *guard {
|
||||
if runtime_pid == pid {
|
||||
// We're OK, so reset it here.
|
||||
*guard = Some((pid, existing.clone()));
|
||||
return Ok(existing);
|
||||
return Ok(existing.clone());
|
||||
} else {
|
||||
// Ok, discard the previous runtime, as it's effectively poisoned by the
|
||||
// fork-exec, and we simply need to leak it and restart from scratch. The memory and
|
||||
// resources will be freed up when the child exits.
|
||||
existing.discard_runtime();
|
||||
existing.runtime.discard_runtime();
|
||||
|
||||
info!("Runtime restarted due to detected process ID change, likely due to running inside a fork call.");
|
||||
}
|
||||
}
|
||||
|
||||
// Create a new Tokio runtime.
|
||||
let runtime = XetRuntime::new()?;
|
||||
let ctx = XetContext::default()?;
|
||||
|
||||
// Check the signal handler. This must be reinstalled on new or after a spawn
|
||||
check_sigint_handler()?;
|
||||
|
||||
// Set the runtime in the global tracker.
|
||||
*guard = Some((pid, runtime.clone()));
|
||||
*guard = Some((pid, ctx.clone()));
|
||||
|
||||
// Spawn a background non-tokio thread to check the sigint flag.
|
||||
std::thread::spawn(signal_check_background_loop);
|
||||
@@ -185,12 +182,10 @@ pub fn init_threadpool() -> Result<Arc<XetRuntime>, RuntimeError> {
|
||||
// being initialized.)
|
||||
drop(guard);
|
||||
|
||||
// Return the runtime
|
||||
Ok(runtime)
|
||||
Ok(ctx)
|
||||
}
|
||||
|
||||
// This function initializes the runtime if not present, otherwise returns the existing one.
|
||||
fn get_threadpool() -> Result<Arc<XetRuntime>, RuntimeError> {
|
||||
pub(crate) fn get_or_init_runtime() -> Result<XetContext, RuntimeError> {
|
||||
// First try a read lock to see if it's already initialized.
|
||||
{
|
||||
let guard = MULTITHREADED_RUNTIME.read().unwrap();
|
||||
@@ -203,8 +198,6 @@ fn get_threadpool() -> Result<Arc<XetRuntime>, RuntimeError> {
|
||||
}
|
||||
}
|
||||
|
||||
// Init and return
|
||||
|
||||
init_threadpool()
|
||||
}
|
||||
|
||||
@@ -222,9 +215,9 @@ where
|
||||
// Now, without the GIL, spawn the task on a new OS thread. This avoids having tokio cache stuff in
|
||||
// thread-local storage that is invalidated after a fork-exec.
|
||||
spawn_os_thread(move || {
|
||||
let runtime = get_threadpool().map_err(convert_multithreading_error)?;
|
||||
let ctx = get_or_init_runtime().map_err(convert_multithreading_error)?;
|
||||
|
||||
runtime
|
||||
ctx.runtime
|
||||
.external_run_async_task(execution_call)
|
||||
.map_err(convert_multithreading_error)?
|
||||
.into()
|
||||
|
||||
@@ -19,7 +19,7 @@ use simulation::upload_concurrency::run_upload_clients_until_cancelled;
|
||||
use tokio::time::sleep;
|
||||
use tracing::info;
|
||||
use xet_client::cas_client::simulation::local_server::ServerLatencyProfile;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::logging::{LoggingConfig, init as init_logging};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
@@ -136,27 +136,32 @@ fn setup_logging(out_dir: &Path) {
|
||||
let log_dest = format!("{}/", out_dir.display());
|
||||
// SAFETY: Called from main() before any threads are spawned.
|
||||
unsafe { std::env::set_var("HF_XET_LOG_DEST", &log_dest) };
|
||||
init_logging(LoggingConfig::default_to_directory("run_upload_scenario".to_string(), out_dir));
|
||||
init_logging(LoggingConfig::from_directory(
|
||||
&xet_runtime::config::XetConfig::new(),
|
||||
"run_upload_scenario".to_string(),
|
||||
out_dir,
|
||||
));
|
||||
}
|
||||
|
||||
/// Runs an async future on a fresh multi-threaded tokio runtime, using XetRuntime for initialization.
|
||||
/// Runs an async future on a fresh multi-threaded tokio runtime, using XetContext for initialization.
|
||||
fn run_async<F>(future: F) -> ScenarioResult<()>
|
||||
where
|
||||
F: Future<Output = ScenarioResult<()>> + Send + 'static,
|
||||
{
|
||||
let xet = XetRuntime::new().map_err(|e| ScenarioError::Runtime(e.to_string()))?;
|
||||
xet.bridge_sync(async move {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|e| ScenarioError::Runtime(e.to_string()))?
|
||||
.block_on(future)
|
||||
let xet = XetContext::default().map_err(|e| ScenarioError::Runtime(e.to_string()))?;
|
||||
xet.runtime
|
||||
.bridge_sync(async move {
|
||||
tokio::task::spawn_blocking(move || {
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(|e| ScenarioError::Runtime(e.to_string()))?
|
||||
.block_on(future)
|
||||
})
|
||||
.await
|
||||
.map_err(ScenarioError::from)?
|
||||
})
|
||||
.await
|
||||
.map_err(ScenarioError::from)?
|
||||
})
|
||||
.map_err(|e| ScenarioError::Runtime(e.to_string()))?
|
||||
.map_err(|e| ScenarioError::Runtime(e.to_string()))?
|
||||
}
|
||||
|
||||
// ── Entry point ──────────────────────────────────────────────────────────────
|
||||
|
||||
@@ -10,6 +10,8 @@ use xet_client::cas_client::simulation::NetworkProfileOptions;
|
||||
use xet_client::cas_client::simulation::local_server::ServerLatencyProfile;
|
||||
use xet_client::cas_client::{LocalTestServer, LocalTestServerBuilder};
|
||||
use xet_client::common::http_client::build_http_client;
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use crate::upload_concurrency::generate_timeline_csv;
|
||||
|
||||
@@ -240,8 +242,9 @@ impl SimulationScenario {
|
||||
/// Waits for the server (and proxy if present) to be ready by GETting the ping endpoint every 100ms.
|
||||
/// Errors if not ready within 30s.
|
||||
async fn wait_until_ready(&self) -> ScenarioResult<()> {
|
||||
let http_client =
|
||||
build_http_client("simulation_scenario", None, None).map_err(|e| ScenarioError::Scenario(e.to_string()))?;
|
||||
let ctx = XetContext::from_external(tokio::runtime::Handle::current(), XetConfig::new());
|
||||
let http_client = build_http_client(&ctx, "simulation_scenario", None, None)
|
||||
.map_err(|e| ScenarioError::Scenario(e.to_string()))?;
|
||||
let server_addr = self.server.http_endpoint();
|
||||
let ping_url = format!("{}{}", base_url(server_addr).trim_end_matches('/'), PING_PATH);
|
||||
let timeout = Duration::from_secs(SERVER_READY_TIMEOUT_SECS);
|
||||
|
||||
@@ -21,7 +21,8 @@ use xet_client::cas_client::adaptive_concurrency::{
|
||||
use xet_client::cas_client::progress_tracked_streams::UploadProgressStream;
|
||||
use xet_client::cas_client::retry_wrapper::RetryWrapper;
|
||||
use xet_client::common::http_client::build_http_client;
|
||||
use xet_runtime::core::xet_config;
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use crate::scenario::base_url;
|
||||
|
||||
@@ -190,11 +191,13 @@ fn spawn_stats_reporter(
|
||||
|
||||
/// Shared context for upload workers.
|
||||
struct UploadContext {
|
||||
ctx: XetContext,
|
||||
url: String,
|
||||
http_client: reqwest_middleware::ClientWithMiddleware,
|
||||
base_data: Bytes,
|
||||
min_data_size: u64,
|
||||
max_data_size: u64,
|
||||
upload_reporting_block_size: usize,
|
||||
counters: UploadCounters,
|
||||
concurrency_controller: Arc<AdaptiveConcurrencyController>,
|
||||
start_instant: Arc<Mutex<Instant>>,
|
||||
@@ -203,18 +206,20 @@ struct UploadContext {
|
||||
}
|
||||
|
||||
/// Spawns `UPLOAD_TASKS` concurrent upload workers into the given `JoinSet`.
|
||||
fn spawn_upload_tasks(join_set: &mut JoinSet<()>, ctx: &UploadContext) {
|
||||
fn spawn_upload_tasks(join_set: &mut JoinSet<()>, upload: &UploadContext) {
|
||||
for _ in 0..UPLOAD_TASKS {
|
||||
let concurrency_controller = ctx.concurrency_controller.clone();
|
||||
let http_client = ctx.http_client.clone();
|
||||
let url = ctx.url.clone();
|
||||
let counters = ctx.counters.clone();
|
||||
let base_data = ctx.base_data.clone();
|
||||
let task_cancel = ctx.cancel.clone();
|
||||
let task_start = Arc::clone(&ctx.start_instant);
|
||||
let end_duration = ctx.end_duration;
|
||||
let min_data_size = ctx.min_data_size;
|
||||
let max_data_size = ctx.max_data_size;
|
||||
let xet_ctx = upload.ctx.clone();
|
||||
let concurrency_controller = upload.concurrency_controller.clone();
|
||||
let http_client = upload.http_client.clone();
|
||||
let url = upload.url.clone();
|
||||
let counters = upload.counters.clone();
|
||||
let base_data = upload.base_data.clone();
|
||||
let task_cancel = upload.cancel.clone();
|
||||
let task_start = Arc::clone(&upload.start_instant);
|
||||
let end_duration = upload.end_duration;
|
||||
let min_data_size = upload.min_data_size;
|
||||
let max_data_size = upload.max_data_size;
|
||||
let upload_reporting_block_size = upload.upload_reporting_block_size;
|
||||
|
||||
join_set.spawn(async move {
|
||||
loop {
|
||||
@@ -245,14 +250,14 @@ fn spawn_upload_tasks(join_set: &mut JoinSet<()>, ctx: &UploadContext) {
|
||||
let do_one_upload = async {
|
||||
counters.retry_wrapper_calls.fetch_add(1, Ordering::Relaxed);
|
||||
let request_start = Instant::now();
|
||||
let result = RetryWrapper::new("upload_benchmark")
|
||||
let result = RetryWrapper::new(xet_ctx.clone(), "upload_benchmark")
|
||||
.with_connection_permit(permit, Some(payload_size))
|
||||
.run({
|
||||
let http_client = http_client.clone();
|
||||
let url = url.clone();
|
||||
let payload_data = payload_data.clone();
|
||||
let http_calls = counters.http_calls.clone();
|
||||
let block_size = xet_config().client.upload_reporting_block_size;
|
||||
let block_size = upload_reporting_block_size;
|
||||
move || {
|
||||
let http_client = http_client.clone();
|
||||
let url = url.clone();
|
||||
@@ -310,7 +315,8 @@ async fn run_upload_clients_impl(
|
||||
let max_data_size = max_data_kb * 1024;
|
||||
let client_id = rand::rng().random_range(0..1000000000_u64);
|
||||
|
||||
let http_client = build_http_client("test_session", None, None)?;
|
||||
let ctx = XetContext::from_external(tokio::runtime::Handle::current(), XetConfig::new());
|
||||
let http_client = build_http_client(&ctx, "test_session", None, None)?;
|
||||
|
||||
let duration_sec = repeat_duration_seconds.unwrap_or(u64::MAX);
|
||||
let client_params = serde_json::json!({
|
||||
@@ -325,7 +331,7 @@ async fn run_upload_clients_impl(
|
||||
let params_path = output_dir.join(format!("client_parameters_{}.json", client_id));
|
||||
std::fs::write(¶ms_path, serde_json::to_string_pretty(&client_params)?)?;
|
||||
|
||||
let concurrency_controller = AdaptiveConcurrencyController::new_upload("test_uploads");
|
||||
let concurrency_controller = AdaptiveConcurrencyController::new_upload(ctx.clone(), "test_uploads");
|
||||
let start_instant = Arc::new(Mutex::new(Instant::now()));
|
||||
let end_duration = Duration::from_secs(duration_sec);
|
||||
let counters = UploadCounters::new();
|
||||
@@ -350,12 +356,16 @@ async fn run_upload_clients_impl(
|
||||
let url_base = base_url(server_addr);
|
||||
let url = format!("{}{}", url_base.trim_end_matches('/'), DUMMY_UPLOAD_PATH);
|
||||
|
||||
let upload_reporting_block_size = ctx.config.client.upload_reporting_block_size;
|
||||
|
||||
let upload_ctx = UploadContext {
|
||||
ctx,
|
||||
url,
|
||||
http_client,
|
||||
base_data,
|
||||
min_data_size,
|
||||
max_data_size,
|
||||
upload_reporting_block_size,
|
||||
counters,
|
||||
concurrency_controller,
|
||||
start_instant,
|
||||
|
||||
1
wasm/hf_xet_wasm/Cargo.lock
generated
1
wasm/hf_xet_wasm/Cargo.lock
generated
@@ -882,6 +882,7 @@ dependencies = [
|
||||
"xet-client",
|
||||
"xet-core-structures",
|
||||
"xet-data",
|
||||
"xet-runtime",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -10,6 +10,7 @@ crate-type = ["cdylib", "rlib"]
|
||||
xet-core-structures = { path = "../../xet_core_structures" }
|
||||
xet-client = { path = "../../xet_client" }
|
||||
xet-data = { path = "../../xet_data" }
|
||||
xet-runtime = { path = "../../xet_runtime" }
|
||||
|
||||
async-trait = "0.1"
|
||||
bytes = "1.11"
|
||||
|
||||
@@ -175,7 +175,11 @@ pub async fn clean_file(file: web_sys::File, endpoint: String, jwt_token: String
|
||||
session_id: uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
|
||||
let upload_session = Arc::new(FileUploadSession::new(Arc::new(config)));
|
||||
let ctx = xet_runtime::core::XetContext::from_external(
|
||||
tokio::runtime::Handle::current(),
|
||||
xet_runtime::config::XetConfig::new(),
|
||||
);
|
||||
let upload_session = Arc::new(FileUploadSession::new(ctx, Arc::new(config)));
|
||||
|
||||
let mut handle = upload_session.start_clean(0, None);
|
||||
|
||||
|
||||
@@ -65,7 +65,11 @@ impl XetSession {
|
||||
},
|
||||
session_id: uuid::Uuid::new_v4().to_string(),
|
||||
};
|
||||
let upload = FileUploadSession::new(Arc::new(config));
|
||||
let ctx = xet_runtime::core::XetContext::from_external(
|
||||
tokio::runtime::Handle::current(),
|
||||
xet_runtime::config::XetConfig::new(),
|
||||
);
|
||||
let upload = FileUploadSession::new(ctx, Arc::new(config));
|
||||
|
||||
Self {
|
||||
upload: Arc::new(upload),
|
||||
|
||||
@@ -75,7 +75,11 @@ impl SingleFileCleaner {
|
||||
_file_id: file_id,
|
||||
session: session.clone(),
|
||||
cpu_task: CPUTask::CurrentThread((Chunker::default(), ShaGeneration::new(sha256))),
|
||||
dedup_manager: FileDeduper::new(UploadSessionDataManager::new(session), file_id),
|
||||
dedup_manager: FileDeduper::new(
|
||||
UploadSessionDataManager::new(session.clone()),
|
||||
file_id,
|
||||
session.ctx.clone(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,7 +117,11 @@ impl SingleFileCleaner {
|
||||
_file_id: file_id,
|
||||
session: session.clone(),
|
||||
cpu_task,
|
||||
dedup_manager: FileDeduper::new(UploadSessionDataManager::new(session), file_id),
|
||||
dedup_manager: FileDeduper::new(
|
||||
UploadSessionDataManager::new(session.clone()),
|
||||
file_id,
|
||||
session.ctx.clone(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,8 +8,9 @@ use xet_core_structures::merklehash::{HashedWrite, MerkleHash};
|
||||
use xet_core_structures::metadata_shard::MDBShardInfo;
|
||||
use xet_core_structures::metadata_shard::shard_in_memory::MDBInMemoryShard;
|
||||
use xet_core_structures::xorb_object::SerializedXorbObject;
|
||||
use xet_data::deduplication::constants::{XORB_CUT_THRESHOLD_BYTES, XORB_CUT_THRESHOLD_CHUNKS};
|
||||
use xet_core_structures::xorb_object::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
|
||||
use xet_data::deduplication::{DataAggregator, DeduplicationMetrics, RawXorbData};
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::configurations::TranslatorConfig;
|
||||
use super::errors::*;
|
||||
@@ -26,6 +27,8 @@ static UPLOAD_CONCURRENCY: usize = 5;
|
||||
/// that succeeds or fails as a unit; i.e. all files get uploaded on finalization, and all shards
|
||||
/// and xorbs needed to reconstruct those files are properly uploaded and registered.
|
||||
pub struct FileUploadSession {
|
||||
pub(crate) ctx: XetContext,
|
||||
|
||||
/// The configuration settings, if needed.
|
||||
pub(crate) config: Arc<TranslatorConfig>,
|
||||
|
||||
@@ -38,7 +41,7 @@ pub struct FileUploadSession {
|
||||
}
|
||||
|
||||
impl FileUploadSession {
|
||||
pub fn new(config: Arc<TranslatorConfig>) -> Self {
|
||||
pub fn new(ctx: XetContext, config: Arc<TranslatorConfig>) -> Self {
|
||||
let headers = match HeaderValue::from_str(&config.data_config.user_agent) {
|
||||
Ok(value) => {
|
||||
let mut headers = http::HeaderMap::new();
|
||||
@@ -49,6 +52,7 @@ impl FileUploadSession {
|
||||
};
|
||||
|
||||
let client = RemoteClient::new(
|
||||
ctx.clone(),
|
||||
&config.data_config.endpoint,
|
||||
&config.data_config.auth,
|
||||
&config.session_id,
|
||||
@@ -60,6 +64,7 @@ impl FileUploadSession {
|
||||
Box::new(XorbUploaderSpawnParallel::new(client.clone(), &config.data_config.prefix, UPLOAD_CONCURRENCY));
|
||||
|
||||
Self {
|
||||
ctx,
|
||||
session_shard: Mutex::new(MDBInMemoryShard::default()),
|
||||
xorb_uploader: Mutex::new(Some(xorb_uploader)),
|
||||
config,
|
||||
@@ -82,8 +87,8 @@ impl FileUploadSession {
|
||||
let mut current_session_data = self.current_session_data.lock().await;
|
||||
|
||||
// Do we need to cut one of these to a xorb?
|
||||
if current_session_data.num_bytes() + file_data.num_bytes() > *XORB_CUT_THRESHOLD_BYTES
|
||||
|| current_session_data.num_chunks() + file_data.num_chunks() > *XORB_CUT_THRESHOLD_CHUNKS
|
||||
if current_session_data.num_bytes() + file_data.num_bytes() > *MAX_XORB_BYTES
|
||||
|| current_session_data.num_chunks() + file_data.num_chunks() > *MAX_XORB_CHUNKS
|
||||
{
|
||||
// Cut the larger one as a xorb, uploading it and registering the files.
|
||||
if current_session_data.num_bytes() > file_data.num_bytes() {
|
||||
@@ -126,8 +131,12 @@ impl FileUploadSession {
|
||||
}
|
||||
|
||||
// XORBs are sent without footer - the server/client reconstructs it from chunk data.
|
||||
let xorb_obj =
|
||||
SerializedXorbObject::from_xorb_with_compression(xorb, self.config.data_config.compression, false)?;
|
||||
let xorb_obj = SerializedXorbObject::from_xorb_with_compression(
|
||||
xorb,
|
||||
self.config.data_config.compression,
|
||||
false,
|
||||
self.ctx.config.xorb.compression_scheme_retest_interval,
|
||||
)?;
|
||||
|
||||
let Some(ref mut xorb_uploader) = *self.xorb_uploader.lock().await else {
|
||||
return Err(DataProcessingError::internal("register xorb after drop"));
|
||||
|
||||
@@ -3,8 +3,8 @@ use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use tokio_with_wasm::alias as wasmtokio;
|
||||
use xet_client::cas_client::Client;
|
||||
use xet_client::ClientError;
|
||||
use xet_client::cas_client::Client;
|
||||
use xet_core_structures::xorb_object::SerializedXorbObject;
|
||||
|
||||
use crate::errors::*;
|
||||
|
||||
@@ -11,7 +11,7 @@ use tracing::info;
|
||||
#[cfg(target_family = "wasm")]
|
||||
use web_time::Instant;
|
||||
use xet_core_structures::ExpWeightedMovingAvg;
|
||||
use xet_runtime::core::xet_config;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::utils::adjustable_semaphore::{AdjustableSemaphore, AdjustableSemaphorePermit};
|
||||
|
||||
use super::super::progress_tracked_streams::ProgressCallback;
|
||||
@@ -42,6 +42,7 @@ pub struct CCLatencyModelState {
|
||||
|
||||
/// The internal state of the concurrency controller.
|
||||
struct ConcurrencyControllerState {
|
||||
ctx: XetContext,
|
||||
/// A running model of the current bandwidth. Uses an exponentially weighted average to predict the
|
||||
rtt_predictor: RTTPredictor,
|
||||
|
||||
@@ -73,12 +74,13 @@ struct ConcurrencyControllerState {
|
||||
}
|
||||
|
||||
impl ConcurrencyControllerState {
|
||||
fn new() -> Self {
|
||||
let config = xet_config();
|
||||
fn new(ctx: XetContext) -> Self {
|
||||
let config = &ctx.config;
|
||||
let rtt_half_life_count = config.client.ac_latency_rtt_half_life;
|
||||
let success_half_life_count = config.client.ac_success_tracking_half_life;
|
||||
|
||||
Self {
|
||||
ctx,
|
||||
rtt_predictor: RTTPredictor::new(rtt_half_life_count),
|
||||
success_ratio_tracking: ExpWeightedMovingAvg::new_count_decay(success_half_life_count),
|
||||
last_adjustment_time: Instant::now(),
|
||||
@@ -92,7 +94,7 @@ impl ConcurrencyControllerState {
|
||||
}
|
||||
|
||||
fn success_ratio_thresholds(&self) -> (f64, f64) {
|
||||
let config = xet_config();
|
||||
let config = &self.ctx.config;
|
||||
let increase_threshold = config.client.ac_healthy_success_ratio_threshold;
|
||||
let decrease_threshold = config.client.ac_unhealthy_success_ratio_threshold;
|
||||
(increase_threshold, decrease_threshold)
|
||||
@@ -130,7 +132,7 @@ impl ConcurrencyControllerState {
|
||||
|
||||
#[inline]
|
||||
fn latency_model_state(&self, current_concurrency: f64) -> CCLatencyModelState {
|
||||
let config = xet_config();
|
||||
let config = &self.ctx.config;
|
||||
let (predicted_max_rtt, prediction_max_rtt_standard_error) = self
|
||||
.rtt_predictor
|
||||
.predict(*config.client.ac_max_reference_transmission_size, current_concurrency);
|
||||
@@ -159,7 +161,7 @@ impl ConcurrencyControllerState {
|
||||
|
||||
let quantile_95 = (mu + REFERENCE_SIZE_QUANTILE_Z * sigma).exp();
|
||||
|
||||
let config = xet_config();
|
||||
let config = &self.ctx.config;
|
||||
let min_size = *config.client.ac_min_reference_transmission_size;
|
||||
let max_size = *config.client.ac_max_reference_transmission_size;
|
||||
|
||||
@@ -234,15 +236,14 @@ impl ConcurrencyControllerState {
|
||||
/// ```ignore
|
||||
/// use crate::adaptive_concurrency::{AdaptiveConcurrencyController, ConnectionPermit};
|
||||
/// use crate::retry_wrapper::RetryWrapper;
|
||||
/// use xet_runtime::core::XetContext;
|
||||
///
|
||||
/// // Create a controller (typically done once during client initialization)
|
||||
/// let upload_controller = AdaptiveConcurrencyController::new_upload("upload");
|
||||
/// let ctx = XetContext::default()?;
|
||||
/// let upload_controller = AdaptiveConcurrencyController::new_upload(ctx.clone(), "upload");
|
||||
///
|
||||
/// // Before making a request, acquire a permit
|
||||
/// let permit = upload_controller.acquire_connection_permit().await?;
|
||||
///
|
||||
/// // Use the permit with RetryWrapper to track the transfer
|
||||
/// let response: UploadResponse = RetryWrapper::new("cas::upload_shard")
|
||||
/// let response: UploadResponse = RetryWrapper::new(ctx, "cas::upload_shard")
|
||||
/// .with_connection_permit(permit, Some(shard_data.len() as u64))
|
||||
/// .run_and_extract_json(move |_partial_report_fn| {
|
||||
/// client.post(url.clone()).body(shard_data.clone()).send()
|
||||
@@ -258,6 +259,7 @@ impl ConcurrencyControllerState {
|
||||
///
|
||||
/// The controller uses these reports to update its models and adjust concurrency accordingly.
|
||||
pub struct AdaptiveConcurrencyController {
|
||||
ctx: XetContext,
|
||||
// The current state, including tracking information and when previous adjustments were made.
|
||||
// Also holds related constants
|
||||
state: Mutex<ConcurrencyControllerState>,
|
||||
@@ -284,6 +286,7 @@ pub struct AdaptiveConcurrencyController {
|
||||
|
||||
impl AdaptiveConcurrencyController {
|
||||
pub fn new(
|
||||
ctx: XetContext,
|
||||
logging_tag: &'static str,
|
||||
concurrency: usize,
|
||||
concurrency_bounds: (usize, usize),
|
||||
@@ -299,9 +302,10 @@ impl AdaptiveConcurrencyController {
|
||||
"Initializing Adaptive Concurrency Controller for {logging_tag} with starting concurrency = {current_concurrency}; min = {min_concurrency}, max = {max_concurrency}, min_bytes_for_adjustment = {min_bytes_required_for_adjustment}, min_completed_transmissions_for_adjustment = {min_completed_transmissions_required_for_adjustment}"
|
||||
);
|
||||
|
||||
let config = xet_config();
|
||||
let config = &ctx.config;
|
||||
Arc::new(Self {
|
||||
state: Mutex::new(ConcurrencyControllerState::new()),
|
||||
ctx: ctx.clone(),
|
||||
state: Mutex::new(ConcurrencyControllerState::new(ctx.clone())),
|
||||
concurrency_semaphore: AdjustableSemaphore::new(
|
||||
current_concurrency as u64,
|
||||
(min_concurrency as u64, max_concurrency as u64),
|
||||
@@ -317,11 +321,12 @@ impl AdaptiveConcurrencyController {
|
||||
}
|
||||
|
||||
/// Create a new concurrency controller with a fixed maximum concurrency; adjustments are disabled.
|
||||
pub fn new_fixed(logging_tag: &'static str, concurrency: usize) -> Arc<Self> {
|
||||
pub fn new_fixed(ctx: XetContext, logging_tag: &'static str, concurrency: usize) -> Arc<Self> {
|
||||
info!("Fixing maximum concurrency for {logging_tag} at {concurrency}; adaptive concurrency disabled.");
|
||||
|
||||
Arc::new(Self {
|
||||
state: Mutex::new(ConcurrencyControllerState::new()),
|
||||
ctx: ctx.clone(),
|
||||
state: Mutex::new(ConcurrencyControllerState::new(ctx)),
|
||||
concurrency_semaphore: AdjustableSemaphore::new(
|
||||
concurrency as u64,
|
||||
(concurrency as u64, concurrency as u64),
|
||||
@@ -337,9 +342,10 @@ impl AdaptiveConcurrencyController {
|
||||
|
||||
/// Create a new concurrency controller for uploads using configuration values.
|
||||
/// This will use adaptive concurrency if enabled, otherwise fixed concurrency.
|
||||
pub fn new_upload(logging_tag: &'static str) -> Arc<Self> {
|
||||
let config = xet_config();
|
||||
pub fn new_upload(ctx: XetContext, logging_tag: &'static str) -> Arc<Self> {
|
||||
let config = ctx.config.clone();
|
||||
Self::new(
|
||||
ctx,
|
||||
logging_tag,
|
||||
config.client.ac_initial_upload_concurrency,
|
||||
(config.client.ac_min_upload_concurrency, config.client.ac_max_upload_concurrency),
|
||||
@@ -350,9 +356,10 @@ impl AdaptiveConcurrencyController {
|
||||
|
||||
/// Create a new concurrency controller for downloads using configuration values.
|
||||
/// This will use adaptive concurrency if enabled, otherwise fixed concurrency.
|
||||
pub fn new_download(logging_tag: &'static str) -> Arc<Self> {
|
||||
let config = xet_config();
|
||||
pub fn new_download(ctx: XetContext, logging_tag: &'static str) -> Arc<Self> {
|
||||
let config = ctx.config.clone();
|
||||
Self::new(
|
||||
ctx,
|
||||
logging_tag,
|
||||
config.client.ac_initial_download_concurrency,
|
||||
(config.client.ac_min_download_concurrency, config.client.ac_max_download_concurrency),
|
||||
@@ -446,7 +453,7 @@ impl AdaptiveConcurrencyController {
|
||||
let t_actual = elapsed_time.as_secs_f64().max(1e-4);
|
||||
|
||||
// Track if the transfer completed within a healthy time.
|
||||
let config = xet_config();
|
||||
let config = &self.ctx.config;
|
||||
let completed_in_time = elapsed_time < config.client.ac_max_healthy_rtt;
|
||||
|
||||
let mut state_lg = self.state.lock().await;
|
||||
@@ -731,7 +738,6 @@ impl ConnectionPermit {
|
||||
// Testing routines.
|
||||
#[cfg(test)]
|
||||
mod test_constants {
|
||||
|
||||
pub const TR_HALF_LIFE_COUNT: f64 = 10.0;
|
||||
pub const INCR_SPACING_MS: u64 = 200;
|
||||
pub const DECR_SPACING_MS: u64 = 100;
|
||||
@@ -743,11 +749,11 @@ mod test_constants {
|
||||
|
||||
#[cfg(test)]
|
||||
impl ConcurrencyControllerState {
|
||||
#[cfg(test)]
|
||||
fn new_testing() -> Self {
|
||||
fn new_testing(ctx: XetContext) -> Self {
|
||||
use self::test_constants::TR_HALF_LIFE_COUNT;
|
||||
|
||||
Self {
|
||||
ctx,
|
||||
rtt_predictor: RTTPredictor::new(TR_HALF_LIFE_COUNT),
|
||||
success_ratio_tracking: ExpWeightedMovingAvg::new_count_decay(TR_HALF_LIFE_COUNT),
|
||||
last_adjustment_time: Instant::now(),
|
||||
@@ -764,8 +770,10 @@ impl ConcurrencyControllerState {
|
||||
#[cfg(test)]
|
||||
impl AdaptiveConcurrencyController {
|
||||
pub fn new_testing(concurrency: usize, concurrency_bounds: (usize, usize)) -> Arc<Self> {
|
||||
let ctx = XetContext::default().expect("test runtime");
|
||||
Arc::new(Self {
|
||||
state: Mutex::new(ConcurrencyControllerState::new_testing()),
|
||||
ctx: ctx.clone(),
|
||||
state: Mutex::new(ConcurrencyControllerState::new_testing(ctx)),
|
||||
concurrency_semaphore: AdjustableSemaphore::new(
|
||||
concurrency as u64,
|
||||
(concurrency_bounds.0 as u64, concurrency_bounds.1 as u64),
|
||||
@@ -783,6 +791,7 @@ impl AdaptiveConcurrencyController {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tokio::time::{self, Duration, advance};
|
||||
use xet_runtime::config::XetConfig;
|
||||
|
||||
use super::test_constants::*;
|
||||
use super::*;
|
||||
@@ -984,13 +993,15 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_reference_size_returns_none_with_insufficient_data() {
|
||||
let state = ConcurrencyControllerState::new_testing();
|
||||
let ctx = XetContext::default().expect("test runtime");
|
||||
let state = ConcurrencyControllerState::new_testing(ctx);
|
||||
assert!(state.estimated_reference_transmission_size().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reference_size_with_uniform_sizes() {
|
||||
let mut state = ConcurrencyControllerState::new_testing();
|
||||
let ctx = XetContext::default().expect("test runtime");
|
||||
let mut state = ConcurrencyControllerState::new_testing(ctx);
|
||||
|
||||
let size: u64 = 10 * 1024 * 1024; // 10 MB
|
||||
for _ in 0..10 {
|
||||
@@ -998,7 +1009,7 @@ mod tests {
|
||||
}
|
||||
|
||||
let ref_size = state.estimated_reference_transmission_size().unwrap();
|
||||
let config = xet_config();
|
||||
let config = XetConfig::new();
|
||||
|
||||
// With zero variance, the 95th percentile should equal the mean (~10MB).
|
||||
debug_assert!(ref_size >= *config.client.ac_min_reference_transmission_size);
|
||||
@@ -1008,21 +1019,23 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_reference_size_bounded_by_minimum() {
|
||||
let mut state = ConcurrencyControllerState::new_testing();
|
||||
let ctx = XetContext::default().expect("test runtime");
|
||||
let mut state = ConcurrencyControllerState::new_testing(ctx);
|
||||
|
||||
let size: u64 = 1024; // 1 KB
|
||||
for _ in 0..10 {
|
||||
state.update_size_tracking(size);
|
||||
}
|
||||
|
||||
let config = xet_config();
|
||||
let config = XetConfig::new();
|
||||
let ref_size = state.estimated_reference_transmission_size().unwrap();
|
||||
assert_eq!(ref_size, *config.client.ac_min_reference_transmission_size);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reference_size_bounded_by_config_maximum() {
|
||||
let mut state = ConcurrencyControllerState::new_testing();
|
||||
let ctx = XetContext::default().expect("test runtime");
|
||||
let mut state = ConcurrencyControllerState::new_testing(ctx);
|
||||
|
||||
let size: u64 = 200 * 1024 * 1024; // 200 MB (above the 64MB config default)
|
||||
for _ in 0..10 {
|
||||
@@ -1030,13 +1043,14 @@ mod tests {
|
||||
}
|
||||
|
||||
let ref_size = state.estimated_reference_transmission_size().unwrap();
|
||||
let config = xet_config();
|
||||
let config = XetConfig::new();
|
||||
assert!(ref_size <= *config.client.ac_max_reference_transmission_size);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reference_size_skips_zero_byte_transfers() {
|
||||
let mut state = ConcurrencyControllerState::new_testing();
|
||||
let ctx = XetContext::default().expect("test runtime");
|
||||
let mut state = ConcurrencyControllerState::new_testing(ctx);
|
||||
|
||||
for _ in 0..10 {
|
||||
state.update_size_tracking(0);
|
||||
@@ -1048,15 +1062,17 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_reference_size_with_mixed_sizes() {
|
||||
let config = xet_config();
|
||||
let config = XetConfig::new();
|
||||
|
||||
let mut small_only_state = ConcurrencyControllerState::new_testing();
|
||||
let ctx_small = XetContext::default().expect("test runtime");
|
||||
let mut small_only_state = ConcurrencyControllerState::new_testing(ctx_small);
|
||||
for _ in 0..10 {
|
||||
small_only_state.update_size_tracking(512 * 1024); // 512 KB
|
||||
}
|
||||
let small_only_ref_size = small_only_state.estimated_reference_transmission_size().unwrap();
|
||||
|
||||
let mut state = ConcurrencyControllerState::new_testing();
|
||||
let ctx = XetContext::default().expect("test runtime");
|
||||
let mut state = ConcurrencyControllerState::new_testing(ctx);
|
||||
|
||||
// Mix of small and large transfers
|
||||
for _ in 0..5 {
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use reqwest_middleware::ClientWithMiddleware;
|
||||
use thiserror::Error;
|
||||
use tracing::info;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use crate::common::auth::CredentialHelper;
|
||||
|
||||
@@ -67,6 +68,7 @@ impl TokenRefresher for ErrTokenRefresher {
|
||||
/// An optional [`CredentialHelper`](crate::common::auth::CredentialHelper) is applied to the
|
||||
/// request before it is sent; pass `None` when no additional credentials are needed.
|
||||
pub struct DirectRefreshRouteTokenRefresher {
|
||||
ctx: XetContext,
|
||||
refresh_route: String,
|
||||
client: ClientWithMiddleware,
|
||||
cred_helper: Option<Arc<dyn CredentialHelper>>,
|
||||
@@ -82,11 +84,13 @@ impl std::fmt::Debug for DirectRefreshRouteTokenRefresher {
|
||||
|
||||
impl DirectRefreshRouteTokenRefresher {
|
||||
pub fn new(
|
||||
ctx: XetContext,
|
||||
refresh_route: impl Into<String>,
|
||||
client: ClientWithMiddleware,
|
||||
cred_helper: Option<Arc<dyn CredentialHelper>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
ctx,
|
||||
refresh_route: refresh_route.into(),
|
||||
client,
|
||||
cred_helper,
|
||||
@@ -98,27 +102,28 @@ impl DirectRefreshRouteTokenRefresher {
|
||||
let refresh_route = self.refresh_route.clone();
|
||||
let cred_helper = self.cred_helper.clone();
|
||||
|
||||
let jwt_info: crate::hub_client::CasJWTInfo = super::retry_wrapper::RetryWrapper::new("xet-token")
|
||||
.run_and_extract_json(move || {
|
||||
let refresh_route = refresh_route.clone();
|
||||
let client = client.clone();
|
||||
let cred_helper = cred_helper.clone();
|
||||
async move {
|
||||
let req = client
|
||||
.get(&refresh_route)
|
||||
.with_extension(crate::common::http_client::Api("xet-token"));
|
||||
let req = if let Some(helper) = cred_helper {
|
||||
helper
|
||||
.fill_credential(req)
|
||||
.await
|
||||
.map_err(reqwest_middleware::Error::middleware)?
|
||||
} else {
|
||||
req
|
||||
};
|
||||
req.send().await
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
let jwt_info: crate::hub_client::CasJWTInfo =
|
||||
super::retry_wrapper::RetryWrapper::new(self.ctx.clone(), "xet-token")
|
||||
.run_and_extract_json(move || {
|
||||
let refresh_route = refresh_route.clone();
|
||||
let client = client.clone();
|
||||
let cred_helper = cred_helper.clone();
|
||||
async move {
|
||||
let req = client
|
||||
.get(&refresh_route)
|
||||
.with_extension(crate::common::http_client::Api("xet-token"));
|
||||
let req = if let Some(helper) = cred_helper {
|
||||
helper
|
||||
.fill_credential(req)
|
||||
.await
|
||||
.map_err(reqwest_middleware::Error::middleware)?
|
||||
} else {
|
||||
req
|
||||
};
|
||||
req.send().await
|
||||
}
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(jwt_info)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ use tracing::{event, info, instrument};
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_core_structures::metadata_shard::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, MDBFileInfo};
|
||||
use xet_core_structures::xorb_object::SerializedXorbObject;
|
||||
use xet_runtime::core::xet_config;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::adaptive_concurrency::{AdaptiveConcurrencyController, ConnectionPermit};
|
||||
use super::auth::AuthConfig;
|
||||
@@ -39,6 +39,7 @@ lazy_static! {
|
||||
}
|
||||
|
||||
pub struct RemoteClient {
|
||||
pub(crate) ctx: XetContext,
|
||||
endpoint: String,
|
||||
dry_run: bool,
|
||||
http_client: Arc<ClientWithMiddleware>,
|
||||
@@ -64,6 +65,7 @@ impl RemoteClient {
|
||||
/// * `unix_socket_path` - Optional Unix socket path for proxying connections (ignored on non-Unix platforms)
|
||||
/// * `custom_headers` - Optional custom headers to include in HTTP requests (should include User-Agent)
|
||||
pub fn new_with_socket(
|
||||
ctx: XetContext,
|
||||
endpoint: &str,
|
||||
auth: &Option<AuthConfig>,
|
||||
session_id: &str,
|
||||
@@ -72,22 +74,29 @@ impl RemoteClient {
|
||||
custom_headers: Option<Arc<HeaderMap>>,
|
||||
) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
ctx: ctx.clone(),
|
||||
endpoint: endpoint.to_string(),
|
||||
dry_run,
|
||||
authenticated_http_client: Arc::new(
|
||||
http_client::build_auth_http_client(auth, session_id, unix_socket_path, custom_headers.clone())
|
||||
http_client::build_auth_http_client(&ctx, auth, session_id, unix_socket_path, custom_headers.clone())
|
||||
.unwrap(),
|
||||
),
|
||||
http_client: Arc::new(
|
||||
http_client::build_http_client(session_id, unix_socket_path, custom_headers.clone()).unwrap(),
|
||||
http_client::build_http_client(&ctx, session_id, unix_socket_path, custom_headers.clone()).unwrap(),
|
||||
),
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
shard_upload_http_client: Arc::new(
|
||||
http_client::build_auth_http_client_no_read_timeout(auth, session_id, unix_socket_path, custom_headers)
|
||||
.unwrap(),
|
||||
http_client::build_auth_http_client_no_read_timeout(
|
||||
&ctx,
|
||||
auth,
|
||||
session_id,
|
||||
unix_socket_path,
|
||||
custom_headers,
|
||||
)
|
||||
.unwrap(),
|
||||
),
|
||||
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload("upload"),
|
||||
download_concurrency_controller: AdaptiveConcurrencyController::new_download("download"),
|
||||
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload(ctx.clone(), "upload"),
|
||||
download_concurrency_controller: AdaptiveConcurrencyController::new_download(ctx.clone(), "download"),
|
||||
detected_reconstruction_api_version: AtomicU32::new(0),
|
||||
})
|
||||
}
|
||||
@@ -104,13 +113,14 @@ impl RemoteClient {
|
||||
/// * `dry_run` - Whether to run in dry-run mode
|
||||
/// * `custom_headers` - Optional custom headers to include in HTTP requests (should include User-Agent)
|
||||
pub fn new(
|
||||
ctx: XetContext,
|
||||
endpoint: &str,
|
||||
auth: &Option<AuthConfig>,
|
||||
session_id: &str,
|
||||
dry_run: bool,
|
||||
custom_headers: Option<Arc<HeaderMap>>,
|
||||
) -> Arc<Self> {
|
||||
Self::new_with_socket(endpoint, auth, session_id, dry_run, None, custom_headers)
|
||||
Self::new_with_socket(ctx, endpoint, auth, session_id, dry_run, None, custom_headers)
|
||||
}
|
||||
|
||||
/// Get the endpoint URL.
|
||||
@@ -143,7 +153,7 @@ impl RemoteClient {
|
||||
let client = self.authenticated_http_client.clone();
|
||||
let api_tag = "cas::query_dedup";
|
||||
|
||||
let result = RetryWrapper::new(api_tag)
|
||||
let result = RetryWrapper::new(self.ctx.clone(), api_tag)
|
||||
.with_429_no_retry()
|
||||
.log_errors_as_info()
|
||||
.run(move || client.get(url.clone()).with_extension(Api(api_tag)).send())
|
||||
@@ -206,7 +216,7 @@ impl RemoteClient {
|
||||
|
||||
let client = self.authenticated_http_client.clone();
|
||||
|
||||
let result: Result<T> = RetryWrapper::new(api_tag)
|
||||
let result: Result<T> = RetryWrapper::new(self.ctx.clone(), api_tag)
|
||||
.with_expected_416()
|
||||
.run_and_extract_json(move || {
|
||||
let mut request = client.get(url.clone()).with_extension(Api(api_tag));
|
||||
@@ -304,7 +314,7 @@ impl Client for RemoteClient {
|
||||
file_id: &MerkleHash,
|
||||
bytes_range: Option<FileRange>,
|
||||
) -> Result<Option<QueryReconstructionResponseV2>> {
|
||||
let forced_version = xet_config().client.reconstruction_api_version;
|
||||
let forced_version = self.ctx.config.client.reconstruction_api_version;
|
||||
self.get_reconstruction_with_version_override(file_id, bytes_range, forced_version)
|
||||
.await
|
||||
}
|
||||
@@ -331,7 +341,7 @@ impl Client for RemoteClient {
|
||||
let api_tag = "cas::batch_get_reconstruction";
|
||||
let client = self.authenticated_http_client.clone();
|
||||
|
||||
let response: BatchQueryReconstructionResponse = RetryWrapper::new(api_tag)
|
||||
let response: BatchQueryReconstructionResponse = RetryWrapper::new(self.ctx.clone(), api_tag)
|
||||
.run_and_extract_json(move || client.get(url.clone()).with_extension(Api(api_tag)).send())
|
||||
.await?;
|
||||
|
||||
@@ -368,7 +378,7 @@ impl Client for RemoteClient {
|
||||
transfer_reporter = transfer_reporter.with_progress_callback(cb);
|
||||
}
|
||||
|
||||
let result = RetryWrapper::new(api_tag)
|
||||
let result = RetryWrapper::new(self.ctx.clone(), api_tag)
|
||||
.with_retry_on_403()
|
||||
.with_connection_permit(download_permit, None)
|
||||
.run_and_extract_custom(
|
||||
@@ -518,7 +528,7 @@ impl Client for RemoteClient {
|
||||
let api_tag = "cas::get_reconstruction_info";
|
||||
let client = self.authenticated_http_client.clone();
|
||||
|
||||
let response: QueryReconstructionResponse = RetryWrapper::new(api_tag)
|
||||
let response: QueryReconstructionResponse = RetryWrapper::new(self.ctx.clone(), api_tag)
|
||||
.run_and_extract_json(move || client.get(url.clone()).with_extension(Api(api_tag)).send())
|
||||
.await?;
|
||||
|
||||
@@ -579,7 +589,7 @@ impl Client for RemoteClient {
|
||||
#[cfg(target_family = "wasm")]
|
||||
let client = self.authenticated_http_client.clone();
|
||||
|
||||
let response: UploadShardResponse = RetryWrapper::new(api_tag)
|
||||
let response: UploadShardResponse = RetryWrapper::new(self.ctx.clone(), api_tag)
|
||||
.with_connection_permit(upload_permit, Some(shard_data.len() as u64))
|
||||
.run_and_extract_json(move || {
|
||||
client
|
||||
@@ -647,7 +657,7 @@ impl Client for RemoteClient {
|
||||
let n_transfer_bytes = serialized_xorb_object.serialized_data.len() as u64;
|
||||
|
||||
let serialized_data = serialized_xorb_object.serialized_data.clone();
|
||||
let block_size = xet_config().client.upload_reporting_block_size;
|
||||
let block_size = self.ctx.config.client.upload_reporting_block_size;
|
||||
|
||||
let mut upload_reporter = StreamProgressReporter::new(n_transfer_bytes)
|
||||
.with_adaptive_concurrency_reporter(upload_permit.get_partial_completion_reporting_function());
|
||||
@@ -661,7 +671,7 @@ impl Client for RemoteClient {
|
||||
|
||||
let api_tag = "cas::upload_xorb";
|
||||
|
||||
let response: UploadXorbResponse = RetryWrapper::new(api_tag)
|
||||
let response: UploadXorbResponse = RetryWrapper::new(self.ctx.clone(), api_tag)
|
||||
.with_connection_permit(upload_permit, Some(n_transfer_bytes))
|
||||
.run_and_extract_json(move || {
|
||||
let upload_stream = UploadProgressStream::wrap_bytes_as_stream(
|
||||
@@ -747,7 +757,6 @@ mod tests {
|
||||
use xet_core_structures::xorb_object::xorb_format_test_utils::{
|
||||
ChunkSize, build_and_verify_xorb_object, build_raw_xorb,
|
||||
};
|
||||
use xet_runtime::core::XetRuntime;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -759,13 +768,14 @@ mod tests {
|
||||
let prefix = PREFIX_DEFAULT;
|
||||
let raw_xorb = build_raw_xorb(3, ChunkSize::Random(512, 10248));
|
||||
|
||||
let threadpool = XetRuntime::new().unwrap();
|
||||
let client = RemoteClient::new(CAS_ENDPOINT, &None, "", false, None);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let client = RemoteClient::new(ctx.clone(), CAS_ENDPOINT, &None, "", false, None);
|
||||
|
||||
let xorb_obj = build_and_verify_xorb_object(raw_xorb, CompressionScheme::LZ4);
|
||||
|
||||
// Act
|
||||
let result = threadpool
|
||||
let result = ctx
|
||||
.runtime
|
||||
.bridge_sync(async move {
|
||||
let permit = client.acquire_upload_permit().await.unwrap();
|
||||
client.upload_xorb(prefix, xorb_obj, None, permit).await
|
||||
|
||||
@@ -8,7 +8,7 @@ use tokio::sync::Mutex;
|
||||
use tokio_retry::RetryIf;
|
||||
use tokio_retry::strategy::{ExponentialBackoff, jitter};
|
||||
use tracing::{error, info};
|
||||
use xet_runtime::core::xet_config;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::adaptive_concurrency::ConnectionPermit;
|
||||
use crate::common::http_client::request_id_from_response;
|
||||
@@ -37,10 +37,12 @@ pub struct RetryWrapper {
|
||||
}
|
||||
|
||||
impl RetryWrapper {
|
||||
pub fn new(api_tag: &'static str) -> Self {
|
||||
pub fn new(ctx: XetContext, api_tag: &'static str) -> Self {
|
||||
let max_attempts = ctx.config.client.retry_max_attempts;
|
||||
let base_delay = ctx.config.client.retry_base_delay;
|
||||
Self {
|
||||
max_attempts: xet_config().client.retry_max_attempts,
|
||||
base_delay: xet_config().client.retry_base_delay,
|
||||
max_attempts,
|
||||
base_delay,
|
||||
no_retry_on_429: false,
|
||||
retry_on_403: false,
|
||||
expected_416: false,
|
||||
@@ -569,11 +571,18 @@ mod tests {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wiremock::matchers::{method, path};
|
||||
use wiremock::{Mock, MockServer, ResponseTemplate};
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn test_context() -> XetContext {
|
||||
let config = XetConfig::new();
|
||||
XetContext::from_external(tokio::runtime::Handle::current(), config)
|
||||
}
|
||||
|
||||
fn connection_wrapper(api: &'static str) -> RetryWrapper {
|
||||
RetryWrapper::new(api)
|
||||
RetryWrapper::new(test_context(), api)
|
||||
.with_base_delay(Duration::from_millis(5))
|
||||
.with_max_attempts(3)
|
||||
}
|
||||
|
||||
@@ -133,7 +133,13 @@ pub trait ClientTestingUtils: Client + Send + Sync {
|
||||
|
||||
shard.add_xorb_block(raw_xorb.xorb_info.clone())?;
|
||||
|
||||
let serialized_xorb = SerializedXorbObject::from_xorb(raw_xorb.clone(), true)?;
|
||||
let cfg = xet_runtime::config::XetConfig::new();
|
||||
let serialized_xorb = SerializedXorbObject::from_xorb(
|
||||
raw_xorb.clone(),
|
||||
true,
|
||||
cfg.xorb.compression_policy.as_str(),
|
||||
cfg.xorb.compression_scheme_retest_interval,
|
||||
)?;
|
||||
|
||||
let upload_permit = self.acquire_upload_permit().await?;
|
||||
self.upload_xorb("default", serialized_xorb, None, upload_permit).await?;
|
||||
|
||||
@@ -662,7 +662,8 @@ pub async fn test_global_dedup(client: Arc<dyn DirectAccessClient>) {
|
||||
.unwrap();
|
||||
|
||||
// Verify the returned shard can be loaded and contains the expected data
|
||||
let sf = MDBShardFile::write_out_from_reader(shard_dir_2.clone(), &mut Cursor::new(&new_shard)).unwrap();
|
||||
let sfc = xet_core_structures::metadata_shard::new_shard_file_cache();
|
||||
let sf = MDBShardFile::write_out_from_reader(shard_dir_2.clone(), &mut Cursor::new(&new_shard), &sfc).unwrap();
|
||||
|
||||
// Verify the shard has the same dedup hashes (the content matches semantically)
|
||||
let returned_dedup_hashes = MDBShardInfo::filter_cas_chunks_for_global_dedup(&mut Cursor::new(&new_shard)).unwrap();
|
||||
|
||||
@@ -5,7 +5,7 @@ use std::mem::size_of;
|
||||
use std::ops::Range;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::atomic::{AtomicU16, AtomicU64, AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, LazyLock, Mutex, Weak};
|
||||
use std::sync::{Arc, Mutex, Weak};
|
||||
|
||||
use anyhow::anyhow;
|
||||
use async_trait::async_trait;
|
||||
@@ -25,6 +25,7 @@ use xet_core_structures::metadata_shard::xorb_structs::MDBXorbInfo;
|
||||
use xet_core_structures::metadata_shard::{MDBShardFile, MDBShardFileHeader, ShardFileManager};
|
||||
use xet_core_structures::serialization_utils::read_u32;
|
||||
use xet_core_structures::xorb_object::{SerializedXorbObject, XorbObject};
|
||||
use xet_runtime::core::XetContext;
|
||||
#[cfg(feature = "fd-track")]
|
||||
use xet_runtime::fd_diagnostics::{report_fd_count, track_fd_scope};
|
||||
use xet_runtime::file_utils::SafeFileCreator;
|
||||
@@ -162,17 +163,22 @@ impl redb::Value for FileShardRef {
|
||||
}
|
||||
}
|
||||
|
||||
/// Weak handle so the cache never keeps a [`redb::Database`] alive; only [`Arc`]s held by
|
||||
/// [`LocalClient`] (and clones) do. When the last strong ref drops, the entry can be purged.
|
||||
type CachedDbWeak = Weak<redb::Database>;
|
||||
/// Process-global cache of open redb databases, keyed by canonicalized path.
|
||||
/// Stores [`Weak`] pointers so the cache never keeps a database alive on its
|
||||
/// own; only [`Arc`]s held by [`LocalClient`] instances do.
|
||||
///
|
||||
/// Both [`get_or_open_db`] and [`LocalClient::drop`] hold this mutex while
|
||||
/// creating/destroying the inner [`redb::Database`], which eliminates the race
|
||||
/// window where a database file lock could still be held between a cache miss
|
||||
/// and a `Database::create` call.
|
||||
static DB_CACHE: std::sync::LazyLock<Mutex<HashMap<PathBuf, Weak<redb::Database>>>> =
|
||||
std::sync::LazyLock::new(|| Mutex::new(HashMap::new()));
|
||||
|
||||
/// Global cache of redb databases keyed by canonical DB file path.
|
||||
/// redb enforces exclusive file locking, so multiple LocalClient instances
|
||||
/// pointing at the same directory must share a single Database handle.
|
||||
static DB_CACHE: LazyLock<Mutex<HashMap<PathBuf, CachedDbWeak>>> = LazyLock::new(|| Mutex::new(HashMap::new()));
|
||||
|
||||
/// Opens or returns a shared [`Arc<redb::Database>`] for `db_path`. The map stores only
|
||||
/// [`Weak`] pointers ([`Arc::downgrade`]); on a hit, [`Weak::upgrade`] yields a new strong ref.
|
||||
/// Opens or returns a shared [`Arc<redb::Database>`] for `db_path`.
|
||||
///
|
||||
/// Must be called (and the returned `Arc` kept alive) while no other thread is
|
||||
/// dropping the last `Arc` for the same path -- which is guaranteed because
|
||||
/// both this function and `LocalClient::drop` serialize on `DB_CACHE`.
|
||||
fn get_or_open_db(db_path: &Path) -> std::result::Result<Arc<redb::Database>, redb::DatabaseError> {
|
||||
#[cfg(feature = "fd-track")]
|
||||
let _fd_scope = track_fd_scope(format!("LocalClient::get_or_open_db({})", db_path.display()));
|
||||
@@ -180,7 +186,7 @@ fn get_or_open_db(db_path: &Path) -> std::result::Result<Arc<redb::Database>, re
|
||||
let mut map = DB_CACHE.lock().unwrap();
|
||||
|
||||
if let Some(weak) = map.get(db_path)
|
||||
&& let Some(db) = Weak::upgrade(weak)
|
||||
&& let Some(db) = weak.upgrade()
|
||||
{
|
||||
tracing::trace!(target: "xet_client::local_cas_redb", path = %db_path.display(), "DB_CACHE hit");
|
||||
#[cfg(feature = "fd-track")]
|
||||
@@ -188,7 +194,7 @@ fn get_or_open_db(db_path: &Path) -> std::result::Result<Arc<redb::Database>, re
|
||||
return Ok(db);
|
||||
}
|
||||
|
||||
// Purge dead entries to avoid unbounded cache growth.
|
||||
// Purge dead entries while we hold the lock.
|
||||
map.retain(|_, weak| weak.strong_count() > 0);
|
||||
|
||||
tracing::trace!(target: "xet_client::local_cas_redb", path = %db_path.display(), "DB_CACHE miss");
|
||||
@@ -232,7 +238,11 @@ fn file_entry_byte_ranges(shard_bytes: &[u8]) -> std::result::Result<Vec<(Merkle
|
||||
}
|
||||
|
||||
pub struct LocalClient {
|
||||
db: Arc<redb::Database>,
|
||||
/// Wrapped in `Option` so `Drop` can `take()` it under the `DB_CACHE` lock,
|
||||
/// ensuring the redb file lock is released before another caller can open
|
||||
/// the same path. Always `Some` during normal operation.
|
||||
db: Option<Arc<redb::Database>>,
|
||||
db_path: PathBuf,
|
||||
shard_manager: Arc<ShardFileManager>,
|
||||
xorb_dir: PathBuf,
|
||||
shard_dir: PathBuf,
|
||||
@@ -252,21 +262,21 @@ pub struct LocalClient {
|
||||
impl LocalClient {
|
||||
/// Create a local client hosted in a temporary directory for testing.
|
||||
/// This is an async function to allow use with current-thread tokio runtime.
|
||||
pub async fn temporary() -> Result<Arc<Self>> {
|
||||
pub async fn temporary(ctx: XetContext) -> Result<Arc<Self>> {
|
||||
let tmp_dir = TempDir::new().unwrap();
|
||||
let path = tmp_dir.path().to_owned();
|
||||
let s = Self::new_internal(path, Some(tmp_dir)).await?;
|
||||
let s = Self::new_internal(ctx, path, Some(tmp_dir)).await?;
|
||||
Ok(Arc::new(s))
|
||||
}
|
||||
|
||||
/// Create a local client hosted in a directory. Effectively, this directory
|
||||
/// is the CAS endpoint and persists across instances of LocalClient.
|
||||
pub async fn new(path: impl AsRef<Path>) -> Result<Arc<Self>> {
|
||||
pub async fn new(ctx: XetContext, path: impl AsRef<Path>) -> Result<Arc<Self>> {
|
||||
let path = path.as_ref().to_owned();
|
||||
Ok(Arc::new(Self::new_internal(path, None).await?))
|
||||
Ok(Arc::new(Self::new_internal(ctx, path, None).await?))
|
||||
}
|
||||
|
||||
async fn new_internal(path: impl AsRef<Path>, tmp_dir: Option<TempDir>) -> Result<Self> {
|
||||
async fn new_internal(ctx: XetContext, path: impl AsRef<Path>, tmp_dir: Option<TempDir>) -> Result<Self> {
|
||||
let base_dir = std::path::absolute(path)?;
|
||||
if !base_dir.exists() {
|
||||
std::fs::create_dir_all(&base_dir)?;
|
||||
@@ -305,16 +315,17 @@ impl LocalClient {
|
||||
}
|
||||
|
||||
// Open / set up the shard lookup
|
||||
let shard_manager = ShardFileManager::new_in_session_directory(shard_dir.clone(), true).await?;
|
||||
let shard_manager = ShardFileManager::new_in_session_directory(&ctx, shard_dir.clone(), true).await?;
|
||||
#[cfg(feature = "fd-track")]
|
||||
report_fd_count("LocalClient::new_internal after shard manager init");
|
||||
|
||||
Ok(Self {
|
||||
db,
|
||||
db: Some(db),
|
||||
db_path,
|
||||
shard_manager,
|
||||
xorb_dir,
|
||||
shard_dir,
|
||||
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload("local_uploads"),
|
||||
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload(ctx, "local_uploads"),
|
||||
url_expiration_ms: AtomicU64::new(u64::MAX),
|
||||
global_dedup_expiration_secs: AtomicU64::new(0),
|
||||
random_ms_delay_window: (AtomicU64::new(0), AtomicU64::new(0)),
|
||||
@@ -324,6 +335,10 @@ impl LocalClient {
|
||||
})
|
||||
}
|
||||
|
||||
fn db(&self) -> &redb::Database {
|
||||
self.db.as_deref().expect("db used after close")
|
||||
}
|
||||
|
||||
/// Internal function to get the path for a given hash entry
|
||||
fn get_path_for_entry(&self, hash: &MerkleHash) -> PathBuf {
|
||||
self.xorb_dir.join(format!("default.{hash:?}"))
|
||||
@@ -331,7 +346,7 @@ impl LocalClient {
|
||||
|
||||
#[cfg(test)]
|
||||
fn is_file_deleted(&self, file_hash: &MerkleHash) -> bool {
|
||||
let Ok(read_txn) = self.db.begin_read() else {
|
||||
let Ok(read_txn) = self.db().begin_read() else {
|
||||
return true;
|
||||
};
|
||||
let Ok(table) = read_txn.open_table(FILE_TO_SHARD_TABLE) else {
|
||||
@@ -444,14 +459,14 @@ impl LocalClient {
|
||||
|
||||
if !in_memory.is_empty() {
|
||||
let shard_path = in_memory.write_to_directory(&self.shard_dir, None)?;
|
||||
let shard = MDBShardFile::load_from_file(&shard_path)?;
|
||||
let shard = MDBShardFile::load_from_file(&shard_path, self.shard_manager.shard_file_cache())?;
|
||||
let shard_hash = shard.shard_hash;
|
||||
self.shard_manager.register_shards(&[shard]).await?;
|
||||
|
||||
// Update FILE_TO_SHARD_TABLE with byte-accurate offsets.
|
||||
let shard_bytes = std::fs::read(&shard_path)?;
|
||||
let file_ranges = file_entry_byte_ranges(&shard_bytes)?;
|
||||
let write_txn = self.db.begin_write().map_err(map_redb_db_error)?;
|
||||
let write_txn = self.db().begin_write().map_err(map_redb_db_error)?;
|
||||
{
|
||||
let mut file_table = write_txn.open_table(FILE_TO_SHARD_TABLE).map_err(map_redb_db_error)?;
|
||||
for (file_hash, offset, length) in &file_ranges {
|
||||
@@ -478,15 +493,23 @@ impl Drop for LocalClient {
|
||||
fn drop(&mut self) {
|
||||
#[cfg(feature = "fd-track")]
|
||||
let _fd_scope = track_fd_scope(format!("LocalClient::drop({})", self.xorb_dir.display()));
|
||||
#[cfg(feature = "fd-track")]
|
||||
report_fd_count("LocalClient::drop start");
|
||||
|
||||
// Drop the database handle while holding the cache lock. This
|
||||
// serializes with `get_or_open_db`, ensuring the redb file lock is
|
||||
// fully released before any other caller can attempt to reopen the
|
||||
// same path.
|
||||
if let Ok(mut map) = DB_CACHE.lock() {
|
||||
let db = self.db.take();
|
||||
if db.as_ref().is_some_and(|d| Arc::strong_count(d) == 1) {
|
||||
map.remove(&self.db_path);
|
||||
}
|
||||
drop(db);
|
||||
}
|
||||
|
||||
#[cfg(feature = "fd-track")]
|
||||
{
|
||||
report_fd_count("LocalClient::drop start");
|
||||
if let Ok(mut map) = DB_CACHE.lock() {
|
||||
map.retain(|_, weak| weak.strong_count() > 0);
|
||||
}
|
||||
report_fd_count("LocalClient::drop end");
|
||||
}
|
||||
report_fd_count("LocalClient::drop end");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -792,7 +815,7 @@ impl LocalClient {
|
||||
/// Removes all FILE_TO_SHARD_TABLE entries whose shard_hash equals `shard_hash`.
|
||||
fn remove_file_entries_for_shard(&self, shard_hash: &MerkleHash) -> Result<()> {
|
||||
let to_remove: Vec<RedbHash> = {
|
||||
let read_txn = self.db.begin_read().map_err(map_redb_db_error)?;
|
||||
let read_txn = self.db().begin_read().map_err(map_redb_db_error)?;
|
||||
let table = read_txn.open_table(FILE_TO_SHARD_TABLE).map_err(map_redb_db_error)?;
|
||||
table
|
||||
.iter()
|
||||
@@ -803,7 +826,7 @@ impl LocalClient {
|
||||
.collect()
|
||||
};
|
||||
if !to_remove.is_empty() {
|
||||
let write_txn = self.db.begin_write().map_err(map_redb_db_error)?;
|
||||
let write_txn = self.db().begin_write().map_err(map_redb_db_error)?;
|
||||
{
|
||||
let mut table = write_txn.open_table(FILE_TO_SHARD_TABLE).map_err(map_redb_db_error)?;
|
||||
for key in &to_remove {
|
||||
@@ -836,7 +859,7 @@ impl super::DeletionControlableClient for LocalClient {
|
||||
}
|
||||
|
||||
async fn list_file_shard_entries(&self) -> Result<Vec<(MerkleHash, MerkleHash)>> {
|
||||
let read_txn = self.db.begin_read().map_err(map_redb_db_error)?;
|
||||
let read_txn = self.db().begin_read().map_err(map_redb_db_error)?;
|
||||
let table = read_txn.open_table(FILE_TO_SHARD_TABLE).map_err(map_redb_db_error)?;
|
||||
let mut entries = Vec::new();
|
||||
for entry in table.iter().map_err(map_redb_db_error)? {
|
||||
@@ -849,7 +872,7 @@ impl super::DeletionControlableClient for LocalClient {
|
||||
}
|
||||
|
||||
async fn delete_file_entry(&self, file_hash: &MerkleHash) -> Result<()> {
|
||||
let write_txn = self.db.begin_write().map_err(map_redb_db_error)?;
|
||||
let write_txn = self.db().begin_write().map_err(map_redb_db_error)?;
|
||||
{
|
||||
let mut table = write_txn.open_table(FILE_TO_SHARD_TABLE).map_err(map_redb_db_error)?;
|
||||
table.remove(&RedbHash::from(*file_hash)).map_err(map_redb_db_error)?;
|
||||
@@ -862,7 +885,7 @@ impl super::DeletionControlableClient for LocalClient {
|
||||
let shard_redb = RedbHash::from(*shard_hash);
|
||||
for _ in 0..4 {
|
||||
let to_delete: Vec<RedbHash> = {
|
||||
let read_txn = self.db.begin_read().map_err(map_redb_db_error)?;
|
||||
let read_txn = self.db().begin_read().map_err(map_redb_db_error)?;
|
||||
let table = read_txn.open_table(GLOBAL_DEDUP_TABLE).map_err(map_redb_db_error)?;
|
||||
table
|
||||
.iter()
|
||||
@@ -877,7 +900,7 @@ impl super::DeletionControlableClient for LocalClient {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let write_txn = self.db.begin_write().map_err(map_redb_db_error)?;
|
||||
let write_txn = self.db().begin_write().map_err(map_redb_db_error)?;
|
||||
{
|
||||
let mut table = write_txn.open_table(GLOBAL_DEDUP_TABLE).map_err(map_redb_db_error)?;
|
||||
for chunk_hash in &to_delete {
|
||||
@@ -888,7 +911,7 @@ impl super::DeletionControlableClient for LocalClient {
|
||||
}
|
||||
|
||||
let still_present = {
|
||||
let read_txn = self.db.begin_read().map_err(map_redb_db_error)?;
|
||||
let read_txn = self.db().begin_read().map_err(map_redb_db_error)?;
|
||||
let table = read_txn.open_table(GLOBAL_DEDUP_TABLE).map_err(map_redb_db_error)?;
|
||||
table
|
||||
.iter()
|
||||
@@ -1016,7 +1039,7 @@ impl super::DeletionControlableClient for LocalClient {
|
||||
// A file entry in a shard is only considered active if the table maps that
|
||||
// file hash to that specific shard, preventing stale entries from resurrecting.
|
||||
let file_to_shard: HashMap<MerkleHash, MerkleHash> = {
|
||||
let read_txn = self.db.begin_read().map_err(map_redb_db_error)?;
|
||||
let read_txn = self.db().begin_read().map_err(map_redb_db_error)?;
|
||||
let table = read_txn.open_table(FILE_TO_SHARD_TABLE).map_err(map_redb_db_error)?;
|
||||
let mut map = HashMap::new();
|
||||
for entry in table.iter().map_err(map_redb_db_error)? {
|
||||
@@ -1108,7 +1131,7 @@ impl super::DeletionControlableClient for LocalClient {
|
||||
// a TOCTOU race: upload_shard writes the file then commits the dedup
|
||||
// entry, so any entry visible in this MVCC snapshot is guaranteed to
|
||||
// have its shard file already on disk when we read the directory below.
|
||||
let read_txn = self.db.begin_read().map_err(map_redb_db_error)?;
|
||||
let read_txn = self.db().begin_read().map_err(map_redb_db_error)?;
|
||||
|
||||
let shard_files = self.shard_file_paths()?;
|
||||
|
||||
@@ -1239,7 +1262,7 @@ impl LocalClient {
|
||||
/// via a direct-seek into the canonical shard on disk. Returns `None` if the file
|
||||
/// is not registered (i.e. deleted or never uploaded).
|
||||
fn get_file_info_from_table(&self, file_hash: &MerkleHash) -> Result<Option<(MDBFileInfo, MerkleHash)>> {
|
||||
let read_txn = self.db.begin_read().map_err(map_redb_db_error)?;
|
||||
let read_txn = self.db().begin_read().map_err(map_redb_db_error)?;
|
||||
let table = read_txn.open_table(FILE_TO_SHARD_TABLE).map_err(map_redb_db_error)?;
|
||||
let Some(entry) = table.get(&RedbHash::from(*file_hash)).map_err(map_redb_db_error)? else {
|
||||
return Ok(None);
|
||||
@@ -1387,7 +1410,7 @@ impl Client for LocalClient {
|
||||
|
||||
async fn query_for_global_dedup_shard(&self, _prefix: &str, chunk_hash: &MerkleHash) -> Result<Option<Bytes>> {
|
||||
self.apply_api_delay().await;
|
||||
let read_txn = self.db.begin_read().map_err(map_redb_db_error)?;
|
||||
let read_txn = self.db().begin_read().map_err(map_redb_db_error)?;
|
||||
let table = read_txn.open_table(GLOBAL_DEDUP_TABLE).map_err(map_redb_db_error)?;
|
||||
|
||||
if let Some(shard) = table.get(&RedbHash::from(*chunk_hash)).map_err(map_redb_db_error)? {
|
||||
@@ -1446,7 +1469,7 @@ impl Client for LocalClient {
|
||||
|
||||
// Write the rebuilt shard to disk (creates proper lookup tables)
|
||||
let shard_path = in_memory_shard.write_to_directory(&self.shard_dir, None)?;
|
||||
let shard = MDBShardFile::load_from_file(&shard_path)?;
|
||||
let shard = MDBShardFile::load_from_file(&shard_path, self.shard_manager.shard_file_cache())?;
|
||||
let shard_hash = shard.shard_hash;
|
||||
|
||||
self.shard_manager.register_shards(&[shard]).await?;
|
||||
@@ -1459,7 +1482,7 @@ impl Client for LocalClient {
|
||||
let file_ranges = file_entry_byte_ranges(&written_shard_bytes)?;
|
||||
|
||||
let shard_hash_redb = RedbHash::from(shard_hash);
|
||||
let write_txn = self.db.begin_write().map_err(map_redb_db_error)?;
|
||||
let write_txn = self.db().begin_write().map_err(map_redb_db_error)?;
|
||||
{
|
||||
let mut dedup_table = write_txn.open_table(GLOBAL_DEDUP_TABLE).map_err(map_redb_db_error)?;
|
||||
for chunk in chunk_hashes {
|
||||
@@ -1711,8 +1734,15 @@ mod tests {
|
||||
use xet_core_structures::xorb_object::xorb_format_test_utils::{
|
||||
ChunkSize, build_and_verify_xorb_object, build_raw_xorb,
|
||||
};
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn test_context() -> XetContext {
|
||||
let config = XetConfig::new();
|
||||
XetContext::from_external(tokio::runtime::Handle::current(), config)
|
||||
}
|
||||
use crate::cas_client::simulation::DeletionControlableClient;
|
||||
use crate::cas_client::simulation::client_testing_utils::ClientTestingUtils;
|
||||
use crate::cas_types::{ChunkRange, XorbReconstructionFetchInfo};
|
||||
@@ -1721,7 +1751,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_common_client_suite() {
|
||||
crate::cas_client::simulation::client_unit_testing::test_client_functionality(|| async {
|
||||
LocalClient::temporary().await.unwrap()
|
||||
LocalClient::temporary(test_context()).await.unwrap()
|
||||
as std::sync::Arc<dyn crate::cas_client::simulation::DirectAccessClient>
|
||||
})
|
||||
.await;
|
||||
@@ -1740,9 +1770,10 @@ mod tests {
|
||||
let link = tmp.path().join("link");
|
||||
std::os::unix::fs::symlink(&real, &link).unwrap();
|
||||
|
||||
let c1 = LocalClient::new(&link).await.unwrap();
|
||||
let c2 = LocalClient::new(&real).await.unwrap();
|
||||
assert!(Arc::ptr_eq(&c1.db, &c2.db));
|
||||
let ctx = test_context();
|
||||
let c1 = LocalClient::new(ctx.clone(), &link).await.unwrap();
|
||||
let c2 = LocalClient::new(ctx, &real).await.unwrap();
|
||||
assert!(Arc::ptr_eq(c1.db.as_ref().unwrap(), c2.db.as_ref().unwrap()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -1752,7 +1783,7 @@ mod tests {
|
||||
let xorb_obj = build_and_verify_xorb_object(xorb, CompressionScheme::Auto);
|
||||
let hash = xorb_obj.hash;
|
||||
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(test_context()).await.unwrap();
|
||||
let permit = client.acquire_upload_permit().await.unwrap();
|
||||
client.upload_xorb("default", xorb_obj, None, permit).await.unwrap();
|
||||
|
||||
@@ -1873,7 +1904,7 @@ mod tests {
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn test_url_expiration() {
|
||||
super::super::client_unit_testing::test_url_expiration_functionality(|| async {
|
||||
LocalClient::temporary().await.unwrap()
|
||||
LocalClient::temporary(test_context()).await.unwrap()
|
||||
as std::sync::Arc<dyn crate::cas_client::simulation::DirectAccessClient>
|
||||
})
|
||||
.await;
|
||||
@@ -1882,7 +1913,7 @@ mod tests {
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn test_api_delay() {
|
||||
super::super::client_unit_testing::test_api_delay_functionality(|| async {
|
||||
LocalClient::temporary().await.unwrap()
|
||||
LocalClient::temporary(test_context()).await.unwrap()
|
||||
as std::sync::Arc<dyn crate::cas_client::simulation::DirectAccessClient>
|
||||
})
|
||||
.await;
|
||||
@@ -1891,7 +1922,7 @@ mod tests {
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn test_global_dedup_shard_expiration() {
|
||||
super::super::client_unit_testing::test_global_dedup_shard_expiration_functionality(|| async {
|
||||
LocalClient::temporary().await.unwrap()
|
||||
LocalClient::temporary(test_context()).await.unwrap()
|
||||
as std::sync::Arc<dyn crate::cas_client::simulation::DirectAccessClient>
|
||||
})
|
||||
.await;
|
||||
@@ -1901,7 +1932,7 @@ mod tests {
|
||||
#[cfg_attr(feature = "smoke-test", ignore)]
|
||||
async fn test_global_dedup_shard_expiration_stress() {
|
||||
super::super::client_unit_testing::test_global_dedup_shard_expiration_stress(|| async {
|
||||
LocalClient::temporary().await.unwrap()
|
||||
LocalClient::temporary(test_context()).await.unwrap()
|
||||
as std::sync::Arc<dyn crate::cas_client::simulation::DirectAccessClient>
|
||||
})
|
||||
.await;
|
||||
@@ -1910,14 +1941,14 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_deletion_suite() {
|
||||
super::super::deletion_unit_testing::test_deletion_functionality(|| async {
|
||||
LocalClient::temporary().await.unwrap()
|
||||
LocalClient::temporary(test_context()).await.unwrap()
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_verify_integrity_detects_missing_cas_block_reference() {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(test_context()).await.unwrap();
|
||||
client.upload_random_file(&[(3, (0, 3)), (4, (0, 2))], 2048).await.unwrap();
|
||||
client.verify_integrity().await.unwrap();
|
||||
|
||||
@@ -1932,7 +1963,7 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_verify_integrity_detects_invalid_chunk_range() {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(test_context()).await.unwrap();
|
||||
client.upload_random_file(&[(5, (0, 3))], 2048).await.unwrap();
|
||||
client.verify_integrity().await.unwrap();
|
||||
|
||||
@@ -1949,7 +1980,7 @@ mod tests {
|
||||
/// Verifies that delete_file_entry does not rewrite shard files (shard hashes remain stable).
|
||||
#[tokio::test]
|
||||
async fn test_delete_file_entry_does_not_rewrite_shards() {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(test_context()).await.unwrap();
|
||||
client.upload_random_file(&[(1, (0, 3))], 2048).await.unwrap();
|
||||
|
||||
let shard_hashes_before: Vec<_> = client.shard_file_paths().unwrap().into_iter().map(|(h, _)| h).collect();
|
||||
@@ -1972,7 +2003,7 @@ mod tests {
|
||||
|
||||
let file_hash;
|
||||
{
|
||||
let client = LocalClient::new(&path).await.unwrap();
|
||||
let client = LocalClient::new(test_context(), &path).await.unwrap();
|
||||
let file = client.upload_random_file(&[(1, (0, 3)), (2, (0, 2))], 2048).await.unwrap();
|
||||
file_hash = file.file_hash;
|
||||
assert!(!client.list_file_shard_entries().await.unwrap().is_empty());
|
||||
@@ -1982,7 +2013,7 @@ mod tests {
|
||||
}
|
||||
|
||||
{
|
||||
let client = LocalClient::new(&path).await.unwrap();
|
||||
let client = LocalClient::new(test_context(), &path).await.unwrap();
|
||||
assert!(
|
||||
client.is_file_deleted(&file_hash),
|
||||
"Entry should be absent from FILE_TO_SHARD_TABLE after restart"
|
||||
@@ -1998,7 +2029,7 @@ mod tests {
|
||||
/// in any shard but exists on disk should pass verify_integrity (dedup case).
|
||||
#[tokio::test]
|
||||
async fn test_verify_integrity_cross_shard_dedup_ok() {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(test_context()).await.unwrap();
|
||||
client.upload_random_file(&[(1, (0, 3))], 2048).await.unwrap();
|
||||
client.verify_integrity().await.unwrap();
|
||||
|
||||
@@ -2022,7 +2053,7 @@ mod tests {
|
||||
/// so missing XORBs for deleted files do not cause false integrity failures.
|
||||
#[tokio::test]
|
||||
async fn test_verify_integrity_skips_deleted_files() {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(test_context()).await.unwrap();
|
||||
let deleted_file = client.upload_random_file(&[(1, (0, 3))], 2048).await.unwrap();
|
||||
let live_file = client.upload_random_file(&[(2, (0, 2))], 2048).await.unwrap();
|
||||
client.verify_integrity().await.unwrap();
|
||||
@@ -2062,7 +2093,7 @@ mod tests {
|
||||
/// to shard files that have been removed.
|
||||
#[tokio::test]
|
||||
async fn test_verify_integrity_detects_stale_dedup_shard_reference() {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(test_context()).await.unwrap();
|
||||
let file = client.upload_random_file(&[(10, (0, 3))], 2048).await.unwrap();
|
||||
client.verify_integrity().await.unwrap();
|
||||
|
||||
@@ -2095,7 +2126,7 @@ mod tests {
|
||||
/// after deleting the original file and its xorbs must not resurrect stale entries.
|
||||
#[tokio::test]
|
||||
async fn test_reupload_same_file_hash_does_not_resurrect_stale_entries() {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(test_context()).await.unwrap();
|
||||
|
||||
// 1. Upload file F in shard S1 referencing xorb X.
|
||||
let file = client.upload_random_file(&[(1, (0, 3))], 2048).await.unwrap();
|
||||
@@ -2139,7 +2170,7 @@ mod tests {
|
||||
/// Tests that list_xorbs_and_tags tags change after file re-creation with a timestamp delay.
|
||||
#[tokio::test]
|
||||
async fn test_list_xorbs_and_tags_timestamp_changes() {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(test_context()).await.unwrap();
|
||||
|
||||
let file1 = client.upload_random_file(&[(1, (0, 2))], 2048).await.unwrap();
|
||||
let xorb_hash = file1.terms[0].xorb_hash;
|
||||
|
||||
@@ -47,6 +47,7 @@ use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use tracing_subscriber::EnvFilter;
|
||||
use xet_client::cas_client::{LocalServer, LocalServerConfig};
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
/// A local HTTP server that wraps a DirectAccessClient for testing and development.
|
||||
///
|
||||
@@ -116,7 +117,8 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
tracing::info!("Listening on: {}:{}", config.host, config.port);
|
||||
|
||||
let server: xet_client::cas_client::LocalServer = LocalServer::new(config).await?;
|
||||
let ctx = XetContext::default().map_err(|e| anyhow::anyhow!("{e}"))?;
|
||||
let server: xet_client::cas_client::LocalServer = LocalServer::new(ctx, config).await?;
|
||||
server.run().await?;
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -14,16 +14,18 @@
|
||||
//! ```no_run
|
||||
//! use anyhow::Result;
|
||||
//! use xet_client::cas_client::{LocalServer, LocalServerConfig};
|
||||
//! use xet_runtime::core::XetContext;
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> Result<()> {
|
||||
//! let ctx = XetContext::default().unwrap();
|
||||
//! let config = LocalServerConfig {
|
||||
//! data_directory: "./data".into(),
|
||||
//! host: "127.0.0.1".to_string(),
|
||||
//! port: 8080,
|
||||
//! in_memory: false,
|
||||
//! };
|
||||
//! let server: LocalServer = LocalServer::new(config).await?;
|
||||
//! let server: LocalServer = LocalServer::new(ctx, config).await?;
|
||||
//! server.run().await?;
|
||||
//! Ok(())
|
||||
//! }
|
||||
@@ -47,6 +49,7 @@ use tokio::net::TcpListener;
|
||||
#[cfg(test)]
|
||||
use tokio::sync::oneshot;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
#[cfg(test)]
|
||||
use super::super::super::RemoteClient;
|
||||
@@ -105,14 +108,14 @@ impl LocalServer {
|
||||
///
|
||||
/// If `in_memory` is false, creates a new `LocalClient` pointing to the configured data directory.
|
||||
/// If `in_memory` is true, creates a new `MemoryClient` (data directory is ignored).
|
||||
pub async fn new(config: LocalServerConfig) -> Result<Self> {
|
||||
pub async fn new(ctx: XetContext, config: LocalServerConfig) -> Result<Self> {
|
||||
let (client, deletion_client): (Arc<dyn DirectAccessClient>, Option<Arc<dyn DeletionControlableClient>>) =
|
||||
if config.in_memory {
|
||||
let client = MemoryClient::new();
|
||||
let client = MemoryClient::new(ctx.clone());
|
||||
let deletion_client = client.clone() as Arc<dyn DeletionControlableClient>;
|
||||
(client, Some(deletion_client))
|
||||
} else {
|
||||
let client = LocalClient::new(&config.data_directory).await?;
|
||||
let client = LocalClient::new(ctx, &config.data_directory).await?;
|
||||
let deletion_client = client.clone() as Arc<dyn DeletionControlableClient>;
|
||||
(client, Some(deletion_client))
|
||||
};
|
||||
@@ -287,35 +290,38 @@ impl LocalTestServer {
|
||||
///
|
||||
/// The server listens on a randomly assigned available port on localhost.
|
||||
pub async fn start_with_socket_proxy(in_memory: bool, socket_path: Option<PathBuf>) -> Self {
|
||||
let ctx = XetContext::default().expect("XetContext::new");
|
||||
if in_memory {
|
||||
let client = MemoryClient::new();
|
||||
let client = MemoryClient::new(ctx.clone());
|
||||
let deletion_client: Arc<dyn DeletionControlableClient> = client.clone();
|
||||
Self::start_with_client_and_socket(client, Some(deletion_client), socket_path).await
|
||||
Self::start_with_client_and_socket(ctx, client, Some(deletion_client), socket_path).await
|
||||
} else {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(ctx.clone()).await.unwrap();
|
||||
let deletion_client: Arc<dyn DeletionControlableClient> = client.clone();
|
||||
Self::start_with_client_and_socket(client, Some(deletion_client), socket_path).await
|
||||
Self::start_with_client_and_socket(ctx, client, Some(deletion_client), socket_path).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Starts a new test server using an existing `DirectAccessClient`.
|
||||
///
|
||||
/// Useful when you need to pre-populate the client with data before starting the server.
|
||||
pub async fn start_with_client(client: Arc<dyn DirectAccessClient>) -> Self {
|
||||
Self::start_with_client_and_socket(client, None, None).await
|
||||
pub async fn start_with_client(ctx: XetContext, client: Arc<dyn DirectAccessClient>) -> Self {
|
||||
Self::start_with_client_and_socket(ctx, client, None, None).await
|
||||
}
|
||||
|
||||
/// Starts a new test server using an existing `DirectAccessClient` and optional
|
||||
/// deletion-capable client, with an optional socket proxy.
|
||||
/// deletion-capable client.
|
||||
pub async fn start_with_client_and_deletion(
|
||||
ctx: XetContext,
|
||||
client: Arc<dyn DirectAccessClient>,
|
||||
deletion_client: Option<Arc<dyn DeletionControlableClient>>,
|
||||
) -> Self {
|
||||
Self::start_with_client_and_socket(client, deletion_client, None).await
|
||||
Self::start_with_client_and_socket(ctx, client, deletion_client, None).await
|
||||
}
|
||||
|
||||
/// Starts a new test server using an existing `DirectAccessClient` with an optional socket proxy.
|
||||
async fn start_with_client_and_socket(
|
||||
ctx: XetContext,
|
||||
client: Arc<dyn DirectAccessClient>,
|
||||
deletion_client: Option<Arc<dyn DeletionControlableClient>>,
|
||||
_socket_path: Option<PathBuf>,
|
||||
@@ -339,7 +345,6 @@ impl LocalTestServer {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
if let Some(socket_path) = _socket_path {
|
||||
// Extract host:port from http://host:port
|
||||
let tcp_addr = tcp_endpoint.strip_prefix("http://").unwrap_or(&tcp_endpoint).to_string();
|
||||
|
||||
let proxy = UnixSocketProxy::new(socket_path.clone(), tcp_addr)
|
||||
@@ -348,9 +353,9 @@ impl LocalTestServer {
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
// Create RemoteClient with socket path
|
||||
let socket_path_str = socket_path.to_string_lossy().to_string();
|
||||
let client = RemoteClient::new_with_socket(
|
||||
ctx.clone(),
|
||||
&tcp_endpoint,
|
||||
&None,
|
||||
"test-session",
|
||||
@@ -362,14 +367,14 @@ impl LocalTestServer {
|
||||
(client, Some(proxy))
|
||||
} else {
|
||||
let client =
|
||||
RemoteClient::new(&tcp_endpoint, &None, "test-session", false, Some(Arc::new(headers)));
|
||||
RemoteClient::new(ctx, &tcp_endpoint, &None, "test-session", false, Some(Arc::new(headers)));
|
||||
(client, None)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
let client = RemoteClient::new(&tcp_endpoint, &None, "test-session", false, None);
|
||||
let client = RemoteClient::new(ctx, &tcp_endpoint, &None, "test-session", false, None);
|
||||
(client, Option::<()>::None)
|
||||
}
|
||||
};
|
||||
@@ -1231,8 +1236,9 @@ mod tests {
|
||||
#[tokio::test]
|
||||
#[cfg_attr(feature = "smoke-test", ignore)]
|
||||
async fn test_deletion_lifecycle_via_server() {
|
||||
let lc = LocalClient::temporary().await.unwrap();
|
||||
let server = LocalTestServer::start_with_client(lc.clone()).await;
|
||||
let ctx = XetContext::default().expect("XetContext::new");
|
||||
let lc = LocalClient::temporary(ctx.clone()).await.unwrap();
|
||||
let server = LocalTestServer::start_with_client(ctx, lc.clone()).await;
|
||||
|
||||
// Upload files via remote client (goes through HTTP server)
|
||||
let file1 = server
|
||||
@@ -1290,27 +1296,18 @@ mod tests {
|
||||
assert!(lc.list_xorbs().await.unwrap().is_empty());
|
||||
}
|
||||
|
||||
/// Keeps a LocalTestServer alive for the duration of the tokio runtime by
|
||||
/// moving it into a spawned task. Returns the endpoint URL.
|
||||
fn detach_server(server: LocalTestServer) -> String {
|
||||
let endpoint = server.endpoint().to_string();
|
||||
tokio::spawn(async move {
|
||||
let _server = server;
|
||||
futures::future::pending::<()>().await;
|
||||
});
|
||||
endpoint
|
||||
}
|
||||
|
||||
/// Runs the common DirectAccessClient test suite via SimulationControlClient.
|
||||
#[tokio::test]
|
||||
#[cfg_attr(feature = "smoke-test", ignore)]
|
||||
async fn test_simulation_control_client_common_suite() {
|
||||
crate::cas_client::simulation::client_unit_testing::test_client_functionality(|| async {
|
||||
let lc = LocalClient::temporary().await.unwrap();
|
||||
let ctx = XetContext::default().expect("XetContext::new");
|
||||
let lc = LocalClient::temporary(ctx.clone()).await.unwrap();
|
||||
let dc: Arc<dyn DeletionControlableClient> = lc.clone();
|
||||
let server = LocalTestServer::start_with_client_and_deletion(lc, Some(dc)).await;
|
||||
let endpoint = detach_server(server);
|
||||
Arc::new(SimulationControlClient::new(&endpoint)) as Arc<dyn DirectAccessClient>
|
||||
let server = LocalTestServer::start_with_client_and_deletion(ctx.clone(), lc, Some(dc)).await;
|
||||
let endpoint = server.endpoint().to_string();
|
||||
Arc::new(SimulationControlClient::new(ctx, &endpoint).with_keep_alive(server))
|
||||
as Arc<dyn DirectAccessClient>
|
||||
})
|
||||
.await;
|
||||
}
|
||||
@@ -1322,7 +1319,8 @@ mod tests {
|
||||
.with_ephemeral_disk()
|
||||
.start()
|
||||
.await;
|
||||
let sc = SimulationControlClient::new(server.http_endpoint());
|
||||
let ctx = XetContext::default().expect("XetContext::new");
|
||||
let sc = SimulationControlClient::new(ctx, server.http_endpoint());
|
||||
|
||||
let file = sc.upload_random_file(&[(1, (0, 4))], CHUNK_SIZE).await.unwrap();
|
||||
let first_chunk = file.terms[0].chunk_hashes[0];
|
||||
@@ -1362,11 +1360,12 @@ mod tests {
|
||||
#[cfg_attr(feature = "smoke-test", ignore)]
|
||||
async fn test_simulation_control_client_deletion_suite() {
|
||||
crate::cas_client::simulation::deletion_unit_testing::test_deletion_functionality(|| async {
|
||||
let lc = LocalClient::temporary().await.unwrap();
|
||||
let ctx = XetContext::default().expect("XetContext::new");
|
||||
let lc = LocalClient::temporary(ctx.clone()).await.unwrap();
|
||||
let dc: Arc<dyn DeletionControlableClient> = lc.clone();
|
||||
let server = LocalTestServer::start_with_client_and_deletion(lc, Some(dc)).await;
|
||||
let endpoint = detach_server(server);
|
||||
Arc::new(SimulationControlClient::new(&endpoint))
|
||||
let server = LocalTestServer::start_with_client_and_deletion(ctx.clone(), lc, Some(dc)).await;
|
||||
let endpoint = server.endpoint().to_string();
|
||||
Arc::new(SimulationControlClient::new(ctx, &endpoint).with_keep_alive(server))
|
||||
})
|
||||
.await;
|
||||
}
|
||||
@@ -1375,7 +1374,8 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_simulation_control_client_deletion_on_memory_backend() {
|
||||
let server = LocalTestServer::start(true).await;
|
||||
let sc = SimulationControlClient::new(server.endpoint());
|
||||
let ctx = XetContext::default().expect("XetContext::new");
|
||||
let sc = SimulationControlClient::new(ctx, server.endpoint());
|
||||
|
||||
// DirectAccessClient methods should work.
|
||||
let xorbs = DirectAccessClient::list_xorbs(&sc).await.unwrap();
|
||||
@@ -1400,7 +1400,8 @@ mod tests {
|
||||
use crate::cas_client::simulation::LocalTestServerBuilder;
|
||||
|
||||
let server = LocalTestServerBuilder::new().with_ephemeral_disk().start().await;
|
||||
let sc = SimulationControlClient::new(server.http_endpoint());
|
||||
let ctx = XetContext::default().expect("XetContext::new");
|
||||
let sc = SimulationControlClient::new(ctx, server.http_endpoint());
|
||||
|
||||
let file = sc.upload_random_file(&[(1, (0, 3))], 2048).await.unwrap();
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ use bytes::Bytes;
|
||||
use http::header::HeaderMap;
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_core_structures::xorb_object::XorbObject;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::simulation_types::{
|
||||
FetchTermDataRequest, FetchTermDataResponse, FileShardsEntry, FileSizeResponse, HashWithTag, TagDeleteRequest,
|
||||
@@ -34,22 +35,35 @@ pub struct SimulationControlClient {
|
||||
endpoint: String,
|
||||
http_client: reqwest::Client,
|
||||
remote_client: Arc<RemoteClient>,
|
||||
_keep_alive: Option<Box<dyn std::any::Any + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl SimulationControlClient {
|
||||
/// Creates a new client connected to the given server endpoint URL.
|
||||
pub fn new(endpoint: &str) -> Self {
|
||||
pub fn new(ctx: XetContext, endpoint: &str) -> Self {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(http::header::USER_AGENT, http::header::HeaderValue::from_static("simulation-control-client"));
|
||||
let remote_client = RemoteClient::new(endpoint, &None, "simulation-session", false, Some(Arc::new(headers)));
|
||||
let remote_client =
|
||||
RemoteClient::new(ctx, endpoint, &None, "simulation-session", false, Some(Arc::new(headers)));
|
||||
|
||||
Self {
|
||||
endpoint: endpoint.to_string(),
|
||||
http_client: reqwest::Client::new(),
|
||||
remote_client,
|
||||
_keep_alive: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Attaches a resource that will be kept alive as long as this client exists.
|
||||
///
|
||||
/// Primarily used in tests to tie the lifetime of a [`LocalTestServer`] to
|
||||
/// the client so that the server is shut down when the client is dropped,
|
||||
/// preventing file-descriptor leaks.
|
||||
pub fn with_keep_alive(mut self, resource: impl std::any::Any + Send + Sync + 'static) -> Self {
|
||||
self._keep_alive = Some(Box::new(resource));
|
||||
self
|
||||
}
|
||||
|
||||
/// Constructs a full URL for a `/simulation/` endpoint path.
|
||||
fn sim_url(&self, path: &str) -> String {
|
||||
format!("{}/simulation{}", self.endpoint, path)
|
||||
|
||||
@@ -19,6 +19,7 @@ use xet_core_structures::metadata_shard::shard_in_memory::MDBInMemoryShard;
|
||||
use xet_core_structures::metadata_shard::streaming_shard::MDBMinimalShard;
|
||||
use xet_core_structures::metadata_shard::xorb_structs::MDBXorbInfo;
|
||||
use xet_core_structures::xorb_object::{SerializedXorbObject, XorbObject};
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::super::Client;
|
||||
use super::super::adaptive_concurrency::AdaptiveConcurrencyController;
|
||||
@@ -76,12 +77,12 @@ pub struct MemoryClient {
|
||||
|
||||
impl MemoryClient {
|
||||
/// Create a new in-memory client.
|
||||
pub fn new() -> Arc<Self> {
|
||||
pub fn new(ctx: XetContext) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
xorbs: RwLock::new(MerkleHashMap::new()),
|
||||
shard: RwLock::new(MDBInMemoryShard::default()),
|
||||
global_dedup: RwLock::new(MerkleHashMap::new()),
|
||||
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload("memory_uploads"),
|
||||
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload(ctx, "memory_uploads"),
|
||||
xorb_generation: AtomicU64::new(0),
|
||||
url_expiration_ms: AtomicU64::new(u64::MAX),
|
||||
global_dedup_expiration_secs: AtomicU64::new(0),
|
||||
@@ -262,23 +263,6 @@ impl MemoryClient {
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MemoryClient {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
xorbs: RwLock::new(MerkleHashMap::new()),
|
||||
shard: RwLock::new(MDBInMemoryShard::default()),
|
||||
global_dedup: RwLock::new(MerkleHashMap::new()),
|
||||
upload_concurrency_controller: AdaptiveConcurrencyController::new_upload("memory_uploads"),
|
||||
xorb_generation: AtomicU64::new(0),
|
||||
url_expiration_ms: AtomicU64::new(u64::MAX),
|
||||
global_dedup_expiration_secs: AtomicU64::new(0),
|
||||
random_ms_delay_window: (AtomicU64::new(0), AtomicU64::new(0)),
|
||||
max_ranges_per_fetch: AtomicUsize::new(usize::MAX),
|
||||
v2_disabled_status: AtomicU16::new(0),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(not(target_family = "wasm"), async_trait)]
|
||||
#[cfg_attr(target_family = "wasm", async_trait(?Send))]
|
||||
impl DirectAccessClient for MemoryClient {
|
||||
@@ -1142,16 +1126,24 @@ fn parse_any_fetch_url(url: &str) -> Result<(MerkleHash, Instant)> {
|
||||
|
||||
#[cfg(all(test, not(target_family = "wasm")))]
|
||||
mod tests {
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::super::client_testing_utils::ClientTestingUtils;
|
||||
use super::super::deletion_controls::DeletionControlableClient;
|
||||
use super::*;
|
||||
|
||||
fn test_ctx() -> XetContext {
|
||||
let config = XetConfig::new();
|
||||
XetContext::from_external(tokio::runtime::Handle::current(), config)
|
||||
}
|
||||
|
||||
fn new_client() -> Arc<dyn super::super::DirectAccessClient> {
|
||||
MemoryClient::new()
|
||||
MemoryClient::new(test_ctx())
|
||||
}
|
||||
|
||||
fn new_deletion_client() -> Arc<MemoryClient> {
|
||||
MemoryClient::new()
|
||||
MemoryClient::new(test_ctx())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -1213,7 +1205,7 @@ mod tests {
|
||||
/// Comprehensive test for RandomXorb insertion and data access.
|
||||
#[tokio::test]
|
||||
async fn test_random_xorb() {
|
||||
let client = MemoryClient::new();
|
||||
let client = MemoryClient::new(test_ctx());
|
||||
|
||||
// Basic insertion and existence
|
||||
let xorb = RandomXorb::from_seed(42, 5, 1024);
|
||||
@@ -1252,7 +1244,7 @@ mod tests {
|
||||
/// Test RandomXorb with large chunk count and scattered range access.
|
||||
#[tokio::test]
|
||||
async fn test_random_xorb_large() {
|
||||
let client = MemoryClient::new();
|
||||
let client = MemoryClient::new(test_ctx());
|
||||
let xorb = RandomXorb::from_seed(12345, 100, 4096);
|
||||
let xorb_hash = client.insert_random_xorb(xorb.clone()).await.unwrap();
|
||||
|
||||
@@ -1267,7 +1259,7 @@ mod tests {
|
||||
/// Comprehensive test for lazy file insertion with on-the-fly xorb generation.
|
||||
#[tokio::test]
|
||||
async fn test_lazy_file() {
|
||||
let client = MemoryClient::new();
|
||||
let client = MemoryClient::new(test_ctx());
|
||||
|
||||
// Single-term file
|
||||
let file = client.insert_random_lazy_file(&[(1, (0, 3))], 256).await.unwrap();
|
||||
@@ -1312,8 +1304,14 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_lazy_file_deterministic() {
|
||||
let term_spec = &[(999, (0, 4))];
|
||||
let file1 = MemoryClient::new().insert_random_lazy_file(term_spec, 512).await.unwrap();
|
||||
let file2 = MemoryClient::new().insert_random_lazy_file(term_spec, 512).await.unwrap();
|
||||
let file1 = MemoryClient::new(test_ctx())
|
||||
.insert_random_lazy_file(term_spec, 512)
|
||||
.await
|
||||
.unwrap();
|
||||
let file2 = MemoryClient::new(test_ctx())
|
||||
.insert_random_lazy_file(term_spec, 512)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(file1.file_hash, file2.file_hash);
|
||||
assert_eq!(file1.data, file2.data);
|
||||
}
|
||||
@@ -1321,7 +1319,7 @@ mod tests {
|
||||
/// Verify materialized and random xorbs coexist correctly.
|
||||
#[tokio::test]
|
||||
async fn test_mixed_xorb_types() {
|
||||
let client = MemoryClient::new();
|
||||
let client = MemoryClient::new(test_ctx());
|
||||
|
||||
let random_xorb = RandomXorb::from_seed(111, 3, 256);
|
||||
let random_hash = client.insert_random_xorb(random_xorb).await.unwrap();
|
||||
|
||||
@@ -80,11 +80,11 @@ impl RemoteSimulationClient {
|
||||
let random_bytes = Bytes::from(random_data);
|
||||
|
||||
let n_upload_bytes = random_bytes.len() as u64;
|
||||
let block_size = xet_runtime::core::xet_config().client.upload_reporting_block_size;
|
||||
let block_size = self.inner.ctx.config.client.upload_reporting_block_size;
|
||||
|
||||
let api_tag = "simulation::dummy_upload";
|
||||
|
||||
RetryWrapper::new(api_tag)
|
||||
RetryWrapper::new(self.inner.ctx.clone(), api_tag)
|
||||
.with_connection_permit(upload_permit, Some(n_upload_bytes))
|
||||
.run(move || {
|
||||
let upload_stream = UploadProgressStream::new(random_bytes.clone(), block_size);
|
||||
|
||||
@@ -14,6 +14,7 @@ use http::header::{self, HeaderMap, HeaderValue};
|
||||
#[cfg(unix)]
|
||||
use tempfile::TempDir;
|
||||
use tokio::sync::oneshot;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::super::RemoteClient;
|
||||
use super::super::interface::Client;
|
||||
@@ -66,7 +67,7 @@ use crate::error::Result;
|
||||
/// .await;
|
||||
///
|
||||
/// // Use existing client
|
||||
/// let client = MemoryClient::new();
|
||||
/// let client = MemoryClient::new(ctx);
|
||||
/// let server = LocalTestServerBuilder::new()
|
||||
/// .with_client(client)
|
||||
/// .with_ephemeral_socket()
|
||||
@@ -171,6 +172,7 @@ impl LocalTestServerBuilder {
|
||||
|
||||
/// Builds and starts the test server.
|
||||
pub async fn start(self) -> LocalTestServer {
|
||||
let ctx = XetContext::default().expect("XetContext::new");
|
||||
#[cfg(unix)]
|
||||
let (socket_path, ephemeral_tempdir) = if self.ephemeral_socket {
|
||||
let tempdir = TempDir::new().expect("Failed to create temporary directory for ephemeral socket");
|
||||
@@ -189,11 +191,11 @@ impl LocalTestServerBuilder {
|
||||
if let Some(client) = self.client {
|
||||
(client, None)
|
||||
} else if self.in_memory {
|
||||
let mc = MemoryClient::new();
|
||||
let mc = MemoryClient::new(ctx.clone());
|
||||
let dc: Arc<dyn DeletionControlableClient> = mc.clone();
|
||||
(mc, Some(dc))
|
||||
} else if self.ephemeral_disk {
|
||||
let lc = LocalClient::temporary()
|
||||
let lc = LocalClient::temporary(ctx.clone())
|
||||
.await
|
||||
.expect("Failed to create LocalClient with temporary directory");
|
||||
let dc: Arc<dyn DeletionControlableClient> = lc.clone();
|
||||
@@ -202,7 +204,9 @@ impl LocalTestServerBuilder {
|
||||
let disk_path = self.disk_location.unwrap_or_else(|| {
|
||||
panic!("with_disk_location must be called when in_memory is false and no client is provided")
|
||||
});
|
||||
let lc = LocalClient::new(&disk_path).await.expect("Failed to create LocalClient");
|
||||
let lc = LocalClient::new(ctx.clone(), &disk_path)
|
||||
.await
|
||||
.expect("Failed to create LocalClient");
|
||||
let dc: Arc<dyn DeletionControlableClient> = lc.clone();
|
||||
(lc, Some(dc))
|
||||
};
|
||||
@@ -256,6 +260,7 @@ impl LocalTestServerBuilder {
|
||||
|
||||
let socket_path_str = socket_path.to_string_lossy().to_string();
|
||||
let client = RemoteClient::new_with_socket(
|
||||
ctx.clone(),
|
||||
&client_endpoint,
|
||||
&None,
|
||||
"test-session",
|
||||
@@ -266,13 +271,21 @@ impl LocalTestServerBuilder {
|
||||
|
||||
(client, Some(proxy))
|
||||
} else {
|
||||
let client = RemoteClient::new(&client_endpoint, &None, "test-session", false, custom_headers.clone());
|
||||
let client = RemoteClient::new(
|
||||
ctx.clone(),
|
||||
&client_endpoint,
|
||||
&None,
|
||||
"test-session",
|
||||
false,
|
||||
custom_headers.clone(),
|
||||
);
|
||||
(client, None)
|
||||
}
|
||||
};
|
||||
|
||||
#[cfg(not(unix))]
|
||||
let remote_client = RemoteClient::new(&client_endpoint, &None, "test-session", false, custom_headers.clone());
|
||||
let remote_client =
|
||||
RemoteClient::new(ctx.clone(), &client_endpoint, &None, "test-session", false, custom_headers.clone());
|
||||
|
||||
let remote_simulation_client = Arc::new(RemoteSimulationClient::new(remote_client));
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::path::PathBuf;
|
||||
use std::u64;
|
||||
|
||||
use clap::Parser;
|
||||
use xet_runtime::config::XetConfig;
|
||||
|
||||
use crate::chunk_cache::{CacheConfig, DiskCache};
|
||||
|
||||
@@ -19,10 +20,14 @@ fn main() {
|
||||
}
|
||||
|
||||
fn print_main(root: PathBuf) {
|
||||
let cache = DiskCache::initialize(&CacheConfig {
|
||||
cache_directory: root,
|
||||
cache_size: u64::MAX,
|
||||
})
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(
|
||||
&xet_config,
|
||||
&CacheConfig {
|
||||
cache_directory: root,
|
||||
cache_size: u64::MAX,
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
cache.print();
|
||||
}
|
||||
|
||||
@@ -3,6 +3,8 @@ use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, LazyLock, Mutex, Weak};
|
||||
|
||||
use xet_runtime::config::XetConfig;
|
||||
|
||||
use super::error::ChunkCacheError;
|
||||
use super::{CacheConfig, ChunkCache, DiskCache};
|
||||
|
||||
@@ -11,8 +13,8 @@ use super::{CacheConfig, ChunkCache, DiskCache};
|
||||
static CACHE_MANAGER: LazyLock<CacheManager> = LazyLock::new(CacheManager::new);
|
||||
|
||||
/// get_cache attempts to return a cache given the provided config parameter
|
||||
pub fn get_cache(config: &CacheConfig) -> Result<Arc<dyn ChunkCache>, ChunkCacheError> {
|
||||
CACHE_MANAGER.get(config)
|
||||
pub fn get_cache(xet_config: &XetConfig, config: &CacheConfig) -> Result<Arc<dyn ChunkCache>, ChunkCacheError> {
|
||||
CACHE_MANAGER.get(xet_config, config)
|
||||
}
|
||||
|
||||
struct CacheManager {
|
||||
@@ -30,7 +32,7 @@ impl CacheManager {
|
||||
/// cache_directory then it will return an Arc to that `DiskCache` instance. If it doesn't exist
|
||||
/// or the `DiskCache` instance has been deallocated (CacheManager only holds a weak pointer)
|
||||
/// then it creates a new instance based on the provided config.
|
||||
fn get(&self, config: &CacheConfig) -> Result<Arc<dyn ChunkCache>, ChunkCacheError> {
|
||||
fn get(&self, xet_config: &XetConfig, config: &CacheConfig) -> Result<Arc<dyn ChunkCache>, ChunkCacheError> {
|
||||
let mut vals = self.vals.lock()?;
|
||||
if let Some(v) = vals.get_mut(&config.cache_directory) {
|
||||
let weak = v.borrow().clone();
|
||||
@@ -40,12 +42,12 @@ impl CacheManager {
|
||||
}
|
||||
// since upgrading failed, creates a new DiskCache, replaces the weak pointer with a
|
||||
// weak pointer to the new instance and then returns the Arc to the new cache instance
|
||||
let result: Arc<dyn ChunkCache> = Arc::new(DiskCache::initialize(config)?);
|
||||
let result: Arc<dyn ChunkCache> = Arc::new(DiskCache::initialize(xet_config, config)?);
|
||||
v.replace(Arc::downgrade(&result));
|
||||
Ok(result)
|
||||
} else {
|
||||
// create a new Cache and insert weak pointer to managed map
|
||||
let result: Arc<dyn ChunkCache> = Arc::new(DiskCache::initialize(config)?);
|
||||
let result: Arc<dyn ChunkCache> = Arc::new(DiskCache::initialize(xet_config, config)?);
|
||||
vals.insert(config.cache_directory.clone(), RefCell::new(Arc::downgrade(&result)));
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ use base64::engine::general_purpose::URL_SAFE;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, error};
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_runtime::core::xet_config;
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::error_printer::ErrorPrinter;
|
||||
use xet_runtime::file_utils::SafeFileCreator;
|
||||
use xet_runtime::utils::output_bytes;
|
||||
@@ -205,7 +205,7 @@ impl DiskCache {
|
||||
/// │ ├── [range 400-402, file_len, file_hash]
|
||||
/// │ ├── [range 404-405, file_len, file_hash]
|
||||
/// │ └── [range 679-700, file_len, file_hash]
|
||||
pub fn initialize(config: &CacheConfig) -> Result<Self, ChunkCacheError> {
|
||||
pub fn initialize(xet_config: &XetConfig, config: &CacheConfig) -> Result<Self, ChunkCacheError> {
|
||||
if config.cache_size == 0 {
|
||||
return Err(ChunkCacheError::InvalidArguments);
|
||||
}
|
||||
@@ -213,7 +213,7 @@ impl DiskCache {
|
||||
let cache_root = config.cache_directory.clone();
|
||||
|
||||
// May take a while; don't block the runtime for this.
|
||||
let state = Self::initialize_state(&cache_root, capacity)?;
|
||||
let state = Self::initialize_state(&cache_root, capacity, xet_config)?;
|
||||
|
||||
Ok(Self {
|
||||
state: Arc::new(RwLock::new(state)),
|
||||
@@ -222,7 +222,11 @@ impl DiskCache {
|
||||
})
|
||||
}
|
||||
|
||||
fn initialize_state(cache_root: &PathBuf, capacity: u64) -> Result<CacheState, ChunkCacheError> {
|
||||
fn initialize_state(
|
||||
cache_root: &PathBuf,
|
||||
capacity: u64,
|
||||
xet_config: &XetConfig,
|
||||
) -> Result<CacheState, ChunkCacheError> {
|
||||
let mut state = HashMap::new();
|
||||
let mut total_bytes = 0;
|
||||
let mut num_items = 0;
|
||||
@@ -284,7 +288,7 @@ impl DiskCache {
|
||||
|
||||
// loop through cache items inside key directory
|
||||
for item in key_readdir {
|
||||
let cache_item = match try_parse_cache_file(item, capacity) {
|
||||
let cache_item = match try_parse_cache_file(item, capacity, xet_config) {
|
||||
Ok(Some(ci)) => ci,
|
||||
Ok(None) => continue,
|
||||
Err(e) => return Err(e),
|
||||
@@ -664,7 +668,11 @@ fn is_ok_dir(dir_result: Result<DirEntry, io::Error>) -> OptionResult<DirEntry,
|
||||
// given a result from readdir attempts to parse it as a cache file handle
|
||||
// i.e. validate its file name against the contents (excluding file-hash-validation)
|
||||
// validate that it is a file, correct len, and is not too large.
|
||||
fn try_parse_cache_file(file_result: io::Result<DirEntry>, capacity: u64) -> OptionResult<CacheItem, ChunkCacheError> {
|
||||
fn try_parse_cache_file(
|
||||
file_result: io::Result<DirEntry>,
|
||||
capacity: u64,
|
||||
config: &XetConfig,
|
||||
) -> OptionResult<CacheItem, ChunkCacheError> {
|
||||
let item = match file_result {
|
||||
Ok(item) => item,
|
||||
Err(e) => {
|
||||
@@ -687,10 +695,10 @@ fn try_parse_cache_file(file_result: io::Result<DirEntry>, capacity: u64) -> Opt
|
||||
if !md.is_file() {
|
||||
return Ok(None);
|
||||
}
|
||||
if md.len() > xet_config().chunk_cache.size_bytes {
|
||||
if md.len() > config.chunk_cache.size_bytes {
|
||||
return Err(ChunkCacheError::general(format!(
|
||||
"Cache directory contains a file larger than {} GB, cache directory state is invalid",
|
||||
(xet_config().chunk_cache.size_bytes as f64 / (1 << 30) as f64)
|
||||
(config.chunk_cache.size_bytes as f64 / (1 << 30) as f64)
|
||||
)));
|
||||
}
|
||||
|
||||
@@ -817,6 +825,7 @@ mod tests {
|
||||
use rand::SeedableRng;
|
||||
use rand::rngs::StdRng;
|
||||
use tempfile::TempDir;
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::utils::output_bytes;
|
||||
|
||||
use super::super::{CacheConfig, ChunkCache};
|
||||
@@ -835,9 +844,9 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: DEFAULT_CHUNK_CACHE_CAPACITY,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
assert!(
|
||||
cache
|
||||
.get(&random_key(&mut rng), &random_range(&mut rng))
|
||||
@@ -854,9 +863,9 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: DEFAULT_CHUNK_CACHE_CAPACITY,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
|
||||
let key = random_key(&mut rng);
|
||||
let range = ChunkRange::new(0, 4);
|
||||
@@ -886,9 +895,9 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: DEFAULT_CHUNK_CACHE_CAPACITY,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
|
||||
let key = random_key(&mut rng);
|
||||
// following parts of test assume overall inserted range includes chunk 0
|
||||
@@ -933,9 +942,9 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: CAP,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
let mut it = RandomEntryIterator::std_from_seed(RANDOM_SEED);
|
||||
|
||||
// fill the cache to almost capacity
|
||||
@@ -957,9 +966,9 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: DEFAULT_CHUNK_CACHE_CAPACITY,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
let mut it = RandomEntryIterator::std_from_seed(RANDOM_SEED).with_range_len(1000);
|
||||
let (key, range, offsets, data) = it.next().unwrap();
|
||||
assert!(cache.put(&key, &range, &offsets, &data).await.is_ok());
|
||||
@@ -974,9 +983,9 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: DEFAULT_CHUNK_CACHE_CAPACITY,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
let (key, range, offsets, data) = it.next().unwrap();
|
||||
assert!(cache.put(&key, &range, &offsets, &data).await.is_ok());
|
||||
(cache_root, cache, key, range, offsets, data)
|
||||
@@ -1025,9 +1034,9 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: DEFAULT_CHUNK_CACHE_CAPACITY,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
|
||||
let mut it = RandomEntryIterator::std_from_seed(RANDOM_SEED);
|
||||
|
||||
@@ -1039,7 +1048,7 @@ mod tests {
|
||||
keys_and_ranges.push((key, range));
|
||||
}
|
||||
|
||||
let cache2 = DiskCache::initialize(&config).unwrap();
|
||||
let cache2 = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
for (i, (key, range)) in keys_and_ranges.iter().enumerate() {
|
||||
let get_result = cache2.get(&key, &range).await;
|
||||
assert!(get_result.is_ok(), "{i} {get_result:?}");
|
||||
@@ -1058,9 +1067,9 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: DEFAULT_CHUNK_CACHE_CAPACITY,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
let mut it = RandomEntryIterator::std_from_seed(RANDOM_SEED).with_range_len(LARGE_FILE as u32);
|
||||
|
||||
let (key, range, offsets, data) = it.next().unwrap();
|
||||
@@ -1068,9 +1077,8 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: LARGE_FILE - 1,
|
||||
..Default::default()
|
||||
};
|
||||
let cache2 = DiskCache::initialize(&config).unwrap();
|
||||
let cache2 = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
|
||||
assert_eq!(cache2.total_bytes().await, 0);
|
||||
}
|
||||
@@ -1079,12 +1087,12 @@ mod tests {
|
||||
async fn test_initialize_stops_loading_early_with_too_many_files() {
|
||||
const LARGE_FILE: u64 = 1000;
|
||||
let cache_root = TempDir::new().unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: LARGE_FILE * 10,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
let mut it = RandomEntryIterator::std_from_seed(RANDOM_SEED).with_range_len(LARGE_FILE as u32);
|
||||
for _ in 0..10 {
|
||||
let (key, range, offsets, data) = it.next().unwrap();
|
||||
@@ -1095,9 +1103,8 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: cap2,
|
||||
..Default::default()
|
||||
};
|
||||
let cache2 = DiskCache::initialize(&config).unwrap();
|
||||
let cache2 = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
|
||||
assert!(cache2.total_bytes().await < cap2 * 3, "{} < {}", cache2.total_bytes().await, cap2 * 3);
|
||||
}
|
||||
@@ -1113,17 +1120,17 @@ mod tests {
|
||||
async fn test_unknown_eviction() {
|
||||
let cache_root = TempDir::new().unwrap();
|
||||
let capacity = 12 * RANGE_LEN as u64;
|
||||
let xet_config = XetConfig::new();
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: capacity,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
let mut it = RandomEntryIterator::std_from_seed(RANDOM_SEED);
|
||||
let (key, range, chunk_byte_indices, data) = it.next().unwrap();
|
||||
cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap();
|
||||
|
||||
let cache2 = DiskCache::initialize(&config).unwrap();
|
||||
let cache2 = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
let get_result = cache2.get(&key, &range).await;
|
||||
assert!(get_result.is_ok());
|
||||
assert!(get_result.unwrap().is_some());
|
||||
@@ -1156,9 +1163,9 @@ mod tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: DEFAULT_CHUNK_CACHE_CAPACITY,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
|
||||
let (key, range, chunk_byte_indices, data) = RandomEntryIterator::std_from_seed(RANDOM_SEED).next().unwrap();
|
||||
cache.put(&key, &range, &chunk_byte_indices, &data).await.unwrap();
|
||||
@@ -1207,12 +1214,12 @@ mod tests {
|
||||
const NUM: u32 = 12;
|
||||
let cache_root = TempDir::new().unwrap();
|
||||
let capacity = (NUM * RANGE_LEN) as u64;
|
||||
let xet_config = XetConfig::new();
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: capacity,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
let mut it = RandomEntryIterator::std_from_seed(RANDOM_SEED).with_one_chunk_ranges(true);
|
||||
let (key, _, _, _) = it.next().unwrap();
|
||||
let mut previously_put: Vec<(Key, ChunkRange)> = Vec::new();
|
||||
@@ -1246,11 +1253,15 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_initialize_with_cache_size_0() {
|
||||
let xet_config = XetConfig::new();
|
||||
assert!(
|
||||
DiskCache::initialize(&CacheConfig {
|
||||
cache_directory: "/tmp".into(),
|
||||
cache_size: 0,
|
||||
})
|
||||
DiskCache::initialize(
|
||||
&xet_config,
|
||||
&CacheConfig {
|
||||
cache_directory: "/tmp".into(),
|
||||
cache_size: 0,
|
||||
},
|
||||
)
|
||||
.is_err()
|
||||
);
|
||||
}
|
||||
@@ -1259,6 +1270,7 @@ mod tests {
|
||||
#[cfg(test)]
|
||||
mod concurrency_tests {
|
||||
use tempfile::TempDir;
|
||||
use xet_runtime::config::XetConfig;
|
||||
|
||||
use super::super::{CacheConfig, ChunkCache};
|
||||
use super::DiskCache;
|
||||
@@ -1276,9 +1288,9 @@ mod concurrency_tests {
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: DEFAULT_CHUNK_CACHE_CAPACITY,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
|
||||
let num_tasks = 2 + rand::random::<u8>() % 14;
|
||||
|
||||
@@ -1308,12 +1320,12 @@ mod concurrency_tests {
|
||||
#[cfg_attr(feature = "smoke-test", ignore)]
|
||||
async fn test_run_concurrently_with_evictions() {
|
||||
let cache_root = TempDir::new().unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: RANGE_LEN as u64 * NUM_ITEMS_PER_TASK as u64,
|
||||
..Default::default()
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
|
||||
let num_tasks = 2 + rand::random::<u8>() % 14;
|
||||
|
||||
@@ -1342,11 +1354,12 @@ mod concurrency_tests {
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn test_run_concurrently_thundering_herd() {
|
||||
let cache_root = TempDir::new().unwrap();
|
||||
let xet_config = XetConfig::new();
|
||||
let config = CacheConfig {
|
||||
cache_directory: cache_root.path().to_path_buf(),
|
||||
cache_size: RANGE_LEN as u64 * NUM_ITEMS_PER_TASK as u64,
|
||||
};
|
||||
let cache = DiskCache::initialize(&config).unwrap();
|
||||
let cache = DiskCache::initialize(&xet_config, &config).unwrap();
|
||||
|
||||
// data inserted is the same
|
||||
let mut it = RandomEntryIterator::std_from_seed(RANDOM_SEED);
|
||||
|
||||
@@ -11,7 +11,7 @@ pub use disk::test_utils::*;
|
||||
use error::ChunkCacheError;
|
||||
#[cfg(test)]
|
||||
use mockall::automock;
|
||||
use xet_runtime::core::xet_config;
|
||||
use xet_runtime::config::XetConfig;
|
||||
|
||||
use crate::cas_types::{ChunkRange, Key};
|
||||
|
||||
@@ -79,11 +79,11 @@ pub struct CacheConfig {
|
||||
pub cache_size: u64,
|
||||
}
|
||||
|
||||
impl Default for CacheConfig {
|
||||
fn default() -> Self {
|
||||
impl CacheConfig {
|
||||
pub fn from_config(config: &XetConfig) -> Self {
|
||||
CacheConfig {
|
||||
cache_directory: PathBuf::from("/tmp"),
|
||||
cache_size: xet_config().chunk_cache.size_bytes,
|
||||
cache_size: config.chunk_cache.size_bytes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,8 @@ use reqwest::{Request, Response};
|
||||
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Middleware, Next};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{Instrument, info, info_span, warn};
|
||||
use xet_runtime::core::{XetRuntime, xet_config};
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::error_printer::{ErrorPrinter, OptionPrinter};
|
||||
|
||||
use crate::cas_client::auth::{AuthConfig, TokenProvider};
|
||||
@@ -69,57 +70,35 @@ fn headers_tag(headers: Option<&HeaderMap>) -> String {
|
||||
|
||||
#[allow(unused_variables)]
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
fn reqwest_client(unix_socket_path: Option<&str>, custom_headers: Option<Arc<HeaderMap>>) -> Result<reqwest::Client> {
|
||||
// Check config if explicit socket path is not provided
|
||||
fn reqwest_client_raw(
|
||||
config: &XetConfig,
|
||||
unix_socket_path: Option<&str>,
|
||||
custom_headers: Option<Arc<HeaderMap>>,
|
||||
) -> std::result::Result<reqwest::Client, reqwest::Error> {
|
||||
let socket_path = unix_socket_path
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| xet_config().client.unix_socket_path.clone());
|
||||
.or_else(|| config.client.unix_socket_path.clone());
|
||||
|
||||
// Build a cache tag that captures both the transport (socket path / TCP) and the
|
||||
// set of default headers so that clients with different headers get separate pools.
|
||||
let tag = format!("{}|{}", socket_path.as_deref().unwrap_or("tcp"), headers_tag(custom_headers.as_deref()));
|
||||
|
||||
// Create client function
|
||||
let socket_path_clone = socket_path.clone();
|
||||
let socket_path_for_builder = socket_path.clone();
|
||||
let custom_headers_for_client = custom_headers.clone();
|
||||
let create_client = move || {
|
||||
let config = &xet_config().client;
|
||||
let mut builder = reqwest::Client::builder()
|
||||
.pool_idle_timeout(config.idle_connection_timeout)
|
||||
.pool_max_idle_per_host(config.max_idle_connections)
|
||||
.connect_timeout(config.connect_timeout)
|
||||
.read_timeout(config.read_timeout)
|
||||
.http1_only();
|
||||
let client_cfg = &config.client;
|
||||
let mut builder = reqwest::Client::builder()
|
||||
.pool_idle_timeout(client_cfg.idle_connection_timeout)
|
||||
.pool_max_idle_per_host(client_cfg.max_idle_connections)
|
||||
.connect_timeout(client_cfg.connect_timeout)
|
||||
.read_timeout(client_cfg.read_timeout)
|
||||
.http1_only();
|
||||
|
||||
#[cfg(unix)]
|
||||
if let Some(ref path) = socket_path_clone {
|
||||
builder = builder.unix_socket(path.clone());
|
||||
}
|
||||
|
||||
if let Some(headers) = custom_headers_for_client {
|
||||
builder = builder.default_headers((*headers).clone());
|
||||
}
|
||||
|
||||
builder.build()
|
||||
};
|
||||
|
||||
// Try to use cached client if in a runtime, otherwise create directly
|
||||
let client = XetRuntime::get_or_create_reqwest_client(tag, create_client)?;
|
||||
|
||||
if socket_path.is_some() {
|
||||
info!(socket_path=?socket_path, "HTTP client configured with Unix socket");
|
||||
} else {
|
||||
let config = &xet_config().client;
|
||||
let custom_headers = custom_headers.as_deref().map(redact_headers);
|
||||
info!(
|
||||
idle_timeout=?config.idle_connection_timeout,
|
||||
max_idle_connections=config.max_idle_connections,
|
||||
custom_headers=?custom_headers,
|
||||
"HTTP client configured"
|
||||
);
|
||||
#[cfg(unix)]
|
||||
if let Some(ref path) = socket_path_for_builder {
|
||||
builder = builder.unix_socket(path.clone());
|
||||
}
|
||||
|
||||
Ok(client)
|
||||
if let Some(headers) = custom_headers_for_client {
|
||||
builder = builder.default_headers((*headers).clone());
|
||||
}
|
||||
|
||||
builder.build()
|
||||
}
|
||||
|
||||
/// Creates a reqwest client with no read_timeout. Used for shard uploads where server-side
|
||||
@@ -128,19 +107,19 @@ fn reqwest_client(unix_socket_path: Option<&str>, custom_headers: Option<Arc<Hea
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
#[allow(unused_variables)]
|
||||
fn reqwest_client_no_read_timeout(
|
||||
config: &XetConfig,
|
||||
unix_socket_path: Option<&str>,
|
||||
custom_headers: Option<Arc<HeaderMap>>,
|
||||
) -> Result<reqwest::Client> {
|
||||
let socket_path = unix_socket_path
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| xet_config().client.unix_socket_path.clone());
|
||||
.or_else(|| config.client.unix_socket_path.clone());
|
||||
|
||||
let config = &xet_config().client;
|
||||
let client_cfg = &config.client;
|
||||
let mut builder = reqwest::Client::builder()
|
||||
.pool_idle_timeout(config.idle_connection_timeout)
|
||||
.pool_max_idle_per_host(config.max_idle_connections)
|
||||
.connect_timeout(config.connect_timeout)
|
||||
// No read_timeout — shard processing time scales with entry count and is unbounded
|
||||
.pool_idle_timeout(client_cfg.idle_connection_timeout)
|
||||
.pool_max_idle_per_host(client_cfg.max_idle_connections)
|
||||
.connect_timeout(client_cfg.connect_timeout)
|
||||
.http1_only();
|
||||
|
||||
#[cfg(unix)]
|
||||
@@ -155,7 +134,7 @@ fn reqwest_client_no_read_timeout(
|
||||
let client = builder.build()?;
|
||||
|
||||
info!(
|
||||
connect_timeout=?config.connect_timeout,
|
||||
connect_timeout=?client_cfg.connect_timeout,
|
||||
"No-read-timeout HTTP client configured (for shard uploads)"
|
||||
);
|
||||
|
||||
@@ -163,20 +142,22 @@ fn reqwest_client_no_read_timeout(
|
||||
}
|
||||
|
||||
#[cfg(target_family = "wasm")]
|
||||
fn reqwest_client(_unix_socket_path: Option<&str>, custom_headers: Option<Arc<HeaderMap>>) -> Result<reqwest::Client> {
|
||||
// For WASM, create a new client with the specified headers, including the user-agent.
|
||||
// Note: we could cache this, but user_agent can vary, so we create per-call
|
||||
// Unix socket path is ignored on WASM
|
||||
fn reqwest_client_raw(
|
||||
_config: &XetConfig,
|
||||
_unix_socket_path: Option<&str>,
|
||||
custom_headers: Option<Arc<HeaderMap>>,
|
||||
) -> std::result::Result<reqwest::Client, reqwest::Error> {
|
||||
let mut builder = reqwest::Client::builder();
|
||||
if let Some(custom_headers) = custom_headers {
|
||||
builder = builder.default_headers((*custom_headers).clone());
|
||||
}
|
||||
Ok(builder.build()?)
|
||||
builder.build()
|
||||
}
|
||||
|
||||
/// Builds authenticated HTTP Client to talk to CAS.
|
||||
#[allow(unused_mut)]
|
||||
pub fn build_auth_http_client(
|
||||
ctx: &XetContext,
|
||||
auth_config: &Option<AuthConfig>,
|
||||
session_id: &str,
|
||||
unix_socket_path: Option<&str>,
|
||||
@@ -186,7 +167,32 @@ pub fn build_auth_http_client(
|
||||
let logging_middleware = Some(LoggingMiddleware);
|
||||
let session_middleware = (!session_id.is_empty()).then(|| SessionMiddleware(session_id.to_owned()));
|
||||
|
||||
let mut builder = ClientBuilder::new(reqwest_client(unix_socket_path, custom_headers)?);
|
||||
let config_arc = ctx.config.clone();
|
||||
let unix_owned = unix_socket_path.map(|s| s.to_string());
|
||||
let custom_for_client = custom_headers.clone();
|
||||
let socket_path = unix_socket_path
|
||||
.map(|s| s.to_string())
|
||||
.or_else(|| config_arc.client.unix_socket_path.clone());
|
||||
let tag = format!("{}|{}", socket_path.as_deref().unwrap_or("tcp"), headers_tag(custom_headers.as_deref()));
|
||||
|
||||
let raw_client = ctx.common.get_or_create_reqwest_client(tag, move || {
|
||||
reqwest_client_raw(config_arc.as_ref(), unix_owned.as_deref(), custom_for_client)
|
||||
})?;
|
||||
|
||||
if socket_path.is_some() {
|
||||
info!(socket_path=?socket_path, "HTTP client configured with Unix socket");
|
||||
} else {
|
||||
let client_cfg = &ctx.config.client;
|
||||
let custom_headers_log = custom_headers.as_deref().map(redact_headers);
|
||||
info!(
|
||||
idle_timeout=?client_cfg.idle_connection_timeout,
|
||||
max_idle_connections=client_cfg.max_idle_connections,
|
||||
custom_headers=?custom_headers_log,
|
||||
"HTTP client configured"
|
||||
);
|
||||
}
|
||||
|
||||
let mut builder = ClientBuilder::new(raw_client);
|
||||
|
||||
#[cfg(unix)]
|
||||
if unix_socket_path.is_some() {
|
||||
@@ -208,6 +214,7 @@ pub fn build_auth_http_client(
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
#[allow(unused_mut)]
|
||||
pub fn build_auth_http_client_no_read_timeout(
|
||||
ctx: &XetContext,
|
||||
auth_config: &Option<AuthConfig>,
|
||||
session_id: &str,
|
||||
unix_socket_path: Option<&str>,
|
||||
@@ -217,7 +224,7 @@ pub fn build_auth_http_client_no_read_timeout(
|
||||
let logging_middleware = Some(LoggingMiddleware);
|
||||
let session_middleware = (!session_id.is_empty()).then(|| SessionMiddleware(session_id.to_owned()));
|
||||
|
||||
let raw_client = reqwest_client_no_read_timeout(unix_socket_path, custom_headers)?;
|
||||
let raw_client = reqwest_client_no_read_timeout(ctx.config.as_ref(), unix_socket_path, custom_headers)?;
|
||||
let mut builder = ClientBuilder::new(raw_client);
|
||||
|
||||
#[cfg(unix)]
|
||||
@@ -234,11 +241,12 @@ pub fn build_auth_http_client_no_read_timeout(
|
||||
|
||||
/// Builds HTTP Client to talk to CAS.
|
||||
pub fn build_http_client(
|
||||
ctx: &XetContext,
|
||||
session_id: &str,
|
||||
unix_socket_path: Option<&str>,
|
||||
custom_headers: Option<Arc<HeaderMap>>,
|
||||
) -> Result<ClientWithMiddleware> {
|
||||
build_auth_http_client(&None, session_id, unix_socket_path, custom_headers)
|
||||
build_auth_http_client(ctx, &None, session_id, unix_socket_path, custom_headers)
|
||||
}
|
||||
|
||||
/// Helper trait to allow the reqwest_middleware client to optionally add a middleware.
|
||||
@@ -460,7 +468,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_build_http_client_without_uds() {
|
||||
let result = build_http_client("test-session", None, None);
|
||||
let ctx = XetContext::default().expect("xet context");
|
||||
let result = build_http_client(&ctx, "test-session", None, None);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
@@ -469,30 +478,35 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_build_no_read_timeout_succeeds() {
|
||||
let result = build_auth_http_client_no_read_timeout(&None, "test-session", None, None);
|
||||
let ctx = XetContext::default().expect("xet context");
|
||||
let result = build_auth_http_client_no_read_timeout(&ctx, &None, "test-session", None, None);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_no_read_timeout_with_empty_session_id() {
|
||||
let result = build_auth_http_client_no_read_timeout(&None, "", None, None);
|
||||
let ctx = XetContext::default().expect("xet context");
|
||||
let result = build_auth_http_client_no_read_timeout(&ctx, &None, "", None, None);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_build_no_read_timeout_with_custom_headers() {
|
||||
let ctx = XetContext::default().expect("xet context");
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Custom-Header", HeaderValue::from_static("test-value"));
|
||||
headers.insert(reqwest::header::USER_AGENT, HeaderValue::from_static("test-agent/1.0"));
|
||||
|
||||
let result = build_auth_http_client_no_read_timeout(&None, "test-session", None, Some(Arc::new(headers)));
|
||||
let result =
|
||||
build_auth_http_client_no_read_timeout(&ctx, &None, "test-session", None, Some(Arc::new(headers)));
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_read_timeout_client_is_distinct_from_standard_client() {
|
||||
let standard = build_auth_http_client(&None, "test-session", None, None).unwrap();
|
||||
let no_timeout = build_auth_http_client_no_read_timeout(&None, "test-session", None, None).unwrap();
|
||||
let ctx = XetContext::default().expect("xet context");
|
||||
let standard = build_auth_http_client(&ctx, &None, "test-session", None, None).unwrap();
|
||||
let no_timeout = build_auth_http_client_no_read_timeout(&ctx, &None, "test-session", None, None).unwrap();
|
||||
|
||||
assert_ne!(
|
||||
format!("{:p}", &standard),
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use http::header::HeaderMap;
|
||||
use urlencoding::encode;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::types::{CasJWTInfo, RepoInfo};
|
||||
use crate::cas_client::exports::ClientWithMiddleware;
|
||||
@@ -35,6 +36,7 @@ impl Operation {
|
||||
}
|
||||
|
||||
pub struct HubClient {
|
||||
ctx: XetContext,
|
||||
endpoint: String,
|
||||
repo_info: RepoInfo,
|
||||
reference: Option<String>,
|
||||
@@ -44,6 +46,7 @@ pub struct HubClient {
|
||||
|
||||
impl HubClient {
|
||||
pub fn new(
|
||||
ctx: XetContext,
|
||||
endpoint: &str,
|
||||
repo_info: RepoInfo,
|
||||
reference: Option<String>,
|
||||
@@ -52,10 +55,11 @@ impl HubClient {
|
||||
custom_headers: Option<HeaderMap>,
|
||||
) -> Result<Self> {
|
||||
Ok(HubClient {
|
||||
ctx: ctx.clone(),
|
||||
endpoint: endpoint.to_owned(),
|
||||
repo_info,
|
||||
reference,
|
||||
client: build_http_client(session_id, None, custom_headers.map(|ch| ch.into()))?,
|
||||
client: build_http_client(&ctx, session_id, None, custom_headers.map(|ch| ch.into()))?,
|
||||
cred_helper,
|
||||
})
|
||||
}
|
||||
@@ -88,7 +92,7 @@ impl HubClient {
|
||||
let client = self.client.clone();
|
||||
let cred_helper = self.cred_helper.clone();
|
||||
|
||||
let info: CasJWTInfo = RetryWrapper::new("xet-token")
|
||||
let info: CasJWTInfo = RetryWrapper::new(self.ctx.clone(), "xet-token")
|
||||
.run_and_extract_json(move || {
|
||||
let url = url.clone();
|
||||
let client = client.clone();
|
||||
@@ -110,6 +114,7 @@ impl HubClient {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use http::header::{self, HeaderMap, HeaderValue};
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::super::{BearerCredentialHelper, HFRepoType, Operation, RepoInfo};
|
||||
use super::HubClient;
|
||||
@@ -122,6 +127,7 @@ mod tests {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::USER_AGENT, HeaderValue::from_static("xtool"));
|
||||
let hub_client = HubClient::new(
|
||||
XetContext::default().expect("runtime"),
|
||||
"https://huggingface.co",
|
||||
RepoInfo {
|
||||
repo_type: HFRepoType::Model,
|
||||
@@ -149,6 +155,7 @@ mod tests {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::USER_AGENT, HeaderValue::from_static("xtool"));
|
||||
let hub_client = HubClient::new(
|
||||
XetContext::default().expect("runtime"),
|
||||
"https://huggingface.co",
|
||||
RepoInfo {
|
||||
repo_type: HFRepoType::Model,
|
||||
@@ -176,6 +183,7 @@ mod tests {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert(header::USER_AGENT, HeaderValue::from_static("xtool"));
|
||||
let hub_client = HubClient::new(
|
||||
XetContext::default().expect("runtime"),
|
||||
"https://huggingface.co",
|
||||
RepoInfo {
|
||||
repo_type: HFRepoType::Model,
|
||||
|
||||
@@ -17,8 +17,8 @@ pub use constants::{
|
||||
hash_is_global_dedup_eligible,
|
||||
};
|
||||
pub use file_structs::Sha256;
|
||||
pub use shard_file_handle::MDBShardFile;
|
||||
pub use shard_file_manager::ShardFileManager;
|
||||
pub use shard_file_handle::{MDBShardFile, ShardFileCache, new_shard_file_cache};
|
||||
pub use shard_file_manager::{ShardFileManager, get_shard_file_cache};
|
||||
pub use shard_format::{MDBShardFileFooter, MDBShardFileHeader, MDBShardInfo};
|
||||
|
||||
// Temporary to transition dependent code to new location
|
||||
|
||||
@@ -6,10 +6,11 @@ use std::sync::Arc;
|
||||
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{error, info};
|
||||
use xet_runtime::core::{XetRuntime, check_sigint_shutdown};
|
||||
use xet_runtime::RuntimeError;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
|
||||
use super::set_operations::shard_set_union;
|
||||
use super::shard_file_handle::MDBShardFile;
|
||||
use super::shard_file_handle::{MDBShardFile, ShardFileCache};
|
||||
use super::{MDBShardFileFooter, MDBShardFileHeader, MDBShardInfo};
|
||||
use crate::error::Result;
|
||||
|
||||
@@ -19,13 +20,16 @@ use crate::error::Result;
|
||||
///
|
||||
/// Ordering of staged shards is preserved.
|
||||
pub fn consolidate_shards_in_directory(
|
||||
runtime: &Arc<XetRuntime>,
|
||||
session_directory: impl AsRef<Path>,
|
||||
target_max_size: u64,
|
||||
skip_on_error: bool,
|
||||
cache: &ShardFileCache,
|
||||
) -> Result<Vec<Arc<MDBShardFile>>> {
|
||||
let session_directory = session_directory.as_ref();
|
||||
// Get the new shards and the shards in the original list to remove.
|
||||
let shard_merge_result = merge_shards(session_directory, session_directory, target_max_size, skip_on_error)?;
|
||||
let shard_merge_result =
|
||||
merge_shards(runtime, session_directory, session_directory, target_max_size, skip_on_error, cache)?;
|
||||
|
||||
// Now, go through and remove all the shards in the delete list.
|
||||
for sfi in shard_merge_result.obsolete_shards {
|
||||
@@ -54,12 +58,14 @@ pub struct ShardMergeResult {
|
||||
/// Ordering of staged shards is preserved.
|
||||
#[allow(clippy::needless_range_loop)] // The alternative is less readable IMO
|
||||
pub fn merge_shards(
|
||||
runtime: &Arc<XetRuntime>,
|
||||
source_directory: impl AsRef<Path>,
|
||||
target_directory: impl AsRef<Path>,
|
||||
target_max_size: u64,
|
||||
skip_on_error: bool,
|
||||
cache: &ShardFileCache,
|
||||
) -> Result<ShardMergeResult> {
|
||||
let mut shards: Vec<_> = MDBShardFile::load_all_valid(source_directory.as_ref())?;
|
||||
let mut shards: Vec<_> = MDBShardFile::load_all_valid(source_directory.as_ref(), cache)?;
|
||||
|
||||
shards.sort_unstable_by_key(|si| si.last_modified_time);
|
||||
|
||||
@@ -80,7 +86,9 @@ pub fn merge_shards(
|
||||
let mut cur_si = MDBShardInfo::default();
|
||||
|
||||
for sfi in shards {
|
||||
check_sigint_shutdown()?;
|
||||
if runtime.in_sigint_shutdown() {
|
||||
return Err(RuntimeError::KeyboardInterrupt.into());
|
||||
}
|
||||
|
||||
// Now, load the new shard data in. To be resiliant to the possibility of shards
|
||||
// being deleted under us (as can happen in shard session resume with multiple
|
||||
@@ -120,7 +128,7 @@ pub fn merge_shards(
|
||||
swap(&mut out_data, &mut cur_data);
|
||||
} else {
|
||||
// Flush everything out and replace the new.
|
||||
let out_sfi = MDBShardFile::write_out_from_reader(&target_directory, &mut Cursor::new(&cur_data))?;
|
||||
let out_sfi = MDBShardFile::write_out_from_reader(&target_directory, &mut Cursor::new(&cur_data), cache)?;
|
||||
dest_shards.push(out_sfi);
|
||||
|
||||
// Move the loaded data into the current buffer.
|
||||
@@ -131,7 +139,7 @@ pub fn merge_shards(
|
||||
|
||||
// If there is any left over at the end, flush that as well.
|
||||
if !cur_data.is_empty() {
|
||||
let out_sfi = MDBShardFile::write_out_from_reader(&target_directory, &mut Cursor::new(&cur_data))?;
|
||||
let out_sfi = MDBShardFile::write_out_from_reader(&target_directory, &mut Cursor::new(&cur_data), cache)?;
|
||||
dest_shards.push(out_sfi);
|
||||
}
|
||||
|
||||
@@ -152,14 +160,17 @@ pub fn merge_shards(
|
||||
|
||||
/// Same as above, but performs it in the background and on a io focused thread.
|
||||
pub fn merge_shards_background(
|
||||
runtime: Arc<XetRuntime>,
|
||||
source_directory: impl AsRef<Path>,
|
||||
target_directory: impl AsRef<Path>,
|
||||
target_max_size: u64,
|
||||
skip_on_error: bool,
|
||||
cache: ShardFileCache,
|
||||
) -> JoinHandle<Result<ShardMergeResult>> {
|
||||
let source_directory = source_directory.as_ref().to_owned();
|
||||
let target_directory = target_directory.as_ref().to_owned();
|
||||
|
||||
XetRuntime::current()
|
||||
.spawn_blocking(move || merge_shards(source_directory, target_directory, target_max_size, skip_on_error))
|
||||
let rt = runtime.clone();
|
||||
runtime.spawn_blocking(move || {
|
||||
merge_shards(&rt, &source_directory, &target_directory, target_max_size, skip_on_error, &cache)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ use clap::Parser;
|
||||
use rand::rngs::StdRng;
|
||||
use rand::{RngExt, SeedableRng};
|
||||
use tempfile::TempDir;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::time;
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_core_structures::metadata_shard::shard_file_manager::ShardFileManager;
|
||||
@@ -16,6 +17,8 @@ use xet_core_structures::metadata_shard::shard_format::MDBShardInfo;
|
||||
use xet_core_structures::metadata_shard::shard_format::test_routines::rng_hash;
|
||||
use xet_core_structures::metadata_shard::shard_in_memory::MDBInMemoryShard;
|
||||
use xet_core_structures::metadata_shard::xorb_structs::{MDBXorbInfo, XorbChunkSequenceEntry, XorbChunkSequenceHeader};
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
const XORB_BLOCK_SIZE: usize = 512;
|
||||
const PAR_TASK: usize = 1;
|
||||
@@ -48,6 +51,7 @@ fn make_shard(size: u64, seed: &mut u64) -> MDBInMemoryShard {
|
||||
}
|
||||
|
||||
async fn run_shard_benchmark(
|
||||
ctx: &XetContext,
|
||||
shard_sizes: Vec<(u64, u64)>,
|
||||
file_contiguity: usize,
|
||||
contiguity: usize,
|
||||
@@ -76,7 +80,7 @@ async fn run_shard_benchmark(
|
||||
|
||||
// Now, spawn tasks to
|
||||
let counter = Arc::new(AtomicUsize::new(0));
|
||||
let mdb = ShardFileManager::new_in_session_directory(dir, false).await?;
|
||||
let mdb = ShardFileManager::new_in_session_directory(ctx, dir, false).await?;
|
||||
|
||||
let start_time = Instant::now();
|
||||
|
||||
@@ -184,6 +188,8 @@ struct ShardBenchmarkArgs {
|
||||
async fn main() {
|
||||
let args = ShardBenchmarkArgs::parse();
|
||||
|
||||
let ctx = XetContext::from_external(Handle::current(), XetConfig::new());
|
||||
|
||||
let temp_dir = TempDir::with_prefix("git-xet-shard").expect("Failed to create temp dir");
|
||||
let dir = args.dir.unwrap_or_else(|| temp_dir.path().into());
|
||||
eprintln!("Using dir {dir:?}");
|
||||
@@ -194,6 +200,7 @@ async fn main() {
|
||||
assert!(dir.exists());
|
||||
|
||||
run_shard_benchmark(
|
||||
&ctx,
|
||||
args.shard_sizes,
|
||||
args.contiguity,
|
||||
args.file_contiguity,
|
||||
|
||||
@@ -45,8 +45,10 @@ impl Default for MDBShardFile {
|
||||
}
|
||||
}
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref MDB_SHARD_FILE_CACHE: RwLock<HashMap<PathBuf, Arc<MDBShardFile>>> = RwLock::new(HashMap::default());
|
||||
pub type ShardFileCache = Arc<RwLock<HashMap<PathBuf, Arc<MDBShardFile>>>>;
|
||||
|
||||
pub fn new_shard_file_cache() -> ShardFileCache {
|
||||
Arc::new(RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
impl MDBShardFile {
|
||||
@@ -64,11 +66,19 @@ impl MDBShardFile {
|
||||
Ok(s)
|
||||
}
|
||||
|
||||
pub fn copy_into_target_directory(&self, target_directory: impl AsRef<Path>) -> Result<Arc<Self>> {
|
||||
Self::write_out_from_reader(target_directory, &mut self.get_reader()?)
|
||||
pub fn copy_into_target_directory(
|
||||
&self,
|
||||
target_directory: impl AsRef<Path>,
|
||||
cache: &ShardFileCache,
|
||||
) -> Result<Arc<Self>> {
|
||||
Self::write_out_from_reader(target_directory, &mut self.get_reader()?, cache)
|
||||
}
|
||||
|
||||
pub fn write_out_from_reader<R: Read>(target_directory: impl AsRef<Path>, reader: &mut R) -> Result<Arc<Self>> {
|
||||
pub fn write_out_from_reader<R: Read>(
|
||||
target_directory: impl AsRef<Path>,
|
||||
reader: &mut R,
|
||||
cache: &ShardFileCache,
|
||||
) -> Result<Arc<Self>> {
|
||||
let target_directory = target_directory.as_ref();
|
||||
|
||||
let mut hashed_write; // Need to access after file is closed.
|
||||
@@ -97,16 +107,17 @@ impl MDBShardFile {
|
||||
|
||||
std::fs::rename(&temp_file_name, &full_file_name)?;
|
||||
|
||||
Self::load_from_hash_and_path(shard_hash, &full_file_name)
|
||||
Self::load_from_hash_and_path(shard_hash, &full_file_name, cache)
|
||||
}
|
||||
|
||||
pub fn export_with_expiration(
|
||||
&self,
|
||||
target_directory: impl AsRef<Path>,
|
||||
shard_valid_for: Duration,
|
||||
cache: &ShardFileCache,
|
||||
) -> Result<Arc<Self>> {
|
||||
let now = SystemTime::now();
|
||||
self.export_with_specific_expiration(target_directory, now.add(shard_valid_for), now)
|
||||
self.export_with_specific_expiration(target_directory, now.add(shard_valid_for), now, cache)
|
||||
}
|
||||
|
||||
pub fn export_with_specific_expiration(
|
||||
@@ -114,6 +125,7 @@ impl MDBShardFile {
|
||||
target_directory: impl AsRef<Path>,
|
||||
expiration: SystemTime,
|
||||
creation_time: SystemTime,
|
||||
cache: &ShardFileCache,
|
||||
) -> Result<Arc<Self>> {
|
||||
// New footer with the proper expiration added.
|
||||
let mut out_footer = self.shard.metadata.clone();
|
||||
@@ -132,15 +144,15 @@ impl MDBShardFile {
|
||||
Self::write_out_from_reader(
|
||||
target_directory,
|
||||
&mut reader.take(out_footer.footer_offset).chain(Cursor::new(out_footer_bytes)),
|
||||
cache,
|
||||
)
|
||||
}
|
||||
|
||||
fn load_from_hash_and_path(shard_hash: MerkleHash, path: &Path) -> Result<Arc<Self>> {
|
||||
fn load_from_hash_and_path(shard_hash: MerkleHash, path: &Path, cache: &ShardFileCache) -> Result<Arc<Self>> {
|
||||
let path = std::path::absolute(path)?;
|
||||
|
||||
// First see if it's in the shard file cache.
|
||||
{
|
||||
let lg = MDB_SHARD_FILE_CACHE.read().unwrap();
|
||||
let lg = cache.read().unwrap();
|
||||
if let Some(sf) = lg.get(&path) {
|
||||
return Ok(sf.clone());
|
||||
}
|
||||
@@ -157,34 +169,32 @@ impl MDBShardFile {
|
||||
disable_verifications: false.into(),
|
||||
});
|
||||
|
||||
MDB_SHARD_FILE_CACHE.write().unwrap().insert(path, sf.clone());
|
||||
cache.write().unwrap().insert(path, sf.clone());
|
||||
|
||||
Ok(sf)
|
||||
}
|
||||
|
||||
fn drop_from_cache(self: Arc<Self>) {
|
||||
MDB_SHARD_FILE_CACHE.write().unwrap().remove_entry(&self.path);
|
||||
fn drop_from_cache(self: Arc<Self>, cache: &ShardFileCache) {
|
||||
cache.write().unwrap().remove_entry(&self.path);
|
||||
}
|
||||
|
||||
pub fn purge_if_needed(self: Arc<Self>) {
|
||||
// If the file no longer exists or isn't the correct length, then purge it from the cache.
|
||||
pub fn purge_if_needed(self: Arc<Self>, cache: &ShardFileCache) {
|
||||
if !self.path.exists() || self.shard.num_bytes() != self.path.metadata().map(|m| m.len()).unwrap_or(0) {
|
||||
info!("Purging shard file from cache: {:?}", self.path);
|
||||
self.drop_from_cache();
|
||||
self.drop_from_cache(cache);
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads the MDBShardFile struct from a file path
|
||||
pub fn load_from_file(path: &Path) -> Result<Arc<Self>> {
|
||||
pub fn load_from_file(path: &Path, cache: &ShardFileCache) -> Result<Arc<Self>> {
|
||||
if let Some(shard_hash) = parse_shard_filename(path.to_str().unwrap()) {
|
||||
Self::load_from_hash_and_path(shard_hash, path)
|
||||
Self::load_from_hash_and_path(shard_hash, path, cache)
|
||||
} else {
|
||||
Err(CoreError::BadFilename(format!("{path:?} not a valid MerkleDB filename.")))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_all_valid(path: impl AsRef<Path>) -> Result<Vec<Arc<Self>>> {
|
||||
Self::load_managed_directory(path, true, false, false, 0)
|
||||
pub fn load_all_valid(path: impl AsRef<Path>, cache: &ShardFileCache) -> Result<Vec<Arc<Self>>> {
|
||||
Self::load_managed_directory(path, true, false, false, 0, cache)
|
||||
}
|
||||
|
||||
pub fn load_managed_directory(
|
||||
@@ -193,6 +203,7 @@ impl MDBShardFile {
|
||||
load_expired: bool,
|
||||
prune_expired: bool,
|
||||
prune_dir_storage_to_size: u64,
|
||||
cache: &ShardFileCache,
|
||||
) -> Result<Vec<Arc<Self>>> {
|
||||
let current_time = current_timestamp();
|
||||
let expiration_buffer = MDB_SHARD_EXPIRATION_BUFFER.as_secs();
|
||||
@@ -201,7 +212,7 @@ impl MDBShardFile {
|
||||
|
||||
let mut total_size = 0;
|
||||
|
||||
Self::scan_impl(path, skip_on_error, |s| {
|
||||
Self::scan_impl(path, skip_on_error, cache, |s| {
|
||||
if load_expired || current_time <= s.shard.metadata.shard_key_expiry {
|
||||
total_size += s.shard.num_bytes();
|
||||
ret.push(s);
|
||||
@@ -210,25 +221,20 @@ impl MDBShardFile {
|
||||
{
|
||||
info!("Deleting expired shard {:?}", &s.path);
|
||||
let _ = std::fs::remove_file(&s.path);
|
||||
Self::drop_from_cache(s);
|
||||
Self::drop_from_cache(s, cache);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
// Do we need to prune the directory to keep things down to size?
|
||||
if prune_dir_storage_to_size != 0 && total_size > prune_dir_storage_to_size {
|
||||
// Flush out the oldest ones first using a heap.
|
||||
|
||||
let heap_predicate = |s1: &Arc<MDBShardFile>, s2: &Arc<MDBShardFile>| {
|
||||
// Compare in reverse so pop is done from earliest shard
|
||||
s2.shard
|
||||
.metadata
|
||||
.shard_creation_timestamp
|
||||
.partial_cmp(&s1.shard.metadata.shard_creation_timestamp)
|
||||
};
|
||||
|
||||
// Turn the return shards into a heap around the shard creation timestamp
|
||||
make_heap_with(&mut ret, heap_predicate);
|
||||
|
||||
while total_size > prune_dir_storage_to_size {
|
||||
@@ -240,17 +246,21 @@ impl MDBShardFile {
|
||||
info!("Pruning shard to maintain cache size: {:?}", &s.path);
|
||||
total_size -= s.shard.num_bytes();
|
||||
let _ = std::fs::remove_file(&s.path);
|
||||
Self::drop_from_cache(s);
|
||||
Self::drop_from_cache(s, cache);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
pub fn clean_shard_cache(path: impl AsRef<Path>, expiration_buffer_secs: u64) -> Result<()> {
|
||||
pub fn clean_shard_cache(
|
||||
path: impl AsRef<Path>,
|
||||
expiration_buffer_secs: u64,
|
||||
cache: &ShardFileCache,
|
||||
) -> Result<()> {
|
||||
let current_time = current_timestamp();
|
||||
|
||||
Self::scan_impl(path, true, |s| {
|
||||
Self::scan_impl(path, true, cache, |s| {
|
||||
if s.shard.metadata.shard_key_expiry.saturating_add(expiration_buffer_secs) <= current_time {
|
||||
info!("Deleting expired shard {:?}", &s.path);
|
||||
let _ = std::fs::remove_file(&s.path);
|
||||
@@ -274,12 +284,13 @@ impl MDBShardFile {
|
||||
fn scan_impl(
|
||||
path: impl AsRef<Path>,
|
||||
skip_on_error: bool,
|
||||
cache: &ShardFileCache,
|
||||
mut callback: impl FnMut(Arc<Self>) -> Result<()>,
|
||||
) -> Result<()> {
|
||||
let path = path.as_ref();
|
||||
|
||||
let mut load_file = |h: MerkleHash, file_name: &Path| -> Result<()> {
|
||||
let s_res = Self::load_from_hash_and_path(h, file_name);
|
||||
let s_res = Self::load_from_hash_and_path(h, file_name, cache);
|
||||
|
||||
let s = match s_res {
|
||||
Ok(s) => s,
|
||||
@@ -323,6 +334,7 @@ impl MDBShardFile {
|
||||
|
||||
/// Write out the current shard, re-keyed with an hmac key, to the output directory in question, returning
|
||||
/// the full path to the new shard.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn export_as_keyed_shard(
|
||||
&self,
|
||||
target_directory: impl AsRef<Path>,
|
||||
@@ -331,6 +343,7 @@ impl MDBShardFile {
|
||||
include_file_info: bool,
|
||||
include_xorb_lookup_table: bool,
|
||||
include_chunk_lookup_table: bool,
|
||||
cache: &ShardFileCache,
|
||||
) -> Result<Arc<Self>> {
|
||||
let mut output_bytes = Vec::<u8>::new();
|
||||
|
||||
@@ -344,7 +357,7 @@ impl MDBShardFile {
|
||||
include_chunk_lookup_table,
|
||||
)?;
|
||||
|
||||
let written_out = Self::write_out_from_reader(target_directory, &mut Cursor::new(output_bytes))?;
|
||||
let written_out = Self::write_out_from_reader(target_directory, &mut Cursor::new(output_bytes), cache)?;
|
||||
written_out.verify_shard_integrity_debug_only();
|
||||
|
||||
Ok(written_out)
|
||||
|
||||
@@ -6,12 +6,12 @@ use std::sync::atomic::AtomicBool;
|
||||
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, instrument, trace, warn};
|
||||
use xet_runtime::core::{XetRuntime, xet_config};
|
||||
use xet_runtime::core::{XetCommon, XetContext};
|
||||
use xet_runtime::utils::RwTaskLock;
|
||||
|
||||
use super::constants::MDB_SHARD_EXPIRATION_BUFFER;
|
||||
use super::file_structs::*;
|
||||
use super::shard_file_handle::MDBShardFile;
|
||||
use super::shard_file_handle::{MDBShardFile, ShardFileCache};
|
||||
use super::shard_file_reconstructor::FileReconstructor;
|
||||
use super::shard_in_memory::MDBInMemoryShard;
|
||||
use super::utils::truncate_hash;
|
||||
@@ -20,9 +20,14 @@ use crate::error::{CoreError, Result};
|
||||
use crate::merklehash::{HMACKey, MerkleHash};
|
||||
use crate::{MerkleHashMap, TruncatedMerkleHashMap};
|
||||
|
||||
// The shard manager cache
|
||||
lazy_static::lazy_static! {
|
||||
static ref MDB_SHARD_FILE_MANAGER_CACHE: RwLock<HashMap<PathBuf, Arc<ShardFileManager>>> = RwLock::new(HashMap::default());
|
||||
type ShardFileManagerCache = Arc<RwLock<HashMap<PathBuf, Arc<ShardFileManager>>>>;
|
||||
|
||||
fn get_sfm_cache(common: &XetCommon) -> ShardFileManagerCache {
|
||||
common.cache_get_or_create("mdb_shard_file_manager_cache", || Arc::new(RwLock::new(HashMap::new())))
|
||||
}
|
||||
|
||||
pub fn get_shard_file_cache(common: &XetCommon) -> ShardFileCache {
|
||||
common.cache_get_or_create("mdb_shard_file_cache", super::shard_file_handle::new_shard_file_cache)
|
||||
}
|
||||
|
||||
// The structure used as the target for the dedup lookup
|
||||
@@ -74,6 +79,8 @@ impl ShardBookkeeper {
|
||||
}
|
||||
|
||||
pub struct ShardFileManager {
|
||||
chunk_index_table_max_size: usize,
|
||||
shard_file_cache: ShardFileCache,
|
||||
shard_bookkeeper: RwTaskLock<ShardBookkeeper, CoreError>,
|
||||
current_state: RwLock<MDBInMemoryShard>,
|
||||
shard_directory: PathBuf,
|
||||
@@ -101,25 +108,28 @@ pub struct ShardFileManager {
|
||||
impl ShardFileManager {
|
||||
// Construct in a session directory.
|
||||
pub async fn new_in_session_directory(
|
||||
ctx: &XetContext,
|
||||
session_directory: impl AsRef<Path>,
|
||||
scan_directory: bool,
|
||||
) -> Result<Arc<Self>> {
|
||||
Self::new_impl(session_directory, false, xet_config().shard.max_target_size, scan_directory, 0).await
|
||||
Self::new_impl(ctx, session_directory, false, ctx.config.shard.max_target_size, scan_directory, 0).await
|
||||
}
|
||||
|
||||
// Construction functions
|
||||
pub async fn new_in_cache_directory(cache_directory: impl AsRef<Path>) -> Result<Arc<Self>> {
|
||||
pub async fn new_in_cache_directory(ctx: &XetContext, cache_directory: impl AsRef<Path>) -> Result<Arc<Self>> {
|
||||
Self::new_impl(
|
||||
ctx,
|
||||
cache_directory,
|
||||
true,
|
||||
xet_config().shard.max_target_size,
|
||||
ctx.config.shard.max_target_size,
|
||||
true,
|
||||
xet_config().shard.cache_size_limit.as_u64(),
|
||||
ctx.config.shard.cache_size_limit.as_u64(),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn new_impl(
|
||||
ctx: &XetContext,
|
||||
directory: impl AsRef<Path>,
|
||||
is_cachable: bool,
|
||||
target_shard_max_size: u64,
|
||||
@@ -133,8 +143,12 @@ impl ShardFileManager {
|
||||
std::fs::create_dir_all(&shard_directory)?;
|
||||
}
|
||||
|
||||
let chunk_index_table_max_size = ctx.config.shard.chunk_index_table_max_size;
|
||||
let shard_file_cache = get_shard_file_cache(&ctx.common);
|
||||
let create_new_sfm = || {
|
||||
Arc::new(Self {
|
||||
chunk_index_table_max_size,
|
||||
shard_file_cache: shard_file_cache.clone(),
|
||||
shard_bookkeeper: RwTaskLock::from_value(ShardBookkeeper::new()),
|
||||
current_state: RwLock::new(MDBInMemoryShard::default()),
|
||||
shard_directory: shard_directory.clone(),
|
||||
@@ -148,8 +162,10 @@ impl ShardFileManager {
|
||||
break 'load_sfm create_new_sfm();
|
||||
}
|
||||
|
||||
let sfm_cache = get_sfm_cache(&ctx.common);
|
||||
|
||||
{
|
||||
let ro_lg = MDB_SHARD_FILE_MANAGER_CACHE.read().await;
|
||||
let ro_lg = sfm_cache.read().await;
|
||||
|
||||
if let Some(sfm) = ro_lg.get(&shard_directory) {
|
||||
sfm.refresh_shard_dir(false, 0).await?;
|
||||
@@ -158,7 +174,7 @@ impl ShardFileManager {
|
||||
}
|
||||
|
||||
// Now, create and insert it.
|
||||
let mut rw_lg = MDB_SHARD_FILE_MANAGER_CACHE.write().await;
|
||||
let mut rw_lg = sfm_cache.write().await;
|
||||
rw_lg.entry(shard_directory.clone()).or_insert_with(create_new_sfm).clone()
|
||||
};
|
||||
|
||||
@@ -176,6 +192,7 @@ impl ShardFileManager {
|
||||
false,
|
||||
prune_expired,
|
||||
prune_cache_to_size,
|
||||
&self.shard_file_cache,
|
||||
)?;
|
||||
|
||||
{
|
||||
@@ -189,7 +206,11 @@ impl ShardFileManager {
|
||||
}
|
||||
|
||||
pub async fn import_shard_from_bytes(&self, shard: &[u8]) -> Result<()> {
|
||||
let new_shard_file = MDBShardFile::write_out_from_reader(&self.shard_directory, &mut Cursor::new(shard))?;
|
||||
let new_shard_file = MDBShardFile::write_out_from_reader(
|
||||
&self.shard_directory,
|
||||
&mut Cursor::new(shard),
|
||||
&self.shard_file_cache,
|
||||
)?;
|
||||
|
||||
self.register_shards(&[new_shard_file]).await
|
||||
}
|
||||
@@ -203,7 +224,11 @@ impl ShardFileManager {
|
||||
let needs_clean = self.shard_directory_cleaned.swap(true, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
if needs_clean {
|
||||
MDBShardFile::clean_shard_cache(&self.shard_directory, MDB_SHARD_EXPIRATION_BUFFER.as_secs())?;
|
||||
MDBShardFile::clean_shard_cache(
|
||||
&self.shard_directory,
|
||||
MDB_SHARD_EXPIRATION_BUFFER.as_secs(),
|
||||
&self.shard_file_cache,
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -211,7 +236,7 @@ impl ShardFileManager {
|
||||
|
||||
pub async fn register_shards_by_path<P: AsRef<Path>>(&self, new_shards: &[P]) -> Result<()> {
|
||||
let new_shards: Vec<Arc<_>> = new_shards.iter().try_fold(Vec::new(), |mut acc, p| {
|
||||
acc.extend(MDBShardFile::load_all_valid(p)?);
|
||||
acc.extend(MDBShardFile::load_all_valid(p, &self.shard_file_cache)?);
|
||||
|
||||
Result::Ok(acc)
|
||||
})?;
|
||||
@@ -251,8 +276,9 @@ impl ShardFileManager {
|
||||
// Begin loading the truncated hashes in the background for this shard so they're ready
|
||||
// when we have to insert them.
|
||||
let s_rth = s.clone();
|
||||
let s_truncated_hashes_jh = XetRuntime::current().spawn_blocking(move || s_rth.read_all_truncated_hashes());
|
||||
let s_truncated_hashes_jh = tokio::task::spawn_blocking(move || s_rth.read_all_truncated_hashes());
|
||||
|
||||
let chunk_index_table_max_size = self.chunk_index_table_max_size;
|
||||
// Update the bookkeeper with the task of
|
||||
self.shard_bookkeeper
|
||||
.update(move |mut sbkp_lg| async move {
|
||||
@@ -269,8 +295,7 @@ impl ShardFileManager {
|
||||
sbkp_lg.shard_collections.push(KeyedShardCollection::new(shard_hmac_key));
|
||||
}
|
||||
|
||||
let update_chunk_lookup =
|
||||
sbkp_lg.total_indexed_chunks < xet_config().shard.chunk_index_table_max_size;
|
||||
let update_chunk_lookup = sbkp_lg.total_indexed_chunks < chunk_index_table_max_size;
|
||||
|
||||
let shard_hash = s.shard_hash;
|
||||
|
||||
@@ -349,7 +374,7 @@ impl ShardFileManager {
|
||||
let mut all_file_info: Vec<MDBFileInfo> =
|
||||
self.current_state.read().await.file_content.values().cloned().collect();
|
||||
|
||||
let shard_files = MDBShardFile::load_all_valid(&self.shard_directory)?;
|
||||
let shard_files = MDBShardFile::load_all_valid(&self.shard_directory, &self.shard_file_cache)?;
|
||||
|
||||
for shard in shard_files {
|
||||
all_file_info.append(&mut shard.read_all_file_info_sections()?);
|
||||
@@ -466,9 +491,9 @@ impl ShardFileManager {
|
||||
lg.add_xorb_block(xorb_block_contents)?;
|
||||
drop(lg);
|
||||
|
||||
// if we cut a new shard, register it after dropping the lock guard
|
||||
if let Some(new_shard_path) = new_shard_path {
|
||||
self.register_shards(&[MDBShardFile::load_from_file(&new_shard_path)?]).await?;
|
||||
self.register_shards(&[MDBShardFile::load_from_file(&new_shard_path, &self.shard_file_cache)?])
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -478,7 +503,6 @@ impl ShardFileManager {
|
||||
pub async fn add_file_reconstruction_info(&self, file_info: MDBFileInfo) -> Result<()> {
|
||||
let mut lg = self.current_state.write().await;
|
||||
|
||||
// cut a new shard if adding this file will take us over the max shard file size
|
||||
let new_shard_path = if lg.shard_file_size() + file_info.num_bytes() > self.target_shard_max_size {
|
||||
let path = Self::cut_shard(&mut lg, &self.shard_directory)?;
|
||||
Some(path)
|
||||
@@ -489,9 +513,9 @@ impl ShardFileManager {
|
||||
lg.add_file_reconstruction_info(file_info)?;
|
||||
drop(lg);
|
||||
|
||||
// if we cut a new shard, register it after dropping the lock guard
|
||||
if let Some(new_shard_path) = new_shard_path {
|
||||
self.register_shards(&[MDBShardFile::load_from_file(&new_shard_path)?]).await?;
|
||||
self.register_shards(&[MDBShardFile::load_from_file(&new_shard_path, &self.shard_file_cache)?])
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -503,13 +527,10 @@ impl ShardFileManager {
|
||||
Ok(new_shard_path)
|
||||
}
|
||||
|
||||
/// Flush the current state of the in-memory lookups to a shard in the session directory,
|
||||
/// returning the hash of the shard and the file written, or None if no file was written.
|
||||
#[instrument(skip_all, name = "ShardFileManager::flush")]
|
||||
pub async fn flush(&self) -> Result<Option<PathBuf>> {
|
||||
let new_shard_path;
|
||||
|
||||
// The locked section here.
|
||||
{
|
||||
let mut lg = self.current_state.write().await;
|
||||
|
||||
@@ -523,11 +544,15 @@ impl ShardFileManager {
|
||||
info!("Shard manager flushed new shard to {new_shard_path:?}.");
|
||||
}
|
||||
|
||||
// Load this one into our local shard catalog
|
||||
self.register_shards(&[MDBShardFile::load_from_file(&new_shard_path)?]).await?;
|
||||
self.register_shards(&[MDBShardFile::load_from_file(&new_shard_path, &self.shard_file_cache)?])
|
||||
.await?;
|
||||
|
||||
Ok(Some(new_shard_path))
|
||||
}
|
||||
|
||||
pub fn shard_file_cache(&self) -> &ShardFileCache {
|
||||
&self.shard_file_cache
|
||||
}
|
||||
}
|
||||
|
||||
impl ShardFileManager {
|
||||
@@ -575,15 +600,22 @@ mod tests {
|
||||
|
||||
use rand::prelude::*;
|
||||
use tempfile::TempDir;
|
||||
use tokio::runtime::Handle;
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::super::file_structs::FileDataSequenceHeader;
|
||||
use super::super::session_directory::{consolidate_shards_in_directory, merge_shards};
|
||||
use super::super::shard_format::test_routines::{gen_random_file_info, rng_hash, simple_hash};
|
||||
use super::super::utils::parse_shard_filename;
|
||||
use super::super::xorb_structs::{XorbChunkSequenceEntry, XorbChunkSequenceHeader};
|
||||
use super::*;
|
||||
use super::{get_shard_file_cache, *};
|
||||
use crate::error::Result;
|
||||
|
||||
fn test_context() -> XetContext {
|
||||
XetContext::from_external(Handle::current(), XetConfig::new())
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub async fn fill_with_specific_shard(
|
||||
shard: &ShardFileManager,
|
||||
@@ -631,6 +663,7 @@ mod tests {
|
||||
|
||||
// Create n_shards new random shards in the directory pointed
|
||||
pub async fn create_random_shard_collection(
|
||||
ctx: &XetContext,
|
||||
seed: u64,
|
||||
shard_dir: impl AsRef<Path>,
|
||||
n_shards: usize,
|
||||
@@ -641,7 +674,7 @@ mod tests {
|
||||
let mut rng = StdRng::seed_from_u64(seed);
|
||||
|
||||
let shard_dir = shard_dir.as_ref();
|
||||
let sfm = ShardFileManager::new_in_session_directory(shard_dir, false).await?;
|
||||
let sfm = ShardFileManager::new_in_session_directory(ctx, shard_dir, false).await?;
|
||||
let mut reference_shard = MDBInMemoryShard::default();
|
||||
|
||||
for _ in 0..n_shards {
|
||||
@@ -756,9 +789,9 @@ mod tests {
|
||||
assert_eq!(result_m.is_some(), result_f.is_some());
|
||||
|
||||
// Make sure retriving the expected file.
|
||||
if result_m.is_some() {
|
||||
assert_eq!(result_m.unwrap().metadata.file_hash, *k);
|
||||
assert_eq!(result_f.unwrap().0.metadata.file_hash, *k);
|
||||
if let (Some(result_m), Some(result_f)) = (result_m, result_f) {
|
||||
assert_eq!(result_m.metadata.file_hash, *k);
|
||||
assert_eq!(result_f.0.metadata.file_hash, *k);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -771,17 +804,23 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn sfm_with_target_shard_size(path: impl AsRef<Path>, target_size: u64) -> Result<Arc<ShardFileManager>> {
|
||||
ShardFileManager::new_impl(path, false, target_size, true, 0).await
|
||||
async fn sfm_with_target_shard_size(
|
||||
ctx: &XetContext,
|
||||
path: impl AsRef<Path>,
|
||||
target_size: u64,
|
||||
) -> Result<Arc<ShardFileManager>> {
|
||||
ShardFileManager::new_impl(ctx, path, false, target_size, true, 0).await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_basic_retrieval() -> Result<()> {
|
||||
let ctx = test_context();
|
||||
let sfc = get_shard_file_cache(&ctx.common);
|
||||
let tmp_dir = TempDir::with_prefix("gitxet_shard_test_1")?;
|
||||
let mut mdb_in_mem = MDBInMemoryShard::default();
|
||||
|
||||
{
|
||||
let mdb = ShardFileManager::new_in_session_directory(tmp_dir.path(), true).await?;
|
||||
let mdb = ShardFileManager::new_in_session_directory(&ctx, tmp_dir.path(), true).await?;
|
||||
|
||||
fill_with_specific_shard(&mdb, &mut mdb_in_mem, &[(0, &[(11, 5)])], &[(100, &[(200, (0, 5))])]).await?;
|
||||
|
||||
@@ -793,11 +832,11 @@ mod tests {
|
||||
verify_metadata_shards_match(&mdb, &mdb_in_mem, true).await?;
|
||||
|
||||
// Verify that the file is correct
|
||||
MDBShardFile::load_from_file(&out_file)?.verify_shard_integrity();
|
||||
MDBShardFile::load_from_file(&out_file, &sfc)?.verify_shard_integrity();
|
||||
}
|
||||
{
|
||||
// Now, make sure that this happens if this directory is opened up
|
||||
let mdb2 = ShardFileManager::new_in_session_directory(tmp_dir.path(), true).await?;
|
||||
let mdb2 = ShardFileManager::new_in_session_directory(&ctx, tmp_dir.path(), true).await?;
|
||||
|
||||
// Make sure it's all in there this round.
|
||||
verify_metadata_shards_match(&mdb2, &mdb_in_mem, true).await?;
|
||||
@@ -808,8 +847,13 @@ mod tests {
|
||||
verify_metadata_shards_match(&mdb2, &mdb_in_mem, true).await?;
|
||||
|
||||
// Now, merge shards in the background.
|
||||
let merged_shards =
|
||||
consolidate_shards_in_directory(tmp_dir.path(), xet_config().shard.max_target_size, false)?;
|
||||
let merged_shards = consolidate_shards_in_directory(
|
||||
&ctx.runtime,
|
||||
tmp_dir.path(),
|
||||
ctx.config.shard.max_target_size,
|
||||
false,
|
||||
&sfc,
|
||||
)?;
|
||||
|
||||
assert_eq!(merged_shards.len(), 1);
|
||||
for si in merged_shards {
|
||||
@@ -825,9 +869,11 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_larger_simulated() -> Result<()> {
|
||||
let ctx = test_context();
|
||||
let sfc = get_shard_file_cache(&ctx.common);
|
||||
let tmp_dir = TempDir::with_prefix("gitxet_shard_test_2")?;
|
||||
let mut mdb_in_mem = MDBInMemoryShard::default();
|
||||
let mdb = ShardFileManager::new_in_session_directory(tmp_dir.path(), true).await?;
|
||||
let mdb = ShardFileManager::new_in_session_directory(&ctx, tmp_dir.path(), true).await?;
|
||||
|
||||
for i in 0..10 {
|
||||
fill_with_random_shard(&mdb, &mut mdb_in_mem, i, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6]).await?;
|
||||
@@ -840,13 +886,13 @@ mod tests {
|
||||
verify_metadata_shards_match(&mdb, &mdb_in_mem, true).await?;
|
||||
|
||||
// Verify that the file is correct
|
||||
MDBShardFile::load_from_file(&out_file)?.verify_shard_integrity();
|
||||
MDBShardFile::load_from_file(&out_file, &sfc)?.verify_shard_integrity();
|
||||
|
||||
// Make sure an empty flush doesn't bother anything.
|
||||
mdb.flush().await?;
|
||||
|
||||
// Now, make sure that this happens if this directory is opened up
|
||||
let mdb2 = ShardFileManager::new_in_session_directory(tmp_dir.path(), true).await?;
|
||||
let mdb2 = ShardFileManager::new_in_session_directory(&ctx, tmp_dir.path(), true).await?;
|
||||
|
||||
// Make sure it's all in there this round.
|
||||
verify_metadata_shards_match(&mdb2, &mdb_in_mem, true).await?;
|
||||
@@ -856,13 +902,15 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_process_session_management() -> Result<()> {
|
||||
let ctx = test_context();
|
||||
let sfc = get_shard_file_cache(&ctx.common);
|
||||
let tmp_dir = TempDir::with_prefix("gitxet_shard_test_3").unwrap();
|
||||
let mut mdb_in_mem = MDBInMemoryShard::default();
|
||||
|
||||
for sesh in 0..3 {
|
||||
for i in 0..10 {
|
||||
{
|
||||
let mdb = ShardFileManager::new_in_session_directory(tmp_dir.path(), true).await?;
|
||||
let mdb = ShardFileManager::new_in_session_directory(&ctx, tmp_dir.path(), true).await?;
|
||||
fill_with_random_shard(&mdb, &mut mdb_in_mem, 100 * sesh + i, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6])
|
||||
.await
|
||||
.unwrap();
|
||||
@@ -875,7 +923,7 @@ mod tests {
|
||||
verify_metadata_shards_match(&mdb, &mdb_in_mem, true).await.unwrap();
|
||||
|
||||
// Verify that the file is correct
|
||||
MDBShardFile::load_from_file(&out_file)?.verify_shard_integrity();
|
||||
MDBShardFile::load_from_file(&out_file, &sfc)?.verify_shard_integrity();
|
||||
|
||||
mdb.flush().await.unwrap();
|
||||
|
||||
@@ -884,8 +932,14 @@ mod tests {
|
||||
}
|
||||
|
||||
{
|
||||
let merged_shards =
|
||||
consolidate_shards_in_directory(tmp_dir.path(), xet_config().shard.max_target_size, false).unwrap();
|
||||
let merged_shards = consolidate_shards_in_directory(
|
||||
&ctx.runtime,
|
||||
tmp_dir.path(),
|
||||
ctx.config.shard.max_target_size,
|
||||
false,
|
||||
&sfc,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(merged_shards.len(), 1);
|
||||
|
||||
@@ -897,7 +951,7 @@ mod tests {
|
||||
|
||||
{
|
||||
// Now, make sure that this happens if this directory is opened up
|
||||
let mdb2 = ShardFileManager::new_in_session_directory(tmp_dir.path(), true).await?;
|
||||
let mdb2 = ShardFileManager::new_in_session_directory(&ctx, tmp_dir.path(), true).await?;
|
||||
|
||||
verify_metadata_shards_match(&mdb2, &mdb_in_mem, true).await.unwrap();
|
||||
}
|
||||
@@ -907,18 +961,20 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_flush_and_consolidation() -> Result<()> {
|
||||
let ctx = test_context();
|
||||
let sfc = get_shard_file_cache(&ctx.common);
|
||||
let tmp_dir = TempDir::with_prefix("gitxet_shard_test_4b")?;
|
||||
let mut mdb_in_mem = MDBInMemoryShard::default();
|
||||
|
||||
const T: u64 = 10000;
|
||||
|
||||
{
|
||||
let mdb = sfm_with_target_shard_size(tmp_dir.path(), T).await?;
|
||||
let mdb = sfm_with_target_shard_size(&ctx, tmp_dir.path(), T).await?;
|
||||
fill_with_random_shard(&mdb, &mut mdb_in_mem, 0, &[16; 16], &[16; 16]).await?;
|
||||
mdb.flush().await?;
|
||||
}
|
||||
{
|
||||
let mdb = sfm_with_target_shard_size(tmp_dir.path(), 2 * T).await?;
|
||||
let mdb = sfm_with_target_shard_size(&ctx, tmp_dir.path(), 2 * T).await?;
|
||||
|
||||
verify_metadata_shards_match(&mdb, &mdb_in_mem, true).await?;
|
||||
|
||||
@@ -931,7 +987,7 @@ mod tests {
|
||||
|
||||
// Reload and verify
|
||||
{
|
||||
let mdb = ShardFileManager::new_in_session_directory(tmp_dir.path(), true).await?;
|
||||
let mdb = ShardFileManager::new_in_session_directory(&ctx, tmp_dir.path(), true).await?;
|
||||
verify_metadata_shards_match(&mdb, &mdb_in_mem, true).await?;
|
||||
}
|
||||
|
||||
@@ -939,7 +995,8 @@ mod tests {
|
||||
{
|
||||
let tmp_merge_dir = TempDir::new()?;
|
||||
|
||||
let shard_merge_result = merge_shards(tmp_dir.path(), tmp_merge_dir.path(), 8 * T, false)?;
|
||||
let shard_merge_result =
|
||||
merge_shards(&ctx.runtime, tmp_dir.path(), tmp_merge_dir.path(), 8 * T, false, &sfc)?;
|
||||
let mut merged_shards = shard_merge_result.merged_shards;
|
||||
let m_del_shards = shard_merge_result.obsolete_shards;
|
||||
|
||||
@@ -951,7 +1008,7 @@ mod tests {
|
||||
assert_eq!(paths.count(), m_del_shards.len());
|
||||
|
||||
// This call should be the same, but
|
||||
let mut rv = consolidate_shards_in_directory(tmp_dir.path(), 8 * T, false)?;
|
||||
let mut rv = consolidate_shards_in_directory(&ctx.runtime, tmp_dir.path(), 8 * T, false, &sfc)?;
|
||||
|
||||
let paths = std::fs::read_dir(tmp_dir.path()).unwrap();
|
||||
let n_paths = paths.count();
|
||||
@@ -973,7 +1030,7 @@ mod tests {
|
||||
|
||||
// Reload and verify
|
||||
{
|
||||
let mdb = ShardFileManager::new_in_session_directory(tmp_dir.path(), true).await?;
|
||||
let mdb = ShardFileManager::new_in_session_directory(&ctx, tmp_dir.path(), true).await?;
|
||||
verify_metadata_shards_match(&mdb, &mdb_in_mem, true).await?;
|
||||
}
|
||||
|
||||
@@ -982,13 +1039,15 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_size_threshholds() -> Result<()> {
|
||||
let ctx = test_context();
|
||||
let sfc = get_shard_file_cache(&ctx.common);
|
||||
let tmp_dir = TempDir::with_prefix("gitxet_shard_test_4")?;
|
||||
let mut mdb_in_mem = MDBInMemoryShard::default();
|
||||
|
||||
const T: u64 = 4096;
|
||||
|
||||
for i in 0..5 {
|
||||
let mdb = sfm_with_target_shard_size(tmp_dir.path(), T).await?;
|
||||
let mdb = sfm_with_target_shard_size(&ctx, tmp_dir.path(), T).await?;
|
||||
fill_with_random_shard(&mdb, &mut mdb_in_mem, i, &[5; 25], &[5; 25]).await?;
|
||||
|
||||
verify_metadata_shards_match(&mdb, &mdb_in_mem, true).await?;
|
||||
@@ -996,7 +1055,7 @@ mod tests {
|
||||
let out_file = mdb.flush().await?.unwrap();
|
||||
|
||||
// Verify that the file is correct
|
||||
MDBShardFile::load_from_file(&out_file).unwrap().verify_shard_integrity();
|
||||
MDBShardFile::load_from_file(&out_file, &sfc).unwrap().verify_shard_integrity();
|
||||
|
||||
// Make sure it still stays together
|
||||
verify_metadata_shards_match(&mdb, &mdb_in_mem, true).await?;
|
||||
@@ -1009,12 +1068,13 @@ mod tests {
|
||||
let mut target_size = T;
|
||||
|
||||
loop {
|
||||
let mdb2 = sfm_with_target_shard_size(tmp_dir.path(), 2 * T).await?;
|
||||
let mdb2 = sfm_with_target_shard_size(&ctx, tmp_dir.path(), 2 * T).await?;
|
||||
|
||||
// Make sure it's all in there this round.
|
||||
verify_metadata_shards_match(&mdb2, &mdb_in_mem, true).await?;
|
||||
|
||||
let merged_shards = consolidate_shards_in_directory(tmp_dir.path(), target_size, false)?;
|
||||
let merged_shards =
|
||||
consolidate_shards_in_directory(&ctx.runtime, tmp_dir.path(), target_size, false, &sfc)?;
|
||||
|
||||
for si in merged_shards.iter() {
|
||||
assert!(si.path.exists());
|
||||
@@ -1041,14 +1101,17 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_keyed_shard_tooling() -> Result<()> {
|
||||
let ctx = test_context();
|
||||
let sfc = get_shard_file_cache(&ctx.common);
|
||||
let tmp_dir = TempDir::with_prefix("shard_test_unkeyed")?;
|
||||
let tmp_dir_path = tmp_dir.path();
|
||||
|
||||
let ref_shard = create_random_shard_collection(0, tmp_dir_path, 2, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6]).await?;
|
||||
let ref_shard =
|
||||
create_random_shard_collection(&ctx, 0, tmp_dir_path, 2, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6]).await?;
|
||||
|
||||
// First, load all of these with a shard file manager and check them.
|
||||
{
|
||||
let shard_file_manager = ShardFileManager::new_in_session_directory(tmp_dir_path, true).await?;
|
||||
let shard_file_manager = ShardFileManager::new_in_session_directory(&ctx, tmp_dir_path, true).await?;
|
||||
verify_metadata_shards_match(&shard_file_manager, &ref_shard, true).await?;
|
||||
}
|
||||
|
||||
@@ -1077,7 +1140,7 @@ mod tests {
|
||||
}
|
||||
};
|
||||
|
||||
let shard = MDBShardFile::load_from_file(p)?;
|
||||
let shard = MDBShardFile::load_from_file(p, &sfc)?;
|
||||
|
||||
// Reexport all these shards as keyed shards.
|
||||
let out = shard
|
||||
@@ -1088,6 +1151,7 @@ mod tests {
|
||||
include_info,
|
||||
include_info,
|
||||
include_info,
|
||||
&sfc,
|
||||
)
|
||||
.unwrap();
|
||||
if key != HMACKey::default() {
|
||||
@@ -1098,7 +1162,7 @@ mod tests {
|
||||
}
|
||||
|
||||
// Now, verify that everything still works great.
|
||||
let shard_file_manager = ShardFileManager::new_in_session_directory(tmp_dir_path_keyed, true).await?;
|
||||
let shard_file_manager = ShardFileManager::new_in_session_directory(&ctx, tmp_dir_path_keyed, true).await?;
|
||||
|
||||
verify_metadata_shards_match(&shard_file_manager, &ref_shard, include_info).await?;
|
||||
}
|
||||
@@ -1106,39 +1170,41 @@ mod tests {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shard_list_with_timestamp_filtering(path: &Path) -> Result<Vec<Arc<MDBShardFile>>> {
|
||||
Ok(ShardFileManager::new_impl(path, false, xet_config().shard.max_target_size, true, 0)
|
||||
async fn shard_list_with_timestamp_filtering(ctx: &XetContext, path: &Path) -> Result<Vec<Arc<MDBShardFile>>> {
|
||||
ShardFileManager::new_impl(ctx, path, false, ctx.config.shard.max_target_size, true, 0)
|
||||
.await?
|
||||
.registered_shard_list()
|
||||
.await?)
|
||||
.await
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[cfg_attr(feature = "smoke-test", ignore)]
|
||||
async fn test_timestamp_filtering() -> Result<()> {
|
||||
let ctx = test_context();
|
||||
let sfc = get_shard_file_cache(&ctx.common);
|
||||
let tmp_dir = TempDir::with_prefix("shard_test_timestamp")?;
|
||||
let tmp_dir_path = tmp_dir.path();
|
||||
|
||||
// Just create a single shard; we'll key it with other keys and timestamps and then test loading.
|
||||
create_random_shard_collection(0, tmp_dir_path, 1, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6]).await?;
|
||||
create_random_shard_collection(&ctx, 0, tmp_dir_path, 1, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6]).await?;
|
||||
|
||||
let path = std::fs::read_dir(tmp_dir_path)?.map(|p| p.unwrap().path()).next().unwrap();
|
||||
|
||||
// Create another that has an expiration date of one second from now.
|
||||
let key: HMACKey = rng_hash(0);
|
||||
|
||||
let shard = MDBShardFile::load_from_file(&path)?;
|
||||
let shard = MDBShardFile::load_from_file(&path, &sfc)?;
|
||||
|
||||
let _tmp_dir_keyed = TempDir::with_prefix("shard_test_keyed_1")?;
|
||||
let tmp_dir_path_keyed = _tmp_dir_keyed.path();
|
||||
|
||||
// Reexport this shard as a keyed shards.
|
||||
let out = shard
|
||||
.export_as_keyed_shard(tmp_dir_path_keyed, key, Duration::new(1, 0), false, false, false)
|
||||
.export_as_keyed_shard(tmp_dir_path_keyed, key, Duration::new(1, 0), false, false, false, &sfc)
|
||||
.unwrap();
|
||||
|
||||
{
|
||||
let loaded_shards = shard_list_with_timestamp_filtering(tmp_dir_path_keyed).await?;
|
||||
let loaded_shards = shard_list_with_timestamp_filtering(&ctx, tmp_dir_path_keyed).await?;
|
||||
|
||||
assert_eq!(loaded_shards.len(), 1);
|
||||
assert_eq!(loaded_shards[0].shard_hash, out.shard_hash)
|
||||
@@ -1149,27 +1215,27 @@ mod tests {
|
||||
std::thread::sleep(Duration::new(2, 10_000_000));
|
||||
|
||||
{
|
||||
let loaded_shards = shard_list_with_timestamp_filtering(tmp_dir_path_keyed).await?;
|
||||
let loaded_shards = shard_list_with_timestamp_filtering(&ctx, tmp_dir_path_keyed).await?;
|
||||
|
||||
// No shards loaded
|
||||
assert!(loaded_shards.is_empty());
|
||||
|
||||
// shard file still there.
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_keyed)?.map(|p| p.unwrap().path()).count();
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_keyed)?.count();
|
||||
assert_eq!(n_files, 1);
|
||||
|
||||
// Now try deletion with a large window; shouldn't touch the shard.
|
||||
MDBShardFile::clean_shard_cache(tmp_dir_path_keyed, 100)?;
|
||||
MDBShardFile::clean_shard_cache(tmp_dir_path_keyed, 100, &sfc)?;
|
||||
|
||||
// shard file still there.
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_keyed)?.map(|p| p.unwrap().path()).count();
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_keyed)?.count();
|
||||
assert_eq!(n_files, 1);
|
||||
|
||||
// Now try deletion with 0 expiration
|
||||
MDBShardFile::clean_shard_cache(tmp_dir_path_keyed, 0)?;
|
||||
MDBShardFile::clean_shard_cache(tmp_dir_path_keyed, 0, &sfc)?;
|
||||
|
||||
// File should be gone.
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_keyed)?.map(|p| p.unwrap().path()).count();
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_keyed)?.count();
|
||||
assert_eq!(n_files, 0);
|
||||
}
|
||||
|
||||
@@ -1179,24 +1245,26 @@ mod tests {
|
||||
#[tokio::test]
|
||||
#[cfg_attr(feature = "smoke-test", ignore)]
|
||||
async fn test_export_expiration() -> Result<()> {
|
||||
let ctx = test_context();
|
||||
let sfc = get_shard_file_cache(&ctx.common);
|
||||
let tmp_dir = TempDir::with_prefix("shard_test_timestamp_2")?;
|
||||
let tmp_dir_path = tmp_dir.path();
|
||||
|
||||
// Just create a single shard; we'll key it with other keys and timestamps and then test loading.
|
||||
create_random_shard_collection(0, tmp_dir_path, 1, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6]).await?;
|
||||
create_random_shard_collection(&ctx, 0, tmp_dir_path, 1, &[1, 5, 10, 8], &[4, 3, 5, 9, 4, 6]).await?;
|
||||
|
||||
let path = std::fs::read_dir(tmp_dir_path)?.map(|p| p.unwrap().path()).next().unwrap();
|
||||
|
||||
let shard = MDBShardFile::load_from_file(&path)?;
|
||||
let shard = MDBShardFile::load_from_file(&path, &sfc)?;
|
||||
|
||||
let _tmp_dir_expiry = TempDir::with_prefix("shard_test_expiry_2")?;
|
||||
let tmp_dir_path_expiry = _tmp_dir_expiry.path();
|
||||
|
||||
// Create another that has an expiration date of one second from now.
|
||||
let out = shard.export_with_expiration(tmp_dir_path_expiry, Duration::new(1, 0))?;
|
||||
let out = shard.export_with_expiration(tmp_dir_path_expiry, Duration::new(1, 0), &sfc)?;
|
||||
|
||||
{
|
||||
let loaded_shards = shard_list_with_timestamp_filtering(tmp_dir_path_expiry).await?;
|
||||
let loaded_shards = shard_list_with_timestamp_filtering(&ctx, tmp_dir_path_expiry).await?;
|
||||
|
||||
assert_eq!(loaded_shards.len(), 1);
|
||||
assert_eq!(loaded_shards[0].shard_hash, out.shard_hash)
|
||||
@@ -1207,26 +1275,26 @@ mod tests {
|
||||
std::thread::sleep(Duration::new(2, 10_000_000));
|
||||
|
||||
{
|
||||
let loaded_shards = shard_list_with_timestamp_filtering(tmp_dir_path_expiry).await?;
|
||||
let loaded_shards = shard_list_with_timestamp_filtering(&ctx, tmp_dir_path_expiry).await?;
|
||||
|
||||
assert!(loaded_shards.is_empty());
|
||||
|
||||
// Make sure it leaves the shard there.
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_expiry)?.map(|p| p.unwrap().path()).count();
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_expiry)?.count();
|
||||
assert_eq!(n_files, 1);
|
||||
|
||||
// Now try deletion with a large window; shouldn't touch the shard.
|
||||
MDBShardFile::clean_shard_cache(tmp_dir_path_expiry, 100)?;
|
||||
MDBShardFile::clean_shard_cache(tmp_dir_path_expiry, 100, &sfc)?;
|
||||
|
||||
// shard file still there.
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_expiry)?.map(|p| p.unwrap().path()).count();
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_expiry)?.count();
|
||||
assert_eq!(n_files, 1);
|
||||
|
||||
// Now try deletion with 0 expiration
|
||||
MDBShardFile::clean_shard_cache(tmp_dir_path_expiry, 0)?;
|
||||
MDBShardFile::clean_shard_cache(tmp_dir_path_expiry, 0, &sfc)?;
|
||||
|
||||
// File should be gone.
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_expiry)?.map(|p| p.unwrap().path()).count();
|
||||
let n_files = std::fs::read_dir(tmp_dir_path_expiry)?.count();
|
||||
assert_eq!(n_files, 0);
|
||||
}
|
||||
|
||||
@@ -1235,13 +1303,15 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_cache_size_pruning() {
|
||||
let ctx = test_context();
|
||||
let sfc = get_shard_file_cache(&ctx.common);
|
||||
let tmp_dir = TempDir::with_prefix("shard_test_cache_size_pruning").unwrap();
|
||||
|
||||
let tmp_dir_1 = tmp_dir.path().join("src");
|
||||
|
||||
let n_shards = 4;
|
||||
|
||||
create_random_shard_collection(0, &tmp_dir_1, n_shards, &[16; 16], &[16; 16])
|
||||
create_random_shard_collection(&ctx, 0, &tmp_dir_1, n_shards, &[16; 16], &[16; 16])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -1257,20 +1327,23 @@ mod tests {
|
||||
|
||||
for (i, p) in std::fs::read_dir(&tmp_dir_1).unwrap().enumerate() {
|
||||
let p = p.unwrap();
|
||||
let shard = MDBShardFile::load_from_file(&p.path()).unwrap();
|
||||
let shard = MDBShardFile::load_from_file(&p.path(), &sfc).unwrap();
|
||||
|
||||
let s = shard
|
||||
.export_with_specific_expiration(tmp_dir_2, expiration, base_time + Duration::from_secs(i as u64))
|
||||
.export_with_specific_expiration(tmp_dir_2, expiration, base_time + Duration::from_secs(i as u64), &sfc)
|
||||
.unwrap();
|
||||
|
||||
shard_list.push(s);
|
||||
}
|
||||
|
||||
let get_shards = |cache_size: u64| async move {
|
||||
let sfm = ShardFileManager::new_impl(tmp_dir_2, false, 64 * 1024, true, cache_size)
|
||||
.await
|
||||
.unwrap();
|
||||
sfm.registered_shard_list().await.unwrap()
|
||||
let get_shards = |cache_size: u64| {
|
||||
let ctx = ctx.clone();
|
||||
async move {
|
||||
let sfm = ShardFileManager::new_impl(&ctx, tmp_dir_2, false, 64 * 1024, true, cache_size)
|
||||
.await
|
||||
.unwrap();
|
||||
sfm.registered_shard_list().await.unwrap()
|
||||
}
|
||||
};
|
||||
|
||||
for i in 0..n_shards {
|
||||
@@ -1286,7 +1359,7 @@ mod tests {
|
||||
|
||||
let existing_files: HashSet<_> = std::fs::read_dir(tmp_dir_2)
|
||||
.unwrap()
|
||||
.map(|p| MDBShardFile::load_from_file(&p.unwrap().path()).unwrap().shard_hash)
|
||||
.map(|p| MDBShardFile::load_from_file(&p.unwrap().path(), &sfc).unwrap().shard_hash)
|
||||
.collect();
|
||||
|
||||
assert_eq!(existing_files, reference_hashes);
|
||||
@@ -1303,11 +1376,13 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_shard_deletion_ok() {
|
||||
let ctx = test_context();
|
||||
let sfc = get_shard_file_cache(&ctx.common);
|
||||
let tmp_dir = TempDir::with_prefix("shard_test_deletion").unwrap();
|
||||
|
||||
let tmp_dir_1 = tmp_dir.path().join("src");
|
||||
|
||||
create_random_shard_collection(0, &tmp_dir_1, 4, &[4; 4], &[4; 4])
|
||||
create_random_shard_collection(&ctx, 0, &tmp_dir_1, 4, &[4; 4], &[4; 4])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -1331,7 +1406,7 @@ mod tests {
|
||||
|
||||
if i == corrupt_file_index {
|
||||
// Load this so the metadata gets cached; that will get loaded from cache.
|
||||
let sfi = MDBShardFile::load_from_file(&dest_file_name).unwrap();
|
||||
let sfi = MDBShardFile::load_from_file(&dest_file_name, &sfc).unwrap();
|
||||
|
||||
// Turn off the typical verification checks on this shard, so it isn't verified
|
||||
// on loading from the cache.
|
||||
@@ -1348,13 +1423,14 @@ mod tests {
|
||||
// Now attempt a merge; this should cause an error.
|
||||
let out_dir_1 = work_dir.join("out_err");
|
||||
std::fs::create_dir_all(&out_dir_1).unwrap();
|
||||
let res = merge_shards(&tmp_src_dir, &out_dir_1, base_size * merge_size, false);
|
||||
let res = merge_shards(&ctx.runtime, &tmp_src_dir, &out_dir_1, base_size * merge_size, false, &sfc);
|
||||
assert!(res.is_err());
|
||||
|
||||
// Now attempt a merge with error skipping; which should not cause an error.
|
||||
let out_dir_2 = work_dir.join("out_skips");
|
||||
std::fs::create_dir_all(&out_dir_2).unwrap();
|
||||
let res = merge_shards(&tmp_src_dir, &out_dir_2, base_size * merge_size, true).unwrap();
|
||||
let res =
|
||||
merge_shards(&ctx.runtime, &tmp_src_dir, &out_dir_2, base_size * merge_size, true, &sfc).unwrap();
|
||||
|
||||
assert_eq!(res.merged_shards.len(), n_merged);
|
||||
|
||||
@@ -1372,4 +1448,31 @@ mod tests {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
|
||||
async fn test_new_in_cache_directory_cache_identity_under_concurrency() -> Result<()> {
|
||||
let ctx = test_context();
|
||||
let tmp_dir = TempDir::with_prefix("shard_test_cache_identity")?;
|
||||
let shard_dir = tmp_dir.path().to_path_buf();
|
||||
|
||||
let mut join_handles = Vec::new();
|
||||
for _ in 0..32 {
|
||||
let ctx = ctx.clone();
|
||||
let shard_dir = shard_dir.clone();
|
||||
join_handles
|
||||
.push(tokio::spawn(async move { ShardFileManager::new_in_cache_directory(&ctx, &shard_dir).await }));
|
||||
}
|
||||
|
||||
let mut managers = Vec::new();
|
||||
for jh in join_handles {
|
||||
managers.push(jh.await.unwrap()?);
|
||||
}
|
||||
|
||||
let first = managers.first().unwrap().clone();
|
||||
for manager in managers.iter().skip(1) {
|
||||
assert!(Arc::ptr_eq(&first, manager));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ use std::time::Duration;
|
||||
use tracing::debug;
|
||||
|
||||
use super::file_structs::*;
|
||||
#[cfg(debug_assertions)]
|
||||
use super::shard_file_handle::{MDBShardFile, new_shard_file_cache};
|
||||
use super::shard_format::MDBShardInfo;
|
||||
use super::utils::{shard_file_name, temp_shard_file_name};
|
||||
use super::xorb_structs::*;
|
||||
@@ -258,8 +260,8 @@ impl MDBInMemoryShard {
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
{
|
||||
use super::MDBShardFile;
|
||||
let shard_file = MDBShardFile::load_from_file(&full_file_name)?;
|
||||
let cache = new_shard_file_cache();
|
||||
let shard_file = MDBShardFile::load_from_file(&full_file_name, &cache)?;
|
||||
shard_file.verify_shard_integrity();
|
||||
}
|
||||
|
||||
|
||||
@@ -26,38 +26,4 @@ xet_runtime::test_configurable_constants! {
|
||||
lazy_static::lazy_static! {
|
||||
/// The maximum chunk size, calculated from the configurable constants above
|
||||
pub static ref MAX_CHUNK_SIZE: usize = (*TARGET_CHUNK_SIZE) * (*MAXIMUM_CHUNK_MULTIPLIER);
|
||||
|
||||
/// The byte threshold at which to cut a new xorb during building.
|
||||
/// Defaults to MAX_XORB_BYTES, but in simulation builds can be lowered
|
||||
/// via the `simulation_max_bytes` xorb config value to produce
|
||||
/// smaller (but still valid) xorbs.
|
||||
pub static ref XORB_CUT_THRESHOLD_BYTES: usize = {
|
||||
#[cfg(feature = "simulation")]
|
||||
{
|
||||
xet_runtime::core::xet_config()
|
||||
.xorb
|
||||
.simulation_max_bytes
|
||||
.map(|bs| (bs.as_u64() as usize).min(*MAX_XORB_BYTES))
|
||||
.unwrap_or(*MAX_XORB_BYTES)
|
||||
}
|
||||
#[cfg(not(feature = "simulation"))]
|
||||
{ *MAX_XORB_BYTES }
|
||||
};
|
||||
|
||||
/// The chunk-count threshold at which to cut a new xorb during building.
|
||||
/// Defaults to MAX_XORB_CHUNKS, but in simulation builds can be lowered
|
||||
/// via the `simulation_max_xorb_chunks` xorb config value to produce
|
||||
/// smaller (but still valid) xorbs.
|
||||
pub static ref XORB_CUT_THRESHOLD_CHUNKS: usize = {
|
||||
#[cfg(feature = "simulation")]
|
||||
{
|
||||
xet_runtime::core::xet_config()
|
||||
.xorb
|
||||
.simulation_max_chunks
|
||||
.unwrap_or(*MAX_XORB_CHUNKS)
|
||||
.min(*MAX_XORB_CHUNKS)
|
||||
}
|
||||
#[cfg(not(feature = "simulation"))]
|
||||
{ *MAX_XORB_CHUNKS }
|
||||
};
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ use futures::AsyncReadExt;
|
||||
use more_asserts::*;
|
||||
use serde::Serialize;
|
||||
use tracing::warn;
|
||||
use xet_runtime::core::xet_config;
|
||||
|
||||
use super::constants::{TARGET_CHUNK_SIZE, XORB_BLOCK_SIZE};
|
||||
use super::xorb_chunk_format::{deserialize_chunk, deserialize_chunk_header, serialize_chunk, write_chunk_header};
|
||||
@@ -1285,11 +1284,16 @@ pub struct SerializedXorbObject {
|
||||
impl SerializedXorbObject {
|
||||
/// Builds the xorb from raw xorb data.
|
||||
///
|
||||
/// The compression scheme is determined by `HF_XET_XORB_COMPRESSION_POLICY`:
|
||||
/// The compression scheme is determined by `compression_policy`:
|
||||
/// auto-detect (default) or an explicit scheme (none, lz4, bg4-lz4).
|
||||
pub fn from_xorb(xorb: RawXorbData, serialize_footer: bool) -> Result<Self, CoreError> {
|
||||
let compression_scheme: CompressionScheme = xet_config().xorb.compression_policy.parse()?;
|
||||
Self::from_xorb_with_compression(xorb, compression_scheme, serialize_footer)
|
||||
pub fn from_xorb(
|
||||
xorb: RawXorbData,
|
||||
serialize_footer: bool,
|
||||
compression_policy: &str,
|
||||
compression_scheme_retest_interval: usize,
|
||||
) -> Result<Self, CoreError> {
|
||||
let compression_scheme: CompressionScheme = compression_policy.parse()?;
|
||||
Self::from_xorb_with_compression(xorb, compression_scheme, serialize_footer, compression_scheme_retest_interval)
|
||||
}
|
||||
|
||||
/// Builds the xorb from raw xorb data with an explicit compression scheme override.
|
||||
@@ -1297,6 +1301,7 @@ impl SerializedXorbObject {
|
||||
xorb: RawXorbData,
|
||||
compression_scheme: CompressionScheme,
|
||||
serialize_footer: bool,
|
||||
compression_scheme_retest_interval: usize,
|
||||
) -> Result<Self, CoreError> {
|
||||
let mut xorb_object_info = XorbObjectInfoV1::default();
|
||||
|
||||
@@ -1325,8 +1330,8 @@ impl SerializedXorbObject {
|
||||
let mut serialized_data = Vec::with_capacity(size_upper_bound);
|
||||
|
||||
// Set the periodic retesting interval
|
||||
let retest_interval = if xet_config().xorb.compression_scheme_retest_interval > 0 {
|
||||
xet_config().xorb.compression_scheme_retest_interval
|
||||
let retest_interval = if compression_scheme_retest_interval > 0 {
|
||||
compression_scheme_retest_interval
|
||||
} else {
|
||||
num_chunks
|
||||
};
|
||||
@@ -1383,6 +1388,7 @@ impl SerializedXorbObject {
|
||||
|
||||
pub mod test_utils {
|
||||
use rand::RngExt;
|
||||
use xet_runtime::config::XetConfig;
|
||||
|
||||
use super::super::raw_xorb_data::test_utils::raw_xorb_to_vec;
|
||||
use super::super::xorb_chunk_format::serialize_chunk;
|
||||
@@ -1556,8 +1562,9 @@ pub mod test_utils {
|
||||
xorb: RawXorbData,
|
||||
compression_scheme: CompressionScheme,
|
||||
) -> SerializedXorbObject {
|
||||
let retest = XetConfig::new().xorb.compression_scheme_retest_interval;
|
||||
let xorb_obj =
|
||||
SerializedXorbObject::from_xorb_with_compression(xorb.clone(), compression_scheme, true).unwrap();
|
||||
SerializedXorbObject::from_xorb_with_compression(xorb.clone(), compression_scheme, true, retest).unwrap();
|
||||
|
||||
verify_serialized_xorb_object(&xorb, compression_scheme, &xorb_obj);
|
||||
|
||||
@@ -2276,11 +2283,13 @@ mod tests {
|
||||
);
|
||||
|
||||
// Switch V1 footer to V0
|
||||
let mut xorb_info_v0 = XorbObjectInfoV0::default();
|
||||
xorb_info_v0.xorb_hash = c.info.xorb_hash;
|
||||
xorb_info_v0.num_chunks = c.info.num_chunks;
|
||||
xorb_info_v0.chunk_boundary_offsets = c.info.chunk_boundary_offsets.clone();
|
||||
xorb_info_v0.chunk_hashes = c.info.chunk_hashes.clone();
|
||||
let xorb_info_v0 = XorbObjectInfoV0 {
|
||||
xorb_hash: c.info.xorb_hash,
|
||||
num_chunks: c.info.num_chunks,
|
||||
chunk_boundary_offsets: c.info.chunk_boundary_offsets.clone(),
|
||||
chunk_hashes: c.info.chunk_hashes.clone(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let mut buf = buf.into_inner();
|
||||
let serialized_chunks_length = c.get_contents_length().unwrap();
|
||||
@@ -2551,8 +2560,15 @@ mod tests {
|
||||
#[test]
|
||||
fn test_from_xorb_uses_default_config() {
|
||||
let raw = build_raw_xorb(4, ChunkSize::Fixed(1024));
|
||||
let serialized = SerializedXorbObject::from_xorb(raw, false).unwrap();
|
||||
let cfg = xet_runtime::config::XetConfig::new();
|
||||
let serialized = SerializedXorbObject::from_xorb(
|
||||
raw,
|
||||
false,
|
||||
cfg.xorb.compression_policy.as_str(),
|
||||
cfg.xorb.compression_scheme_retest_interval,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(serialized.num_chunks, 4);
|
||||
assert!(serialized.serialized_data.len() > 0);
|
||||
assert!(!serialized.serialized_data.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,15 +6,18 @@ use tokio::runtime::Runtime;
|
||||
use xet_client::cas_client::{Client, MemoryClient};
|
||||
use xet_data::file_reconstruction::FileReconstructor;
|
||||
use xet_runtime::config::ReconstructionConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
struct BenchFixture {
|
||||
ctx: XetContext,
|
||||
client: Arc<dyn Client>,
|
||||
file_hash: xet_core_structures::merklehash::MerkleHash,
|
||||
_file_size: u64,
|
||||
}
|
||||
|
||||
async fn create_fixture(num_xorbs: usize, chunks_per_xorb: u64, chunk_size: usize) -> BenchFixture {
|
||||
let client = MemoryClient::new();
|
||||
let ctx = XetContext::default().expect("xet context");
|
||||
let client = MemoryClient::new(ctx.clone());
|
||||
|
||||
let term_spec: Vec<(u64, (u64, u64))> = (0..num_xorbs).map(|i| ((i + 1) as u64, (0, chunks_per_xorb))).collect();
|
||||
|
||||
@@ -22,6 +25,7 @@ async fn create_fixture(num_xorbs: usize, chunks_per_xorb: u64, chunk_size: usiz
|
||||
let file_size = file_contents.data.len() as u64;
|
||||
|
||||
BenchFixture {
|
||||
ctx,
|
||||
client: client as Arc<dyn Client>,
|
||||
file_hash: file_contents.file_hash,
|
||||
_file_size: file_size,
|
||||
@@ -40,13 +44,14 @@ fn bench_sequential_non_vectored(c: &mut Criterion) {
|
||||
&fixture,
|
||||
|b, fix| {
|
||||
b.to_async(&rt).iter(|| {
|
||||
let ctx = fix.ctx.clone();
|
||||
let client = fix.client.clone();
|
||||
let hash = fix.file_hash;
|
||||
let cfg = config.clone();
|
||||
async move {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let path = dir.path().join("out.bin");
|
||||
FileReconstructor::new(&client, hash)
|
||||
FileReconstructor::new(&ctx, &client, hash)
|
||||
.with_config(cfg)
|
||||
.reconstruct_to_file(&path, None, false)
|
||||
.await
|
||||
@@ -69,13 +74,14 @@ fn bench_sequential_vectored(c: &mut Criterion) {
|
||||
&fixture,
|
||||
|b, fix| {
|
||||
b.to_async(&rt).iter(|| {
|
||||
let ctx = fix.ctx.clone();
|
||||
let client = fix.client.clone();
|
||||
let hash = fix.file_hash;
|
||||
let cfg = config.clone();
|
||||
async move {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let path = dir.path().join("out.bin");
|
||||
FileReconstructor::new(&client, hash)
|
||||
FileReconstructor::new(&ctx, &client, hash)
|
||||
.with_config(cfg)
|
||||
.reconstruct_to_file(&path, None, false)
|
||||
.await
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
use std::collections::VecDeque;
|
||||
|
||||
use xet_runtime::core::xet_config;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
pub(crate) struct DefragPrevention {
|
||||
nranges_in_streaming_fragmentation_estimator: usize,
|
||||
|
||||
/// This tracks the number of chunks in each of the last N ranges
|
||||
rolling_last_nranges: VecDeque<usize>,
|
||||
|
||||
@@ -22,6 +24,18 @@ pub(crate) struct DefragPrevention {
|
||||
}
|
||||
|
||||
impl DefragPrevention {
|
||||
pub(crate) fn new(ctx: &XetContext) -> Self {
|
||||
let d = &ctx.config.deduplication;
|
||||
Self {
|
||||
nranges_in_streaming_fragmentation_estimator: d.nranges_in_streaming_fragmentation_estimator,
|
||||
rolling_last_nranges: VecDeque::with_capacity(d.nranges_in_streaming_fragmentation_estimator),
|
||||
rolling_nranges_chunks: 0,
|
||||
defrag_at_low_threshold: true,
|
||||
min_chunks_per_range: d.min_n_chunks_per_range,
|
||||
min_chunks_per_range_historesis_factor: d.min_n_chunks_per_range_hysteresis_factor,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn increment_last_range_in_fragmentation_estimate(&mut self, nchunks: usize) {
|
||||
if let Some(back) = self.rolling_last_nranges.back_mut() {
|
||||
*back += nchunks;
|
||||
@@ -31,14 +45,14 @@ impl DefragPrevention {
|
||||
pub(crate) fn add_range_to_fragmentation_estimate(&mut self, nchunks: usize) {
|
||||
self.rolling_last_nranges.push_back(nchunks);
|
||||
self.rolling_nranges_chunks += nchunks;
|
||||
if self.rolling_last_nranges.len() > xet_config().deduplication.nranges_in_streaming_fragmentation_estimator {
|
||||
if self.rolling_last_nranges.len() > self.nranges_in_streaming_fragmentation_estimator {
|
||||
self.rolling_nranges_chunks -= self.rolling_last_nranges.pop_front().unwrap();
|
||||
}
|
||||
}
|
||||
/// Returns the average number of chunks per range
|
||||
/// None if there is is not enough data for an estimate
|
||||
pub(crate) fn rolling_chunks_per_range(&self) -> Option<f32> {
|
||||
if self.rolling_last_nranges.len() < xet_config().deduplication.nranges_in_streaming_fragmentation_estimator {
|
||||
if self.rolling_last_nranges.len() < self.nranges_in_streaming_fragmentation_estimator {
|
||||
None
|
||||
} else {
|
||||
Some(self.rolling_nranges_chunks as f32 / self.rolling_last_nranges.len() as f32)
|
||||
@@ -79,17 +93,3 @@ impl DefragPrevention {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DefragPrevention {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
rolling_last_nranges: VecDeque::with_capacity(
|
||||
xet_config().deduplication.nranges_in_streaming_fragmentation_estimator,
|
||||
),
|
||||
rolling_nranges_chunks: 0,
|
||||
defrag_at_low_threshold: true,
|
||||
min_chunks_per_range: xet_config().deduplication.min_n_chunks_per_range,
|
||||
min_chunks_per_range_historesis_factor: xet_config().deduplication.min_n_chunks_per_range_hysteresis_factor,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,8 +7,9 @@ use xet_core_structures::metadata_shard::file_structs::{
|
||||
FileDataSequenceEntry, FileDataSequenceHeader, FileMetadataExt, FileVerificationEntry, MDBFileInfo,
|
||||
};
|
||||
use xet_core_structures::metadata_shard::hash_is_global_dedup_eligible;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::constants::{XORB_CUT_THRESHOLD_BYTES, XORB_CUT_THRESHOLD_CHUNKS};
|
||||
use super::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
|
||||
use super::data_aggregator::DataAggregator;
|
||||
use super::dedup_metrics::DeduplicationMetrics;
|
||||
use super::defrag_prevention::DefragPrevention;
|
||||
@@ -17,6 +18,9 @@ use super::{Chunk, RawXorbData};
|
||||
use crate::progress_tracking::upload_tracking::FileXorbDependency;
|
||||
|
||||
pub struct FileDeduper<DataInterfaceType: DeduplicationDataInterface> {
|
||||
#[cfg_attr(not(feature = "simulation"), allow(dead_code))]
|
||||
ctx: XetContext,
|
||||
|
||||
data_mng: DataInterfaceType,
|
||||
|
||||
/// A tag for tracking the file externally
|
||||
@@ -55,8 +59,9 @@ pub struct FileDeduper<DataInterfaceType: DeduplicationDataInterface> {
|
||||
}
|
||||
|
||||
impl<DataInterfaceType: DeduplicationDataInterface> FileDeduper<DataInterfaceType> {
|
||||
pub fn new(data_manager: DataInterfaceType, file_id: u64) -> Self {
|
||||
pub fn new(data_manager: DataInterfaceType, file_id: u64, ctx: XetContext) -> Self {
|
||||
Self {
|
||||
ctx: ctx.clone(),
|
||||
data_mng: data_manager,
|
||||
file_id,
|
||||
new_data: Vec::new(),
|
||||
@@ -65,7 +70,7 @@ impl<DataInterfaceType: DeduplicationDataInterface> FileDeduper<DataInterfaceTyp
|
||||
chunk_hashes: Vec::new(),
|
||||
file_info: Vec::new(),
|
||||
internally_referencing_entries: Vec::new(),
|
||||
defrag_tracker: DefragPrevention::default(),
|
||||
defrag_tracker: DefragPrevention::new(&ctx),
|
||||
min_spacing_between_global_dedup_queries: 0,
|
||||
next_chunk_index_eligible_for_global_dedup_query: 0,
|
||||
deduplication_metrics: DeduplicationMetrics::default(),
|
||||
@@ -85,6 +90,27 @@ impl<DataInterfaceType: DeduplicationDataInterface> FileDeduper<DataInterfaceTyp
|
||||
// All the previous chunk are stored here, use it as the global chunk index start.
|
||||
let global_chunk_index_start = self.chunk_hashes.len();
|
||||
|
||||
#[cfg(feature = "simulation")]
|
||||
let xorb_cut_bytes = self
|
||||
.ctx
|
||||
.config
|
||||
.xorb
|
||||
.simulation_max_bytes
|
||||
.map(|bs| (bs.as_u64() as usize).min(*MAX_XORB_BYTES))
|
||||
.unwrap_or(*MAX_XORB_BYTES);
|
||||
#[cfg(not(feature = "simulation"))]
|
||||
let xorb_cut_bytes = *MAX_XORB_BYTES;
|
||||
#[cfg(feature = "simulation")]
|
||||
let xorb_cut_chunks = self
|
||||
.ctx
|
||||
.config
|
||||
.xorb
|
||||
.simulation_max_chunks
|
||||
.unwrap_or(*MAX_XORB_CHUNKS)
|
||||
.min(*MAX_XORB_CHUNKS);
|
||||
#[cfg(not(feature = "simulation"))]
|
||||
let xorb_cut_chunks = *MAX_XORB_CHUNKS;
|
||||
|
||||
let chunk_hashes = Vec::from_iter(chunks.iter().map(|c| c.hash));
|
||||
|
||||
// Now, parallelize the querying of potential new shards on the server end with
|
||||
@@ -213,9 +239,7 @@ impl<DataInterfaceType: DeduplicationDataInterface> FileDeduper<DataInterfaceTyp
|
||||
dedup_metrics.new_chunks += 1;
|
||||
|
||||
// Do we need to cut a new xorb first?
|
||||
if self.new_data_size + n_bytes > *XORB_CUT_THRESHOLD_BYTES
|
||||
|| self.new_data.len() + 1 > *XORB_CUT_THRESHOLD_CHUNKS
|
||||
{
|
||||
if self.new_data_size + n_bytes > xorb_cut_bytes || self.new_data.len() + 1 > xorb_cut_chunks {
|
||||
let new_xorb = self.cut_new_xorb();
|
||||
xorb_dependencies.push(FileXorbDependency {
|
||||
file_id: self.file_id,
|
||||
|
||||
@@ -8,7 +8,7 @@ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::{JoinHandle, JoinSet};
|
||||
use xet_client::cas_types::FileRange;
|
||||
use xet_runtime::core::{XetRuntime, check_sigint_shutdown};
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphorePermit;
|
||||
|
||||
use super::super::data_writer::{DataFuture, DataWriter};
|
||||
@@ -44,6 +44,7 @@ type PendingWrite = (Bytes, Option<AdjustableSemaphorePermit>);
|
||||
/// Background writer thread that processes queue items and dispatches data
|
||||
/// to an output sink (a `Write` impl or a stream function).
|
||||
struct SyncWriterThread {
|
||||
ctx: XetContext,
|
||||
rx: UnboundedReceiver<SequentialRetrievalItem>,
|
||||
bytes_written: Arc<AtomicU64>,
|
||||
progress_updater: Option<Arc<ItemProgressUpdater>>,
|
||||
@@ -54,12 +55,14 @@ struct SyncWriterThread {
|
||||
|
||||
impl SyncWriterThread {
|
||||
fn new(
|
||||
ctx: XetContext,
|
||||
rx: UnboundedReceiver<SequentialRetrievalItem>,
|
||||
bytes_written: Arc<AtomicU64>,
|
||||
progress_updater: Option<Arc<ItemProgressUpdater>>,
|
||||
run_state: Arc<RunState>,
|
||||
) -> Self {
|
||||
Self {
|
||||
ctx,
|
||||
rx,
|
||||
bytes_written,
|
||||
progress_updater,
|
||||
@@ -142,7 +145,7 @@ impl SyncWriterThread {
|
||||
break;
|
||||
}
|
||||
|
||||
check_sigint_shutdown()?;
|
||||
self.ctx.check_sigint_shutdown()?;
|
||||
}
|
||||
|
||||
debug_assert!(self.finished);
|
||||
@@ -156,7 +159,7 @@ impl SyncWriterThread {
|
||||
let mut pending_writes: VecDeque<PendingWrite> = VecDeque::new();
|
||||
|
||||
while !self.finished || !pending_writes.is_empty() {
|
||||
check_sigint_shutdown()?;
|
||||
self.ctx.check_sigint_shutdown()?;
|
||||
|
||||
// If no pending writes, block to get at least one.
|
||||
if pending_writes.is_empty() {
|
||||
@@ -393,6 +396,7 @@ impl SequentialWriter {
|
||||
/// moved to a background thread for blocking I/O operations.
|
||||
#[allow(clippy::new_ret_no_self)]
|
||||
pub(crate) fn new<W: Write + Send + 'static>(
|
||||
ctx: &XetContext,
|
||||
writer: W,
|
||||
use_vectorized: bool,
|
||||
run_state: Arc<RunState>,
|
||||
@@ -404,9 +408,11 @@ impl SequentialWriter {
|
||||
let run_state_thread = run_state.clone();
|
||||
let bytes_written_clone = bytes_written.clone();
|
||||
let progress_updater = run_state.progress_updater().cloned();
|
||||
let ctx_thread = ctx.clone();
|
||||
|
||||
let handle = XetRuntime::current().spawn_blocking(move || {
|
||||
let writer_thread = SyncWriterThread::new(rx, bytes_written_clone, progress_updater, run_state_thread);
|
||||
let handle = ctx.runtime.spawn_blocking(move || {
|
||||
let writer_thread =
|
||||
SyncWriterThread::new(ctx_thread, rx, bytes_written_clone, progress_updater, run_state_thread);
|
||||
let result = if use_vectorized {
|
||||
writer_thread.run_vectorized(writer)
|
||||
} else {
|
||||
@@ -434,10 +440,15 @@ mod tests {
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphore;
|
||||
|
||||
use super::*;
|
||||
|
||||
fn test_context() -> XetContext {
|
||||
XetContext::default().unwrap()
|
||||
}
|
||||
|
||||
struct SharedBuffer(Arc<std::sync::Mutex<Vec<u8>>>);
|
||||
|
||||
impl Write for SharedBuffer {
|
||||
@@ -597,7 +608,12 @@ mod tests {
|
||||
let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let buffer_clone = buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(
|
||||
&test_context(),
|
||||
Box::new(SharedBuffer(buffer_clone)),
|
||||
false,
|
||||
RunState::new_for_test(),
|
||||
);
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
|
||||
@@ -623,7 +639,12 @@ mod tests {
|
||||
let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let buffer_clone = buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(
|
||||
&test_context(),
|
||||
Box::new(SharedBuffer(buffer_clone)),
|
||||
false,
|
||||
RunState::new_for_test(),
|
||||
);
|
||||
|
||||
// Create futures that resolve with delays
|
||||
let f0: DataFuture = Box::pin(async {
|
||||
@@ -649,7 +670,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_size_mismatch_error() {
|
||||
let buffer = std::io::Cursor::new(Vec::new());
|
||||
let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(buffer), false, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hello")))
|
||||
@@ -672,7 +693,8 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(FailingWriter), false, RunState::new_for_test());
|
||||
let mut writer =
|
||||
SequentialWriter::new(&test_context(), Box::new(FailingWriter), false, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 4), None, immediate_future(Bytes::from("Test")))
|
||||
@@ -701,7 +723,8 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
let writer = SequentialWriter::new(Box::new(FlushFailingWriter), false, RunState::new_for_test());
|
||||
let writer =
|
||||
SequentialWriter::new(&test_context(), Box::new(FlushFailingWriter), false, RunState::new_for_test());
|
||||
let result = writer.finish().await;
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result, Err(FileReconstructionError::IoError(_))));
|
||||
@@ -712,7 +735,12 @@ mod tests {
|
||||
let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let buffer_clone = buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(
|
||||
&test_context(),
|
||||
Box::new(SharedBuffer(buffer_clone)),
|
||||
false,
|
||||
RunState::new_for_test(),
|
||||
);
|
||||
|
||||
let failing_future: DataFuture =
|
||||
Box::pin(async { Err(FileReconstructionError::InternalError("Simulated future error".to_string())) });
|
||||
@@ -729,7 +757,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_size_mismatch_too_small() {
|
||||
let buffer = std::io::Cursor::new(Vec::new());
|
||||
let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(buffer), false, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 10), None, immediate_future(Bytes::from("Hi")))
|
||||
@@ -743,7 +771,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_size_mismatch_too_large() {
|
||||
let buffer = std::io::Cursor::new(Vec::new());
|
||||
let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(buffer), false, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 2), None, immediate_future(Bytes::from("Hello World")))
|
||||
@@ -759,7 +787,12 @@ mod tests {
|
||||
let buffer = Arc::new(std::sync::Mutex::new(Vec::new()));
|
||||
let buffer_clone = buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(
|
||||
&test_context(),
|
||||
Box::new(SharedBuffer(buffer_clone)),
|
||||
false,
|
||||
RunState::new_for_test(),
|
||||
);
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
|
||||
@@ -784,7 +817,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_non_sequential_range_returns_error() {
|
||||
let buffer = std::io::Cursor::new(Vec::new());
|
||||
let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(buffer), false, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
|
||||
@@ -801,7 +834,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_first_range_must_start_at_zero() {
|
||||
let buffer = std::io::Cursor::new(Vec::new());
|
||||
let mut writer = SequentialWriter::new(Box::new(buffer), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(buffer), false, RunState::new_for_test());
|
||||
|
||||
let result = writer
|
||||
.set_next_term_data_source(FileRange::new(5, 10), None, immediate_future(Bytes::from("Hello")))
|
||||
@@ -816,7 +849,12 @@ mod tests {
|
||||
let buffer_clone = buffer.clone();
|
||||
let semaphore = AdjustableSemaphore::new(2, (0, 2));
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(SharedBuffer(buffer_clone)), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(
|
||||
&test_context(),
|
||||
Box::new(SharedBuffer(buffer_clone)),
|
||||
false,
|
||||
RunState::new_for_test(),
|
||||
);
|
||||
|
||||
let permit1 = semaphore.acquire().await.unwrap();
|
||||
let permit2 = semaphore.acquire().await.unwrap();
|
||||
@@ -853,7 +891,7 @@ mod tests {
|
||||
let buffer = test_writer.buffer.clone();
|
||||
let vectored_count = test_writer.vectored_write_count.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
|
||||
@@ -880,7 +918,7 @@ mod tests {
|
||||
let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(3));
|
||||
let buffer = test_writer.buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
|
||||
@@ -910,7 +948,7 @@ mod tests {
|
||||
let test_writer = TestWriter::new(TestWriterConfig::vectorized());
|
||||
let buffer = test_writer.buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test());
|
||||
|
||||
// Create futures that resolve with different delays
|
||||
let f0: DataFuture = Box::pin(async {
|
||||
@@ -940,7 +978,7 @@ mod tests {
|
||||
let buffer = test_writer.buffer.clone();
|
||||
let vectored_count = test_writer.vectored_write_count.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test());
|
||||
|
||||
// Write 100 single-byte chunks
|
||||
for i in 0..100u8 {
|
||||
@@ -969,7 +1007,7 @@ mod tests {
|
||||
let test_writer = TestWriter::new(TestWriterConfig::vectorized_with_interrupts());
|
||||
let buffer = test_writer.buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
|
||||
@@ -996,7 +1034,7 @@ mod tests {
|
||||
let buffer = test_writer.buffer.clone();
|
||||
let semaphore = AdjustableSemaphore::new(2, (0, 2));
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test());
|
||||
|
||||
let permit1 = semaphore.acquire().await.unwrap();
|
||||
let permit2 = semaphore.acquire().await.unwrap();
|
||||
@@ -1031,7 +1069,7 @@ mod tests {
|
||||
let buffer = test_writer.buffer.clone();
|
||||
let semaphore = AdjustableSemaphore::new(3, (0, 3));
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test());
|
||||
|
||||
let permit1 = semaphore.acquire().await.unwrap();
|
||||
let permit2 = semaphore.acquire().await.unwrap();
|
||||
@@ -1067,7 +1105,7 @@ mod tests {
|
||||
let write_count = test_writer.write_count.clone();
|
||||
let vectored_count = test_writer.vectored_write_count.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), false, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
|
||||
@@ -1095,7 +1133,7 @@ mod tests {
|
||||
let test_writer = TestWriter::new(TestWriterConfig::partial(3));
|
||||
let buffer = test_writer.buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), false, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), false, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("Hello")))
|
||||
@@ -1125,7 +1163,7 @@ mod tests {
|
||||
let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(1));
|
||||
let buffer = test_writer.buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test());
|
||||
|
||||
writer
|
||||
.set_next_term_data_source(FileRange::new(0, 5), None, immediate_future(Bytes::from("ABCDE")))
|
||||
@@ -1148,7 +1186,7 @@ mod tests {
|
||||
let test_writer = TestWriter::new(TestWriterConfig::vectorized());
|
||||
let buffer = test_writer.buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test());
|
||||
|
||||
// Write in chunks of 1000 bytes
|
||||
for i in 0..10 {
|
||||
@@ -1177,7 +1215,7 @@ mod tests {
|
||||
let test_writer = TestWriter::new(TestWriterConfig::vectorized_partial(100));
|
||||
let buffer = test_writer.buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test());
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test());
|
||||
|
||||
// Write in chunks of 500 bytes
|
||||
for i in 0..10 {
|
||||
@@ -1204,7 +1242,7 @@ mod tests {
|
||||
async fn test_vectorized_exceeded_max_slice() {
|
||||
let test_writer = TestWriter::new(TestWriterConfig::vectorized_hard_limit(2)); // hard limit set to 2 slices at a time
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test()); // controlled writev at max 24 slices at a time
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test()); // controlled writev at max 24 slices at a time
|
||||
|
||||
// Write in slices of 10 bytes, creating in total 1000 slices
|
||||
for i in 0..1000 {
|
||||
@@ -1237,7 +1275,7 @@ mod tests {
|
||||
let test_writer = TestWriter::new(TestWriterConfig::vectorized_hard_limit(40)); // hard limit set to 40 slices at a time
|
||||
let buffer = test_writer.buffer.clone();
|
||||
|
||||
let mut writer = SequentialWriter::new(Box::new(test_writer), true, RunState::new_for_test()); // controlled writev at max 24 slices at a time
|
||||
let mut writer = SequentialWriter::new(&test_context(), Box::new(test_writer), true, RunState::new_for_test()); // controlled writev at max 24 slices at a time
|
||||
|
||||
// Write in slices of 10 bytes, creating in total 1000 slices
|
||||
for i in 0..1000 {
|
||||
|
||||
@@ -11,7 +11,7 @@ use xet_client::cas_types::FileRange;
|
||||
use xet_client::chunk_cache::ChunkCache;
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_runtime::config::ReconstructionConfig;
|
||||
use xet_runtime::core::{XetRuntime, xet_config};
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::utils::ClosureGuard;
|
||||
use xet_runtime::utils::adjustable_semaphore::AdjustableSemaphore;
|
||||
|
||||
@@ -25,6 +25,7 @@ use crate::progress_tracking::ItemProgressUpdater;
|
||||
/// and writing the reassembled data to an output. Supports byte range requests and
|
||||
/// uses memory-limited buffering with adaptive prefetching.
|
||||
pub struct FileReconstructor {
|
||||
ctx: XetContext,
|
||||
client: Arc<dyn Client>,
|
||||
file_hash: MerkleHash,
|
||||
byte_range: Option<FileRange>,
|
||||
@@ -44,13 +45,14 @@ pub struct FileReconstructor {
|
||||
}
|
||||
|
||||
impl FileReconstructor {
|
||||
pub fn new(client: &Arc<dyn Client>, file_hash: MerkleHash) -> Self {
|
||||
pub fn new(ctx: &XetContext, client: &Arc<dyn Client>, file_hash: MerkleHash) -> Self {
|
||||
Self {
|
||||
ctx: ctx.clone(),
|
||||
client: client.clone(),
|
||||
file_hash,
|
||||
byte_range: None,
|
||||
progress_updater: default_progress_updater(),
|
||||
config: Arc::new(xet_config().reconstruction.clone()),
|
||||
config: Arc::new(ctx.config.reconstruction.clone()),
|
||||
chunk_cache: None,
|
||||
custom_buffer_semaphore: None,
|
||||
cancellation_token: CancellationToken::new(),
|
||||
@@ -138,7 +140,7 @@ impl FileReconstructor {
|
||||
|
||||
let run_state = RunState::new(self.cancellation_token.clone(), self.file_hash, self.progress_updater.clone());
|
||||
|
||||
let data_writer = SequentialWriter::new(file, self.config.use_vectored_write, run_state.clone());
|
||||
let data_writer = SequentialWriter::new(&self.ctx, file, self.config.use_vectored_write, run_state.clone());
|
||||
|
||||
self.run(data_writer, run_state, false).await
|
||||
}
|
||||
@@ -155,7 +157,7 @@ impl FileReconstructor {
|
||||
);
|
||||
|
||||
let run_state = RunState::new(self.cancellation_token.clone(), self.file_hash, self.progress_updater.clone());
|
||||
let data_writer = SequentialWriter::new(writer, self.config.use_vectored_write, run_state.clone());
|
||||
let data_writer = SequentialWriter::new(&self.ctx, writer, self.config.use_vectored_write, run_state.clone());
|
||||
self.run(data_writer, run_state, false).await
|
||||
}
|
||||
|
||||
@@ -229,6 +231,7 @@ impl FileReconstructor {
|
||||
_is_streaming: bool,
|
||||
) -> std::result::Result<u64, RunError> {
|
||||
let Self {
|
||||
ctx,
|
||||
client,
|
||||
byte_range,
|
||||
config,
|
||||
@@ -243,6 +246,7 @@ impl FileReconstructor {
|
||||
let requested_range = byte_range.unwrap_or_else(FileRange::full);
|
||||
|
||||
let mut term_manager = ReconstructionTermManager::new(
|
||||
ctx.clone(),
|
||||
config.clone(),
|
||||
client.clone(),
|
||||
file_hash,
|
||||
@@ -252,8 +256,8 @@ impl FileReconstructor {
|
||||
.await?;
|
||||
|
||||
let using_global_memory_limit = custom_buffer_semaphore.is_none();
|
||||
let download_buffer_semaphore = custom_buffer_semaphore
|
||||
.unwrap_or_else(|| XetRuntime::current().common().reconstruction_download_buffer.clone());
|
||||
let download_buffer_semaphore =
|
||||
custom_buffer_semaphore.unwrap_or_else(|| ctx.common.reconstruction_download_buffer.clone());
|
||||
|
||||
// Dynamic buffer scaling: the target buffer size grows with the number of active
|
||||
// downloads: target = (base + n * perfile).min(limit). On start we increment to
|
||||
@@ -264,7 +268,7 @@ impl FileReconstructor {
|
||||
let _download_count_decrement_guard;
|
||||
|
||||
if using_global_memory_limit {
|
||||
let active_downloads = XetRuntime::current().common().active_downloads.clone();
|
||||
let active_downloads = ctx.common.active_downloads.clone();
|
||||
let n = active_downloads.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
|
||||
let base = config.download_buffer_size.as_u64();
|
||||
@@ -348,7 +352,12 @@ impl FileReconstructor {
|
||||
};
|
||||
|
||||
let data_future = file_term
|
||||
.get_data_task(client.clone(), run_state.progress_updater().cloned(), chunk_cache.clone())
|
||||
.get_data_task(
|
||||
ctx.clone(),
|
||||
client.clone(),
|
||||
run_state.progress_updater().cloned(),
|
||||
chunk_cache.clone(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
@@ -415,9 +424,11 @@ mod tests {
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::runtime::Handle;
|
||||
use xet_client::cas_client::{ClientTestingUtils, DirectAccessClient, LocalClient, RandomFileContents};
|
||||
use xet_client::cas_types::FileRange;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::*;
|
||||
use crate::progress_tracking::ItemProgressUpdater;
|
||||
@@ -436,7 +447,7 @@ mod tests {
|
||||
|
||||
/// Creates a test client and uploads a random file with the given term specification.
|
||||
async fn setup_test_file(term_spec: &[(u64, (u64, u64))]) -> (Arc<LocalClient>, RandomFileContents) {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(XetContext::default().unwrap()).await.unwrap();
|
||||
let file_contents = client.upload_random_file(term_spec, TEST_CHUNK_SIZE).await.unwrap();
|
||||
(client, file_contents)
|
||||
}
|
||||
@@ -453,7 +464,8 @@ mod tests {
|
||||
let writer = StaticCursorWriter(buffer.clone());
|
||||
|
||||
let mut reconstructor =
|
||||
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_hash).with_config(config);
|
||||
FileReconstructor::new(&XetContext::default().unwrap(), &(client.clone() as Arc<dyn Client>), file_hash)
|
||||
.with_config(config);
|
||||
|
||||
if let Some(range) = byte_range {
|
||||
reconstructor = reconstructor.with_byte_range(range);
|
||||
@@ -480,7 +492,8 @@ mod tests {
|
||||
let file_path = temp_dir.path().join("output.bin");
|
||||
|
||||
let mut reconstructor =
|
||||
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_hash).with_config(config);
|
||||
FileReconstructor::new(&XetContext::default().unwrap(), &(client.clone() as Arc<dyn Client>), file_hash)
|
||||
.with_config(config);
|
||||
|
||||
if let Some(range) = byte_range {
|
||||
reconstructor = reconstructor.with_byte_range(range);
|
||||
@@ -507,7 +520,8 @@ mod tests {
|
||||
let file_path = temp_dir.path().join("output.bin");
|
||||
|
||||
let mut reconstructor =
|
||||
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_hash).with_config(config);
|
||||
FileReconstructor::new(&XetContext::default().unwrap(), &(client.clone() as Arc<dyn Client>), file_hash)
|
||||
.with_config(config);
|
||||
|
||||
if let Some(range) = byte_range {
|
||||
reconstructor = reconstructor.with_byte_range(range);
|
||||
@@ -532,7 +546,8 @@ mod tests {
|
||||
let file_path = temp_dir.path().join("output.bin");
|
||||
|
||||
let mut reconstructor =
|
||||
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_hash).with_config(config);
|
||||
FileReconstructor::new(&XetContext::default().unwrap(), &(client.clone() as Arc<dyn Client>), file_hash)
|
||||
.with_config(config);
|
||||
|
||||
if let Some(range) = byte_range {
|
||||
reconstructor = reconstructor.with_byte_range(range);
|
||||
@@ -574,7 +589,7 @@ mod tests {
|
||||
config.use_vectored_write = use_vectored;
|
||||
|
||||
// Test 1: reconstruct_to_writer
|
||||
let vec_result = reconstruct_to_vec(client, h, None, &config, None).await.unwrap();
|
||||
let vec_result = reconstruct_to_vec(&client, h, None, &config, None).await.unwrap();
|
||||
assert_eq!(vec_result, *expected, "vec failed (vectored={use_vectored})");
|
||||
|
||||
// Test 2: reconstruct_to_file
|
||||
@@ -606,7 +621,7 @@ mod tests {
|
||||
config.use_vectored_write = use_vectored;
|
||||
|
||||
// Test 1: reconstruct_to_writer
|
||||
let vec_result = reconstruct_to_vec(client, file_contents.file_hash, Some(range), &config, None)
|
||||
let vec_result = reconstruct_to_vec(&client, file_contents.file_hash, Some(range), &config, None)
|
||||
.await
|
||||
.expect("reconstruct_to_vec should succeed");
|
||||
assert_eq!(vec_result, expected, "vec failed (vectored={use_vectored})");
|
||||
@@ -683,12 +698,16 @@ mod tests {
|
||||
let writer = StaticCursorWriter(buffer.clone());
|
||||
|
||||
let progress_updater = ItemProgressUpdater::new_standalone("file");
|
||||
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(&config)
|
||||
.with_progress_updater(progress_updater.clone())
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
let bytes_written = FileReconstructor::new(
|
||||
&XetContext::default().unwrap(),
|
||||
&(client.clone() as Arc<dyn Client>),
|
||||
file_contents.file_hash,
|
||||
)
|
||||
.with_config(&config)
|
||||
.with_progress_updater(progress_updater.clone())
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(bytes_written, file_contents.data.len() as u64);
|
||||
}
|
||||
@@ -705,13 +724,17 @@ mod tests {
|
||||
let writer = StaticCursorWriter(buffer.clone());
|
||||
|
||||
let progress_updater = ItemProgressUpdater::new_standalone("file");
|
||||
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(&config)
|
||||
.with_byte_range(range)
|
||||
.with_progress_updater(progress_updater.clone())
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
let bytes_written = FileReconstructor::new(
|
||||
&XetContext::default().unwrap(),
|
||||
&(client.clone() as Arc<dyn Client>),
|
||||
file_contents.file_hash,
|
||||
)
|
||||
.with_config(&config)
|
||||
.with_byte_range(range)
|
||||
.with_progress_updater(progress_updater.clone())
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(bytes_written, expected_bytes);
|
||||
}
|
||||
@@ -729,12 +752,16 @@ mod tests {
|
||||
let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
|
||||
let writer = StaticCursorWriter(buffer.clone());
|
||||
|
||||
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(&config)
|
||||
.with_progress_updater(task.clone())
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
let bytes_written = FileReconstructor::new(
|
||||
&XetContext::default().unwrap(),
|
||||
&(client.clone() as Arc<dyn Client>),
|
||||
file_contents.file_hash,
|
||||
)
|
||||
.with_config(&config)
|
||||
.with_progress_updater(task.clone())
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(bytes_written, file_contents.data.len() as u64);
|
||||
|
||||
@@ -759,12 +786,16 @@ mod tests {
|
||||
let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
|
||||
let writer = StaticCursorWriter(buffer.clone());
|
||||
|
||||
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(&config)
|
||||
.with_progress_updater(task.clone())
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
let bytes_written = FileReconstructor::new(
|
||||
&XetContext::default().unwrap(),
|
||||
&(client.clone() as Arc<dyn Client>),
|
||||
file_contents.file_hash,
|
||||
)
|
||||
.with_config(&config)
|
||||
.with_progress_updater(task.clone())
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(bytes_written, file_size);
|
||||
|
||||
@@ -1050,8 +1081,9 @@ mod tests {
|
||||
// Create a tiny semaphore (1 permit) to force sequential processing
|
||||
// This ensures each term is fully written before the next is fetched
|
||||
let tiny_semaphore = AdjustableSemaphore::new(1, (1, 1));
|
||||
let ctx = XetContext::from_external(Handle::current(), XetConfig::new());
|
||||
|
||||
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
FileReconstructor::new(&ctx, &(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(url_refresh_test_config())
|
||||
.with_buffer_semaphore(tiny_semaphore)
|
||||
.reconstruct_to_writer(writer)
|
||||
@@ -1084,8 +1116,9 @@ mod tests {
|
||||
let writer_buffer = writer.buffer.clone();
|
||||
|
||||
let tiny_semaphore = AdjustableSemaphore::new(1, (1, 1));
|
||||
let ctx = XetContext::from_external(Handle::current(), XetConfig::new());
|
||||
|
||||
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
FileReconstructor::new(&ctx, &(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(url_refresh_test_config())
|
||||
.with_buffer_semaphore(tiny_semaphore)
|
||||
.reconstruct_to_writer(writer)
|
||||
@@ -1110,8 +1143,9 @@ mod tests {
|
||||
let writer_buffer = writer.buffer.clone();
|
||||
|
||||
let tiny_semaphore = AdjustableSemaphore::new(1, (1, 1));
|
||||
let ctx = XetContext::from_external(Handle::current(), XetConfig::new());
|
||||
|
||||
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
FileReconstructor::new(&ctx, &(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(url_refresh_test_config())
|
||||
.with_buffer_semaphore(tiny_semaphore)
|
||||
.reconstruct_to_writer(writer)
|
||||
@@ -1136,8 +1170,9 @@ mod tests {
|
||||
let writer_buffer = writer.buffer.clone();
|
||||
|
||||
let tiny_semaphore = AdjustableSemaphore::new(1, (1, 1));
|
||||
let ctx = XetContext::from_external(Handle::current(), XetConfig::new());
|
||||
|
||||
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
FileReconstructor::new(&ctx, &(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(url_refresh_test_config())
|
||||
.with_buffer_semaphore(tiny_semaphore)
|
||||
.reconstruct_to_writer(writer)
|
||||
@@ -1161,10 +1196,11 @@ mod tests {
|
||||
let writer_buffer = writer.buffer.clone();
|
||||
|
||||
let tiny_semaphore = AdjustableSemaphore::new(1, (0, 1));
|
||||
let ctx = XetContext::from_external(Handle::current(), XetConfig::new());
|
||||
|
||||
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
|
||||
|
||||
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
FileReconstructor::new(&ctx, &(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_byte_range(range)
|
||||
.with_config(url_refresh_test_config())
|
||||
.with_buffer_semaphore(tiny_semaphore)
|
||||
@@ -1184,29 +1220,32 @@ mod tests {
|
||||
runtime_config.reconstruction.download_buffer_limit = xet_runtime::utils::ByteSize::from("4kb");
|
||||
let expected_total = runtime_config.reconstruction.download_buffer_limit.as_u64();
|
||||
|
||||
let rt = XetRuntime::new_with_config(runtime_config).unwrap();
|
||||
let ctx = XetContext::with_config(runtime_config).unwrap();
|
||||
let runtime = ctx.runtime.clone();
|
||||
|
||||
rt.bridge_sync(async move {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (2, (0, 2)), (3, (0, 2))]).await;
|
||||
let sem = XetRuntime::current().common().reconstruction_download_buffer.clone();
|
||||
runtime
|
||||
.bridge_sync(async move {
|
||||
let ctx = ctx.clone();
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (2, (0, 2)), (3, (0, 2))]).await;
|
||||
let sem = ctx.common.reconstruction_download_buffer.clone();
|
||||
|
||||
// Pre-grow to max so the run's increment request is a no-op.
|
||||
let p = sem.increment_total_permits(u64::MAX).unwrap();
|
||||
drop(p);
|
||||
assert_eq!(sem.total_permits(), expected_total);
|
||||
// Pre-grow to max so the run's increment request is a no-op.
|
||||
let p = sem.increment_total_permits(u64::MAX).unwrap();
|
||||
drop(p);
|
||||
assert_eq!(sem.total_permits(), expected_total);
|
||||
|
||||
let mut config = test_config();
|
||||
config.download_buffer_perfile_size = xet_runtime::utils::ByteSize::from("8kb");
|
||||
let mut config = test_config();
|
||||
config.download_buffer_perfile_size = xet_runtime::utils::ByteSize::from("8kb");
|
||||
|
||||
let reconstructed = reconstruct_to_vec(&client, file_contents.file_hash, None, &config, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(reconstructed, file_contents.data);
|
||||
let reconstructed = reconstruct_to_vec(&client, file_contents.file_hash, None, &config, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(reconstructed, file_contents.data);
|
||||
|
||||
assert_eq!(sem.total_permits(), expected_total);
|
||||
assert_eq!(XetRuntime::current().common().active_downloads.load(Ordering::Relaxed), 0);
|
||||
})
|
||||
.unwrap();
|
||||
assert_eq!(sem.total_permits(), expected_total);
|
||||
assert_eq!(ctx.common.active_downloads.load(Ordering::Relaxed), 0);
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// ==================== File Output Specific Tests ====================
|
||||
@@ -1221,7 +1260,7 @@ mod tests {
|
||||
range: FileRange,
|
||||
config: ReconstructionConfig,
|
||||
) -> Result<u64> {
|
||||
FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_hash)
|
||||
FileReconstructor::new(&XetContext::default().unwrap(), &(client.clone() as Arc<dyn Client>), file_hash)
|
||||
.with_byte_range(range)
|
||||
.with_config(config)
|
||||
.reconstruct_to_file(file_path, None, false)
|
||||
@@ -1238,7 +1277,7 @@ mod tests {
|
||||
// Each xorb has ~64KB of data (16 chunks * 4KB), giving us ~1MB total with 16 xorbs.
|
||||
let term_spec: Vec<(u64, (u64, u64))> = (1..=16).map(|i| (i, (0, 16))).collect();
|
||||
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(XetContext::default().unwrap()).await.unwrap();
|
||||
let file_contents = client.upload_random_file(&term_spec, LARGE_CHUNK_SIZE).await.unwrap();
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
|
||||
@@ -1277,7 +1316,7 @@ mod tests {
|
||||
let config = config.clone();
|
||||
|
||||
join_set.spawn(async move {
|
||||
FileReconstructor::new(&(client as Arc<dyn Client>), file_hash)
|
||||
FileReconstructor::new(&XetContext::default().unwrap(), &(client as Arc<dyn Client>), file_hash)
|
||||
.with_byte_range(range)
|
||||
.with_config(config)
|
||||
.reconstruct_to_file(&file_path, None, false)
|
||||
@@ -1420,7 +1459,7 @@ mod tests {
|
||||
/// LocalClient with max_ranges_per_fetch=2 (tests V2 response splitting without HTTP).
|
||||
#[tokio::test]
|
||||
async fn test_local_client_max_ranges_2_disjoint() {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(XetContext::default().unwrap()).await.unwrap();
|
||||
client.set_max_ranges_per_fetch(2);
|
||||
|
||||
let term_spec = &[(1, (0, 2)), (1, (4, 6)), (1, (8, 10)), (1, (12, 14))];
|
||||
@@ -1436,7 +1475,7 @@ mod tests {
|
||||
/// LocalClient with max_ranges_per_fetch=1 (every range gets its own fetch entry).
|
||||
#[tokio::test]
|
||||
async fn test_local_client_max_ranges_1_multi_xorb() {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
let client = LocalClient::temporary(XetContext::default().unwrap()).await.unwrap();
|
||||
client.set_max_ranges_per_fetch(1);
|
||||
|
||||
let term_spec = &[(1, (0, 2)), (2, (0, 2)), (1, (4, 6)), (2, (4, 6))];
|
||||
@@ -1466,12 +1505,16 @@ mod tests {
|
||||
let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
|
||||
let writer = StaticCursorWriter(buffer.clone());
|
||||
|
||||
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(&config)
|
||||
.with_cancellation_token(token)
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
let bytes_written = FileReconstructor::new(
|
||||
&XetContext::default().unwrap(),
|
||||
&(client.clone() as Arc<dyn Client>),
|
||||
file_contents.file_hash,
|
||||
)
|
||||
.with_config(&config)
|
||||
.with_cancellation_token(token)
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(bytes_written, 0);
|
||||
}
|
||||
@@ -1519,13 +1562,17 @@ mod tests {
|
||||
// Use a tiny semaphore to force sequential term processing.
|
||||
let tiny_semaphore = AdjustableSemaphore::new(1, (1, 1));
|
||||
|
||||
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(&config)
|
||||
.with_cancellation_token(token)
|
||||
.with_buffer_semaphore(tiny_semaphore)
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
let bytes_written = FileReconstructor::new(
|
||||
&XetContext::default().unwrap(),
|
||||
&(client.clone() as Arc<dyn Client>),
|
||||
file_contents.file_hash,
|
||||
)
|
||||
.with_config(&config)
|
||||
.with_cancellation_token(token)
|
||||
.with_buffer_semaphore(tiny_semaphore)
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Verify cancellation returned Ok(0) and only partial data was written.
|
||||
assert_eq!(bytes_written, 0);
|
||||
@@ -1542,12 +1589,16 @@ mod tests {
|
||||
let buffer = Arc::new(std::sync::Mutex::new(Cursor::new(Vec::new())));
|
||||
let writer = StaticCursorWriter(buffer.clone());
|
||||
|
||||
let bytes_written = FileReconstructor::new(&(client.clone() as Arc<dyn Client>), file_contents.file_hash)
|
||||
.with_config(&config)
|
||||
.with_cancellation_token(token)
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
let bytes_written = FileReconstructor::new(
|
||||
&XetContext::default().unwrap(),
|
||||
&(client.clone() as Arc<dyn Client>),
|
||||
file_contents.file_hash,
|
||||
)
|
||||
.with_config(&config)
|
||||
.with_cancellation_token(token)
|
||||
.reconstruct_to_writer(writer)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(bytes_written, file_contents.data.len() as u64);
|
||||
assert_eq!(buffer.lock().unwrap().get_ref().clone(), file_contents.data);
|
||||
@@ -1559,10 +1610,10 @@ mod tests {
|
||||
mod multirange_tests {
|
||||
use super::*;
|
||||
|
||||
fn with_multirange_config(enable: bool) -> Arc<XetRuntime> {
|
||||
fn with_multirange_config(enable: bool) -> XetContext {
|
||||
let mut config = xet_runtime::config::XetConfig::new();
|
||||
config.client.enable_multirange_fetching = enable;
|
||||
XetRuntime::new_with_config(config).unwrap()
|
||||
XetContext::with_config(config).unwrap()
|
||||
}
|
||||
|
||||
/// Exercises multiple disjoint-range scenarios through LocalClient with both
|
||||
@@ -1570,38 +1621,39 @@ mod tests {
|
||||
#[test]
|
||||
fn test_multirange_local_client() {
|
||||
for enable in [false, true] {
|
||||
let rt = with_multirange_config(enable);
|
||||
rt.bridge_sync(async move {
|
||||
let scenarios: Vec<Vec<(u64, (u64, u64))>> = vec![
|
||||
vec![(1, (0, 2)), (1, (4, 6)), (1, (8, 10))],
|
||||
vec![
|
||||
(1, (0, 2)),
|
||||
(2, (0, 2)),
|
||||
(1, (4, 6)),
|
||||
(2, (4, 6)),
|
||||
(1, (8, 10)),
|
||||
(2, (8, 10)),
|
||||
],
|
||||
vec![
|
||||
(1, (0, 2)),
|
||||
(2, (0, 3)),
|
||||
(3, (2, 5)),
|
||||
(1, (5, 8)),
|
||||
(2, (6, 8)),
|
||||
(3, (0, 2)),
|
||||
],
|
||||
];
|
||||
let config = test_config();
|
||||
for term_spec in &scenarios {
|
||||
let (client, fc) = setup_test_file(term_spec).await;
|
||||
reconstruct_and_verify_full(&client, &fc, config.clone()).await;
|
||||
let ctx = with_multirange_config(enable);
|
||||
ctx.runtime
|
||||
.bridge_sync(async move {
|
||||
let scenarios: Vec<Vec<(u64, (u64, u64))>> = vec![
|
||||
vec![(1, (0, 2)), (1, (4, 6)), (1, (8, 10))],
|
||||
vec![
|
||||
(1, (0, 2)),
|
||||
(2, (0, 2)),
|
||||
(1, (4, 6)),
|
||||
(2, (4, 6)),
|
||||
(1, (8, 10)),
|
||||
(2, (8, 10)),
|
||||
],
|
||||
vec![
|
||||
(1, (0, 2)),
|
||||
(2, (0, 3)),
|
||||
(3, (2, 5)),
|
||||
(1, (5, 8)),
|
||||
(2, (6, 8)),
|
||||
(3, (0, 2)),
|
||||
],
|
||||
];
|
||||
let config = test_config();
|
||||
for term_spec in &scenarios {
|
||||
let (client, fc) = setup_test_file(term_spec).await;
|
||||
reconstruct_and_verify_full(&client, &fc, config.clone()).await;
|
||||
|
||||
let file_len = fc.data.len() as u64;
|
||||
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
|
||||
reconstruct_and_verify_range(&client, &fc, range, config.clone()).await;
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
let file_len = fc.data.len() as u64;
|
||||
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
|
||||
reconstruct_and_verify_range(&client, &fc, range, config.clone()).await;
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1609,19 +1661,20 @@ mod tests {
|
||||
#[test]
|
||||
fn test_multirange_max_ranges() {
|
||||
for enable in [false, true] {
|
||||
let rt = with_multirange_config(enable);
|
||||
rt.bridge_sync(async {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
client.set_max_ranges_per_fetch(2);
|
||||
let ctx = with_multirange_config(enable);
|
||||
ctx.runtime
|
||||
.bridge_sync(async {
|
||||
let client = LocalClient::temporary(XetContext::default().unwrap()).await.unwrap();
|
||||
client.set_max_ranges_per_fetch(2);
|
||||
|
||||
let term_spec = &[(1, (0, 2)), (1, (4, 6)), (1, (8, 10)), (1, (12, 14))];
|
||||
let fc = client.upload_random_file(term_spec, TEST_CHUNK_SIZE).await.unwrap();
|
||||
let term_spec = &[(1, (0, 2)), (1, (4, 6)), (1, (8, 10)), (1, (12, 14))];
|
||||
let fc = client.upload_random_file(term_spec, TEST_CHUNK_SIZE).await.unwrap();
|
||||
|
||||
let config = test_config();
|
||||
let result = reconstruct_to_vec(&client, fc.file_hash, None, &config, None).await.unwrap();
|
||||
assert_eq!(result, fc.data.as_ref());
|
||||
})
|
||||
.unwrap();
|
||||
let config = test_config();
|
||||
let result = reconstruct_to_vec(&client, fc.file_hash, None, &config, None).await.unwrap();
|
||||
assert_eq!(result, fc.data.as_ref());
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1647,7 +1700,8 @@ mod tests {
|
||||
let writer = StaticCursorWriter(buffer.clone());
|
||||
|
||||
let client: Arc<dyn Client> = server.remote_client().clone();
|
||||
let mut reconstructor = FileReconstructor::new(&client, file_hash).with_config(config);
|
||||
let mut reconstructor =
|
||||
FileReconstructor::new(&XetContext::default().unwrap(), &client, file_hash).with_config(config);
|
||||
|
||||
if let Some(range) = byte_range {
|
||||
reconstructor = reconstructor.with_byte_range(range);
|
||||
@@ -1843,10 +1897,10 @@ mod tests {
|
||||
|
||||
// ==================== Multirange via Server ====================
|
||||
|
||||
fn with_multirange_config(enable: bool) -> Arc<XetRuntime> {
|
||||
fn with_multirange_config(enable: bool) -> XetContext {
|
||||
let mut config = xet_runtime::config::XetConfig::new();
|
||||
config.client.enable_multirange_fetching = enable;
|
||||
XetRuntime::new_with_config(config).unwrap()
|
||||
XetContext::with_config(config).unwrap()
|
||||
}
|
||||
|
||||
/// Exercises HTTP server path with full, max-ranges-split, and partial-range
|
||||
@@ -1854,49 +1908,50 @@ mod tests {
|
||||
#[test]
|
||||
fn test_multirange_via_server() {
|
||||
for enable in [false, true] {
|
||||
let rt = with_multirange_config(enable);
|
||||
rt.bridge_sync(async {
|
||||
let config = test_config();
|
||||
let ctx = with_multirange_config(enable);
|
||||
ctx.runtime
|
||||
.bridge_sync(async {
|
||||
let config = test_config();
|
||||
|
||||
// Full reconstruction with disjoint ranges
|
||||
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
|
||||
let fc = server
|
||||
.remote_client()
|
||||
.upload_random_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))], TEST_CHUNK_SIZE)
|
||||
.await
|
||||
.unwrap();
|
||||
let result = reconstruct_via_server(&server, fc.file_hash, None, &config).await.unwrap();
|
||||
assert_eq!(result, fc.data.as_ref());
|
||||
// Full reconstruction with disjoint ranges
|
||||
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
|
||||
let fc = server
|
||||
.remote_client()
|
||||
.upload_random_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))], TEST_CHUNK_SIZE)
|
||||
.await
|
||||
.unwrap();
|
||||
let result = reconstruct_via_server(&server, fc.file_hash, None, &config).await.unwrap();
|
||||
assert_eq!(result, fc.data.as_ref());
|
||||
|
||||
// Multi-xorb with max_ranges_per_fetch=2
|
||||
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
|
||||
let fc = server
|
||||
.remote_client()
|
||||
.upload_random_file(
|
||||
&[(1, (0, 2)), (2, (0, 2)), (1, (4, 6)), (2, (4, 6)), (1, (8, 10))],
|
||||
TEST_CHUNK_SIZE,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
server.set_max_ranges_per_fetch(2);
|
||||
let result = reconstruct_via_server(&server, fc.file_hash, None, &config).await.unwrap();
|
||||
assert_eq!(result, fc.data.as_ref());
|
||||
// Multi-xorb with max_ranges_per_fetch=2
|
||||
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
|
||||
let fc = server
|
||||
.remote_client()
|
||||
.upload_random_file(
|
||||
&[(1, (0, 2)), (2, (0, 2)), (1, (4, 6)), (2, (4, 6)), (1, (8, 10))],
|
||||
TEST_CHUNK_SIZE,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
server.set_max_ranges_per_fetch(2);
|
||||
let result = reconstruct_via_server(&server, fc.file_hash, None, &config).await.unwrap();
|
||||
assert_eq!(result, fc.data.as_ref());
|
||||
|
||||
// Partial byte range
|
||||
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
|
||||
let fc = server
|
||||
.remote_client()
|
||||
.upload_random_file(&[(1, (0, 3)), (2, (0, 2)), (1, (3, 5)), (2, (4, 6))], TEST_CHUNK_SIZE)
|
||||
.await
|
||||
.unwrap();
|
||||
let file_len = fc.data.len() as u64;
|
||||
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
|
||||
let result = reconstruct_via_server(&server, fc.file_hash, Some(range), &config)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, &fc.data[range.start as usize..range.end as usize]);
|
||||
})
|
||||
.unwrap();
|
||||
// Partial byte range
|
||||
let server = xet_client::cas_client::LocalTestServerBuilder::new().start().await;
|
||||
let fc = server
|
||||
.remote_client()
|
||||
.upload_random_file(&[(1, (0, 3)), (2, (0, 2)), (1, (3, 5)), (2, (4, 6))], TEST_CHUNK_SIZE)
|
||||
.await
|
||||
.unwrap();
|
||||
let file_len = fc.data.len() as u64;
|
||||
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
|
||||
let result = reconstruct_via_server(&server, fc.file_hash, Some(range), &config)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, &fc.data[range.start as usize..range.end as usize]);
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
} // mod server_tests
|
||||
|
||||
@@ -8,7 +8,7 @@ use xet_client::cas_client::Client;
|
||||
use xet_client::cas_types::{ChunkRange, FileRange, HttpRange};
|
||||
use xet_client::chunk_cache::ChunkCache;
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_runtime::core::xet_config;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::utils::UniqueId;
|
||||
|
||||
use super::super::FileReconstructionError;
|
||||
@@ -58,6 +58,7 @@ impl FileTerm {
|
||||
/// only one download per xorb block (other callers wait without acquiring CAS permits).
|
||||
pub async fn get_data_task(
|
||||
&self,
|
||||
ctx: XetContext,
|
||||
client: Arc<dyn Client>,
|
||||
progress_updater: Option<Arc<ItemProgressUpdater>>,
|
||||
chunk_cache: Option<Arc<dyn ChunkCache>>,
|
||||
@@ -74,7 +75,7 @@ impl FileTerm {
|
||||
|
||||
let task = tokio::task::spawn(async move {
|
||||
let xorb_block_data = xorb_block
|
||||
.retrieve_data(client, url_info, progress_updater, chunk_cache)
|
||||
.retrieve_data(ctx, client, url_info, progress_updater, chunk_cache)
|
||||
.await?;
|
||||
Ok(file_term.extract_bytes(&xorb_block_data))
|
||||
});
|
||||
@@ -108,6 +109,7 @@ struct FileTermEntry {
|
||||
/// download (with dedup and compression enabled)
|
||||
/// along with the Vec<FileTerm>.
|
||||
pub async fn retrieve_file_term_block(
|
||||
ctx: &XetContext,
|
||||
client: Arc<dyn Client>,
|
||||
file_hash: MerkleHash,
|
||||
query_file_byte_range: FileRange,
|
||||
@@ -141,7 +143,7 @@ pub async fn retrieve_file_term_block(
|
||||
// Track the current byte offset in the output file as we process terms sequentially.
|
||||
let mut cur_file_byte_offset = query_file_byte_range.start;
|
||||
|
||||
let enable_multirange = xet_config().client.enable_multirange_fetching;
|
||||
let enable_multirange = ctx.config.client.enable_multirange_fetching;
|
||||
|
||||
for (local_term_index, term) in raw_reconstruction.terms.iter().enumerate() {
|
||||
let xorb_hash: MerkleHash = term.hash.into();
|
||||
@@ -353,6 +355,7 @@ mod tests {
|
||||
use more_asserts::assert_le;
|
||||
use xet_client::cas_client::{ClientTestingUtils, LocalClient, RandomFileContents};
|
||||
use xet_client::cas_types::{ChunkRange, FileRange};
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::utils::UniqueId;
|
||||
|
||||
use super::*;
|
||||
@@ -382,10 +385,11 @@ mod tests {
|
||||
|
||||
/// Creates a test client and uploads a random file with the given term specification.
|
||||
/// Returns the client and file contents for verification.
|
||||
async fn setup_test_file(term_spec: &[(u64, (u64, u64))]) -> (Arc<LocalClient>, RandomFileContents) {
|
||||
let client = LocalClient::temporary().await.unwrap();
|
||||
async fn setup_test_file(term_spec: &[(u64, (u64, u64))]) -> (XetContext, Arc<LocalClient>, RandomFileContents) {
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let client = LocalClient::temporary(ctx.clone()).await.unwrap();
|
||||
let file_contents = client.upload_random_file(term_spec, TEST_CHUNK_SIZE).await.unwrap();
|
||||
(client, file_contents)
|
||||
(ctx, client, file_contents)
|
||||
}
|
||||
|
||||
/// Retrieves file terms and thoroughly verifies their correctness.
|
||||
@@ -400,6 +404,7 @@ mod tests {
|
||||
/// - Cross-references with the known file contents for correctness
|
||||
/// - Verifies number of file terms matches expected from term_spec
|
||||
async fn retrieve_and_verify(
|
||||
ctx: &XetContext,
|
||||
client: &Arc<LocalClient>,
|
||||
file_contents: &RandomFileContents,
|
||||
requested_range: Option<FileRange>,
|
||||
@@ -408,7 +413,7 @@ mod tests {
|
||||
let dyn_client: Arc<dyn Client> = client.clone();
|
||||
|
||||
let (returned_range, _, file_terms) =
|
||||
retrieve_file_term_block(dyn_client.clone(), file_contents.file_hash, requested_range)
|
||||
retrieve_file_term_block(ctx, dyn_client.clone(), file_contents.file_hash, requested_range)
|
||||
.await
|
||||
.expect("retrieve_file_term_block should succeed")
|
||||
.expect("file_terms should not be None for valid range");
|
||||
@@ -479,7 +484,10 @@ mod tests {
|
||||
assert!(file_contents.xorbs.contains_key(&file_term.xorb_block.xorb_hash));
|
||||
|
||||
// Get the data task and await it.
|
||||
let data_future = file_term.get_data_task(dyn_client.clone(), None, None).await.unwrap();
|
||||
let data_future = file_term
|
||||
.get_data_task(ctx.clone(), dyn_client.clone(), None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let data = data_future.await.unwrap();
|
||||
|
||||
// Verify the data size matches the byte range.
|
||||
@@ -514,10 +522,10 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_xorb_block_references_exact() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (2, 4)), (1, (4, 6))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (2, 4)), (1, (4, 6))]).await;
|
||||
let file_range = FileRange::new(0, file_contents.data.len() as u64);
|
||||
let dyn_client: Arc<dyn Client> = client.clone();
|
||||
let (_, _, file_terms) = retrieve_file_term_block(dyn_client, file_contents.file_hash, file_range)
|
||||
let (_, _, file_terms) = retrieve_file_term_block(&runtime, dyn_client, file_contents.file_hash, file_range)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
@@ -528,13 +536,14 @@ mod tests {
|
||||
let expected = vec![ChunkRange::new(0, 2), ChunkRange::new(2, 4), ChunkRange::new(4, 6)];
|
||||
assert_eq!(ref_ranges, expected);
|
||||
|
||||
let (client2, file_contents2) = setup_test_file(&[(1, (0, 5)), (1, (0, 5))]).await;
|
||||
let (runtime2, client2, file_contents2) = setup_test_file(&[(1, (0, 5)), (1, (0, 5))]).await;
|
||||
let file_range2 = FileRange::new(0, file_contents2.data.len() as u64);
|
||||
let dyn_client2: Arc<dyn Client> = client2.clone();
|
||||
let (_, _, file_terms2) = retrieve_file_term_block(dyn_client2, file_contents2.file_hash, file_range2)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let (_, _, file_terms2) =
|
||||
retrieve_file_term_block(&runtime2, dyn_client2, file_contents2.file_hash, file_range2)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
verify_xorb_block_references(&file_terms2);
|
||||
let block2 = &file_terms2[0].xorb_block;
|
||||
let ref_ranges2: Vec<ChunkRange> = block2.references.iter().map(|r| r.term_chunks).collect();
|
||||
@@ -544,57 +553,58 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_xorb_full_range() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_terms_same_xorb() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (2, 4)), (1, (4, 6))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (2, 4)), (1, (4, 6))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_xorbs() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 3)), (2, (0, 2)), (3, (0, 4))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 3)), (2, (0, 2)), (3, (0, 4))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_overlapping_chunk_ranges() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 5)), (1, (1, 3)), (1, (2, 4))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 5)), (1, (1, 3)), (1, (2, 4))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_partial_range_middle() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
retrieve_and_verify(&client, &file_contents, Some(FileRange::new(file_len / 4, file_len * 3 / 4))).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(FileRange::new(file_len / 4, file_len * 3 / 4)))
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_partial_range_start() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
retrieve_and_verify(&client, &file_contents, Some(FileRange::new(0, file_len / 2))).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(FileRange::new(0, file_len / 2))).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_partial_range_end() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 10))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
retrieve_and_verify(&client, &file_contents, Some(FileRange::new(file_len / 2, file_len))).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(FileRange::new(file_len / 2, file_len))).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_beyond_file_end() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 3))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 3))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
let beyond_range = FileRange::new(file_len + 1000, file_len + 2000);
|
||||
|
||||
let dyn_client: Arc<dyn Client> = client.clone();
|
||||
let result = retrieve_file_term_block(dyn_client, file_contents.file_hash, beyond_range).await;
|
||||
let result = retrieve_file_term_block(&runtime, dyn_client, file_contents.file_hash, beyond_range).await;
|
||||
|
||||
match result {
|
||||
Ok(None) => {},
|
||||
@@ -605,49 +615,50 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_interleaved_xorbs() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (2, (0, 2)), (1, (2, 4)), (2, (2, 4))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) =
|
||||
setup_test_file(&[(1, (0, 2)), (2, (0, 2)), (1, (2, 4)), (2, (2, 4))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_non_contiguous_chunks() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (4, 6))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (4, 6))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_adjacent_chunks() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 3)), (1, (3, 5))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 3)), (1, (3, 5))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_chunk_terms() {
|
||||
let (client, file_contents) =
|
||||
let (runtime, client, file_contents) =
|
||||
setup_test_file(&[(1, (0, 1)), (1, (1, 2)), (1, (2, 3)), (2, (0, 1)), (2, (1, 2))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_large_file_many_xorbs() {
|
||||
let term_spec: Vec<(u64, (u64, u64))> = (1..=10).map(|i| (i, (0, 3))).collect();
|
||||
let (client, file_contents) = setup_test_file(&term_spec).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&term_spec).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_xorb_block_deduplication() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 5)), (1, (0, 5))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 5)), (1, (0, 5))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retrieval_url_acquisition() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
let file_range = FileRange::new(0, file_contents.data.len() as u64);
|
||||
let dyn_client: Arc<dyn Client> = client.clone();
|
||||
|
||||
let (_, _, file_terms) = retrieve_file_term_block(dyn_client, file_contents.file_hash, file_range)
|
||||
let (_, _, file_terms) = retrieve_file_term_block(&runtime, dyn_client, file_contents.file_hash, file_range)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
@@ -673,94 +684,95 @@ mod tests {
|
||||
(2, (4, 6)),
|
||||
(1, (0, 2)),
|
||||
];
|
||||
let (client, file_contents) = setup_test_file(term_spec).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(term_spec).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_repeated_xorb_different_ranges() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (3, 5)), (1, (1, 3)), (1, (4, 6))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) =
|
||||
setup_test_file(&[(1, (0, 2)), (1, (3, 5)), (1, (1, 3)), (1, (4, 6))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_chunk_file() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 1))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 1))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_many_small_terms_from_different_xorbs() {
|
||||
let term_spec: Vec<(u64, (u64, u64))> = (1..=20).map(|i| (i, (0, 1))).collect();
|
||||
let (client, file_contents) = setup_test_file(&term_spec).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&term_spec).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_range_few_bytes_before_end() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
|
||||
let range = FileRange::new(0, file_len - 3);
|
||||
retrieve_and_verify(&client, &file_contents, Some(range)).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(range)).await;
|
||||
|
||||
let range = FileRange::new(0, file_len - 1);
|
||||
retrieve_and_verify(&client, &file_contents, Some(range)).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(range)).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_range_few_bytes_after_start() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
|
||||
let range = FileRange::new(3, file_len);
|
||||
retrieve_and_verify(&client, &file_contents, Some(range)).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(range)).await;
|
||||
|
||||
let range = FileRange::new(1, file_len);
|
||||
retrieve_and_verify(&client, &file_contents, Some(range)).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(range)).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_range_few_bytes_offset_both_ends() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
|
||||
let range = FileRange::new(2, file_len - 2);
|
||||
retrieve_and_verify(&client, &file_contents, Some(range)).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(range)).await;
|
||||
|
||||
let range = FileRange::new(file_len / 2 - 1, file_len / 2 + 1);
|
||||
retrieve_and_verify(&client, &file_contents, Some(range)).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(range)).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_range_single_byte_at_various_positions() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 5))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
|
||||
retrieve_and_verify(&client, &file_contents, Some(FileRange::new(0, 1))).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(FileRange::new(0, 1))).await;
|
||||
|
||||
retrieve_and_verify(&client, &file_contents, Some(FileRange::new(file_len - 1, file_len))).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(FileRange::new(file_len - 1, file_len))).await;
|
||||
|
||||
let mid = file_len / 2;
|
||||
retrieve_and_verify(&client, &file_contents, Some(FileRange::new(mid, mid + 1))).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(FileRange::new(mid, mid + 1))).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_term_range_ends_mid_chunk() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 3)), (2, (0, 3)), (3, (0, 3))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 3)), (2, (0, 3)), (3, (0, 3))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
|
||||
let range = FileRange::new(0, file_len - 5);
|
||||
retrieve_and_verify(&client, &file_contents, Some(range)).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(range)).await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_term_range_starts_mid_chunk() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 3)), (2, (0, 3)), (3, (0, 3))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 3)), (2, (0, 3)), (3, (0, 3))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
|
||||
let range = FileRange::new(5, file_len);
|
||||
retrieve_and_verify(&client, &file_contents, Some(range)).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(range)).await;
|
||||
}
|
||||
|
||||
// ==================== Multi-Disjoint Range Edge Cases ====================
|
||||
@@ -769,42 +781,43 @@ mod tests {
|
||||
/// This creates one XorbBlock with chunk_ranges = [(0,2), (4,6), (8,10)].
|
||||
#[tokio::test]
|
||||
async fn test_triple_disjoint_same_xorb() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))]).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))]).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
/// Triple disjoint ranges with a partial byte range spanning the gap.
|
||||
#[tokio::test]
|
||||
async fn test_triple_disjoint_partial_range_across_gap() {
|
||||
let (client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))]).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(&[(1, (0, 2)), (1, (4, 6)), (1, (8, 10))]).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
let range = FileRange::new(file_len / 4, file_len * 3 / 4);
|
||||
retrieve_and_verify(&client, &file_contents, Some(range)).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(range)).await;
|
||||
}
|
||||
|
||||
/// Two xorbs, each with two disjoint ranges, interleaved in file order.
|
||||
#[tokio::test]
|
||||
async fn test_two_xorbs_interleaved_disjoint() {
|
||||
let term_spec = &[(1, (0, 2)), (2, (0, 2)), (1, (4, 6)), (2, (4, 6))];
|
||||
let (client, file_contents) = setup_test_file(term_spec).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(term_spec).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
/// Two xorbs interleaved with disjoint ranges, partial byte range.
|
||||
#[tokio::test]
|
||||
async fn test_two_xorbs_interleaved_disjoint_partial() {
|
||||
let term_spec = &[(1, (0, 2)), (2, (0, 2)), (1, (4, 6)), (2, (4, 6))];
|
||||
let (client, file_contents) = setup_test_file(term_spec).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(term_spec).await;
|
||||
let file_len = file_contents.data.len() as u64;
|
||||
retrieve_and_verify(&client, &file_contents, Some(FileRange::new(file_len / 3, file_len * 2 / 3))).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, Some(FileRange::new(file_len / 3, file_len * 2 / 3)))
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Single xorb with four disjoint ranges, each a single chunk wide.
|
||||
#[tokio::test]
|
||||
async fn test_four_single_chunk_disjoint() {
|
||||
let term_spec = &[(1, (0, 1)), (1, (3, 4)), (1, (6, 7)), (1, (9, 10))];
|
||||
let (client, file_contents) = setup_test_file(term_spec).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(term_spec).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
/// Mix of contiguous and disjoint ranges from the same xorb.
|
||||
@@ -812,8 +825,8 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_contiguous_then_disjoint() {
|
||||
let term_spec = &[(1, (0, 2)), (1, (2, 4)), (1, (8, 10))];
|
||||
let (client, file_contents) = setup_test_file(term_spec).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(term_spec).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
|
||||
/// Three xorbs with complex disjoint access patterns.
|
||||
@@ -827,7 +840,7 @@ mod tests {
|
||||
(2, (6, 8)),
|
||||
(3, (0, 2)),
|
||||
];
|
||||
let (client, file_contents) = setup_test_file(term_spec).await;
|
||||
retrieve_and_verify(&client, &file_contents, None).await;
|
||||
let (runtime, client, file_contents) = setup_test_file(term_spec).await;
|
||||
retrieve_and_verify(&runtime, &client, &file_contents, None).await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,6 +11,7 @@ use xet_client::cas_types::FileRange;
|
||||
use xet_core_structures::ExpWeightedMovingAvg;
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_runtime::config::ReconstructionConfig;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::super::FileReconstructionError;
|
||||
use super::super::error::Result;
|
||||
@@ -23,6 +24,7 @@ type RawFetchedFileTerms = Result<Option<(Vec<FileTerm>, u64, u64)>>;
|
||||
/// Prefetches reconstruction blocks ahead of consumption based on observed completion rates
|
||||
/// to minimize download latency while controlling memory usage.
|
||||
pub struct ReconstructionTermManager {
|
||||
ctx: XetContext,
|
||||
config: Arc<ReconstructionConfig>,
|
||||
client: Arc<dyn Client>,
|
||||
file_hash: MerkleHash,
|
||||
@@ -40,6 +42,7 @@ pub struct ReconstructionTermManager {
|
||||
|
||||
impl ReconstructionTermManager {
|
||||
pub async fn new(
|
||||
ctx: XetContext,
|
||||
config: Arc<ReconstructionConfig>,
|
||||
client: Arc<dyn Client>,
|
||||
file_hash: MerkleHash,
|
||||
@@ -52,6 +55,7 @@ impl ReconstructionTermManager {
|
||||
let requested_byte_range = file_byte_range;
|
||||
|
||||
let mut s = Self {
|
||||
ctx,
|
||||
config,
|
||||
client,
|
||||
file_hash,
|
||||
@@ -281,9 +285,10 @@ impl ReconstructionTermManager {
|
||||
let known_final_byte_position = self.known_final_byte_position.clone();
|
||||
let client = self.client.clone();
|
||||
let file_hash = self.file_hash;
|
||||
let runtime = self.ctx.clone();
|
||||
|
||||
let jh = tokio::task::spawn(async move {
|
||||
let result = retrieve_file_term_block(client, file_hash, prefetch_block_range).await;
|
||||
let result = retrieve_file_term_block(&runtime, client, file_hash, prefetch_block_range).await;
|
||||
|
||||
// See if we're done with the file.
|
||||
if let Ok(Some((ref returned_range, transfer_bytes, ref file_terms))) = result {
|
||||
|
||||
@@ -5,6 +5,7 @@ use tracing::{debug, info};
|
||||
use xet_client::cas_client::{Client, URLProvider};
|
||||
use xet_client::cas_types::{FileRange, HttpRange};
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::utils::UniqueId;
|
||||
|
||||
use super::super::FileReconstructionError;
|
||||
@@ -59,7 +60,12 @@ impl TermBlockRetrievalURLs {
|
||||
/// ignored if the acquisition ID of the current URLs is different from the one passed in
|
||||
/// as reference in the request; this indicates that the caller has a stale URL already and
|
||||
/// the new request will get a new URL.
|
||||
pub async fn refresh_retrieval_urls(&self, client: Arc<dyn Client>, acquisition_id: UniqueId) -> Result<()> {
|
||||
pub async fn refresh_retrieval_urls(
|
||||
&self,
|
||||
ctx: &XetContext,
|
||||
client: Arc<dyn Client>,
|
||||
acquisition_id: UniqueId,
|
||||
) -> Result<()> {
|
||||
if self.xorb_block_retrieval_urls.read().await.0 != acquisition_id {
|
||||
// Another task already refreshed while we were waiting for the read lock.
|
||||
debug!(
|
||||
@@ -91,7 +97,7 @@ impl TermBlockRetrievalURLs {
|
||||
|
||||
// Re-fetch the entire block to get fresh URLs, then verify the structure matches.
|
||||
let Some((returned_range, _transfer_bytes, file_terms)) =
|
||||
retrieve_file_term_block(client, self.file_hash, self.byte_range).await?
|
||||
retrieve_file_term_block(ctx, client, self.file_hash, self.byte_range).await?
|
||||
else {
|
||||
return Err(FileReconstructionError::CorruptedReconstruction(
|
||||
"On URL refresh, the returned reconstruction was None.".to_owned(),
|
||||
@@ -131,6 +137,7 @@ impl TermBlockRetrievalURLs {
|
||||
|
||||
/// Provides download URLs for a xorb block, handling URL refresh on expiration.
|
||||
pub struct XorbURLProvider {
|
||||
pub ctx: XetContext,
|
||||
pub client: Arc<dyn Client>,
|
||||
pub url_info: Arc<TermBlockRetrievalURLs>,
|
||||
pub xorb_block_index: usize,
|
||||
@@ -148,7 +155,7 @@ impl URLProvider for XorbURLProvider {
|
||||
|
||||
async fn refresh_url(&self) -> std::result::Result<(), xet_client::ClientError> {
|
||||
self.url_info
|
||||
.refresh_retrieval_urls(self.client.clone(), *self.last_acquisition_id.lock().await)
|
||||
.refresh_retrieval_urls(&self.ctx, self.client.clone(), *self.last_acquisition_id.lock().await)
|
||||
.await
|
||||
.map_err(|e| xet_client::ClientError::Other(e.to_string()))
|
||||
}
|
||||
@@ -162,6 +169,7 @@ mod tests {
|
||||
use xet_client::cas_client::{ClientTestingUtils, LocalClient, URLProvider};
|
||||
use xet_client::cas_types::{FileRange, HttpRange};
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::utils::UniqueId;
|
||||
|
||||
use super::{TermBlockRetrievalURLs, XorbURLProvider};
|
||||
@@ -188,8 +196,9 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_refresh_skipped_when_already_refreshed() {
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let (client, file_contents) = {
|
||||
let c = LocalClient::temporary().await.unwrap();
|
||||
let c = LocalClient::temporary(ctx.clone()).await.unwrap();
|
||||
let fc = c.upload_random_file(&[(1, (0, 3))], 64).await.unwrap();
|
||||
(c, fc)
|
||||
};
|
||||
@@ -198,7 +207,7 @@ mod tests {
|
||||
let dyn_client: Arc<dyn xet_client::cas_client::Client> = client.clone();
|
||||
|
||||
let (_, _, file_terms) =
|
||||
super::retrieve_file_term_block(dyn_client.clone(), file_contents.file_hash, file_range)
|
||||
super::retrieve_file_term_block(&ctx, dyn_client.clone(), file_contents.file_hash, file_range)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
@@ -210,20 +219,27 @@ mod tests {
|
||||
|
||||
// Refresh with a stale (different) ID should be a no-op.
|
||||
let stale_id = UniqueId::new();
|
||||
url_info.refresh_retrieval_urls(dyn_client.clone(), stale_id).await.unwrap();
|
||||
url_info
|
||||
.refresh_retrieval_urls(&ctx, dyn_client.clone(), stale_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let (id_after, _, _) = url_info.get_retrieval_url(0).await;
|
||||
assert!(id_after == original_id, "refresh with stale ID should not change acquisition ID");
|
||||
|
||||
// Refresh with the correct ID should update URLs.
|
||||
url_info.refresh_retrieval_urls(dyn_client.clone(), original_id).await.unwrap();
|
||||
url_info
|
||||
.refresh_retrieval_urls(&ctx, dyn_client.clone(), original_id)
|
||||
.await
|
||||
.unwrap();
|
||||
let (refreshed_id, _, _) = url_info.get_retrieval_url(0).await;
|
||||
assert!(refreshed_id != original_id, "refresh with correct ID should change acquisition ID");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_xorb_url_provider_retrieve_and_refresh() {
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let (client, file_contents) = {
|
||||
let c = LocalClient::temporary().await.unwrap();
|
||||
let c = LocalClient::temporary(ctx.clone()).await.unwrap();
|
||||
let fc = c.upload_random_file(&[(1, (0, 3))], 64).await.unwrap();
|
||||
(c, fc)
|
||||
};
|
||||
@@ -232,7 +248,7 @@ mod tests {
|
||||
let dyn_client: Arc<dyn xet_client::cas_client::Client> = client.clone();
|
||||
|
||||
let (_, _, file_terms) =
|
||||
super::retrieve_file_term_block(dyn_client.clone(), file_contents.file_hash, file_range)
|
||||
super::retrieve_file_term_block(&ctx, dyn_client.clone(), file_contents.file_hash, file_range)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
@@ -240,6 +256,7 @@ mod tests {
|
||||
let url_info = file_terms[0].url_info.clone();
|
||||
|
||||
let provider = XorbURLProvider {
|
||||
ctx: ctx.clone(),
|
||||
client: dyn_client.clone(),
|
||||
url_info,
|
||||
xorb_block_index: 0,
|
||||
|
||||
@@ -7,7 +7,7 @@ use xet_client::cas_client::{Client, ProgressCallback};
|
||||
use xet_client::cas_types::{ChunkRange, Key};
|
||||
use xet_client::chunk_cache::ChunkCache;
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_runtime::core::xet_config;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::utils::UniqueId;
|
||||
|
||||
use super::super::error::Result;
|
||||
@@ -94,6 +94,7 @@ impl XorbBlock {
|
||||
/// can retry.
|
||||
pub async fn retrieve_data(
|
||||
self: Arc<Self>,
|
||||
ctx: XetContext,
|
||||
client: Arc<dyn Client>,
|
||||
url_info: Arc<TermBlockRetrievalURLs>,
|
||||
progress_updater: Option<Arc<ItemProgressUpdater>>,
|
||||
@@ -111,7 +112,7 @@ impl XorbBlock {
|
||||
// blocks (multiple disjoint chunk ranges per block) are cached.
|
||||
if let Some(ref cache) = chunk_cache {
|
||||
let cache_key = Key {
|
||||
prefix: xet_config().data.default_prefix.clone(),
|
||||
prefix: ctx.config.data.default_prefix.clone(),
|
||||
hash: self.xorb_hash,
|
||||
};
|
||||
let chunk_range = chunk_ranges.first().copied().unwrap_or_default();
|
||||
@@ -133,6 +134,7 @@ impl XorbBlock {
|
||||
let permit = client.acquire_download_permit().await?;
|
||||
|
||||
let url_provider = XorbURLProvider {
|
||||
ctx: ctx.clone(),
|
||||
client: client.clone(),
|
||||
url_info,
|
||||
xorb_block_index,
|
||||
@@ -155,7 +157,7 @@ impl XorbBlock {
|
||||
// Store in chunk cache (best-effort, non-blocking).
|
||||
if let Some(cache) = chunk_cache {
|
||||
let cache_key = Key {
|
||||
prefix: xet_config().data.default_prefix.clone(),
|
||||
prefix: ctx.config.data.default_prefix.clone(),
|
||||
hash: self.xorb_hash,
|
||||
};
|
||||
let chunk_range = chunk_ranges.first().copied().unwrap_or_default();
|
||||
|
||||
@@ -7,7 +7,7 @@ use anyhow::Result;
|
||||
use clap::{Args, Parser, Subcommand};
|
||||
use xet_data::processing::configurations::TranslatorConfig;
|
||||
use xet_data::processing::{FileUploadSession, Sha256Policy, XetFileInfo};
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
#[derive(Parser)]
|
||||
struct XCommand {
|
||||
@@ -56,16 +56,15 @@ impl Command {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_threadpool() -> Arc<XetRuntime> {
|
||||
static THREADPOOL: OnceLock<Arc<XetRuntime>> = OnceLock::new();
|
||||
THREADPOOL
|
||||
.get_or_init(|| XetRuntime::new().expect("Error starting multithreaded runtime."))
|
||||
fn get_xet_context() -> XetContext {
|
||||
static CTX: OnceLock<XetContext> = OnceLock::new();
|
||||
CTX.get_or_init(|| XetContext::default().expect("Error starting multithreaded runtime."))
|
||||
.clone()
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let cli = XCommand::parse();
|
||||
let _ = get_threadpool().bridge_sync(async move { cli.run().await }).unwrap();
|
||||
let _ = get_xet_context().runtime.bridge_sync(async move { cli.run().await }).unwrap();
|
||||
}
|
||||
|
||||
async fn clean_file(arg: &CleanArg) -> Result<()> {
|
||||
@@ -85,7 +84,9 @@ async fn clean(mut reader: impl Read, mut writer: impl Write, size: u64) -> Resu
|
||||
|
||||
let mut read_buf = vec![0u8; READ_BLOCK_SIZE];
|
||||
|
||||
let translator = FileUploadSession::new(TranslatorConfig::local_config(std::env::current_dir()?)?.into()).await?;
|
||||
let runtime = get_xet_context();
|
||||
let translator =
|
||||
FileUploadSession::new(Arc::new(TranslatorConfig::local_config(&runtime, std::env::current_dir()?)?)).await?;
|
||||
|
||||
let mut size_read = 0;
|
||||
let (_id, mut handle) = translator.start_clean(None, Some(size), Sha256Policy::Compute)?;
|
||||
@@ -132,7 +133,8 @@ async fn smudge(_name: Arc<str>, mut reader: impl Read, output_path: PathBuf) ->
|
||||
|
||||
// Use local config pointing to current directory
|
||||
let cas_path = std::env::current_dir()?;
|
||||
let config = TranslatorConfig::local_config(cas_path)?;
|
||||
let runtime = get_xet_context();
|
||||
let config = TranslatorConfig::local_config(&runtime, cas_path)?;
|
||||
let session = xet_data::processing::FileDownloadSession::new(config.into(), None).await?;
|
||||
|
||||
let (_id, _n_bytes) = session.download_file(&xet_file, &output_path).await?;
|
||||
|
||||
@@ -10,14 +10,14 @@ use walkdir::WalkDir;
|
||||
use xet_client::cas_client::RemoteClient;
|
||||
use xet_client::cas_client::auth::TokenRefresher;
|
||||
use xet_client::cas_types::{FileRange, QueryReconstructionResponse};
|
||||
use xet_client::hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
|
||||
use xet_client::hub_client::{BearerCredentialHelper, CredentialHelper, HubClient, Operation, RepoInfo};
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_core_structures::xorb_object::CompressionScheme;
|
||||
use xet_data::processing::data_client::default_config;
|
||||
use xet_data::processing::migration_tool::hub_client_token_refresher::HubClientTokenRefresher;
|
||||
use xet_data::processing::migration_tool::migrate::migrate_files_impl;
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
const DEFAULT_HF_ENDPOINT: &str = "https://huggingface.co";
|
||||
const USER_AGENT: &str = concat!("xtool", "/", env!("CARGO_PKG_VERSION"));
|
||||
@@ -48,7 +48,7 @@ struct CliOverrides {
|
||||
}
|
||||
|
||||
impl XCommand {
|
||||
async fn run(self) -> Result<()> {
|
||||
async fn run(self, ctx: &XetContext) -> Result<()> {
|
||||
let endpoint = self
|
||||
.overrides
|
||||
.endpoint
|
||||
@@ -63,15 +63,16 @@ impl XCommand {
|
||||
|
||||
let cred_helper = BearerCredentialHelper::new(token, "");
|
||||
let hub_client = HubClient::new(
|
||||
ctx.clone(),
|
||||
&endpoint,
|
||||
RepoInfo::try_from(&self.overrides.repo_type, &self.overrides.repo_id)?,
|
||||
Some("main".to_owned()),
|
||||
"",
|
||||
Some(cred_helper),
|
||||
Some(cred_helper as Arc<dyn CredentialHelper>),
|
||||
Some(headers),
|
||||
)?;
|
||||
|
||||
self.command.run(hub_client).await
|
||||
self.command.run(ctx, hub_client).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -122,14 +123,14 @@ struct QueryArg {
|
||||
}
|
||||
|
||||
impl Command {
|
||||
async fn run(self, hub_client: HubClient) -> Result<()> {
|
||||
async fn run(self, ctx: &XetContext, hub_client: HubClient) -> Result<()> {
|
||||
match self {
|
||||
Command::Dedup(arg) => {
|
||||
let file_paths = walk_files(arg.files, arg.recursive);
|
||||
eprintln!("Dedupping {} files...", file_paths.len());
|
||||
|
||||
let (all_file_info, clean_ret, total_bytes_trans) =
|
||||
migrate_files_impl(file_paths, None, arg.sequential, hub_client, None, !arg.migrate).await?;
|
||||
migrate_files_impl(ctx, file_paths, None, arg.sequential, hub_client, None, !arg.migrate).await?;
|
||||
|
||||
// Print file info for analysis
|
||||
if !arg.migrate {
|
||||
@@ -158,7 +159,7 @@ impl Command {
|
||||
},
|
||||
Command::Query(arg) => {
|
||||
let file_hash = MerkleHash::from_hex(&arg.hash)?;
|
||||
let ret = query_reconstruction(file_hash, arg.bytes_range, hub_client).await?;
|
||||
let ret = query_reconstruction(ctx, file_hash, arg.bytes_range, hub_client).await?;
|
||||
|
||||
eprintln!("{ret:?}");
|
||||
|
||||
@@ -197,6 +198,7 @@ fn is_git_special_files(path: &str) -> bool {
|
||||
}
|
||||
|
||||
async fn query_reconstruction(
|
||||
ctx: &XetContext,
|
||||
file_hash: MerkleHash,
|
||||
bytes_range: Option<FileRange>,
|
||||
hub_client: HubClient,
|
||||
@@ -213,13 +215,20 @@ async fn query_reconstruction(
|
||||
headers.insert(http::header::USER_AGENT, http::HeaderValue::from_static(USER_AGENT));
|
||||
|
||||
let config = default_config(
|
||||
ctx,
|
||||
jwt_info.cas_url.clone(),
|
||||
Some((jwt_info.access_token, jwt_info.exp)),
|
||||
Some(token_refresher),
|
||||
Some(Arc::new(headers)),
|
||||
)?;
|
||||
let remote_client =
|
||||
RemoteClient::new(&jwt_info.cas_url, &config.session.auth, "", true, config.session.custom_headers.clone());
|
||||
let remote_client = RemoteClient::new(
|
||||
ctx.clone(),
|
||||
&jwt_info.cas_url,
|
||||
&config.session.auth,
|
||||
"",
|
||||
true,
|
||||
config.session.custom_headers.clone(),
|
||||
);
|
||||
|
||||
// Use V1 directly so the query tool returns the raw QueryReconstructionResponse for inspection.
|
||||
remote_client
|
||||
@@ -248,8 +257,9 @@ fn main() -> Result<()> {
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e.to_string()))?;
|
||||
}
|
||||
|
||||
let threadpool = XetRuntime::new_with_config(config)?;
|
||||
threadpool.bridge_sync(async move { cli.run().await })??;
|
||||
let ctx = XetContext::with_config(config)?;
|
||||
let ctx_ref = ctx.clone();
|
||||
ctx.runtime.bridge_sync(async move { cli.run(&ctx_ref).await })??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -4,13 +4,13 @@ use std::sync::Arc;
|
||||
use http::HeaderMap;
|
||||
use tracing::info;
|
||||
use xet_client::cas_client::auth::AuthConfig;
|
||||
use xet_runtime::core::{xet_cache_root, xet_config};
|
||||
use xet_runtime::core::{XetContext, xet_cache_root};
|
||||
|
||||
use crate::error::Result;
|
||||
|
||||
/// Session-specific configuration that varies per upload/download session.
|
||||
/// These are runtime values that cannot be configured via environment variables.
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionContext {
|
||||
/// The endpoint URL. Use the `local://` prefix (configurable via `HF_XET_DATA_LOCAL_CAS_SCHEME`)
|
||||
/// to specify a local filesystem path, or `memory://` for in-memory storage.
|
||||
@@ -23,13 +23,13 @@ pub struct SessionContext {
|
||||
|
||||
impl SessionContext {
|
||||
/// Returns true if this endpoint points to a local filesystem path.
|
||||
pub fn is_local(&self) -> bool {
|
||||
self.endpoint.starts_with(&xet_config().data.local_cas_scheme)
|
||||
pub fn is_local(&self, ctx: &XetContext) -> bool {
|
||||
self.endpoint.starts_with(ctx.config.data.local_cas_scheme.as_str())
|
||||
}
|
||||
|
||||
/// Returns the local filesystem path if this is a local endpoint.
|
||||
pub fn local_path(&self) -> Option<PathBuf> {
|
||||
let path = self.endpoint.strip_prefix(&xet_config().data.local_cas_scheme)?;
|
||||
pub fn local_path(&self, ctx: &XetContext) -> Option<PathBuf> {
|
||||
let path = self.endpoint.strip_prefix(ctx.config.data.local_cas_scheme.as_str())?;
|
||||
Some(PathBuf::from(path))
|
||||
}
|
||||
|
||||
@@ -39,9 +39,9 @@ impl SessionContext {
|
||||
}
|
||||
|
||||
/// Creates a SessionContext for local filesystem-based operations.
|
||||
pub fn for_local_path(base_dir: impl AsRef<Path>) -> Self {
|
||||
pub fn for_local_path(ctx: &XetContext, base_dir: impl AsRef<Path>) -> Self {
|
||||
let path = base_dir.as_ref().to_path_buf();
|
||||
let endpoint = format!("{}{}", xet_config().data.local_cas_scheme, path.display());
|
||||
let endpoint = format!("{}{}", ctx.config.data.local_cas_scheme, path.display());
|
||||
Self {
|
||||
endpoint,
|
||||
auth: None,
|
||||
@@ -65,8 +65,9 @@ impl SessionContext {
|
||||
|
||||
/// Main configuration for file upload/download operations.
|
||||
/// Combines session-specific values with runtime-computed paths derived from the endpoint.
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TranslatorConfig {
|
||||
pub ctx: XetContext,
|
||||
pub session: SessionContext,
|
||||
|
||||
/// Directory for caching shard files.
|
||||
@@ -88,10 +89,10 @@ impl TranslatorConfig {
|
||||
}
|
||||
|
||||
/// Creates a new TranslatorConfig from a SessionContext, computing all derived paths.
|
||||
pub fn new(session: SessionContext) -> Result<Self> {
|
||||
let config = xet_config();
|
||||
pub fn new(ctx: &XetContext, session: SessionContext) -> Result<Self> {
|
||||
let config = ctx.config.as_ref();
|
||||
|
||||
let (shard_cache_directory, shard_session_directory) = if let Some(local_path) = session.local_path() {
|
||||
let (shard_cache_directory, shard_session_directory) = if let Some(local_path) = session.local_path(ctx) {
|
||||
let base_path = local_path.join("xet");
|
||||
std::fs::create_dir_all(&base_path)?;
|
||||
|
||||
@@ -120,6 +121,7 @@ impl TranslatorConfig {
|
||||
);
|
||||
|
||||
Ok(Self {
|
||||
ctx: ctx.clone(),
|
||||
session,
|
||||
shard_cache_directory,
|
||||
shard_session_directory,
|
||||
@@ -128,18 +130,19 @@ impl TranslatorConfig {
|
||||
}
|
||||
|
||||
/// Creates a TranslatorConfig for local filesystem-based storage.
|
||||
pub fn local_config(base_dir: impl AsRef<Path>) -> Result<Self> {
|
||||
Self::new(SessionContext::for_local_path(base_dir))
|
||||
pub fn local_config(ctx: &XetContext, base_dir: impl AsRef<Path>) -> Result<Self> {
|
||||
Self::new(ctx, SessionContext::for_local_path(ctx, base_dir))
|
||||
}
|
||||
|
||||
/// Creates a TranslatorConfig that uses in-memory storage for XORBs.
|
||||
/// Shard data still uses file-based storage in the provided base directory.
|
||||
pub fn memory_config(base_dir: impl AsRef<Path>) -> Result<Self> {
|
||||
pub fn memory_config(ctx: &XetContext, base_dir: impl AsRef<Path>) -> Result<Self> {
|
||||
let session = SessionContext::for_memory();
|
||||
let config = xet_config();
|
||||
let config = ctx.config.as_ref();
|
||||
let base_path = Self::create_base_xet_dir(base_dir)?;
|
||||
|
||||
Ok(Self {
|
||||
ctx: ctx.clone(),
|
||||
session,
|
||||
shard_cache_directory: base_path.join(&config.shard.cache_subdir),
|
||||
shard_session_directory: base_path.join(&config.session.dir_name),
|
||||
@@ -150,7 +153,7 @@ impl TranslatorConfig {
|
||||
/// Creates a TranslatorConfig that connects to a CAS server at the given endpoint.
|
||||
/// Shard cache and session directories are created under the provided base directory.
|
||||
/// Useful for tests that use LocalTestServer.
|
||||
pub fn test_server_config(endpoint: impl AsRef<str>, base_dir: impl AsRef<Path>) -> Result<Self> {
|
||||
pub fn test_server_config(ctx: &XetContext, endpoint: impl AsRef<str>, base_dir: impl AsRef<Path>) -> Result<Self> {
|
||||
let session = SessionContext {
|
||||
endpoint: endpoint.as_ref().to_string(),
|
||||
auth: None,
|
||||
@@ -158,10 +161,11 @@ impl TranslatorConfig {
|
||||
repo_paths: vec!["".into()],
|
||||
session_id: None,
|
||||
};
|
||||
let config = xet_config();
|
||||
let config = ctx.config.as_ref();
|
||||
let base_path = Self::create_base_xet_dir(base_dir)?;
|
||||
|
||||
Ok(Self {
|
||||
ctx: ctx.clone(),
|
||||
session,
|
||||
shard_cache_directory: base_path.join(&config.shard.cache_subdir),
|
||||
shard_session_directory: base_path.join(&config.session.dir_name),
|
||||
@@ -194,21 +198,23 @@ fn compute_cache_path(endpoint: &str) -> PathBuf {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use tempfile::tempdir;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::{SessionContext, TranslatorConfig};
|
||||
|
||||
#[test]
|
||||
fn test_session_context_mode_detection() {
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let temp_dir = tempdir().unwrap();
|
||||
let local_session = SessionContext::for_local_path(temp_dir.path());
|
||||
assert!(local_session.is_local());
|
||||
let local_session = SessionContext::for_local_path(&ctx, temp_dir.path());
|
||||
assert!(local_session.is_local(&ctx));
|
||||
assert!(!local_session.is_memory());
|
||||
assert_eq!(local_session.local_path().unwrap(), temp_dir.path().to_path_buf());
|
||||
assert_eq!(local_session.local_path(&ctx).unwrap(), temp_dir.path().to_path_buf());
|
||||
|
||||
let memory_session = SessionContext::for_memory();
|
||||
assert!(memory_session.is_memory());
|
||||
assert!(!memory_session.is_local());
|
||||
assert!(memory_session.local_path().is_none());
|
||||
assert!(!memory_session.is_local(&ctx));
|
||||
assert!(memory_session.local_path(&ctx).is_none());
|
||||
|
||||
let remote_session = SessionContext {
|
||||
endpoint: "http://localhost:8080".into(),
|
||||
@@ -217,20 +223,22 @@ mod tests {
|
||||
repo_paths: Vec::new(),
|
||||
session_id: None,
|
||||
};
|
||||
assert!(!remote_session.is_local());
|
||||
assert!(!remote_session.is_local(&ctx));
|
||||
assert!(!remote_session.is_memory());
|
||||
assert!(remote_session.local_path().is_none());
|
||||
assert!(remote_session.local_path(&ctx).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_and_server_configs_use_base_xet_layout() {
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let temp_dir = tempdir().unwrap();
|
||||
|
||||
let memory_config = TranslatorConfig::memory_config(temp_dir.path()).unwrap();
|
||||
let memory_config = TranslatorConfig::memory_config(&ctx, temp_dir.path()).unwrap();
|
||||
assert!(memory_config.shard_cache_directory.starts_with(temp_dir.path().join("xet")));
|
||||
assert!(memory_config.shard_session_directory.starts_with(temp_dir.path().join("xet")));
|
||||
|
||||
let server_config = TranslatorConfig::test_server_config("http://localhost:8080", temp_dir.path()).unwrap();
|
||||
let server_config =
|
||||
TranslatorConfig::test_server_config(&ctx, "http://localhost:8080", temp_dir.path()).unwrap();
|
||||
assert!(server_config.shard_cache_directory.starts_with(temp_dir.path().join("xet")));
|
||||
assert!(server_config.shard_session_directory.starts_with(temp_dir.path().join("xet")));
|
||||
}
|
||||
|
||||
@@ -9,8 +9,8 @@ use tracing::{Instrument, Span, info_span, instrument};
|
||||
use uuid::Uuid;
|
||||
use xet_client::cas_client::auth::{AuthConfig, TokenRefresher};
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::core::par_utils::run_constrained_with_semaphore;
|
||||
use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_config};
|
||||
|
||||
use super::configurations::{SessionContext, TranslatorConfig};
|
||||
use super::file_cleaner::Sha256Policy;
|
||||
@@ -19,6 +19,7 @@ use crate::deduplication::{Chunker, DeduplicationMetrics};
|
||||
use crate::error::Result;
|
||||
|
||||
pub fn default_config(
|
||||
ctx: &XetContext,
|
||||
endpoint: String,
|
||||
token_info: Option<(String, u64)>,
|
||||
token_refresher: Option<Arc<dyn TokenRefresher>>,
|
||||
@@ -35,7 +36,7 @@ pub fn default_config(
|
||||
session_id: Some(Uuid::now_v7().to_string()),
|
||||
};
|
||||
|
||||
TranslatorConfig::new(session)
|
||||
TranslatorConfig::new(ctx, session)
|
||||
}
|
||||
|
||||
#[instrument(skip_all, name = "clean_bytes", fields(bytes.len = bytes.len()))]
|
||||
@@ -61,7 +62,7 @@ pub async fn clean_file(
|
||||
let span = Span::current();
|
||||
span.record("file.name", filename.as_ref().to_str());
|
||||
span.record("file.len", filesize);
|
||||
let mut buffer = vec![0u8; u64::min(filesize, *xet_config().data.ingestion_block_size) as usize];
|
||||
let mut buffer = vec![0u8; u64::min(filesize, *processor.ctx.config.data.ingestion_block_size) as usize];
|
||||
|
||||
let (_id, mut handle) =
|
||||
processor.start_clean(Some(filename.as_ref().to_string_lossy().into()), Some(filesize), sha256_policy)?;
|
||||
@@ -99,7 +100,7 @@ pub async fn clean_file(
|
||||
/// - Verify that downloaded files are correctly reassembled
|
||||
/// - Check if a file needs to be uploaded (by comparing hashes)
|
||||
/// - Generate cache keys for local file operations
|
||||
fn hash_single_file(filename: String, buffer_size: usize) -> Result<XetFileInfo> {
|
||||
fn hash_single_file(ctx: XetContext, filename: String, buffer_size: usize) -> Result<XetFileInfo> {
|
||||
let mut reader = File::open(&filename)?;
|
||||
let filesize = reader.metadata()?.len();
|
||||
|
||||
@@ -108,7 +109,7 @@ fn hash_single_file(filename: String, buffer_size: usize) -> Result<XetFileInfo>
|
||||
let mut chunk_hashes: Vec<(MerkleHash, u64)> = Vec::new();
|
||||
|
||||
loop {
|
||||
check_sigint_shutdown()?;
|
||||
ctx.check_sigint_shutdown()?;
|
||||
|
||||
let bytes_read = reader.read(&mut buffer)?;
|
||||
if bytes_read == 0 {
|
||||
@@ -158,15 +159,17 @@ fn hash_single_file(filename: String, buffer_size: usize) -> Result<XetFileInfo>
|
||||
/// - No authentication or server connection required
|
||||
/// - Pure local computation
|
||||
#[instrument(skip_all, name = "data_client::hash_files", fields(num_files=file_paths.len()))]
|
||||
pub async fn hash_files_async(file_paths: Vec<String>) -> Result<Vec<XetFileInfo>> {
|
||||
let rt = XetRuntime::current();
|
||||
let semaphore = rt.common().file_ingestion_semaphore.clone();
|
||||
let buffer_size = *xet_config().data.ingestion_block_size as usize;
|
||||
pub async fn hash_files_async(ctx: &XetContext, file_paths: Vec<String>) -> Result<Vec<XetFileInfo>> {
|
||||
let runtime = ctx.runtime.clone();
|
||||
let semaphore = ctx.common.file_ingestion_semaphore.clone();
|
||||
let buffer_size = *ctx.config.data.ingestion_block_size as usize;
|
||||
|
||||
let hash_futures = file_paths.into_iter().map(|file_path| {
|
||||
let rt = rt.clone();
|
||||
let runtime = runtime.clone();
|
||||
let ctx = ctx.clone();
|
||||
async move {
|
||||
rt.spawn_blocking(move || hash_single_file(file_path, buffer_size))
|
||||
runtime
|
||||
.spawn_blocking(move || hash_single_file(ctx, file_path, buffer_size))
|
||||
.await
|
||||
.map_err(|e| std::io::Error::other(e.to_string()))?
|
||||
}
|
||||
@@ -183,6 +186,7 @@ mod tests {
|
||||
use dirs::home_dir;
|
||||
use serial_test::serial;
|
||||
use tempfile::tempdir;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::utils::EnvVarGuard;
|
||||
|
||||
use super::*;
|
||||
@@ -194,7 +198,8 @@ mod tests {
|
||||
let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());
|
||||
|
||||
let endpoint = "http://localhost:8080".to_string();
|
||||
let result = default_config(endpoint, None, None, None);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let result = default_config(&ctx, endpoint, None, None, None);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let config = result.unwrap();
|
||||
@@ -211,7 +216,8 @@ mod tests {
|
||||
let hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir_hf_home.path().to_str().unwrap());
|
||||
|
||||
let endpoint = "http://localhost:8080".to_string();
|
||||
let result = default_config(endpoint, None, None, None);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let result = default_config(&ctx, endpoint, None, None, None);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let config = result.unwrap();
|
||||
@@ -224,7 +230,8 @@ mod tests {
|
||||
let _hf_home_guard = EnvVarGuard::set("HF_HOME", temp_dir.path().to_str().unwrap());
|
||||
|
||||
let endpoint = "http://localhost:8080".to_string();
|
||||
let result = default_config(endpoint, None, None, None);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let result = default_config(&ctx, endpoint, None, None, None);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let config = result.unwrap();
|
||||
@@ -238,7 +245,8 @@ mod tests {
|
||||
let _hf_xet_cache_guard = EnvVarGuard::set("HF_XET_CACHE", temp_dir.path().to_str().unwrap());
|
||||
|
||||
let endpoint = "http://localhost:8080".to_string();
|
||||
let result = default_config(endpoint, None, None, None);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let result = default_config(&ctx, endpoint, None, None, None);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let config = result.unwrap();
|
||||
@@ -249,7 +257,8 @@ mod tests {
|
||||
#[serial(default_config_env)]
|
||||
fn test_default_config_without_env_vars() {
|
||||
let endpoint = "http://localhost:8080".to_string();
|
||||
let result = default_config(endpoint, None, None, None);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let result = default_config(&ctx, endpoint, None, None, None);
|
||||
|
||||
let expected = home_dir().unwrap().join(".cache").join("huggingface").join("xet");
|
||||
|
||||
@@ -269,7 +278,8 @@ mod tests {
|
||||
std::fs::write(&file_path, b"").unwrap();
|
||||
|
||||
let buffer_size = 8 * 1024 * 1024; // 8MB
|
||||
let result = hash_single_file(file_path.to_str().unwrap().to_string(), buffer_size);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let result = hash_single_file(ctx, file_path.to_str().unwrap().to_string(), buffer_size);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let file_info = result.unwrap();
|
||||
@@ -285,7 +295,8 @@ mod tests {
|
||||
std::fs::write(&file_path, content).unwrap();
|
||||
|
||||
let buffer_size = 8 * 1024 * 1024; // 8MB
|
||||
let result = hash_single_file(file_path.to_str().unwrap().to_string(), buffer_size);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let result = hash_single_file(ctx, file_path.to_str().unwrap().to_string(), buffer_size);
|
||||
assert!(result.is_ok());
|
||||
|
||||
let file_info = result.unwrap();
|
||||
@@ -306,14 +317,15 @@ mod tests {
|
||||
std::fs::write(&file_path, &content).unwrap();
|
||||
|
||||
let file_path_str = file_path.to_str().unwrap().to_string();
|
||||
let ctx = XetContext::default().unwrap();
|
||||
|
||||
// Hash with 8MB buffer size
|
||||
let result1 = hash_single_file(file_path_str.clone(), 8 * 1024 * 1024);
|
||||
let result1 = hash_single_file(ctx.clone(), file_path_str.clone(), 8 * 1024 * 1024);
|
||||
assert!(result1.is_ok());
|
||||
let file_info1 = result1.unwrap();
|
||||
|
||||
// Hash with 4MB buffer size
|
||||
let result2 = hash_single_file(file_path_str, 4 * 1024 * 1024);
|
||||
let result2 = hash_single_file(ctx.clone(), file_path_str, 4 * 1024 * 1024);
|
||||
assert!(result2.is_ok());
|
||||
let file_info2 = result2.unwrap();
|
||||
|
||||
@@ -326,7 +338,8 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_hash_file_not_found() {
|
||||
let buffer_size = 8 * 1024 * 1024; // 8MB
|
||||
let result = hash_single_file("/nonexistent/file.txt".to_string(), buffer_size);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let result = hash_single_file(ctx, "/nonexistent/file.txt".to_string(), buffer_size);
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
@@ -345,7 +358,8 @@ mod tests {
|
||||
file2_path.to_str().unwrap().to_string(),
|
||||
];
|
||||
|
||||
let result = hash_files_async(file_paths).await;
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let result = hash_files_async(&ctx, file_paths).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let file_infos = result.unwrap();
|
||||
@@ -370,21 +384,22 @@ mod tests {
|
||||
std::fs::write(&file_path, &content).unwrap();
|
||||
|
||||
let file_path_str = file_path.to_str().unwrap().to_string();
|
||||
let ctx = XetContext::default().unwrap();
|
||||
|
||||
// Hash with 8MB buffer size - file is exactly 2x buffer size
|
||||
let result1 = hash_single_file(file_path_str.clone(), 8 * 1024 * 1024);
|
||||
let result1 = hash_single_file(ctx.clone(), file_path_str.clone(), 8 * 1024 * 1024);
|
||||
assert!(result1.is_ok());
|
||||
let file_info1 = result1.unwrap();
|
||||
assert_eq!(file_info1.file_size(), Some(file_size as u64));
|
||||
assert!(!file_info1.hash().is_empty());
|
||||
|
||||
// Hash with 4MB buffer size - file is exactly 4x buffer size
|
||||
let result2 = hash_single_file(file_path_str.clone(), 4 * 1024 * 1024);
|
||||
let result2 = hash_single_file(ctx.clone(), file_path_str.clone(), 4 * 1024 * 1024);
|
||||
assert!(result2.is_ok());
|
||||
let file_info2 = result2.unwrap();
|
||||
|
||||
// Hash with 2MB buffer size - file is exactly 8x buffer size
|
||||
let result3 = hash_single_file(file_path_str, 2 * 1024 * 1024);
|
||||
let result3 = hash_single_file(ctx, file_path_str, 2 * 1024 * 1024);
|
||||
assert!(result3.is_ok());
|
||||
let file_info3 = result3.unwrap();
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ impl UploadSessionDataManager {
|
||||
}
|
||||
|
||||
fn global_dedup_queries_enabled(&self) -> bool {
|
||||
xet_runtime::core::xet_config().deduplication.global_dedup_query_enabled
|
||||
self.session.ctx.config.deduplication.global_dedup_query_enabled
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ use chrono::{DateTime, Utc};
|
||||
use tracing::{Instrument, debug_span, info, instrument};
|
||||
use xet_core_structures::metadata_shard::Sha256;
|
||||
use xet_core_structures::metadata_shard::file_structs::FileMetadataExt;
|
||||
use xet_runtime::core::{XetRuntime, xet_config};
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::XetFileInfo;
|
||||
use super::deduplication_interface::UploadSessionDataManager;
|
||||
@@ -53,6 +53,8 @@ impl From<Option<Sha256>> for Sha256Policy {
|
||||
|
||||
/// A class that encapsulates the clean and data task around a single file.
|
||||
pub struct SingleFileCleaner {
|
||||
ctx: XetContext,
|
||||
|
||||
// File name, if known.
|
||||
file_name: Option<Arc<str>>,
|
||||
|
||||
@@ -86,15 +88,17 @@ impl SingleFileCleaner {
|
||||
sha256: Sha256Policy,
|
||||
session: Arc<FileUploadSession>,
|
||||
) -> Self {
|
||||
let deduper = FileDeduper::new(UploadSessionDataManager::new(session.clone()), file_id);
|
||||
let ctx = session.ctx.clone();
|
||||
let deduper = FileDeduper::new(UploadSessionDataManager::new(session.clone()), file_id, ctx.clone());
|
||||
|
||||
let (sha_generator, provided_sha256) = match sha256 {
|
||||
Sha256Policy::Compute => (Some(Sha256Generator::default()), None),
|
||||
Sha256Policy::Compute => (Some(Sha256Generator::new(ctx.clone())), None),
|
||||
Sha256Policy::Provided(hash) => (None, Some(hash)),
|
||||
Sha256Policy::Skip => (None, None),
|
||||
};
|
||||
|
||||
Self {
|
||||
ctx,
|
||||
file_name,
|
||||
file_id,
|
||||
dedup_manager_fut: Box::pin(async move { Ok(deduper) }),
|
||||
@@ -133,7 +137,7 @@ impl SingleFileCleaner {
|
||||
}
|
||||
|
||||
pub async fn add_data_from_bytes(&mut self, data: Bytes) -> Result<()> {
|
||||
let block_size = *xet_config().data.ingestion_block_size as usize;
|
||||
let block_size = *self.ctx.config.data.ingestion_block_size as usize;
|
||||
if data.len() > block_size {
|
||||
let mut pos = 0;
|
||||
while pos < data.len() {
|
||||
@@ -160,9 +164,9 @@ impl SingleFileCleaner {
|
||||
let chunk_data_jh = {
|
||||
let mut chunker = std::mem::take(&mut self.chunker);
|
||||
let data = data.clone();
|
||||
let rt = XetRuntime::current();
|
||||
let runtime = self.ctx.runtime.clone();
|
||||
|
||||
rt.spawn_blocking(move || {
|
||||
runtime.spawn_blocking(move || {
|
||||
let chunks: Arc<[Chunk]> = Arc::from(chunker.next_block_bytes(&data, false));
|
||||
(chunks, chunker)
|
||||
})
|
||||
|
||||
@@ -11,7 +11,7 @@ use tracing::instrument;
|
||||
use xet_client::cas_client::Client;
|
||||
use xet_client::cas_types::FileRange;
|
||||
use xet_client::chunk_cache::ChunkCache;
|
||||
use xet_runtime::core::{XetRuntime, xet_config};
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::XetFileInfo;
|
||||
use super::configurations::TranslatorConfig;
|
||||
@@ -25,6 +25,7 @@ use crate::progress_tracking::{GroupProgress, ItemProgressUpdater, UniqueID};
|
||||
/// This struct parallels `FileUploadSession` for the download path. It holds the
|
||||
/// CAS client and a shared progress group for all downloads in the session.
|
||||
pub struct FileDownloadSession {
|
||||
ctx: XetContext,
|
||||
client: Arc<dyn Client>,
|
||||
chunk_cache: Option<Arc<dyn ChunkCache>>,
|
||||
progress: Arc<GroupProgress>,
|
||||
@@ -41,13 +42,15 @@ impl FileDownloadSession {
|
||||
.map(Cow::Borrowed)
|
||||
.unwrap_or_else(|| Cow::Owned(UniqueID::new().to_string()));
|
||||
|
||||
let ctx = config.ctx.clone();
|
||||
let client = create_remote_client(&config, &session_id, false).await?;
|
||||
let progress = GroupProgress::with_speed_config(
|
||||
xet_config().data.progress_update_speed_sampling_window,
|
||||
xet_config().data.progress_update_speed_min_observations,
|
||||
ctx.config.data.progress_update_speed_sampling_window,
|
||||
ctx.config.data.progress_update_speed_min_observations,
|
||||
);
|
||||
|
||||
Ok(Arc::new(Self {
|
||||
ctx,
|
||||
client,
|
||||
chunk_cache,
|
||||
progress,
|
||||
@@ -59,10 +62,16 @@ impl FileDownloadSession {
|
||||
/// Construct a download session from an existing CAS client.
|
||||
///
|
||||
/// This path uses default progress speed settings. Use [`Self::new`] when the
|
||||
/// session should inherit the configured speed parameters from `xet_config`.
|
||||
pub fn from_client(client: Arc<dyn Client>, chunk_cache: Option<Arc<dyn ChunkCache>>) -> Arc<Self> {
|
||||
/// session should inherit the configured speed parameters from the context used
|
||||
/// to build [`TranslatorConfig`].
|
||||
pub fn from_client(
|
||||
ctx: &XetContext,
|
||||
client: Arc<dyn Client>,
|
||||
chunk_cache: Option<Arc<dyn ChunkCache>>,
|
||||
) -> Arc<Self> {
|
||||
let progress = GroupProgress::new();
|
||||
Arc::new(Self {
|
||||
ctx: ctx.clone(),
|
||||
client,
|
||||
chunk_cache,
|
||||
progress,
|
||||
@@ -110,9 +119,9 @@ impl FileDownloadSession {
|
||||
self.check_not_finalized()?;
|
||||
let id = UniqueID::new();
|
||||
let session = self.clone();
|
||||
let rt = XetRuntime::current();
|
||||
let semaphore = rt.common().file_download_semaphore.clone();
|
||||
let handle = rt.spawn(async move {
|
||||
let runtime = self.ctx.runtime.clone();
|
||||
let semaphore = self.ctx.common.file_download_semaphore.clone();
|
||||
let handle = runtime.spawn(async move {
|
||||
let _permit = semaphore.acquire().await?;
|
||||
session.download_file_with_id(&file_info, &write_path, id).await
|
||||
});
|
||||
@@ -292,7 +301,7 @@ impl FileDownloadSession {
|
||||
) -> Result<FileReconstructor> {
|
||||
let file_id = file_info.merkle_hash()?;
|
||||
|
||||
let mut reconstructor = FileReconstructor::new(&self.client, file_id);
|
||||
let mut reconstructor = FileReconstructor::new(&self.ctx, &self.client, file_id);
|
||||
|
||||
match range {
|
||||
Some(range) if range.end < u64::MAX => {
|
||||
@@ -371,22 +380,28 @@ mod tests {
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use tempfile::tempdir;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::*;
|
||||
use crate::processing::configurations::TranslatorConfig;
|
||||
use crate::processing::file_cleaner::Sha256Policy;
|
||||
use crate::processing::{FileUploadSession, XetFileInfo};
|
||||
|
||||
fn get_threadpool() -> Arc<XetRuntime> {
|
||||
static THREADPOOL: OnceLock<Arc<XetRuntime>> = OnceLock::new();
|
||||
fn get_runtime() -> Arc<xet_runtime::core::XetRuntime> {
|
||||
static THREADPOOL: OnceLock<Arc<xet_runtime::core::XetRuntime>> = OnceLock::new();
|
||||
THREADPOOL
|
||||
.get_or_init(|| XetRuntime::new().expect("Error starting multithreaded runtime."))
|
||||
.get_or_init(|| {
|
||||
XetContext::default()
|
||||
.expect("Error starting multithreaded runtime.")
|
||||
.runtime
|
||||
.clone()
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
async fn upload_data(cas_path: &Path, data: &[u8]) -> XetFileInfo {
|
||||
let upload_session = FileUploadSession::new(TranslatorConfig::local_config(cas_path).unwrap().into())
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let upload_session = FileUploadSession::new(TranslatorConfig::local_config(&ctx, cas_path).unwrap().into())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -401,9 +416,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_file() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -411,7 +425,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("output.txt");
|
||||
@@ -425,9 +439,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_file_creates_parent_dirs() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -435,7 +448,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("deep").join("nested").join("dir").join("output.txt");
|
||||
@@ -450,9 +463,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_to_writer() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -460,7 +472,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("partial_writer.txt");
|
||||
@@ -480,16 +492,15 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_to_writer_parallel_partitioned_file() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
let original_data = b"abcdefghijklmnopqrstuvwxyz0123456789";
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("partitioned.txt");
|
||||
@@ -528,9 +539,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_multiple_files_concurrent() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -541,7 +551,7 @@ mod tests {
|
||||
let xfi_a = upload_data(&cas_path, data_a).await;
|
||||
let xfi_b = upload_data(&cas_path, data_b).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_a = temp.path().join("out_a.txt");
|
||||
@@ -570,9 +580,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_stream_async() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -580,7 +589,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
|
||||
@@ -597,9 +606,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_stream_blocking() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -607,7 +615,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
|
||||
@@ -630,9 +638,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_stream_returns_none_after_finish() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -640,7 +647,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
|
||||
@@ -656,9 +663,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_stream_multiple_concurrent() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -669,7 +675,7 @@ mod tests {
|
||||
let xfi_a = upload_data(&cas_path, data_a).await;
|
||||
let xfi_b = upload_data(&cas_path, data_b).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id_a, mut stream_a) = session.download_stream(&xfi_a, None).await.unwrap();
|
||||
@@ -702,9 +708,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_drop_stream_without_reading() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -712,7 +717,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
|
||||
@@ -728,9 +733,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_drop_stream_multiple_cycles_then_download() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -738,7 +742,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
for i in 0..5u32 {
|
||||
@@ -759,9 +763,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_drop_stream_blocking_mid_read_then_download() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -769,7 +772,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, stream) = session.download_stream(&xfi, None).await.unwrap();
|
||||
@@ -792,9 +795,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_cancel_stream_before_start_returns_none() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -802,7 +804,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
|
||||
@@ -815,9 +817,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_cancel_stream_after_first_chunk_returns_none() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.bridge_sync(async {
|
||||
let temp = tempdir().unwrap();
|
||||
let cas_path = temp.path().join("cas");
|
||||
@@ -825,7 +826,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, mut stream) = session.download_stream(&xfi, None).await.unwrap();
|
||||
@@ -845,7 +846,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_to_writer_range_from() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -855,7 +856,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("range_from.bin");
|
||||
@@ -870,7 +871,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_to_writer_range_to() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -880,7 +881,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("range_to.bin");
|
||||
@@ -895,7 +896,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_to_writer_full_range() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -905,7 +906,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("full_range.bin");
|
||||
@@ -920,7 +921,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_to_writer_range_inclusive() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -930,7 +931,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("range_incl.bin");
|
||||
@@ -947,7 +948,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_stream_range_bounded() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -957,7 +958,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, mut stream) = session.download_stream_range(&xfi, 4..12).await.unwrap();
|
||||
@@ -974,7 +975,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_stream_range_from() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -984,7 +985,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, mut stream) = session.download_stream_range(&xfi, 10..).await.unwrap();
|
||||
@@ -1001,7 +1002,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_stream_range_to() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -1011,7 +1012,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, mut stream) = session.download_stream_range(&xfi, ..6).await.unwrap();
|
||||
@@ -1030,7 +1031,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_file_unknown_size() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -1041,7 +1042,7 @@ mod tests {
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
let xfi_no_size = XetFileInfo::new_hash_only(xfi.hash().to_string());
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("output_unknown.txt");
|
||||
@@ -1055,7 +1056,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_stream_unknown_size() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -1066,7 +1067,7 @@ mod tests {
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
let xfi_no_size = XetFileInfo::new_hash_only(xfi.hash().to_string());
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, mut stream) = session.download_stream(&xfi_no_size, None).await.unwrap();
|
||||
@@ -1084,7 +1085,7 @@ mod tests {
|
||||
#[cfg(not(debug_assertions))]
|
||||
#[test]
|
||||
fn test_download_file_size_mismatch_error() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -1095,7 +1096,7 @@ mod tests {
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
let wrong_size_xfi = XetFileInfo::new(xfi.hash().to_string(), 999);
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("output_mismatch.txt");
|
||||
@@ -1132,7 +1133,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_to_writer_empty_range() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -1142,7 +1143,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("empty_range.bin");
|
||||
@@ -1157,7 +1158,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_download_to_writer_inverted_range_errors() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -1167,7 +1168,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("inverted_range.bin");
|
||||
@@ -1182,7 +1183,7 @@ mod tests {
|
||||
#[cfg(not(debug_assertions))]
|
||||
#[test]
|
||||
fn test_download_to_writer_range_start_beyond_file_size_errors() {
|
||||
let runtime = get_threadpool();
|
||||
let runtime = get_runtime();
|
||||
runtime
|
||||
.clone()
|
||||
.external_run_async_task(async {
|
||||
@@ -1192,7 +1193,7 @@ mod tests {
|
||||
|
||||
let xfi = upload_data(&cas_path, original_data).await;
|
||||
|
||||
let config = TranslatorConfig::local_config(&cas_path).unwrap();
|
||||
let config = TranslatorConfig::local_config(&XetContext::default().unwrap(), &cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let out_path = temp.path().join("beyond_size.bin");
|
||||
|
||||
@@ -15,16 +15,14 @@ use tracing::{Instrument, Span, info_span, instrument};
|
||||
use xet_client::cas_client::{Client, ProgressCallback};
|
||||
use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
|
||||
use xet_core_structures::xorb_object::SerializedXorbObject;
|
||||
use xet_runtime::core::{XetRuntime, xet_config};
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::XetFileInfo;
|
||||
use super::configurations::TranslatorConfig;
|
||||
use super::file_cleaner::{Sha256Policy, SingleFileCleaner};
|
||||
use super::remote_client_interface::create_remote_client;
|
||||
use super::shard_interface::SessionShardInterface;
|
||||
use crate::deduplication::constants::{
|
||||
MAX_XORB_BYTES, MAX_XORB_CHUNKS, XORB_CUT_THRESHOLD_BYTES, XORB_CUT_THRESHOLD_CHUNKS,
|
||||
};
|
||||
use crate::deduplication::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
|
||||
use crate::deduplication::{DataAggregator, DeduplicationMetrics, RawXorbData};
|
||||
use crate::error::{DataError, Result};
|
||||
use crate::progress_tracking::upload_tracking::{CompletionTracker, FileXorbDependency};
|
||||
@@ -37,6 +35,7 @@ use crate::progress_tracking::{GroupProgress, GroupProgressReport, ItemProgressR
|
||||
/// that succeeds or fails as a unit; i.e. all files get uploaded on finalization, and all shards
|
||||
/// and xorbs needed to reconstruct those files are properly uploaded and registered.
|
||||
pub struct FileUploadSession {
|
||||
pub(crate) ctx: XetContext,
|
||||
pub(crate) client: Arc<dyn Client + Send + Sync>,
|
||||
pub(crate) shard_interface: SessionShardInterface,
|
||||
|
||||
@@ -70,6 +69,7 @@ impl FileUploadSession {
|
||||
}
|
||||
|
||||
async fn new_impl(config: Arc<TranslatorConfig>, dry_run: bool) -> Result<Arc<FileUploadSession>> {
|
||||
let ctx = config.ctx.clone();
|
||||
let session_id = config
|
||||
.session
|
||||
.session_id
|
||||
@@ -78,16 +78,17 @@ impl FileUploadSession {
|
||||
.unwrap_or_else(|| Cow::Owned(UniqueID::new().to_string()));
|
||||
|
||||
let progress = GroupProgress::with_speed_config(
|
||||
xet_config().data.progress_update_speed_sampling_window,
|
||||
xet_config().data.progress_update_speed_min_observations,
|
||||
ctx.config.data.progress_update_speed_sampling_window,
|
||||
ctx.config.data.progress_update_speed_min_observations,
|
||||
);
|
||||
let completion_tracker = Arc::new(CompletionTracker::new(progress.clone()));
|
||||
|
||||
let client = create_remote_client(&config, &session_id, dry_run).await?;
|
||||
|
||||
let shard_interface = SessionShardInterface::new(config.clone(), client.clone(), dry_run).await?;
|
||||
let shard_interface = SessionShardInterface::new(&ctx, config.clone(), client.clone(), dry_run).await?;
|
||||
|
||||
Ok(Arc::new(Self {
|
||||
ctx,
|
||||
shard_interface,
|
||||
client,
|
||||
completion_tracker,
|
||||
@@ -115,7 +116,8 @@ impl FileUploadSession {
|
||||
let updater = self.progress.new_item(UniqueID::new(), file_name.clone());
|
||||
let file_id = self.completion_tracker.register_new_file(updater, Some(file_size));
|
||||
|
||||
let ingestion_concurrency_limiter = XetRuntime::current().common().file_ingestion_semaphore.clone();
|
||||
let ingestion_concurrency_limiter = self.ctx.common.file_ingestion_semaphore.clone();
|
||||
let ingestion_block_size = *self.ctx.config.data.ingestion_block_size;
|
||||
let session = self.clone();
|
||||
|
||||
cleaning_tasks.push(tokio::spawn(async move {
|
||||
@@ -144,7 +146,7 @@ impl FileUploadSession {
|
||||
while bytes_read < file_size {
|
||||
// Allocate a block of bytes, read into it.
|
||||
let bytes_left = file_size - bytes_read;
|
||||
let n_bytes_read = (*xet_config().data.ingestion_block_size).min(bytes_left) as usize;
|
||||
let n_bytes_read = ingestion_block_size.min(bytes_left) as usize;
|
||||
|
||||
// Read in the data here; we are assuming the file doesn't change size
|
||||
// on the disk while we are reading it.
|
||||
@@ -244,11 +246,12 @@ impl FileUploadSession {
|
||||
let tracking_name: Arc<str> = Arc::from(file_path.to_string_lossy().as_ref());
|
||||
let (id, cleaner) = self.start_clean(Some(tracking_name), Some(file_size), sha256)?;
|
||||
|
||||
let rt = XetRuntime::current();
|
||||
let semaphore = rt.common().file_ingestion_semaphore.clone();
|
||||
let handle = rt.spawn(async move {
|
||||
let session = self.clone();
|
||||
let runtime = self.ctx.runtime.clone();
|
||||
let semaphore = self.ctx.common.file_ingestion_semaphore.clone();
|
||||
let handle = runtime.spawn(async move {
|
||||
let _permit = semaphore.acquire().await?;
|
||||
Self::feed_file_to_cleaner(cleaner, &file_path).await
|
||||
Self::feed_file_to_cleaner(&session, cleaner, &file_path).await
|
||||
});
|
||||
|
||||
Ok((id, handle))
|
||||
@@ -266,9 +269,9 @@ impl FileUploadSession {
|
||||
self.check_not_finalized()?;
|
||||
let (id, mut cleaner) = self.start_clean(tracking_name, Some(bytes.len() as u64), sha256)?;
|
||||
|
||||
let rt = XetRuntime::current();
|
||||
let semaphore = rt.common().file_ingestion_semaphore.clone();
|
||||
let handle = rt.spawn(async move {
|
||||
let runtime = self.ctx.runtime.clone();
|
||||
let semaphore = self.ctx.common.file_ingestion_semaphore.clone();
|
||||
let handle = runtime.spawn(async move {
|
||||
let _permit = semaphore.acquire().await?;
|
||||
cleaner.add_data(&bytes).await?;
|
||||
cleaner.finish().await
|
||||
@@ -278,12 +281,13 @@ impl FileUploadSession {
|
||||
}
|
||||
|
||||
async fn feed_file_to_cleaner(
|
||||
_session: &Arc<Self>,
|
||||
mut cleaner: SingleFileCleaner,
|
||||
file_path: &Path,
|
||||
) -> Result<(XetFileInfo, DeduplicationMetrics)> {
|
||||
let mut reader = File::open(file_path)?;
|
||||
let filesize = reader.metadata()?.len();
|
||||
let mut buffer = vec![0u8; u64::min(filesize, *xet_config().data.ingestion_block_size) as usize];
|
||||
let mut buffer = vec![0u8; u64::min(filesize, *_session.ctx.config.data.ingestion_block_size) as usize];
|
||||
|
||||
loop {
|
||||
let n = reader.read(&mut buffer)?;
|
||||
@@ -343,13 +347,23 @@ impl FileUploadSession {
|
||||
|
||||
// Serialize the object; this can be relatively expensive, so run it on a compute thread.
|
||||
// XORBs are sent without footer - the server/client reconstructs it from chunk data.
|
||||
let xorb_obj = XetRuntime::current()
|
||||
.spawn_blocking(move || SerializedXorbObject::from_xorb(xorb, false))
|
||||
let runtime = self.ctx.runtime.clone();
|
||||
let compression_policy = self.ctx.config.xorb.compression_policy.clone();
|
||||
let compression_scheme_retest_interval = self.ctx.config.xorb.compression_scheme_retest_interval;
|
||||
let xorb_obj = runtime
|
||||
.spawn_blocking(move || {
|
||||
SerializedXorbObject::from_xorb(
|
||||
xorb,
|
||||
false,
|
||||
compression_policy.as_str(),
|
||||
compression_scheme_retest_interval,
|
||||
)
|
||||
})
|
||||
.await??;
|
||||
|
||||
let session = self.clone();
|
||||
let upload_permit = self.client.acquire_upload_permit().await?;
|
||||
let cas_prefix = xet_config().data.default_prefix.clone();
|
||||
let cas_prefix = self.ctx.config.data.default_prefix.clone();
|
||||
let completion_tracker = self.completion_tracker.clone();
|
||||
let xorb_hash = xorb_obj.hash;
|
||||
let raw_num_bytes = xorb_obj.raw_num_bytes;
|
||||
@@ -397,9 +411,30 @@ impl FileUploadSession {
|
||||
{
|
||||
let mut current_session_data = self.current_session_data.lock().await;
|
||||
|
||||
#[cfg(feature = "simulation")]
|
||||
let xorb_cut_bytes = self
|
||||
.ctx
|
||||
.config
|
||||
.xorb
|
||||
.simulation_max_bytes
|
||||
.map(|bs| (bs.as_u64() as usize).min(*MAX_XORB_BYTES))
|
||||
.unwrap_or(*MAX_XORB_BYTES);
|
||||
#[cfg(not(feature = "simulation"))]
|
||||
let xorb_cut_bytes = *MAX_XORB_BYTES;
|
||||
#[cfg(feature = "simulation")]
|
||||
let xorb_cut_chunks = self
|
||||
.ctx
|
||||
.config
|
||||
.xorb
|
||||
.simulation_max_chunks
|
||||
.unwrap_or(*MAX_XORB_CHUNKS)
|
||||
.min(*MAX_XORB_CHUNKS);
|
||||
#[cfg(not(feature = "simulation"))]
|
||||
let xorb_cut_chunks = *MAX_XORB_CHUNKS;
|
||||
|
||||
// Do we need to cut one of these to a xorb?
|
||||
if current_session_data.num_bytes() + file_data.num_bytes() > *XORB_CUT_THRESHOLD_BYTES
|
||||
|| current_session_data.num_chunks() + file_data.num_chunks() > *XORB_CUT_THRESHOLD_CHUNKS
|
||||
if current_session_data.num_bytes() + file_data.num_bytes() > xorb_cut_bytes
|
||||
|| current_session_data.num_chunks() + file_data.num_chunks() > xorb_cut_chunks
|
||||
{
|
||||
// Cut the larger one as a xorb, uploading it and registering the files.
|
||||
if current_session_data.num_bytes() > file_data.num_bytes() {
|
||||
@@ -577,20 +612,11 @@ mod tests {
|
||||
use std::fs::{File, OpenOptions};
|
||||
use std::io::{Read, Write};
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, OnceLock};
|
||||
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use crate::processing::{FileDownloadSession, FileUploadSession, XetFileInfo};
|
||||
|
||||
/// Return a shared threadpool to be reused as needed.
|
||||
fn get_threadpool() -> Arc<XetRuntime> {
|
||||
static THREADPOOL: OnceLock<Arc<XetRuntime>> = OnceLock::new();
|
||||
THREADPOOL
|
||||
.get_or_init(|| XetRuntime::new().expect("Error starting multithreaded runtime."))
|
||||
.clone()
|
||||
}
|
||||
|
||||
/// Cleans (converts) a regular file into a pointer file.
|
||||
///
|
||||
/// * `input_path`: path to the original file
|
||||
@@ -607,7 +633,8 @@ mod tests {
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let upload_session = FileUploadSession::new(TranslatorConfig::local_config(cas_path).unwrap().into())
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let upload_session = FileUploadSession::new(TranslatorConfig::local_config(&ctx, cas_path).unwrap().into())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -638,7 +665,8 @@ mod tests {
|
||||
|
||||
let xet_file = serde_json::from_str::<XetFileInfo>(&input).unwrap();
|
||||
|
||||
let config = TranslatorConfig::local_config(cas_path).unwrap();
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let config = TranslatorConfig::local_config(&ctx, cas_path).unwrap();
|
||||
let session = FileDownloadSession::new(config.into(), None).await.unwrap();
|
||||
|
||||
let (_id, _n_bytes) = session.download_file(&xet_file, output_path).await.unwrap();
|
||||
@@ -655,10 +683,9 @@ mod tests {
|
||||
let temp = tempdir().unwrap();
|
||||
let original_data = b"Hello, world!";
|
||||
|
||||
let runtime = get_threadpool();
|
||||
let ctx = XetContext::default().unwrap();
|
||||
|
||||
runtime
|
||||
.clone()
|
||||
ctx.runtime
|
||||
.bridge_sync(async move {
|
||||
let cas_path = temp.path().join("cas");
|
||||
|
||||
@@ -686,16 +713,17 @@ mod tests {
|
||||
let temp = tempdir().unwrap();
|
||||
let data = b"Hello, skip sha256!";
|
||||
|
||||
let runtime = get_threadpool();
|
||||
let ctx = XetContext::default().unwrap();
|
||||
|
||||
runtime
|
||||
.clone()
|
||||
ctx.runtime
|
||||
.bridge_sync(async move {
|
||||
let cas_path = temp.path().join("cas");
|
||||
|
||||
let upload_session = FileUploadSession::new(TranslatorConfig::local_config(&cas_path).unwrap().into())
|
||||
.await
|
||||
.unwrap();
|
||||
let session_ctx = XetContext::default().unwrap();
|
||||
let upload_session =
|
||||
FileUploadSession::new(TranslatorConfig::local_config(&session_ctx, &cas_path).unwrap().into())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let (_id, mut cleaner) = upload_session
|
||||
.start_clean(Some("test".into()), Some(data.len() as u64), Sha256Policy::Skip)
|
||||
|
||||
@@ -3,9 +3,9 @@ use std::sync::Arc;
|
||||
use http::header;
|
||||
use tracing::{Instrument, Span, info_span, instrument};
|
||||
use xet_client::cas_client::auth::TokenRefresher;
|
||||
use xet_client::hub_client::{BearerCredentialHelper, HubClient, Operation, RepoInfo};
|
||||
use xet_client::hub_client::{BearerCredentialHelper, CredentialHelper, HubClient, Operation, RepoInfo};
|
||||
use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::core::par_utils::run_constrained;
|
||||
|
||||
use super::super::data_client::{clean_file, default_config};
|
||||
@@ -37,16 +37,18 @@ pub async fn migrate_with_external_runtime(
|
||||
let cred_helper = BearerCredentialHelper::new(hub_token.to_owned(), "");
|
||||
let mut headers = header::HeaderMap::new();
|
||||
headers.insert(header::USER_AGENT, header::HeaderValue::from_static(USER_AGENT));
|
||||
let ctx = XetContext::default()?;
|
||||
let hub_client = HubClient::new(
|
||||
ctx.clone(),
|
||||
hub_endpoint,
|
||||
RepoInfo::try_from(repo_type, repo_id)?,
|
||||
Some("main".to_owned()),
|
||||
"",
|
||||
Some(cred_helper),
|
||||
Some(cred_helper as Arc<dyn CredentialHelper>),
|
||||
Some(headers),
|
||||
)?;
|
||||
|
||||
migrate_files_impl(file_paths, sha256s, false, hub_client, cas_endpoint, false).await?;
|
||||
migrate_files_impl(&ctx, file_paths, sha256s, false, hub_client, cas_endpoint, false).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -56,6 +58,7 @@ pub type MigrationInfo = (Vec<MDBFileInfo>, Vec<(XetFileInfo, u64)>, u64);
|
||||
|
||||
#[instrument(skip_all, name = "migrate_files", fields(session_id = tracing::field::Empty, num_files = file_paths.len()))]
|
||||
pub async fn migrate_files_impl(
|
||||
ctx: &XetContext,
|
||||
file_paths: Vec<String>,
|
||||
sha256s: Option<Vec<String>>,
|
||||
sequential: bool,
|
||||
@@ -76,6 +79,7 @@ pub async fn migrate_files_impl(
|
||||
headers.insert(http::header::USER_AGENT, http::HeaderValue::from_static(USER_AGENT));
|
||||
|
||||
let config = default_config(
|
||||
ctx,
|
||||
cas,
|
||||
Some((jwt_info.access_token, jwt_info.exp)),
|
||||
Some(token_refresher),
|
||||
@@ -86,7 +90,7 @@ pub async fn migrate_files_impl(
|
||||
let num_workers = if sequential {
|
||||
1
|
||||
} else {
|
||||
XetRuntime::current().num_worker_threads()
|
||||
ctx.runtime.num_worker_threads()
|
||||
};
|
||||
let processor = if dry_run {
|
||||
FileUploadSession::dry_run(config.into()).await?
|
||||
|
||||
@@ -11,12 +11,13 @@ pub(crate) async fn create_remote_client(
|
||||
dry_run: bool,
|
||||
) -> Result<Arc<dyn Client>> {
|
||||
let session = &config.session;
|
||||
let runtime = config.ctx.clone();
|
||||
|
||||
if let Some(local_path) = session.local_path() {
|
||||
if let Some(local_path) = session.local_path(&config.ctx) {
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
{
|
||||
let xorb_path = local_path.join("xet").join("xorbs");
|
||||
Ok(xet_client::cas_client::LocalClient::new(xorb_path).await?)
|
||||
Ok(xet_client::cas_client::LocalClient::new(runtime, xorb_path).await?)
|
||||
}
|
||||
#[cfg(target_family = "wasm")]
|
||||
{
|
||||
@@ -24,9 +25,10 @@ pub(crate) async fn create_remote_client(
|
||||
unimplemented!("Local file system access is not available in WASM")
|
||||
}
|
||||
} else if session.is_memory() {
|
||||
Ok(xet_client::cas_client::MemoryClient::new())
|
||||
Ok(xet_client::cas_client::MemoryClient::new(runtime))
|
||||
} else {
|
||||
Ok(RemoteClient::new(
|
||||
runtime,
|
||||
&session.endpoint,
|
||||
&session.auth,
|
||||
session_id,
|
||||
|
||||
@@ -1,15 +1,20 @@
|
||||
use sha2::{Digest, Sha256 as sha2Sha256};
|
||||
use tokio::task::{JoinError, JoinHandle};
|
||||
use xet_core_structures::metadata_shard::Sha256;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
/// Helper struct to generate a sha256 hash.
|
||||
#[derive(Debug, Default)]
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct Sha256Generator {
|
||||
ctx: XetContext,
|
||||
hasher: Option<JoinHandle<Result<sha2Sha256, JoinError>>>,
|
||||
}
|
||||
|
||||
impl Sha256Generator {
|
||||
pub(crate) fn new(ctx: XetContext) -> Self {
|
||||
Self { ctx, hasher: None }
|
||||
}
|
||||
|
||||
/// Complete the last block, then hand off the new chunks to the new hasher.
|
||||
pub async fn update(&mut self, new_data: impl AsRef<[u8]> + Send + Sync + 'static) -> Result<(), JoinError> {
|
||||
let mut hasher = match self.hasher.take() {
|
||||
@@ -19,8 +24,8 @@ impl Sha256Generator {
|
||||
|
||||
// The previous task returns the hasher; we consume that and pass it on.
|
||||
// Use the compute background thread for this process.
|
||||
let rt = XetRuntime::current();
|
||||
self.hasher = Some(rt.spawn_blocking(move || {
|
||||
let runtime = self.ctx.runtime.clone();
|
||||
self.hasher = Some(runtime.spawn_blocking(move || {
|
||||
hasher.update(&new_data);
|
||||
|
||||
Ok(hasher)
|
||||
@@ -57,7 +62,7 @@ mod sha_tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sha_generation_builder() {
|
||||
let mut sha_generator = Sha256Generator::default();
|
||||
let mut sha_generator = Sha256Generator::new(xet_runtime::core::XetContext::default().unwrap());
|
||||
sha_generator.update(TEST_DATA.as_bytes()).await.unwrap();
|
||||
let hash = sha_generator.finalize().await.unwrap();
|
||||
|
||||
@@ -66,7 +71,7 @@ mod sha_tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sha_generation_build_multiple_chunks() {
|
||||
let mut sha_generator = Sha256Generator::default();
|
||||
let mut sha_generator = Sha256Generator::new(xet_runtime::core::XetContext::default().unwrap());
|
||||
let td = TEST_DATA.as_bytes();
|
||||
sha_generator.update(&td[0..4]).await.unwrap();
|
||||
sha_generator.update(&td[4..td.len()]).await.unwrap();
|
||||
@@ -83,7 +88,7 @@ mod sha_tests {
|
||||
let mut rand_data = [0u8; 4096];
|
||||
rng().fill(&mut rand_data[..]);
|
||||
|
||||
let mut sha_generator = Sha256Generator::default();
|
||||
let mut sha_generator = Sha256Generator::new(xet_runtime::core::XetContext::default().unwrap());
|
||||
|
||||
// Add in random chunks.
|
||||
let mut pos = 0;
|
||||
|
||||
@@ -19,15 +19,16 @@ use xet_core_structures::metadata_shard::session_directory::{
|
||||
use xet_core_structures::metadata_shard::shard_in_memory::MDBInMemoryShard;
|
||||
use xet_core_structures::metadata_shard::xorb_structs::MDBXorbInfo;
|
||||
use xet_core_structures::metadata_shard::{
|
||||
MDB_SHARD_LOCAL_CACHE_EXPIRATION, MDBShardFile, MDBShardFileHeader, ShardFileManager,
|
||||
MDB_SHARD_LOCAL_CACHE_EXPIRATION, MDBShardFile, MDBShardFileHeader, ShardFileManager, get_shard_file_cache,
|
||||
};
|
||||
use xet_runtime::core::xet_config;
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::error_printer::ErrorPrinter;
|
||||
|
||||
use super::configurations::TranslatorConfig;
|
||||
use crate::error::Result;
|
||||
|
||||
pub struct SessionShardInterface {
|
||||
ctx: XetContext,
|
||||
session_shard_manager: Arc<ShardFileManager>,
|
||||
cache_shard_manager: Arc<ShardFileManager>,
|
||||
|
||||
@@ -54,6 +55,7 @@ pub struct SessionShardInterface {
|
||||
|
||||
impl SessionShardInterface {
|
||||
pub async fn new(
|
||||
ctx: &XetContext,
|
||||
config: Arc<TranslatorConfig>,
|
||||
client: Arc<dyn Client + Send + Sync>,
|
||||
dry_run: bool,
|
||||
@@ -78,10 +80,12 @@ impl SessionShardInterface {
|
||||
let shard_merge_jh = {
|
||||
if !dry_run {
|
||||
Some(merge_shards_background(
|
||||
ctx.runtime.clone(),
|
||||
&xorb_metadata_staging_dir,
|
||||
&session_dir,
|
||||
xet_config().shard.max_target_size,
|
||||
ctx.config.shard.max_target_size,
|
||||
true,
|
||||
get_shard_file_cache(&ctx.common),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
@@ -89,8 +93,8 @@ impl SessionShardInterface {
|
||||
};
|
||||
|
||||
// Load the cache and session shard managers.
|
||||
let cache_shard_manager = ShardFileManager::new_in_cache_directory(cache_dir).await?;
|
||||
let session_shard_manager = ShardFileManager::new_in_session_directory(&session_dir, false).await?;
|
||||
let cache_shard_manager = ShardFileManager::new_in_cache_directory(ctx, cache_dir).await?;
|
||||
let session_shard_manager = ShardFileManager::new_in_session_directory(ctx, &session_dir, false).await?;
|
||||
|
||||
// Get the new merged shard handles here.
|
||||
let shard_merge_result = {
|
||||
@@ -106,7 +110,7 @@ impl SessionShardInterface {
|
||||
if !shard_merge_result.merged_shards.is_empty() {
|
||||
// Create a new shard manager to just hold the resumed session shards
|
||||
let resumed_session_shard_manager =
|
||||
ShardFileManager::new_in_session_directory(&session_dir, false).await?;
|
||||
ShardFileManager::new_in_session_directory(ctx, &session_dir, false).await?;
|
||||
|
||||
resumed_session_shard_manager
|
||||
.register_shards(&shard_merge_result.merged_shards)
|
||||
@@ -122,6 +126,7 @@ impl SessionShardInterface {
|
||||
shard_merge_result.obsolete_shards.iter().map(|sfi| sfi.path.clone()).collect();
|
||||
|
||||
Ok(Self {
|
||||
ctx: ctx.clone(),
|
||||
session_shard_manager,
|
||||
cache_shard_manager,
|
||||
xorb_metadata_staging_dir,
|
||||
@@ -138,7 +143,7 @@ impl SessionShardInterface {
|
||||
pub async fn query_dedup_shard_by_chunk(&self, chunk_hash: &MerkleHash) -> Result<bool> {
|
||||
let Ok(Some(new_shard)) = self
|
||||
.client
|
||||
.query_for_global_dedup_shard(&xet_config().data.default_prefix, chunk_hash)
|
||||
.query_for_global_dedup_shard(&self.ctx.config.data.default_prefix, chunk_hash)
|
||||
.await
|
||||
.info_error("Error attempting to query global dedup lookup.")
|
||||
else {
|
||||
@@ -201,11 +206,11 @@ impl SessionShardInterface {
|
||||
xorb_shard.add_xorb_block(xorb_block_contents)?;
|
||||
|
||||
let time_now = SystemTime::now();
|
||||
let flush_interval = xet_config().data.session_xorb_metadata_flush_interval;
|
||||
let flush_interval = self.ctx.config.data.session_xorb_metadata_flush_interval;
|
||||
|
||||
// Flush if it's time or we've hit enough new shards that we should do the flush
|
||||
if *last_flush + flush_interval < time_now
|
||||
|| xorb_shard.num_xorb_entries() >= xet_config().data.session_xorb_metadata_flush_max_count
|
||||
|| xorb_shard.num_xorb_entries() >= self.ctx.config.data.session_xorb_metadata_flush_max_count
|
||||
{
|
||||
xorb_shard.write_to_directory(&self.xorb_metadata_staging_dir, Some(*MDB_SHARD_LOCAL_CACHE_EXPIRATION))?;
|
||||
|
||||
@@ -237,11 +242,11 @@ impl SessionShardInterface {
|
||||
|
||||
// First, scan, merge, and fill out any shards in the session directory
|
||||
let shard_list = consolidate_shards_in_directory(
|
||||
&self.ctx.runtime,
|
||||
self.session_shard_manager.shard_directory(),
|
||||
xet_config().shard.max_target_size,
|
||||
// Here, we want to error out if some of the information isn't present or corrupt, so set skip_on_error to
|
||||
// false.
|
||||
self.ctx.config.shard.max_target_size,
|
||||
false,
|
||||
self.session_shard_manager.shard_file_cache(),
|
||||
)?;
|
||||
|
||||
// Upload all the shards and move each to the common directory.
|
||||
@@ -286,6 +291,7 @@ impl SessionShardInterface {
|
||||
let new_shard_path = si.export_with_expiration(
|
||||
cache_shard_manager.shard_directory(),
|
||||
*MDB_SHARD_LOCAL_CACHE_EXPIRATION,
|
||||
cache_shard_manager.shard_file_cache(),
|
||||
)?;
|
||||
|
||||
// Register that new shard in the cache shard manager
|
||||
@@ -347,7 +353,8 @@ mod tests {
|
||||
let mdb_in_mem = MDBInMemoryShard::default();
|
||||
let temp_shard_file_path = mdb_in_mem.write_to_directory(tmp_dir_path, None)?;
|
||||
|
||||
let shard_file = MDBShardFile::load_from_file(&temp_shard_file_path)?;
|
||||
let sfc = xet_core_structures::metadata_shard::new_shard_file_cache();
|
||||
let shard_file = MDBShardFile::load_from_file(&temp_shard_file_path, &sfc)?;
|
||||
assert_eq!(
|
||||
shard_file.shard.header.footer_size,
|
||||
size_of::<xet_core_structures::metadata_shard::MDBShardFileFooter>() as u64
|
||||
|
||||
@@ -9,6 +9,7 @@ use tempfile::TempDir;
|
||||
use xet_client::cas_client::{Client, LocalClient};
|
||||
#[cfg(feature = "simulation")]
|
||||
use xet_client::cas_client::{LocalTestServer, LocalTestServerBuilder};
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
use super::configurations::TranslatorConfig;
|
||||
use super::data_client::clean_file;
|
||||
@@ -187,6 +188,7 @@ pub struct HydrateDehydrateTest {
|
||||
pub src_dir: PathBuf,
|
||||
pub ptr_dir: PathBuf,
|
||||
pub dest_dir: PathBuf,
|
||||
ctx: XetContext,
|
||||
use_test_server: bool,
|
||||
/// Kept alive so the test server stays running for the duration of the test.
|
||||
#[cfg(feature = "simulation")]
|
||||
@@ -224,6 +226,7 @@ impl HydrateDehydrateTest {
|
||||
src_dir,
|
||||
ptr_dir,
|
||||
dest_dir,
|
||||
ctx: XetContext::default().expect("xet context"),
|
||||
_temp_dir,
|
||||
use_test_server,
|
||||
#[cfg(feature = "simulation")]
|
||||
@@ -263,7 +266,9 @@ impl HydrateDehydrateTest {
|
||||
#[cfg(feature = "simulation")]
|
||||
pub async fn ensure_server_created(&mut self) {
|
||||
if self.use_test_server && self.test_server.is_none() {
|
||||
let local_client = LocalClient::new(self.cas_dir.join("xet/xorbs")).await.unwrap();
|
||||
let local_client = LocalClient::new(self.ctx.clone(), self.cas_dir.join("xet/xorbs"))
|
||||
.await
|
||||
.unwrap();
|
||||
self.test_server = Some(LocalTestServerBuilder::new().with_client(local_client).start().await);
|
||||
}
|
||||
}
|
||||
@@ -280,7 +285,9 @@ impl HydrateDehydrateTest {
|
||||
#[cfg(feature = "simulation")]
|
||||
{
|
||||
if self.test_server.is_none() {
|
||||
let local_client = LocalClient::new(self.cas_dir.join("xet/xorbs")).await.unwrap();
|
||||
let local_client = LocalClient::new(self.ctx.clone(), self.cas_dir.join("xet/xorbs"))
|
||||
.await
|
||||
.unwrap();
|
||||
self.test_server = Some(LocalTestServerBuilder::new().with_client(local_client).start().await);
|
||||
}
|
||||
self.test_server.as_ref().unwrap().remote_client().clone() as Arc<dyn Client>
|
||||
@@ -290,12 +297,14 @@ impl HydrateDehydrateTest {
|
||||
panic!("test server requires the 'simulation' feature");
|
||||
}
|
||||
} else {
|
||||
LocalClient::new(self.cas_dir.join("xet/xorbs")).await.unwrap() as Arc<dyn Client>
|
||||
LocalClient::new(self.ctx.clone(), self.cas_dir.join("xet/xorbs"))
|
||||
.await
|
||||
.unwrap() as Arc<dyn Client>
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new_upload_session(&self) -> Arc<FileUploadSession> {
|
||||
let config = Arc::new(TranslatorConfig::local_config(&self.cas_dir).unwrap());
|
||||
let config = Arc::new(TranslatorConfig::local_config(&self.ctx, &self.cas_dir).unwrap());
|
||||
FileUploadSession::new(config.clone()).await.unwrap()
|
||||
}
|
||||
|
||||
@@ -345,7 +354,7 @@ impl HydrateDehydrateTest {
|
||||
|
||||
pub async fn hydrate(&mut self) {
|
||||
let client = self.get_or_create_client().await;
|
||||
let session = FileDownloadSession::from_client(client, None);
|
||||
let session = FileDownloadSession::from_client(&self.ctx, client, None);
|
||||
|
||||
for entry in read_dir(&self.ptr_dir).unwrap() {
|
||||
let entry = entry.unwrap();
|
||||
@@ -358,7 +367,7 @@ impl HydrateDehydrateTest {
|
||||
|
||||
pub async fn hydrate_partitioned_writers(&mut self, partitions: usize) {
|
||||
let client = self.get_or_create_client().await;
|
||||
let session = FileDownloadSession::from_client(client, None);
|
||||
let session = FileDownloadSession::from_client(&self.ctx, client, None);
|
||||
|
||||
for entry in read_dir(&self.ptr_dir).unwrap() {
|
||||
let entry = entry.unwrap();
|
||||
@@ -402,7 +411,7 @@ impl HydrateDehydrateTest {
|
||||
|
||||
pub async fn hydrate_stream(&mut self) {
|
||||
let client = self.get_or_create_client().await;
|
||||
let session = FileDownloadSession::from_client(client, None);
|
||||
let session = FileDownloadSession::from_client(&self.ctx, client, None);
|
||||
|
||||
for entry in read_dir(&self.ptr_dir).unwrap() {
|
||||
let entry = entry.unwrap();
|
||||
@@ -430,6 +439,7 @@ impl HydrateDehydrateTest {
|
||||
pub struct TestEnvironment {
|
||||
_temp_dir: TempDir,
|
||||
pub base_dir: PathBuf,
|
||||
pub ctx: XetContext,
|
||||
pub config: Arc<super::configurations::TranslatorConfig>,
|
||||
#[cfg(feature = "simulation")]
|
||||
_server: Option<LocalTestServer>,
|
||||
@@ -439,22 +449,25 @@ impl TestEnvironment {
|
||||
pub async fn new() -> Self {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let base_dir = temp_dir.path().to_path_buf();
|
||||
let ctx = XetContext::default().unwrap();
|
||||
|
||||
#[cfg(feature = "simulation")]
|
||||
let (config, server) = {
|
||||
let server = LocalTestServerBuilder::new().start().await;
|
||||
let config = Arc::new(
|
||||
super::configurations::TranslatorConfig::test_server_config(server.http_endpoint(), &base_dir).unwrap(),
|
||||
super::configurations::TranslatorConfig::test_server_config(&ctx, server.http_endpoint(), &base_dir)
|
||||
.unwrap(),
|
||||
);
|
||||
(config, Some(server))
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "simulation"))]
|
||||
let config = Arc::new(super::configurations::TranslatorConfig::local_config(&base_dir).unwrap());
|
||||
let config = Arc::new(super::configurations::TranslatorConfig::local_config(&ctx, &base_dir).unwrap());
|
||||
|
||||
Self {
|
||||
_temp_dir: temp_dir,
|
||||
base_dir,
|
||||
ctx,
|
||||
config,
|
||||
#[cfg(feature = "simulation")]
|
||||
_server: server,
|
||||
|
||||
@@ -8,13 +8,15 @@ pub use xet_data::processing::data_client::hash_files_async;
|
||||
use xet_data::processing::data_client::{clean_bytes, default_config};
|
||||
use xet_data::processing::{FileDownloadSession, FileUploadSession, Sha256Policy, XetFileInfo};
|
||||
use xet_data::{DataError, Result};
|
||||
use xet_runtime::core::XetContext;
|
||||
use xet_runtime::core::par_utils::run_constrained_with_semaphore;
|
||||
use xet_runtime::core::{XetRuntime, xet_config};
|
||||
|
||||
use super::progress_tracking::{GroupProgressCallbackUpdater, ItemProgressCallbackUpdater, TrackingProgressUpdater};
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[instrument(skip_all, name = "data_client::upload_bytes", fields(session_id = tracing::field::Empty, num_files=file_contents.len()))]
|
||||
pub async fn upload_bytes_async(
|
||||
ctx: &XetContext,
|
||||
file_contents: Vec<Vec<u8>>,
|
||||
sha256_policies: Vec<Sha256Policy>,
|
||||
endpoint: Option<String>,
|
||||
@@ -32,7 +34,8 @@ pub async fn upload_bytes_async(
|
||||
}
|
||||
|
||||
let config = default_config(
|
||||
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
|
||||
ctx,
|
||||
endpoint.unwrap_or_else(|| ctx.config.data.default_cas_endpoint.clone()),
|
||||
token_info,
|
||||
token_refresher,
|
||||
custom_headers,
|
||||
@@ -40,7 +43,7 @@ pub async fn upload_bytes_async(
|
||||
|
||||
Span::current().record("session_id", &config.session.session_id);
|
||||
|
||||
let semaphore = XetRuntime::current().common().file_ingestion_semaphore.clone();
|
||||
let semaphore = ctx.common.file_ingestion_semaphore.clone();
|
||||
let upload_session = FileUploadSession::new(config.into()).await?;
|
||||
|
||||
let bridge = progress_updater.map(|updater| GroupProgressCallbackUpdater::start(upload_session.clone(), updater));
|
||||
@@ -61,6 +64,7 @@ pub async fn upload_bytes_async(
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[instrument(skip_all, name = "data_client::upload_files",
|
||||
fields(session_id = tracing::field::Empty,
|
||||
num_files=file_paths.len(),
|
||||
@@ -72,6 +76,7 @@ pub async fn upload_bytes_async(
|
||||
defrag_prevented_dedup_chunks = tracing::field::Empty
|
||||
))]
|
||||
pub async fn upload_async(
|
||||
ctx: &XetContext,
|
||||
file_paths: Vec<String>,
|
||||
sha256_policies: Vec<Sha256Policy>,
|
||||
endpoint: Option<String>,
|
||||
@@ -89,7 +94,8 @@ pub async fn upload_async(
|
||||
}
|
||||
|
||||
let config = default_config(
|
||||
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
|
||||
ctx,
|
||||
endpoint.unwrap_or_else(|| ctx.config.data.default_cas_endpoint.clone()),
|
||||
token_info,
|
||||
token_refresher,
|
||||
custom_headers,
|
||||
@@ -125,6 +131,7 @@ pub async fn upload_async(
|
||||
|
||||
#[instrument(skip_all, name = "data_client::download", fields(session_id = tracing::field::Empty, num_files=file_infos.len()))]
|
||||
pub async fn download_async(
|
||||
ctx: &XetContext,
|
||||
file_infos: Vec<(XetFileInfo, String)>,
|
||||
endpoint: Option<String>,
|
||||
token_info: Option<(String, u64)>,
|
||||
@@ -138,7 +145,8 @@ pub async fn download_async(
|
||||
return Err(DataError::ParameterError("updaters are not same length as pointer_files".to_string()));
|
||||
}
|
||||
let config: Arc<_> = default_config(
|
||||
endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
|
||||
ctx,
|
||||
endpoint.unwrap_or_else(|| ctx.config.data.default_cas_endpoint.clone()),
|
||||
token_info,
|
||||
token_refresher,
|
||||
custom_headers,
|
||||
|
||||
@@ -11,6 +11,7 @@ pub mod progress_tracking;
|
||||
|
||||
// Re-exports from xet_data so external consumers (hf_xet, git_xet) don't need
|
||||
// a direct xet_data dependency.
|
||||
pub use data_client::hash_files_async;
|
||||
pub use xet_data::processing::configurations::{SessionContext, TranslatorConfig};
|
||||
pub use xet_data::processing::data_client::{clean_bytes, clean_file, default_config, hash_files_async};
|
||||
pub use xet_data::processing::data_client::{clean_bytes, clean_file, default_config};
|
||||
pub use xet_data::processing::{FileDownloadSession, FileUploadSession, Sha256Policy, XetFileInfo};
|
||||
|
||||
@@ -39,8 +39,9 @@ pub(super) async fn create_translator_config(
|
||||
|
||||
// Build token refresher
|
||||
let token_refresher: Option<Arc<dyn TokenRefresher>> = if let Some((url, token_refresh_headers)) = token_refresh {
|
||||
let client = build_http_client(&session_id, None, Some(Arc::new(token_refresh_headers)))?;
|
||||
let direct_route_refresher = DirectRefreshRouteTokenRefresher::new(url, client, None);
|
||||
let client = build_http_client(&session.inner.ctx, &session_id, None, Some(Arc::new(token_refresh_headers)))?;
|
||||
let direct_route_refresher =
|
||||
DirectRefreshRouteTokenRefresher::new(session.inner.ctx.clone(), url, client, None);
|
||||
|
||||
// CAS endpoint is not provided but CAS token refresh endpoint is provided, we
|
||||
// refresh once to get the CAS endpoint, and fill the token info if nothing is provided.
|
||||
@@ -58,9 +59,10 @@ pub(super) async fn create_translator_config(
|
||||
None
|
||||
};
|
||||
|
||||
let endpoint = endpoint.unwrap_or_else(|| session.inner.config.data.default_cas_endpoint.clone());
|
||||
let endpoint = endpoint.unwrap_or_else(|| session.inner.ctx.config.data.default_cas_endpoint.clone());
|
||||
|
||||
let mut config = xet_data::processing::data_client::default_config(
|
||||
&session.inner.ctx,
|
||||
endpoint,
|
||||
token_info,
|
||||
token_refresher,
|
||||
|
||||
@@ -77,9 +77,9 @@ impl AuthGroupBuilder<XetDownloadStreamGroup> {
|
||||
/// API for creating authenticated streaming downloads.
|
||||
///
|
||||
/// Obtain via [`XetSession::new_download_stream_group`] — configure per-group
|
||||
/// auth on the returned [`DownloadStreamGroupBuilder`], then call
|
||||
/// [`build`](DownloadStreamGroupBuilder::build) (async) or
|
||||
/// [`build_blocking`](DownloadStreamGroupBuilder::build_blocking) (sync).
|
||||
/// auth on the returned [`XetDownloadStreamGroupBuilder`], then call
|
||||
/// [`build`](XetDownloadStreamGroupBuilder::build) (async) or
|
||||
/// [`build_blocking`](XetDownloadStreamGroupBuilder::build_blocking) (sync).
|
||||
///
|
||||
/// Create streams with [`download_stream`](Self::download_stream) /
|
||||
/// [`download_stream_blocking`](Self::download_stream_blocking) for ordered
|
||||
|
||||
@@ -400,7 +400,7 @@ mod tests {
|
||||
// finish() must block while download_file_to_path() holds the state lock.
|
||||
fn test_finish_blocked_while_download_registration_holds_state_lock() -> Result<()> {
|
||||
let session = XetSessionBuilder::new().build()?;
|
||||
let runtime = session.inner.runtime.clone();
|
||||
let runtime = session.inner.ctx.runtime.clone();
|
||||
let group = session.new_file_download_group()?.build_blocking()?;
|
||||
let group_for_thread = group.clone();
|
||||
let runtime_for_thread = runtime.clone();
|
||||
@@ -722,8 +722,8 @@ mod tests {
|
||||
tokio::time::sleep(
|
||||
session
|
||||
.inner
|
||||
.runtime
|
||||
.config()
|
||||
.ctx
|
||||
.config
|
||||
.data
|
||||
.progress_update_interval
|
||||
.saturating_add(Duration::from_secs(1)),
|
||||
@@ -842,7 +842,7 @@ mod tests {
|
||||
futures::executor::block_on(async {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
let endpoint = format!("local://{}", temp.path().join("cas").display());
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
|
||||
let data = b"hello from futures executor";
|
||||
let file_info = {
|
||||
@@ -883,7 +883,7 @@ mod tests {
|
||||
smol::block_on(async {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
let endpoint = format!("local://{}", temp.path().join("cas").display());
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
|
||||
let data = b"hello from smol executor";
|
||||
let file_info = {
|
||||
@@ -924,7 +924,7 @@ mod tests {
|
||||
async_std::task::block_on(async {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
let endpoint = format!("local://{}", temp.path().join("cas").display());
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
|
||||
let data = b"hello from async-std executor";
|
||||
let file_info = {
|
||||
@@ -1035,8 +1035,8 @@ mod tests {
|
||||
std::thread::sleep(
|
||||
session
|
||||
.inner
|
||||
.runtime
|
||||
.config()
|
||||
.ctx
|
||||
.config
|
||||
.data
|
||||
.progress_update_interval
|
||||
.saturating_add(Duration::from_secs(1)),
|
||||
@@ -1125,7 +1125,7 @@ mod tests {
|
||||
// download_file_to_path_blocking returns WrongRuntimeMode on an External-mode session.
|
||||
async fn test_download_file_to_path_blocking_errors_in_external_mode() {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::External);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::External);
|
||||
let group = session.new_file_download_group().unwrap().build().await.unwrap();
|
||||
let file_info = XetFileInfo {
|
||||
hash: String::new(),
|
||||
@@ -1147,7 +1147,7 @@ mod tests {
|
||||
// because tokio sets a thread-local runtime context that it detects and rejects.
|
||||
fn test_download_file_to_path_blocking_panics_in_async_context() {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
let group = session.new_file_download_group().unwrap().build_blocking().unwrap();
|
||||
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let file_info = XetFileInfo {
|
||||
|
||||
@@ -6,9 +6,8 @@ use std::sync::{Arc, Mutex, Weak};
|
||||
use tracing::info;
|
||||
use uuid::Uuid;
|
||||
use xet_data::progress_tracking::UniqueID;
|
||||
use xet_runtime::RuntimeError;
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
#[cfg(feature = "fd-track")]
|
||||
use xet_runtime::fd_diagnostics::{report_fd_count, track_fd_scope};
|
||||
|
||||
@@ -24,11 +23,7 @@ use super::upload_commit::XetUploadCommitBuilder;
|
||||
/// Lives behind `Arc<XetSessionInner>` — do not use this type directly.
|
||||
#[doc(hidden)]
|
||||
pub struct XetSessionInner {
|
||||
// Independently cloned by background tasks, so needs its own Arc.
|
||||
pub(super) runtime: Arc<XetRuntime>,
|
||||
|
||||
// Only accessed through &self; no independent cloning needed.
|
||||
pub(super) config: XetConfig,
|
||||
pub(super) ctx: XetContext,
|
||||
|
||||
// Root of the cancellation tree. Child commits/groups create child TaskRuntimes via
|
||||
// task_runtime.child(), which links their cancellation tokens to this root. Calling
|
||||
@@ -150,11 +145,11 @@ impl XetSessionBuilder {
|
||||
/// drivers), it is silently ignored and [`build`](Self::build) will fall back to creating
|
||||
/// an owned thread pool instead.
|
||||
///
|
||||
/// If the handle is already in use by another live `XetSession`, [`build`](Self::build) will
|
||||
/// also fall back to creating an owned thread pool — the duplicate is logged at `INFO` level
|
||||
/// and no error is returned.
|
||||
/// Handles can be shared by multiple sessions. Each session gets its own
|
||||
/// [`XetContext`] (`config` + `common`), while the underlying runtime
|
||||
/// may be shared.
|
||||
pub fn with_tokio_handle(self, handle: tokio::runtime::Handle) -> Self {
|
||||
let accept = XetRuntime::handle_meets_requirements(&handle);
|
||||
let accept = XetContext::handle_meets_requirements(&handle);
|
||||
if !accept {
|
||||
info!("supplied tokio handle rejected (missing drivers or wrong flavor); falling back to Owned mode");
|
||||
}
|
||||
@@ -166,48 +161,25 @@ impl XetSessionBuilder {
|
||||
|
||||
/// Consume the builder and create a [`XetSession`].
|
||||
///
|
||||
/// If a tokio runtime handle is available (either from
|
||||
/// [`with_tokio_handle`](Self::with_tokio_handle) or auto-detected via
|
||||
/// `Handle::try_current()`), and it meets requirements, the session wraps
|
||||
/// it — no second thread pool is created. Otherwise, an owned multi-thread
|
||||
/// runtime is created; async methods use an internal bridge and work from
|
||||
/// any executor, and `_blocking` methods are available.
|
||||
/// Threadpool selection order:
|
||||
/// 1. Reuse the current owned runtime from thread-local storage, when present.
|
||||
/// 2. Otherwise, use a provided tokio handle (or auto-detected current handle) if valid.
|
||||
/// 3. Otherwise, create a new owned thread pool.
|
||||
///
|
||||
/// If the detected or provided handle is already registered to another live `XetSession`,
|
||||
/// the duplicate attach is silently rejected and an owned runtime is created instead.
|
||||
/// This prevents two sessions from fighting over the same tokio runtime's task scheduler.
|
||||
/// Each build creates a fresh [`XetContext`] around the selected runtime, so sessions
|
||||
/// can share the same execution backend while keeping independent config and common state.
|
||||
pub fn build(self) -> Result<XetSession, SessionError> {
|
||||
#[cfg(feature = "fd-track")]
|
||||
let _fd_scope = track_fd_scope("XetSessionBuilder::build");
|
||||
|
||||
let handle = self.tokio_handle.or_else(|| {
|
||||
tokio::runtime::Handle::try_current()
|
||||
.ok()
|
||||
.filter(XetRuntime::handle_meets_requirements)
|
||||
});
|
||||
|
||||
let runtime = match handle {
|
||||
Some(h) => {
|
||||
info!("XetSession using External runtime (wrapping caller's tokio handle)");
|
||||
let result = XetRuntime::from_external_with_config(h, self.config.clone());
|
||||
match result {
|
||||
Ok(runtime) => runtime,
|
||||
Err(RuntimeError::ExternalAlreadyAttached(_)) => {
|
||||
info!(
|
||||
"An existing XetSession already wraps caller's tokio handle, switching to creating Owned runtime"
|
||||
);
|
||||
XetRuntime::new_with_config(self.config.clone())?
|
||||
},
|
||||
Err(e) => Err(e)?,
|
||||
}
|
||||
},
|
||||
None => {
|
||||
info!("XetSession creating Owned runtime (new thread pool)");
|
||||
XetRuntime::new_with_config(self.config.clone())?
|
||||
},
|
||||
let ctx = if let Some(h) = self.tokio_handle {
|
||||
info!("XetSession using explicitly provided tokio handle");
|
||||
XetContext::from_external(h, self.config)
|
||||
} else {
|
||||
XetContext::with_config(self.config)?
|
||||
};
|
||||
|
||||
let session = XetSession::new(self.config, runtime);
|
||||
let session = XetSession::new(ctx);
|
||||
info!("Session created, session_id={}", session.inner.id);
|
||||
#[cfg(feature = "fd-track")]
|
||||
report_fd_count("XetSessionBuilder::build complete");
|
||||
@@ -218,8 +190,8 @@ impl XetSessionBuilder {
|
||||
/// Handle for managing file uploads and downloads.
|
||||
///
|
||||
/// `XetSession` is the top-level entry point for the xet-session API. It
|
||||
/// owns a `XetRuntime` (tokio thread pool) and shared HTTP settings (endpoint,
|
||||
/// custom headers). Auth tokens are configured per-operation on the builder
|
||||
/// owns a [`XetContext`] (configuration and tokio thread pool). CAS endpoints,
|
||||
/// custom headers, and auth tokens are configured per-operation on the builder
|
||||
/// types returned by each factory method, not on the session itself, so uploads,
|
||||
/// file downloads, and streaming downloads can each carry a different access-level
|
||||
/// token from the same session.
|
||||
@@ -246,12 +218,11 @@ pub struct XetSession {
|
||||
|
||||
impl XetSession {
|
||||
/// Low-level constructor used by [`XetSessionBuilder::build`].
|
||||
fn new(config: XetConfig, runtime: Arc<XetRuntime>) -> Self {
|
||||
let task_runtime = TaskRuntime::new_root(runtime.clone());
|
||||
fn new(ctx: XetContext) -> Self {
|
||||
let task_runtime = TaskRuntime::new_root(ctx.runtime.clone());
|
||||
Self {
|
||||
inner: Arc::new(XetSessionInner {
|
||||
runtime,
|
||||
config,
|
||||
ctx,
|
||||
task_runtime,
|
||||
active_download_stream_groups: Mutex::new(HashMap::new()),
|
||||
id: Uuid::now_v7(),
|
||||
@@ -367,7 +338,7 @@ impl XetSession {
|
||||
let _fd_scope = track_fd_scope(format!("XetSession::sigint_abort({})", self.inner.id));
|
||||
|
||||
info!("Session SIGINT abort, session_id={}", self.inner.id);
|
||||
self.inner.runtime.perform_sigint_shutdown();
|
||||
self.inner.ctx.runtime.perform_sigint_shutdown();
|
||||
|
||||
let active_download_stream_groups = std::mem::take(&mut *self.inner.active_download_stream_groups.lock()?);
|
||||
for (_id, weak_group) in active_download_stream_groups {
|
||||
@@ -383,7 +354,7 @@ impl XetSession {
|
||||
|
||||
#[cfg(test)]
|
||||
pub(super) fn check_alive(&self) -> Result<(), SessionError> {
|
||||
if self.inner.runtime.in_sigint_shutdown() {
|
||||
if self.inner.ctx.runtime.in_sigint_shutdown() {
|
||||
return Err(SessionError::KeyboardInterrupt);
|
||||
}
|
||||
self.inner.task_runtime.check_state("session")
|
||||
@@ -406,7 +377,7 @@ impl XetSession {
|
||||
mod tests {
|
||||
use tempfile::tempdir;
|
||||
use xet_data::processing::{Sha256Policy, XetFileInfo};
|
||||
use xet_runtime::core::{RuntimeMode, XetRuntime};
|
||||
use xet_runtime::core::{RuntimeMode, XetContext};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -487,13 +458,13 @@ mod tests {
|
||||
assert!(matches!(group_err, SessionError::UserCancelled(_)));
|
||||
}
|
||||
|
||||
// ── XetRuntime::handle_meets_requirements ────────────────────────────────
|
||||
// ── XetContext::handle_meets_requirements ────────────────────────────────
|
||||
|
||||
#[test]
|
||||
// A multi-thread runtime with enable_all() meets all requirements.
|
||||
fn test_handle_multi_thread_all_features_returns_true() {
|
||||
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
assert!(XetRuntime::handle_meets_requirements(rt.handle()));
|
||||
assert!(XetContext::handle_meets_requirements(rt.handle()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -501,28 +472,28 @@ mod tests {
|
||||
// A current_thread runtime is rejected even when enable_all() is set.
|
||||
fn test_handle_current_thread_returns_false() {
|
||||
let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
|
||||
assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
|
||||
assert!(!XetContext::handle_meets_requirements(rt.handle()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
// A multi-thread runtime with no drivers enabled returns false.
|
||||
fn test_handle_without_any_driver_returns_false() {
|
||||
let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
|
||||
assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
|
||||
assert!(!XetContext::handle_meets_requirements(rt.handle()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
// A multi-thread runtime with only enable_time() is missing the IO driver.
|
||||
fn test_handle_without_io_driver_returns_false() {
|
||||
let rt = tokio::runtime::Builder::new_multi_thread().enable_time().build().unwrap();
|
||||
assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
|
||||
assert!(!XetContext::handle_meets_requirements(rt.handle()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
// A multi-thread runtime with only enable_io() is missing the time driver.
|
||||
fn test_handle_without_time_driver_returns_false() {
|
||||
let rt = tokio::runtime::Builder::new_multi_thread().enable_io().build().unwrap();
|
||||
assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
|
||||
assert!(!XetContext::handle_meets_requirements(rt.handle()));
|
||||
}
|
||||
|
||||
// ── External-mode _blocking guard ────────────────────────────────────────
|
||||
@@ -531,7 +502,7 @@ mod tests {
|
||||
// build_blocking on an External-mode session returns WrongRuntimeMode.
|
||||
async fn test_new_upload_commit_blocking_errors_in_external_mode() {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::External);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::External);
|
||||
let err = session.new_upload_commit().unwrap().build_blocking().err().unwrap();
|
||||
assert!(matches!(err, SessionError::WrongRuntimeMode(_)));
|
||||
}
|
||||
@@ -540,7 +511,7 @@ mod tests {
|
||||
// build_blocking on an External-mode session returns WrongRuntimeMode.
|
||||
async fn test_new_file_download_group_blocking_errors_in_external_mode() {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::External);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::External);
|
||||
let err = session.new_file_download_group().unwrap().build_blocking().err().unwrap();
|
||||
assert!(matches!(err, SessionError::WrongRuntimeMode(_)));
|
||||
}
|
||||
@@ -551,7 +522,7 @@ mod tests {
|
||||
// build_blocking panics when called from within a tokio runtime on an Owned-mode session.
|
||||
fn test_new_upload_commit_blocking_panics_in_async_context() {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
rt.block_on(async { session.new_upload_commit().unwrap().build_blocking() })
|
||||
@@ -563,7 +534,7 @@ mod tests {
|
||||
// build_blocking panics when called from within a tokio runtime on an Owned-mode session.
|
||||
fn test_new_file_download_group_blocking_panics_in_async_context() {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
rt.block_on(async { session.new_file_download_group().unwrap().build_blocking() })
|
||||
@@ -767,31 +738,28 @@ mod tests {
|
||||
assert_eq!(collected_b, data_b);
|
||||
}
|
||||
|
||||
// ── Duplicate tokio handle rejection ─────────────────────────────────────
|
||||
// ── Shared tokio handle behavior ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
// Building a second session with the same tokio handle while the first is alive must
|
||||
// fall back to Owned mode rather than returning an error — the duplicate is handled
|
||||
// gracefully so callers do not need to track handle ownership.
|
||||
fn test_build_with_same_handle_falls_back_to_owned() {
|
||||
// Building multiple sessions with the same tokio handle is allowed.
|
||||
fn test_build_with_same_handle_stays_external() {
|
||||
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let handle = tokio_rt.handle().clone();
|
||||
|
||||
let first = XetSessionBuilder::new().with_tokio_handle(handle.clone()).build().unwrap();
|
||||
assert_eq!(first.inner.runtime.mode(), RuntimeMode::External, "first build must use External runtime");
|
||||
assert_eq!(first.inner.ctx.runtime.mode(), RuntimeMode::External, "first build must use External runtime");
|
||||
|
||||
let second = XetSessionBuilder::new().with_tokio_handle(handle).build();
|
||||
assert!(second.is_ok(), "second build with the same tokio handle must still succeed");
|
||||
assert_eq!(
|
||||
second.unwrap().inner.runtime.mode(),
|
||||
RuntimeMode::Owned,
|
||||
"second build must fall back to Owned runtime when External handle is already in use"
|
||||
second.unwrap().inner.ctx.runtime.mode(),
|
||||
RuntimeMode::External,
|
||||
"second build should remain External when sharing the same tokio handle"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
// After the first session is dropped (deregistering the handle), a new session can
|
||||
// attach to the same tokio handle successfully.
|
||||
// Dropping one session must not affect creating another session with the same handle.
|
||||
fn test_build_with_same_handle_succeeds_after_first_is_dropped() {
|
||||
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let handle = tokio_rt.handle().clone();
|
||||
|
||||
@@ -559,7 +559,7 @@ mod tests {
|
||||
let cas_path = temp.path().join("cas");
|
||||
let endpoint = format!("local://{}", cas_path.display());
|
||||
let session = XetSessionBuilder::new().build()?;
|
||||
let runtime = session.inner.runtime.clone();
|
||||
let runtime = session.inner.ctx.runtime.clone();
|
||||
let commit = session.new_upload_commit()?.with_endpoint(&endpoint).build_blocking()?;
|
||||
let commit_for_thread = commit.clone();
|
||||
let runtime_for_thread = runtime.clone();
|
||||
@@ -1160,7 +1160,7 @@ mod tests {
|
||||
let endpoint = format!("local://{}", temp.path().join("cas").display());
|
||||
futures::executor::block_on(async {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
|
||||
let data = b"hello from non-tokio executor";
|
||||
let commit = session
|
||||
@@ -1188,7 +1188,7 @@ mod tests {
|
||||
let endpoint = format!("local://{}", temp.path().join("cas").display());
|
||||
smol::block_on(async {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
|
||||
let data = b"hello from smol executor";
|
||||
let commit = session
|
||||
@@ -1216,7 +1216,7 @@ mod tests {
|
||||
let endpoint = format!("local://{}", temp.path().join("cas").display());
|
||||
async_std::task::block_on(async {
|
||||
let session = XetSessionBuilder::new().build().unwrap();
|
||||
assert_eq!(session.inner.runtime.mode(), RuntimeMode::Owned);
|
||||
assert_eq!(session.inner.ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
|
||||
let data = b"hello from async-std executor";
|
||||
let commit = session
|
||||
|
||||
@@ -9,6 +9,7 @@ use tokio::sync::Mutex;
|
||||
use xet::legacy::progress_tracking::{ItemProgressUpdate, ProgressUpdate, TrackingProgressUpdater};
|
||||
use xet::legacy::{Sha256Policy, XetFileInfo, data_client};
|
||||
use xet_client::cas_client::LocalTestServerBuilder;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
/// A test `TrackingProgressUpdater` that records all updates.
|
||||
#[derive(Debug, Default)]
|
||||
@@ -38,8 +39,13 @@ fn make_endpoint(server: &xet_client::cas_client::LocalTestServer) -> Option<Str
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_context() -> XetContext {
|
||||
XetContext::default().expect("xet context")
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_upload_bytes_and_download_roundtrip() {
|
||||
let ctx = test_context();
|
||||
let server = LocalTestServerBuilder::new().start().await;
|
||||
let endpoint = make_endpoint(&server);
|
||||
|
||||
@@ -47,7 +53,7 @@ mod tests {
|
||||
let policies = vec![Sha256Policy::Compute; contents.len()];
|
||||
|
||||
let file_infos =
|
||||
data_client::upload_bytes_async(contents.clone(), policies, endpoint.clone(), None, None, None, None)
|
||||
data_client::upload_bytes_async(&ctx, contents.clone(), policies, endpoint.clone(), None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -67,7 +73,7 @@ mod tests {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let paths = data_client::download_async(download_pairs, endpoint, None, None, None, None)
|
||||
let paths = data_client::download_async(&ctx, download_pairs, endpoint, None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -80,6 +86,7 @@ mod tests {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_upload_files_and_download_roundtrip() {
|
||||
let ctx = test_context();
|
||||
let server = LocalTestServerBuilder::new().start().await;
|
||||
let endpoint = make_endpoint(&server);
|
||||
|
||||
@@ -99,9 +106,10 @@ mod tests {
|
||||
policies.push(Sha256Policy::Compute);
|
||||
}
|
||||
|
||||
let file_infos = data_client::upload_async(file_paths, policies, endpoint.clone(), None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let file_infos =
|
||||
data_client::upload_async(&ctx, file_paths, policies, endpoint.clone(), None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(file_infos.len(), 3);
|
||||
|
||||
@@ -115,7 +123,7 @@ mod tests {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let paths = data_client::download_async(download_pairs, endpoint, None, None, None, None)
|
||||
let paths = data_client::download_async(&ctx, download_pairs, endpoint, None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -127,6 +135,7 @@ mod tests {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_upload_bytes_with_progress_updater() {
|
||||
let ctx = test_context();
|
||||
let server = LocalTestServerBuilder::new().start().await;
|
||||
let endpoint = make_endpoint(&server);
|
||||
|
||||
@@ -134,10 +143,18 @@ mod tests {
|
||||
let policies = vec![Sha256Policy::Compute; contents.len()];
|
||||
let updater = Arc::new(RecordingUpdater::default());
|
||||
|
||||
let file_infos =
|
||||
data_client::upload_bytes_async(contents, policies, endpoint, None, None, Some(updater.clone()), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let file_infos = data_client::upload_bytes_async(
|
||||
&ctx,
|
||||
contents,
|
||||
policies,
|
||||
endpoint,
|
||||
None,
|
||||
None,
|
||||
Some(updater.clone()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(file_infos.len(), 2);
|
||||
|
||||
@@ -157,6 +174,7 @@ mod tests {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_download_with_per_file_progress_updaters() {
|
||||
let ctx = test_context();
|
||||
let server = LocalTestServerBuilder::new().start().await;
|
||||
let endpoint = make_endpoint(&server);
|
||||
|
||||
@@ -164,7 +182,7 @@ mod tests {
|
||||
let policies = vec![Sha256Policy::Compute; contents.len()];
|
||||
|
||||
let file_infos =
|
||||
data_client::upload_bytes_async(contents.clone(), policies, endpoint.clone(), None, None, None, None)
|
||||
data_client::upload_bytes_async(&ctx, contents.clone(), policies, endpoint.clone(), None, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -184,7 +202,7 @@ mod tests {
|
||||
let updaters: Vec<Arc<dyn TrackingProgressUpdater>> =
|
||||
vec![updater_a.clone() as Arc<dyn TrackingProgressUpdater>, updater_b.clone()];
|
||||
|
||||
let paths = data_client::download_async(download_pairs, endpoint, None, None, Some(updaters), None)
|
||||
let paths = data_client::download_async(&ctx, download_pairs, endpoint, None, None, Some(updaters), None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -202,6 +220,7 @@ mod tests {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_upload_files_with_progress_updater() {
|
||||
let ctx = test_context();
|
||||
let server = LocalTestServerBuilder::new().start().await;
|
||||
let endpoint = make_endpoint(&server);
|
||||
|
||||
@@ -219,10 +238,18 @@ mod tests {
|
||||
|
||||
let updater = Arc::new(RecordingUpdater::default());
|
||||
|
||||
let file_infos =
|
||||
data_client::upload_async(file_paths, policies, endpoint.clone(), None, None, Some(updater.clone()), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let file_infos = data_client::upload_async(
|
||||
&ctx,
|
||||
file_paths,
|
||||
policies,
|
||||
endpoint.clone(),
|
||||
None,
|
||||
None,
|
||||
Some(updater.clone()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(file_infos.len(), 2);
|
||||
|
||||
@@ -235,6 +262,7 @@ mod tests {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_upload_download_large_files() {
|
||||
let ctx = test_context();
|
||||
let server = LocalTestServerBuilder::new().start().await;
|
||||
let endpoint = make_endpoint(&server);
|
||||
|
||||
@@ -245,6 +273,7 @@ mod tests {
|
||||
fs::write(&path, &large_data).unwrap();
|
||||
|
||||
let file_infos = data_client::upload_async(
|
||||
&ctx,
|
||||
vec![path.to_string_lossy().to_string()],
|
||||
vec![Sha256Policy::Compute],
|
||||
endpoint.clone(),
|
||||
@@ -263,6 +292,7 @@ mod tests {
|
||||
let out_path = download_dir.path().join("large_out.bin");
|
||||
|
||||
let paths = data_client::download_async(
|
||||
&ctx,
|
||||
vec![(file_infos[0].clone(), out_path.to_string_lossy().to_string())],
|
||||
endpoint,
|
||||
None,
|
||||
@@ -279,6 +309,7 @@ mod tests {
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn test_progress_updates_are_monotonic() {
|
||||
let ctx = test_context();
|
||||
let server = LocalTestServerBuilder::new().start().await;
|
||||
let endpoint = make_endpoint(&server);
|
||||
|
||||
@@ -290,6 +321,7 @@ mod tests {
|
||||
let updater = Arc::new(RecordingUpdater::default());
|
||||
|
||||
data_client::upload_async(
|
||||
&ctx,
|
||||
vec![path.to_string_lossy().to_string()],
|
||||
vec![Sha256Policy::Compute],
|
||||
endpoint,
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
//! - **Tokio async** (External mode): standard `#[tokio::test]` tests.
|
||||
//! - **Blocking** (Owned mode): sync `build()` + `_blocking` methods.
|
||||
//! - **Non-tokio async bridge** (Owned mode): `futures::executor`, `smol`, `async-std` driving async methods via
|
||||
//! `XetRuntime::bridge_async`.
|
||||
//! `XetContext::bridge_async`.
|
||||
//! - **Deficient tokio runtime** (fallback to Owned mode): tokio runtimes missing IO/time drivers or using
|
||||
//! `current_thread` flavor.
|
||||
//! - **Blocking from non-tokio executors**: `_blocking` methods called from within smol/async-std/futures executor
|
||||
@@ -784,7 +784,7 @@ fn blocking_multiple_commits_and_groups() {
|
||||
// ── 3. Non-tokio async bridge tests (Owned mode) ────────────────────────
|
||||
//
|
||||
// build() from a non-tokio executor creates an Owned-mode runtime.
|
||||
// Async methods use XetRuntime::bridge_async: the future runs on the owned tokio
|
||||
// Async methods use XetContext::bridge_async: the future runs on the owned tokio
|
||||
// pool while the caller's executor polls the oneshot receiver.
|
||||
|
||||
#[test]
|
||||
@@ -2089,8 +2089,7 @@ fn fd_leak_single_session_roundtrip() {
|
||||
#[test]
|
||||
#[serial(fd_leak)]
|
||||
fn fd_leak_isolate_components() {
|
||||
use xet_runtime::config::XetConfig;
|
||||
use xet_runtime::core::XetRuntime;
|
||||
use xet_runtime::core::XetContext;
|
||||
|
||||
let report_nonzero_delta = |label: &str, baseline: usize| {
|
||||
let delta = fd_delta_from_baseline(baseline);
|
||||
@@ -2101,14 +2100,14 @@ fn fd_leak_isolate_components() {
|
||||
|
||||
// Warmup: first runtime creation installs signal handlers / global state.
|
||||
{
|
||||
let rt = XetRuntime::new_with_config(XetConfig::new()).unwrap();
|
||||
drop(rt);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
drop(ctx);
|
||||
}
|
||||
|
||||
let before = count_open_fds();
|
||||
{
|
||||
let rt = XetRuntime::new_with_config(XetConfig::new()).unwrap();
|
||||
drop(rt);
|
||||
let ctx = XetContext::default().unwrap();
|
||||
drop(ctx);
|
||||
}
|
||||
assert_fd_delta_eventually_le("runtime create/drop", before, FD_TOLERANCE);
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::any::Any;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::AtomicU64;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
@@ -7,15 +9,10 @@ use tokio::sync::Semaphore;
|
||||
use crate::config::XetConfig;
|
||||
use crate::utils::adjustable_semaphore::AdjustableSemaphore;
|
||||
|
||||
/// Holds global values that are shared across the entire runtime.
|
||||
/// Holds shared state that is common across the entire context.
|
||||
///
|
||||
/// Accessible via `XetRuntime::current().common()`.
|
||||
#[derive(Debug)]
|
||||
/// Accessible via `ctx.common` on a [`super::XetContext`].
|
||||
pub struct XetCommon {
|
||||
// A cached reqwest Client to be shared by all high-level clients.
|
||||
// The String tag identifies the client type (e.g., "tcp" for regular, socket path for UDS).
|
||||
global_reqwest_client: Mutex<Option<(String, Client)>>,
|
||||
|
||||
/// Limits the number of files being ingested (cleaned/uploaded) concurrently.
|
||||
pub file_ingestion_semaphore: Arc<Semaphore>,
|
||||
|
||||
@@ -27,13 +24,17 @@ pub struct XetCommon {
|
||||
|
||||
/// Tracks the number of currently active file downloads for dynamic buffer scaling.
|
||||
pub active_downloads: Arc<AtomicU64>,
|
||||
|
||||
/// Type-erased cache for runtime-scoped resources. Subsystems store their own
|
||||
/// caches here (keyed by a unique string) instead of using process-global statics,
|
||||
/// so everything is cleaned up when the runtime drops.
|
||||
runtime_cache: Mutex<HashMap<String, Box<dyn Any + Send + Sync>>>,
|
||||
}
|
||||
|
||||
impl XetCommon {
|
||||
/// Creates a new `XetCommon` instance with the given configuration.
|
||||
pub fn new(config: &XetConfig) -> Self {
|
||||
Self {
|
||||
global_reqwest_client: Mutex::new(None),
|
||||
file_ingestion_semaphore: Arc::new(Semaphore::new(config.data.max_concurrent_file_ingestion)),
|
||||
file_download_semaphore: Arc::new(Semaphore::new(config.data.max_concurrent_file_downloads)),
|
||||
reconstruction_download_buffer: {
|
||||
@@ -42,9 +43,31 @@ impl XetCommon {
|
||||
AdjustableSemaphore::new(base, (base, limit))
|
||||
},
|
||||
active_downloads: Arc::new(AtomicU64::new(0)),
|
||||
runtime_cache: Mutex::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Retrieves a cached value by key, or creates and stores it using `create`.
|
||||
///
|
||||
/// Values are stored as type-erased `Box<dyn Any>` and recovered via `downcast_ref`.
|
||||
/// Typical stored types are `Arc<Mutex<...>>` or `Arc<RwLock<...>>`, making
|
||||
/// the clone cheap (just an Arc bump).
|
||||
pub fn cache_get_or_create<T, F>(&self, key: &str, create: F) -> T
|
||||
where
|
||||
T: Clone + Send + Sync + 'static,
|
||||
F: FnOnce() -> T,
|
||||
{
|
||||
let mut guard = self.runtime_cache.lock().unwrap();
|
||||
if let Some(existing) = guard.get(key)
|
||||
&& let Some(val) = existing.downcast_ref::<T>()
|
||||
{
|
||||
return val.clone();
|
||||
}
|
||||
let value = create();
|
||||
guard.insert(key.to_string(), Box::new(value.clone()));
|
||||
value
|
||||
}
|
||||
|
||||
/// Gets or creates a reqwest client, using a tag to identify the client type.
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -57,15 +80,13 @@ impl XetCommon {
|
||||
where
|
||||
F: FnOnce() -> std::result::Result<Client, reqwest::Error>,
|
||||
{
|
||||
let mut guard = self.global_reqwest_client.lock()?;
|
||||
let client_cache: Arc<Mutex<Option<(String, Client)>>> =
|
||||
self.cache_get_or_create("global_reqwest_client", || Arc::new(Mutex::new(None)));
|
||||
let mut guard = client_cache.lock()?;
|
||||
|
||||
match guard.as_ref() {
|
||||
Some((cached_tag, cached_client)) if cached_tag == &tag => {
|
||||
// Tag matches, return a clone of the existing client
|
||||
Ok(cached_client.clone())
|
||||
},
|
||||
Some((cached_tag, cached_client)) if cached_tag == &tag => Ok(cached_client.clone()),
|
||||
_ => {
|
||||
// Tag doesn't match or no client exists, create a new one
|
||||
let new_client = create_client_fn()?;
|
||||
*guard = Some((tag, new_client.clone()));
|
||||
Ok(new_client)
|
||||
@@ -124,14 +145,6 @@ mod tests {
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initializes_with_empty_client_cache() {
|
||||
let common = XetCommon::new(&XetConfig::new());
|
||||
|
||||
let guard = common.global_reqwest_client.lock().unwrap();
|
||||
assert!(guard.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replaces_client_when_tag_changes() {
|
||||
let common = XetCommon::new(&XetConfig::new());
|
||||
@@ -142,23 +155,22 @@ mod tests {
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
{
|
||||
let guard = common.global_reqwest_client.lock().unwrap();
|
||||
let (tag, _) = guard.as_ref().unwrap();
|
||||
assert_eq!(tag, "tcp");
|
||||
}
|
||||
|
||||
let _client2 = common
|
||||
.get_or_create_reqwest_client("/tmp/socket.sock".to_string(), || {
|
||||
reqwest::Client::builder().user_agent("uds-client").build()
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
{
|
||||
let guard = common.global_reqwest_client.lock().unwrap();
|
||||
let (tag, _) = guard.as_ref().unwrap();
|
||||
assert_eq!(tag, "/tmp/socket.sock");
|
||||
}
|
||||
// Second call with a different tag should have triggered creation again;
|
||||
// verify by calling with the new tag and confirming no creation happens.
|
||||
let call_count = AtomicUsize::new(0);
|
||||
let _client3 = common
|
||||
.get_or_create_reqwest_client("/tmp/socket.sock".to_string(), || {
|
||||
call_count.fetch_add(1, Ordering::SeqCst);
|
||||
reqwest::Client::builder().build()
|
||||
})
|
||||
.unwrap();
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -178,4 +190,46 @@ mod tests {
|
||||
|
||||
assert_eq!(common.active_downloads.load(Ordering::Relaxed), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_get_or_create_returns_cached_value() {
|
||||
let common = XetCommon::new(&XetConfig::new());
|
||||
let call_count = AtomicUsize::new(0);
|
||||
|
||||
let v1: Arc<Mutex<Vec<i32>>> = common.cache_get_or_create("my_cache", || {
|
||||
call_count.fetch_add(1, Ordering::SeqCst);
|
||||
Arc::new(Mutex::new(vec![1, 2, 3]))
|
||||
});
|
||||
|
||||
let v2: Arc<Mutex<Vec<i32>>> = common.cache_get_or_create("my_cache", || {
|
||||
call_count.fetch_add(1, Ordering::SeqCst);
|
||||
Arc::new(Mutex::new(vec![4, 5, 6]))
|
||||
});
|
||||
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 1);
|
||||
assert!(Arc::ptr_eq(&v1, &v2));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_different_keys_are_independent() {
|
||||
let common = XetCommon::new(&XetConfig::new());
|
||||
|
||||
let v1: Arc<String> = common.cache_get_or_create("key_a", || Arc::new("alpha".to_string()));
|
||||
let v2: Arc<String> = common.cache_get_or_create("key_b", || Arc::new("beta".to_string()));
|
||||
|
||||
assert_eq!(v1.as_str(), "alpha");
|
||||
assert_eq!(v2.as_str(), "beta");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cache_type_mismatch_creates_new_entry() {
|
||||
let common = XetCommon::new(&XetConfig::new());
|
||||
|
||||
let _v1: Arc<String> = common.cache_get_or_create("shared_key", || Arc::new("original".to_string()));
|
||||
|
||||
// Same key but different type -- downcast fails, so create is called and the
|
||||
// entry is replaced with the new type.
|
||||
let v2: Arc<i32> = common.cache_get_or_create("shared_key", || Arc::new(42));
|
||||
assert_eq!(*v2, 42);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
use std::cell::RefCell;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::XetRuntime;
|
||||
use crate::config::XetConfig;
|
||||
|
||||
// Use thread-local references to the config that caches access. This way, xet_config() will
|
||||
// can be called outside of an existing runtime.
|
||||
thread_local! {
|
||||
static THREAD_CONFIG_REF: RefCell<Option<Arc<XetConfig>>> = const { RefCell::new(None) };
|
||||
}
|
||||
|
||||
pub fn xet_config() -> Arc<XetConfig> {
|
||||
if let Some(config) = THREAD_CONFIG_REF.with_borrow(|config| config.clone()) {
|
||||
return config;
|
||||
}
|
||||
|
||||
let config = {
|
||||
if let Some(runtime) = XetRuntime::current_if_exists() {
|
||||
runtime.config().clone()
|
||||
} else {
|
||||
Arc::new(XetConfig::new())
|
||||
}
|
||||
};
|
||||
|
||||
THREAD_CONFIG_REF.set(Some(config.clone()));
|
||||
|
||||
config
|
||||
}
|
||||
94
xet_runtime/src/core/context.rs
Normal file
94
xet_runtime/src/core/context.rs
Normal file
@@ -0,0 +1,94 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tokio::runtime::Handle as TokioRuntimeHandle;
|
||||
use tracing::info;
|
||||
|
||||
use super::XetCommon;
|
||||
use super::runtime::XetRuntime;
|
||||
use crate::config::XetConfig;
|
||||
use crate::error::RuntimeError;
|
||||
|
||||
/// Bundles the thread pool, configuration, and shared state into a single clonable handle.
|
||||
///
|
||||
/// Every major struct in the codebase should accept `&XetContext` in its constructor
|
||||
/// and store a clone. This replaces the thread-local globals, allowing multiple
|
||||
/// independent runtimes within the same process.
|
||||
#[derive(Clone)]
|
||||
pub struct XetContext {
|
||||
pub runtime: Arc<XetRuntime>,
|
||||
pub config: Arc<XetConfig>,
|
||||
pub common: Arc<XetCommon>,
|
||||
}
|
||||
|
||||
impl XetContext {
|
||||
/// Creates a context from a pre-built thread pool and configuration.
|
||||
pub fn new(config: XetConfig, runtime: Arc<XetRuntime>) -> Self {
|
||||
let config = Arc::new(config);
|
||||
let common = Arc::new(XetCommon::new(&config));
|
||||
Self {
|
||||
runtime,
|
||||
config,
|
||||
common,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a context with default configuration and an auto-detected thread pool.
|
||||
///
|
||||
/// If called from an owned runtime worker thread, reuses that owned [`XetRuntime`].
|
||||
/// Otherwise, if called from within an existing tokio runtime, wraps that runtime.
|
||||
/// If neither is available, spins up a new owned tokio thread pool.
|
||||
#[allow(clippy::should_implement_trait)]
|
||||
pub fn default() -> Result<Self, RuntimeError> {
|
||||
Self::with_config(XetConfig::new())
|
||||
}
|
||||
|
||||
/// Creates a context with the given configuration and an auto-detected thread pool.
|
||||
///
|
||||
/// Follows the same runtime selection as [`default`](Self::default):
|
||||
/// reuse an owned runtime if available, wrap an existing tokio handle, or create a new one.
|
||||
pub fn with_config(config: XetConfig) -> Result<Self, RuntimeError> {
|
||||
let runtime = if let Some(runtime) = XetRuntime::current_if_exists() {
|
||||
runtime
|
||||
} else if let Ok(handle) = TokioRuntimeHandle::try_current()
|
||||
&& Self::handle_meets_requirements(&handle)
|
||||
{
|
||||
info!(
|
||||
"Detected compatible existing Tokio runtime; using external handle instead of creating a new thread pool"
|
||||
);
|
||||
XetRuntime::from_external(handle)
|
||||
} else {
|
||||
XetRuntime::new(&config)?
|
||||
};
|
||||
Ok(Self::new(config, runtime))
|
||||
}
|
||||
|
||||
/// Wraps a caller-provided tokio handle with the given configuration.
|
||||
pub fn from_external(rt_handle: TokioRuntimeHandle, config: XetConfig) -> Self {
|
||||
Self::new(config, XetRuntime::from_external(rt_handle))
|
||||
}
|
||||
|
||||
/// Checks whether a tokio handle meets the requirements for use with xet.
|
||||
pub fn handle_meets_requirements(handle: &TokioRuntimeHandle) -> bool {
|
||||
XetRuntime::handle_meets_requirements(handle)
|
||||
}
|
||||
|
||||
/// Returns an error if the runtime is in the middle of a SIGINT shutdown.
|
||||
#[inline]
|
||||
pub fn check_sigint_shutdown(&self) -> Result<(), RuntimeError> {
|
||||
if self.runtime.in_sigint_shutdown() {
|
||||
Err(RuntimeError::KeyboardInterrupt)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for XetContext {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("XetContext")
|
||||
.field("runtime", &self.runtime)
|
||||
.field("config", &"...")
|
||||
.field("common", &"...")
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
pub mod common;
|
||||
pub mod exports;
|
||||
|
||||
pub mod context;
|
||||
pub mod runtime;
|
||||
|
||||
pub use common::XetCommon;
|
||||
pub use runtime::{RuntimeMode, XetRuntime, check_sigint_shutdown};
|
||||
pub use context::XetContext;
|
||||
pub use runtime::{RuntimeMode, XetRuntime};
|
||||
|
||||
pub mod sync_primatives;
|
||||
pub use sync_primatives::{SyncJoinHandle, spawn_os_thread};
|
||||
@@ -18,7 +20,3 @@ mod cache_dir;
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
pub use cache_dir::xet_cache_root;
|
||||
|
||||
mod config;
|
||||
|
||||
pub use config::xet_config;
|
||||
|
||||
@@ -11,7 +11,6 @@ use std::sync::{Arc, LazyLock, OnceLock, Weak};
|
||||
use std::task::{Context, Waker};
|
||||
|
||||
use futures::FutureExt;
|
||||
use reqwest::Client;
|
||||
use tokio::runtime::{Builder as TokioRuntimeBuilder, Handle as TokioRuntimeHandle, Runtime as TokioRuntime};
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::task::JoinHandle;
|
||||
@@ -19,7 +18,6 @@ use tracing::debug;
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
use tracing::info;
|
||||
|
||||
use super::XetCommon;
|
||||
use crate::config::XetConfig;
|
||||
use crate::error::RuntimeError;
|
||||
#[cfg(feature = "fd-track")]
|
||||
@@ -73,19 +71,6 @@ fn get_num_tokio_worker_threads() -> usize {
|
||||
n
|
||||
}
|
||||
|
||||
/// Quick function to check for a sigint shutdown.
|
||||
#[inline]
|
||||
pub fn check_sigint_shutdown() -> Result<(), RuntimeError> {
|
||||
if XetRuntime::current_if_exists()
|
||||
.map(|rt| rt.in_sigint_shutdown())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
Err(RuntimeError::KeyboardInterrupt)
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether the runtime owns its tokio thread pool or wraps an external handle.
|
||||
///
|
||||
/// - **`Owned`**: runtime created its own thread pool. Both async bridging ([`XetRuntime::bridge_async`]) and sync
|
||||
@@ -101,6 +86,19 @@ pub enum RuntimeMode {
|
||||
|
||||
type OwnedRuntimeCell = Arc<std::sync::RwLock<Option<Arc<TokioRuntime>>>>;
|
||||
|
||||
// Use thread-local references to the active XetRuntime on each tokio worker thread so code can
|
||||
// resolve the active pool. Weak (not Arc) avoids a cycle: worker TLS -> Arc<XetRuntime> ->
|
||||
// OwnedRuntimeCell -> TokioRuntime -> worker threads.
|
||||
thread_local! {
|
||||
static THREAD_THREADPOOL_REF: RefCell<Option<(u32, Weak<XetRuntime>)>> =
|
||||
const { RefCell::new(None) };
|
||||
}
|
||||
|
||||
// External-mode XetRuntime instances from `from_external_with_config` are registered by tokio runtime ID
|
||||
// so duplicate attachment can be detected. Removed in Drop when the last Arc is released.
|
||||
static EXTERNAL_THREADPOOL_REGISTRY: LazyLock<std::sync::RwLock<HashMap<tokio::runtime::Id, Weak<XetRuntime>>>> =
|
||||
LazyLock::new(|| std::sync::RwLock::new(HashMap::new()));
|
||||
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(target_family = "wasm", allow(dead_code))]
|
||||
enum RuntimeBackend {
|
||||
@@ -131,19 +129,20 @@ impl<F: FnOnce()> Drop for CallbackGuard<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// This module provides a simple wrapper around Tokio's runtime to create a thread pool
|
||||
/// with some default settings. It is intended to be used as a singleton thread pool for
|
||||
/// the entire application.
|
||||
/// [`XetRuntime`] is the async execution backend: it either owns a Tokio multi-thread runtime or wraps an
|
||||
/// external [`TokioRuntimeHandle`](tokio::runtime::Handle) (see [`RuntimeMode`]).
|
||||
///
|
||||
/// The `ThreadPool` struct encapsulates a Tokio runtime and provides methods to run
|
||||
/// futures to completion, spawn new tasks, and get a handle to the runtime.
|
||||
/// It exposes [`Self::bridge_async`] and [`Self::bridge_sync`] to run work on the pool, [`Self::spawn`] for
|
||||
/// fire-and-forget tasks, and [`Self::perform_sigint_shutdown`] / [`Self::in_sigint_shutdown`] so callers can
|
||||
/// align with process-wide SIGINT teardown.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use xet_runtime::config::XetConfig;
|
||||
/// use xet_runtime::core::XetRuntime;
|
||||
///
|
||||
/// let pool = XetRuntime::new().expect("Error initializing runtime.");
|
||||
/// let pool = XetRuntime::new(&XetConfig::new()).expect("Error initializing runtime.");
|
||||
///
|
||||
/// let result = pool
|
||||
/// .bridge_sync(async {
|
||||
@@ -154,25 +153,6 @@ impl<F: FnOnce()> Drop for CallbackGuard<F> {
|
||||
///
|
||||
/// assert_eq!(result, 42);
|
||||
/// ```
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// The `new_threadpool` function will intentionally panic if the Tokio runtime cannot be
|
||||
/// created. This is because the application should not continue running without a
|
||||
/// functioning thread pool.
|
||||
///
|
||||
/// # Settings
|
||||
///
|
||||
/// The thread pool is configured with the following settings:
|
||||
/// - 4 worker threads
|
||||
/// - Thread names prefixed with "hf-xet-"
|
||||
/// - 8MB stack size per thread (default is 2MB)
|
||||
/// - Maximum of 100 blocking threads
|
||||
/// - All Tokio features enabled (IO, Timer, Signal, Reactor)
|
||||
///
|
||||
/// # Structs
|
||||
///
|
||||
/// - `ThreadPool`: The main struct that encapsulates the Tokio runtime.
|
||||
#[derive(Debug)]
|
||||
pub struct XetRuntime {
|
||||
// Runtime backend and its owned state (if any).
|
||||
@@ -183,107 +163,37 @@ pub struct XetRuntime {
|
||||
// while holding a reference to the runtime does.
|
||||
handle_ref: OnceLock<TokioRuntimeHandle>,
|
||||
|
||||
// The number of external threads calling into this threadpool.
|
||||
// The number of external threads calling into this runtime.
|
||||
external_executor_count: AtomicUsize,
|
||||
|
||||
// Are we in the middle of a sigint shutdown?
|
||||
sigint_shutdown: AtomicBool,
|
||||
|
||||
// Shared state that is common across the entire runtime.
|
||||
common: XetCommon,
|
||||
|
||||
// Primary configuration struct
|
||||
config: Arc<XetConfig>,
|
||||
|
||||
// System monitor instance if enabled, monitor starts on initiation
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
system_monitor: Option<SystemMonitor>,
|
||||
}
|
||||
|
||||
// Use thread-local references to the runtime that are set on initialization among all
|
||||
// the worker threads in the runtime. This way, XetRuntime::current() will always refer to
|
||||
// the runtime active with that worker thread.
|
||||
//
|
||||
// IMPORTANT: Uses Weak<XetRuntime> instead of Arc to avoid a reference cycle:
|
||||
// worker thread TLS -> Arc<XetRuntime> -> OwnedRuntimeCell -> TokioRuntime -> worker threads
|
||||
// With Weak, the cycle is broken: when the last external Arc<XetRuntime> is dropped,
|
||||
// the runtime can shut down and join its worker threads normally.
|
||||
thread_local! {
|
||||
static THREAD_RUNTIME_REF: RefCell<Option<(u32, Weak<XetRuntime>)>> = const { RefCell::new(None) };
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
fn system_monitor_for_config(config: &XetConfig) -> Option<SystemMonitor> {
|
||||
config
|
||||
.system_monitor
|
||||
.enabled
|
||||
.then(|| {
|
||||
SystemMonitor::follow_process(config.system_monitor.sample_interval, config.system_monitor.log_path.clone())
|
||||
.ok()
|
||||
})
|
||||
.flatten()
|
||||
}
|
||||
|
||||
// Registry for External-mode runtimes created via from_external_with_config.
|
||||
// Keyed by tokio runtime ID so current_if_exists() can find the right XetRuntime
|
||||
// (with the correct XetConfig and XetCommon) when called from the caller's tokio threads,
|
||||
// where THREAD_RUNTIME_REF is never set.
|
||||
//
|
||||
// Uses std::sync (not tokio::sync) because the registry must be accessible from non-async
|
||||
// contexts such as Drop impls and sync builder methods.
|
||||
static EXTERNAL_RUNTIME_REGISTRY: LazyLock<std::sync::RwLock<HashMap<tokio::runtime::Id, Weak<XetRuntime>>>> =
|
||||
LazyLock::new(|| std::sync::RwLock::new(HashMap::new()));
|
||||
|
||||
impl XetRuntime {
|
||||
/// Return the current threadpool that the current worker thread uses. Will fail if
|
||||
/// called from a thread that is not spawned from the current runtime.
|
||||
#[inline]
|
||||
pub fn current() -> Arc<Self> {
|
||||
if let Some(rt) = Self::current_if_exists() {
|
||||
return rt;
|
||||
}
|
||||
|
||||
let Ok(tokio_rt) = TokioRuntimeHandle::try_current() else {
|
||||
panic!("ThreadPool::current() called before ThreadPool::new() or on thread outside of current runtime.");
|
||||
};
|
||||
|
||||
Self::from_external(tokio_rt)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn current_if_exists() -> Option<Arc<Self>> {
|
||||
// 1. Thread-local: set by on_thread_start in new_with_config (Owned mode).
|
||||
let maybe_rt = THREAD_RUNTIME_REF.with_borrow(|rt| {
|
||||
rt.as_ref().and_then(|(pid, weak)| {
|
||||
if *pid == std::process::id() {
|
||||
weak.upgrade()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
if let Some(rt) = maybe_rt {
|
||||
return Some(rt);
|
||||
}
|
||||
|
||||
// 2. Handle registry: set by from_external_with_config (External mode). Returns the XetRuntime with the correct
|
||||
// XetConfig and XetCommon for this runtime.
|
||||
if let Ok(handle) = TokioRuntimeHandle::try_current() {
|
||||
if let Ok(reg) = EXTERNAL_RUNTIME_REGISTRY.read()
|
||||
&& let Some(weak) = reg.get(&handle.id())
|
||||
&& let Some(rt) = weak.upgrade()
|
||||
{
|
||||
return Some(rt);
|
||||
}
|
||||
// Fallback: no XetSession owns this handle; create a bare default-config wrapper.
|
||||
Some(Self::from_external(handle))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a new runtime with the default configuration.
|
||||
pub fn new() -> Result<Arc<Self>, RuntimeError> {
|
||||
Self::new_with_config(XetConfig::new())
|
||||
}
|
||||
|
||||
/// Creates a new runtime with the given configuration.
|
||||
pub fn new_with_config(config: XetConfig) -> Result<Arc<Self>, RuntimeError> {
|
||||
/// Creates a new owned tokio thread pool with the given configuration.
|
||||
pub fn new(config: &XetConfig) -> Result<Arc<Self>, RuntimeError> {
|
||||
#[cfg(feature = "fd-track")]
|
||||
let _fd_scope = track_fd_scope("XetRuntime::new_with_config");
|
||||
let _fd_scope = track_fd_scope("XetRuntime::new");
|
||||
|
||||
let runtime = Arc::new(std::sync::RwLock::new(None));
|
||||
|
||||
// First, get an Arc value holding the runtime that we can initialize the
|
||||
// thread-local THREAD_RUNTIME_REF with
|
||||
let rt = Arc::new(Self {
|
||||
backend: RuntimeBackend::OwnedThreadPool {
|
||||
runtime: runtime.clone(),
|
||||
@@ -291,35 +201,16 @@ impl XetRuntime {
|
||||
handle_ref: OnceLock::new(),
|
||||
external_executor_count: 0.into(),
|
||||
sigint_shutdown: false.into(),
|
||||
common: XetCommon::new(&config),
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
system_monitor: config
|
||||
.system_monitor
|
||||
.enabled
|
||||
.then(|| {
|
||||
SystemMonitor::follow_process(
|
||||
config.system_monitor.sample_interval,
|
||||
config.system_monitor.log_path.clone(),
|
||||
)
|
||||
.ok()
|
||||
})
|
||||
.flatten(),
|
||||
config: Arc::new(config),
|
||||
system_monitor: system_monitor_for_config(config),
|
||||
});
|
||||
|
||||
// Each tokio worker thread stores a Weak reference so it can resolve its owning
|
||||
// XetRuntime via current()/current_if_exists(). Weak (not Arc) avoids a cycle:
|
||||
// XetRuntime owns the tokio runtime, so a strong TLS ref from workers would prevent
|
||||
// the runtime from being dropped when the last external Arc<XetRuntime> is released.
|
||||
let rt_weak = Arc::downgrade(&rt);
|
||||
let pid = std::process::id();
|
||||
let set_threadlocal_reference = move || {
|
||||
THREAD_RUNTIME_REF.set(Some((pid, rt_weak.clone())));
|
||||
THREAD_THREADPOOL_REF.set(Some((pid, rt_weak.clone())));
|
||||
};
|
||||
|
||||
// Set the name of a new thread for the threadpool. Names are prefixed with
|
||||
// `THREADPOOL_THREAD_ID_PREFIX` and suffixed with a counter:
|
||||
// e.g. hf-xet-0, hf-xet-1, hf-xet-2, ...
|
||||
let thread_id = AtomicUsize::new(0);
|
||||
let get_thread_name = move || {
|
||||
let id = thread_id.fetch_add(1, Ordering::Relaxed);
|
||||
@@ -329,7 +220,6 @@ impl XetRuntime {
|
||||
let mut tokio_rt_builder = {
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
{
|
||||
// A new multithreaded runtime with a capped number of threads
|
||||
TokioRuntimeBuilder::new_multi_thread()
|
||||
}
|
||||
|
||||
@@ -343,75 +233,44 @@ impl XetRuntime {
|
||||
tokio_rt_builder.worker_threads(get_num_tokio_worker_threads());
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
let tokio_rt_builder = tokio_rt_builder
|
||||
.on_thread_start(set_threadlocal_reference)
|
||||
.thread_keep_alive(std::time::Duration::from_millis(100));
|
||||
|
||||
#[cfg(target_family = "wasm")]
|
||||
let tokio_rt_builder = tokio_rt_builder.on_thread_start(set_threadlocal_reference);
|
||||
|
||||
let tokio_rt = tokio_rt_builder
|
||||
.thread_name_fn(get_thread_name) // thread names will be hf-xet-0, hf-xet-1, etc.
|
||||
.on_thread_start(set_threadlocal_reference) // Set the local runtime reference.
|
||||
.thread_stack_size(THREADPOOL_STACK_SIZE) // 8MB stack size, default is 2MB
|
||||
.thread_keep_alive(std::time::Duration::from_millis(100)) // Don't keep idle blocking threads for long
|
||||
.enable_all() // enable all features, including IO/Timer/Signal/Reactor
|
||||
.thread_name_fn(get_thread_name)
|
||||
.thread_stack_size(THREADPOOL_STACK_SIZE)
|
||||
.enable_all()
|
||||
.build()
|
||||
.map_err(RuntimeError::RuntimeInit)?;
|
||||
|
||||
// Now that the runtime is created, fill out the original struct.
|
||||
let handle = tokio_rt.handle().clone();
|
||||
let tokio_rt = Arc::new(tokio_rt);
|
||||
*runtime.write().unwrap() = Some(tokio_rt); // Only fails if other thread destroyed mutex; unwrap ok.
|
||||
rt.handle_ref.set(handle).unwrap(); // Only fails if set called twice; unwrap ok.
|
||||
*runtime.write().unwrap() = Some(tokio_rt);
|
||||
rt.handle_ref.set(handle).unwrap();
|
||||
|
||||
#[cfg(feature = "fd-track")]
|
||||
report_fd_count("XetRuntime::new_with_config complete");
|
||||
report_fd_count("XetRuntime::new complete");
|
||||
|
||||
Ok(rt)
|
||||
}
|
||||
|
||||
/// Wrap a caller-provided tokio handle after validating that it meets requirements.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// - [`RuntimeError::InvalidRuntime`] — the handle lacks multi-thread flavor, time driver, or IO driver.
|
||||
/// - [`RuntimeError::ExternalAlreadyAttached`] — a live `XetRuntime` is already registered for this handle (checked
|
||||
/// inside [`from_external_with_config`](Self::from_external_with_config)).
|
||||
///
|
||||
/// Not available on WASM targets.
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
pub fn from_validated_external(
|
||||
rt_handle: TokioRuntimeHandle,
|
||||
config: XetConfig,
|
||||
) -> Result<Arc<Self>, RuntimeError> {
|
||||
if !Self::handle_meets_requirements(&rt_handle) {
|
||||
return Err(RuntimeError::InvalidRuntime(
|
||||
"supplied tokio handle does not meet requirements \
|
||||
(missing drivers or wrong flavor)"
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
Self::from_external_with_config(rt_handle, config)
|
||||
}
|
||||
|
||||
/// Wrap an existing tokio [`TokioRuntimeHandle`] with a [`XetRuntime`] using the provided
|
||||
/// [`XetConfig`]. No new thread pool is created; `spawn()` calls will schedule work on the
|
||||
/// runtime that owns `rt_handle`.
|
||||
///
|
||||
/// The resulting `XetRuntime` is registered in `EXTERNAL_RUNTIME_REGISTRY` so that
|
||||
/// [`XetRuntime::current()`] called from tasks running on `rt_handle`'s threads will return
|
||||
/// this instance (with the correct config and shared `XetCommon`) rather than a default
|
||||
/// throwaway. The entry is removed when the last `Arc<XetRuntime>` is dropped.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// - [`RuntimeError::ExternalAlreadyAttached`] — a live `XetRuntime` is already registered for `rt_handle`'s tokio
|
||||
/// runtime ID (i.e. the same handle was wrapped twice while the first is still alive). Drop the existing
|
||||
/// `XetRuntime` first, or use a different handle.
|
||||
/// Wraps an existing tokio handle with a new `XetRuntime`, using the config for
|
||||
/// system monitor setup.
|
||||
pub fn from_external_with_config(
|
||||
rt_handle: TokioRuntimeHandle,
|
||||
config: XetConfig,
|
||||
config: &XetConfig,
|
||||
) -> Result<Arc<Self>, RuntimeError> {
|
||||
#[cfg(feature = "fd-track")]
|
||||
let _fd_scope = track_fd_scope("XetRuntime::from_external_with_config");
|
||||
|
||||
let id = rt_handle.id();
|
||||
|
||||
let mut reg = EXTERNAL_RUNTIME_REGISTRY.write()?;
|
||||
let mut reg = EXTERNAL_THREADPOOL_REGISTRY.write()?;
|
||||
if let Some(existing) = reg.get(&id)
|
||||
&& existing.upgrade().is_some()
|
||||
{
|
||||
@@ -423,20 +282,8 @@ impl XetRuntime {
|
||||
handle_ref: rt_handle.into(),
|
||||
external_executor_count: 0.into(),
|
||||
sigint_shutdown: false.into(),
|
||||
common: XetCommon::new(&config),
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
system_monitor: config
|
||||
.system_monitor
|
||||
.enabled
|
||||
.then(|| {
|
||||
SystemMonitor::follow_process(
|
||||
config.system_monitor.sample_interval,
|
||||
config.system_monitor.log_path.clone(),
|
||||
)
|
||||
.ok()
|
||||
})
|
||||
.flatten(),
|
||||
config: Arc::new(config),
|
||||
system_monitor: system_monitor_for_config(config),
|
||||
});
|
||||
|
||||
reg.insert(id, Arc::downgrade(&rt));
|
||||
@@ -447,35 +294,30 @@ impl XetRuntime {
|
||||
Ok(rt)
|
||||
}
|
||||
|
||||
/// Wrap an existing tokio [`TokioRuntimeHandle`] with a [`XetRuntime`] using a default
|
||||
/// [`XetConfig`]. Prefer [`from_external_with_config`](Self::from_external_with_config) when
|
||||
/// you have a config available.
|
||||
///
|
||||
/// Unlike [`from_external_with_config`](Self::from_external_with_config), this function does
|
||||
/// **not** register the runtime in `EXTERNAL_RUNTIME_REGISTRY` and therefore performs no
|
||||
/// duplicate-handle check. It is intended for lightweight, short-lived wrapping where
|
||||
/// registry lookup via [`XetRuntime::current()`] is not required.
|
||||
/// Wraps an existing tokio handle without system monitoring.
|
||||
pub fn from_external(rt_handle: TokioRuntimeHandle) -> Arc<Self> {
|
||||
let config = XetConfig::new();
|
||||
Arc::new(Self {
|
||||
backend: RuntimeBackend::External { handle_id: None },
|
||||
handle_ref: rt_handle.into(),
|
||||
external_executor_count: 0.into(),
|
||||
sigint_shutdown: false.into(),
|
||||
common: XetCommon::new(&config),
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
system_monitor: config
|
||||
.system_monitor
|
||||
.enabled
|
||||
.then(|| {
|
||||
SystemMonitor::follow_process(
|
||||
config.system_monitor.sample_interval,
|
||||
config.system_monitor.log_path.clone(),
|
||||
)
|
||||
.ok()
|
||||
})
|
||||
.flatten(),
|
||||
config: Arc::new(config),
|
||||
system_monitor: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the current thread's active owned [`XetRuntime`], if any.
|
||||
///
|
||||
/// This is populated on owned runtime worker threads and on spawn_blocking
|
||||
/// threads created by an owned runtime. External runtimes do not set this.
|
||||
#[inline]
|
||||
pub fn current_if_exists() -> Option<Arc<Self>> {
|
||||
let pid = std::process::id();
|
||||
THREAD_THREADPOOL_REF.with_borrow(|entry| {
|
||||
entry
|
||||
.as_ref()
|
||||
.filter(|(entry_pid, _)| *entry_pid == pid)
|
||||
.and_then(|(_, weak_pool)| weak_pool.upgrade())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -484,36 +326,6 @@ impl XetRuntime {
|
||||
self.handle_ref.get().expect("Not initialized with handle set.").clone()
|
||||
}
|
||||
|
||||
/// Returns a reference to the shared `XetCommon` state.
|
||||
#[inline]
|
||||
pub fn common(&self) -> &XetCommon {
|
||||
&self.common
|
||||
}
|
||||
|
||||
/// Gets or creates a reqwest client, using a tag to identify the client type.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tag` - A string identifier for the client (e.g., "tcp" for regular, socket path for UDS)
|
||||
/// * `f` - A function that creates the client if needed
|
||||
///
|
||||
/// # Returns
|
||||
/// Returns a clone of the cached client if the tag matches and we're in a runtime,
|
||||
/// or creates a new client otherwise. This allows creating high-level clients outside
|
||||
/// a runtime, like in tests.
|
||||
pub fn get_or_create_reqwest_client<F>(tag: String, f: F) -> crate::error::Result<Client>
|
||||
where
|
||||
F: FnOnce() -> std::result::Result<Client, reqwest::Error>,
|
||||
{
|
||||
// Cache the reqwest Client if we are running inside a runtime, otherwise
|
||||
// create a new one. This allows creating high-level clients outside a
|
||||
// runtime, like in tests.
|
||||
if let Some(rt) = Self::current_if_exists() {
|
||||
rt.common().get_or_create_reqwest_client(tag, f)
|
||||
} else {
|
||||
Ok(f()?)
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn num_worker_threads(&self) -> usize {
|
||||
self.handle().metrics().num_workers()
|
||||
@@ -552,7 +364,13 @@ impl XetRuntime {
|
||||
|
||||
// When a task is shut down, it will stop running at whichever .await it has yielded at. All local
|
||||
// variables are destroyed by running their destructor.
|
||||
let maybe_runtime = runtime_cell.write().expect("cancel_all called recursively.").take();
|
||||
let maybe_runtime = match runtime_cell.write() {
|
||||
Ok(mut guard) => guard.take(),
|
||||
Err(poisoned) => {
|
||||
eprintln!("WARNING: perform_sigint_shutdown encountered a poisoned runtime lock; continuing shutdown.");
|
||||
poisoned.into_inner().take()
|
||||
},
|
||||
};
|
||||
|
||||
let Some(runtime) = maybe_runtime else {
|
||||
eprintln!("WARNING: perform_sigint_shutdown called on runtime that has already been shut down.");
|
||||
@@ -641,7 +459,7 @@ impl XetRuntime {
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
// If the runtime has been shut down, this will immediately abort.
|
||||
debug!("threadpool: spawn called, {}", self);
|
||||
debug!("xet-runtime: spawn called, {}", self);
|
||||
self.handle().spawn(future)
|
||||
}
|
||||
|
||||
@@ -652,7 +470,7 @@ impl XetRuntime {
|
||||
/// channel (compatible with any executor).
|
||||
///
|
||||
/// This is the primary async entry point. Session-level async methods should call
|
||||
/// `self.runtime.bridge_async(...)`.
|
||||
/// `ctx.runtime.bridge_async(...)`.
|
||||
pub async fn bridge_async<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
|
||||
where
|
||||
F: Future<Output = T> + Send + 'static,
|
||||
@@ -674,7 +492,7 @@ impl XetRuntime {
|
||||
/// `spawn_blocking` threads, OS threads, or the main thread is fine).
|
||||
///
|
||||
/// This is the primary sync entry point. Session-level `_blocking` methods
|
||||
/// should simply call `self.runtime.bridge_sync(...)`.
|
||||
/// should simply call `ctx.runtime.bridge_sync(...)`.
|
||||
pub fn bridge_sync<F>(&self, future: F) -> Result<F::Output, RuntimeError>
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
@@ -744,30 +562,21 @@ impl XetRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
/// Spawn a blocking task on the runtime's blocking thread pool. The task runs with this
|
||||
/// runtime stored in thread-local storage so [`XetRuntime::current()`] works inside `f`.
|
||||
///
|
||||
/// The receiver must be an `Arc<XetRuntime>` so the runtime can be installed in the
|
||||
/// blocking thread (e.g. `rt.spawn_blocking(|| { ... })` where `rt: Arc<XetRuntime>`).
|
||||
/// Spawn a blocking task on the runtime's blocking thread pool. Installs a weak thread-local
|
||||
/// reference to this pool for the duration of `f`.
|
||||
pub fn spawn_blocking<F, R>(self: &Arc<Self>, f: F) -> JoinHandle<R>
|
||||
where
|
||||
F: FnOnce() -> R + Send + 'static,
|
||||
R: Send + 'static,
|
||||
{
|
||||
let rt_weak = Arc::downgrade(self);
|
||||
let pool_weak = Arc::downgrade(self);
|
||||
self.handle().spawn_blocking(move || {
|
||||
let pid = std::process::id();
|
||||
THREAD_RUNTIME_REF.set(Some((pid, rt_weak)));
|
||||
THREAD_THREADPOOL_REF.set(Some((pid, pool_weak)));
|
||||
f()
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns a reference to the primary configuration struct.
|
||||
#[inline]
|
||||
pub fn config(&self) -> &Arc<XetConfig> {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Returns the runtime mode (Owned or External).
|
||||
#[inline]
|
||||
pub fn mode(&self) -> RuntimeMode {
|
||||
@@ -777,6 +586,24 @@ impl XetRuntime {
|
||||
}
|
||||
}
|
||||
|
||||
/// Wraps a caller-provided tokio handle after validating that it meets requirements.
|
||||
///
|
||||
/// Not available on WASM targets.
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
pub fn from_validated_external(
|
||||
rt_handle: TokioRuntimeHandle,
|
||||
config: &XetConfig,
|
||||
) -> Result<Arc<Self>, RuntimeError> {
|
||||
if !Self::handle_meets_requirements(&rt_handle) {
|
||||
return Err(RuntimeError::InvalidRuntime(
|
||||
"supplied tokio handle does not meet requirements \
|
||||
(missing drivers or wrong flavor)"
|
||||
.into(),
|
||||
));
|
||||
}
|
||||
Self::from_external_with_config(rt_handle, config)
|
||||
}
|
||||
|
||||
/// Probe whether a tokio runtime handle meets the requirements for use as an
|
||||
/// External-mode runtime.
|
||||
///
|
||||
@@ -835,28 +662,26 @@ impl Drop for XetRuntime {
|
||||
|
||||
self.handle_ref.take();
|
||||
|
||||
if let RuntimeBackend::External { handle_id: Some(id) } = &self.backend {
|
||||
if let Ok(mut reg) = EXTERNAL_RUNTIME_REGISTRY.write() {
|
||||
reg.remove(id);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// When dropping from within an async context, the default TokioRuntime Drop
|
||||
// would panic ("Cannot drop a runtime in a context where blocking is not allowed").
|
||||
// Avoid this by taking ownership of the runtime and using shutdown_background(),
|
||||
// which spawns a thread for the blocking shutdown work instead.
|
||||
let in_async_context = TokioRuntimeHandle::try_current().is_ok();
|
||||
if let RuntimeBackend::OwnedThreadPool { runtime } = &self.backend
|
||||
&& let Ok(mut guard) = runtime.write()
|
||||
&& let Some(rt_arc) = guard.take()
|
||||
&& let Ok(rt) = Arc::try_unwrap(rt_arc)
|
||||
{
|
||||
if in_async_context {
|
||||
rt.shutdown_background();
|
||||
} else {
|
||||
rt.shutdown_timeout(std::time::Duration::from_secs(5));
|
||||
}
|
||||
match &self.backend {
|
||||
RuntimeBackend::External { handle_id: Some(id) } => {
|
||||
if let Ok(mut reg) = EXTERNAL_THREADPOOL_REGISTRY.write() {
|
||||
reg.remove(id);
|
||||
}
|
||||
},
|
||||
RuntimeBackend::External { handle_id: None } => {},
|
||||
RuntimeBackend::OwnedThreadPool { runtime } => {
|
||||
let in_async_context = TokioRuntimeHandle::try_current().is_ok();
|
||||
if let Ok(mut guard) = runtime.write()
|
||||
&& let Some(rt_arc) = guard.take()
|
||||
&& let Ok(rt) = Arc::try_unwrap(rt_arc)
|
||||
{
|
||||
if in_async_context {
|
||||
rt.shutdown_background();
|
||||
} else {
|
||||
rt.shutdown_timeout(std::time::Duration::from_secs(5));
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -892,95 +717,73 @@ impl Display for XetRuntime {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_get_or_create_reqwest_client_returns_client() {
|
||||
let result =
|
||||
XetRuntime::get_or_create_reqwest_client("test".to_string(), || reqwest::Client::builder().build());
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_spawn_blocking_sets_current_runtime() {
|
||||
let rt = XetRuntime::new().expect("Failed to create runtime");
|
||||
let rt_clone = rt.clone();
|
||||
let jh = rt.spawn_blocking(move || {
|
||||
let current = XetRuntime::current();
|
||||
Arc::ptr_eq(¤t, &rt_clone)
|
||||
});
|
||||
let same = rt.bridge_sync(async { jh.await.unwrap() }).unwrap();
|
||||
assert!(same);
|
||||
}
|
||||
|
||||
/// current_if_exists() must return the session-owned XetRuntime (with the correct config)
|
||||
/// when called from tasks on an External-mode runtime, not a default-config throwaway.
|
||||
#[test]
|
||||
fn test_current_if_exists_sees_external_runtime_config() {
|
||||
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let mut config = XetConfig::new();
|
||||
config.data.default_cas_endpoint = "https://test-endpoint.example.com".into();
|
||||
let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), config).unwrap();
|
||||
|
||||
// current_if_exists() from within the runtime must find the registered entry.
|
||||
tokio_rt.block_on(async {
|
||||
let found = XetRuntime::current_if_exists().expect("should find a runtime");
|
||||
assert!(Arc::ptr_eq(&found, &xet_rt), "must be the same XetRuntime instance");
|
||||
assert_eq!(found.config().data.default_cas_endpoint, "https://test-endpoint.example.com");
|
||||
});
|
||||
|
||||
// After drop the entry is removed; current_if_exists() falls back to a default wrapper.
|
||||
drop(xet_rt);
|
||||
tokio_rt.block_on(async {
|
||||
let found = XetRuntime::current_if_exists().expect("should still find a runtime");
|
||||
assert_ne!(found.config().data.default_cas_endpoint, "https://test-endpoint.example.com");
|
||||
});
|
||||
}
|
||||
use crate::core::XetContext;
|
||||
|
||||
#[test]
|
||||
fn test_bridge_async_owned_mode_runs_on_pool() {
|
||||
let rt = XetRuntime::new().unwrap();
|
||||
assert_eq!(rt.mode(), RuntimeMode::Owned);
|
||||
let result = rt.bridge_sync(async {
|
||||
let inner_rt = XetRuntime::new().unwrap();
|
||||
inner_rt.bridge_async("test", async { 42 }).await.unwrap()
|
||||
});
|
||||
let ctx = XetContext::default().unwrap();
|
||||
assert_eq!(ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
let rt = ctx.runtime.clone();
|
||||
let result = ctx
|
||||
.runtime
|
||||
.bridge_sync(async move { rt.bridge_async("test", async { 42 }).await.unwrap() });
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_async_external_mode_runs_directly() {
|
||||
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
|
||||
assert_eq!(xet_rt.mode(), RuntimeMode::External);
|
||||
let config = XetConfig::new();
|
||||
let rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), &config).unwrap();
|
||||
let ctx = XetContext::new(config, rt);
|
||||
assert_eq!(ctx.runtime.mode(), RuntimeMode::External);
|
||||
|
||||
let result = tokio_rt.block_on(async { xet_rt.bridge_async("test", async { 99 }).await.unwrap() });
|
||||
let result = tokio_rt.block_on(async { ctx.runtime.bridge_async("test", async { 99 }).await.unwrap() });
|
||||
assert_eq!(result, 99);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_sync_owned_mode() {
|
||||
let rt = XetRuntime::new().unwrap();
|
||||
assert_eq!(rt.mode(), RuntimeMode::Owned);
|
||||
let result = rt.bridge_sync(async { 123 }).unwrap();
|
||||
let ctx = XetContext::default().unwrap();
|
||||
assert_eq!(ctx.runtime.mode(), RuntimeMode::Owned);
|
||||
let result = ctx.runtime.bridge_sync(async { 123 }).unwrap();
|
||||
assert_eq!(result, 123);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_reuses_owned_xet_runtime_from_tls() {
|
||||
let parent = XetContext::default().unwrap();
|
||||
let parent_runtime = parent.runtime.clone();
|
||||
let parent_config = parent.config.clone();
|
||||
|
||||
let child = parent
|
||||
.runtime
|
||||
.bridge_sync(async move { XetContext::default().unwrap() })
|
||||
.unwrap();
|
||||
|
||||
assert!(Arc::ptr_eq(&child.runtime, &parent_runtime));
|
||||
assert!(!Arc::ptr_eq(&child.config, &parent_config));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_sync_from_spawn_blocking_owned_mode() {
|
||||
let rt = XetRuntime::new().unwrap();
|
||||
let rt_clone = rt.clone();
|
||||
let jh = rt.spawn_blocking(move || rt_clone.bridge_sync(async { 456 }).unwrap());
|
||||
let result = rt.bridge_sync(async { jh.await.unwrap() }).unwrap();
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let rt = ctx.runtime.clone();
|
||||
let rt2 = ctx.runtime.clone();
|
||||
let jh = rt.spawn_blocking(move || rt2.bridge_sync(async { 456 }).unwrap());
|
||||
let result = ctx.runtime.bridge_sync(async { jh.await.unwrap() }).unwrap();
|
||||
assert_eq!(result, 456);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_sync_external_mode_returns_error() {
|
||||
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
|
||||
assert_eq!(xet_rt.mode(), RuntimeMode::External);
|
||||
let config = XetConfig::new();
|
||||
let rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), &config).unwrap();
|
||||
let ctx = XetContext::new(config, rt);
|
||||
assert_eq!(ctx.runtime.mode(), RuntimeMode::External);
|
||||
|
||||
let result = xet_rt.bridge_sync(async { 789 });
|
||||
let result = ctx.runtime.bridge_sync(async { 789 });
|
||||
assert!(matches!(result, Err(RuntimeError::InvalidRuntime(_))));
|
||||
}
|
||||
|
||||
@@ -1009,15 +812,17 @@ mod tests {
|
||||
#[test]
|
||||
fn test_from_validated_external_accepts_valid_handle() {
|
||||
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let xet_rt = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
|
||||
assert_eq!(xet_rt.mode(), RuntimeMode::External);
|
||||
let config = XetConfig::new();
|
||||
let rt = XetRuntime::from_validated_external(tokio_rt.handle().clone(), &config).unwrap();
|
||||
assert_eq!(rt.mode(), RuntimeMode::External);
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
#[test]
|
||||
fn test_from_validated_external_rejects_current_thread_runtime() {
|
||||
let tokio_rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
|
||||
let result = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new());
|
||||
let config = XetConfig::new();
|
||||
let result = XetRuntime::from_validated_external(tokio_rt.handle().clone(), &config);
|
||||
assert!(matches!(result, Err(RuntimeError::InvalidRuntime(_))));
|
||||
}
|
||||
|
||||
@@ -1025,16 +830,17 @@ mod tests {
|
||||
#[test]
|
||||
fn test_from_validated_external_rejects_runtime_without_drivers() {
|
||||
let tokio_rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
|
||||
let result = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new());
|
||||
let config = XetConfig::new();
|
||||
let result = XetRuntime::from_validated_external(tokio_rt.handle().clone(), &config);
|
||||
assert!(matches!(result, Err(RuntimeError::InvalidRuntime(_))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_async_owned_mode_catches_panic() {
|
||||
let rt = XetRuntime::new().unwrap();
|
||||
let rt2 = rt.clone();
|
||||
let result = rt.bridge_sync(async move {
|
||||
rt2.bridge_async("panic_test", async {
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let rt = ctx.runtime.clone();
|
||||
let result = ctx.runtime.bridge_sync(async move {
|
||||
rt.bridge_async("panic_test", async {
|
||||
panic!("intentional test panic");
|
||||
})
|
||||
.await
|
||||
@@ -1044,37 +850,95 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
// Wrapping the same tokio handle a second time (while the first XetRuntime is alive)
|
||||
// must return ExternalAlreadyAttached.
|
||||
fn test_from_external_with_config_duplicate_handle_fails() {
|
||||
fn test_context_config_preserved_through_external() {
|
||||
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let _first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
|
||||
let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new());
|
||||
assert!(
|
||||
matches!(second, Err(RuntimeError::ExternalAlreadyAttached(_))),
|
||||
"expected ExternalAlreadyAttached for duplicate handle, got: {second:?}"
|
||||
);
|
||||
let mut config = XetConfig::new();
|
||||
config.data.default_cas_endpoint = "https://test-endpoint.example.com".into();
|
||||
let rt = XetRuntime::from_external(tokio_rt.handle().clone());
|
||||
let ctx = XetContext::new(config, rt);
|
||||
assert_eq!(ctx.config.data.default_cas_endpoint, "https://test-endpoint.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
// After the first XetRuntime is dropped (deregistered), wrapping the same handle again
|
||||
// must succeed.
|
||||
fn test_from_external_with_config_reuse_handle_after_drop() {
|
||||
fn test_check_sigint_shutdown_not_triggered() {
|
||||
let ctx = XetContext::default().unwrap();
|
||||
assert!(ctx.check_sigint_shutdown().is_ok());
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
#[test]
|
||||
fn test_from_external_with_config_rejects_second_attach() {
|
||||
let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
|
||||
let config = XetConfig::new();
|
||||
|
||||
let first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), &config).unwrap();
|
||||
let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), &config);
|
||||
|
||||
assert!(matches!(second, Err(RuntimeError::ExternalAlreadyAttached(_))));
|
||||
drop(first);
|
||||
let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new());
|
||||
assert!(second.is_ok(), "expected Ok after previous XetRuntime was dropped, got: {second:?}");
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
#[test]
|
||||
// Two distinct tokio runtimes must each accept their own XetRuntime without conflict.
|
||||
fn test_from_external_with_config_distinct_handles_both_succeed() {
|
||||
let rt_a = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let rt_b = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
|
||||
let xet_a = XetRuntime::from_external_with_config(rt_a.handle().clone(), XetConfig::new());
|
||||
let xet_b = XetRuntime::from_external_with_config(rt_b.handle().clone(), XetConfig::new());
|
||||
assert!(xet_a.is_ok());
|
||||
assert!(xet_b.is_ok());
|
||||
fn test_perform_sigint_shutdown_tolerates_poisoned_runtime_lock() {
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let runtime = ctx.runtime.clone();
|
||||
let runtime_cell = runtime.runtime_cell_if_owned().unwrap().clone();
|
||||
|
||||
let _ = std::thread::spawn(move || {
|
||||
let _guard = runtime_cell.write().unwrap();
|
||||
panic!("intentional poison for test");
|
||||
})
|
||||
.join();
|
||||
|
||||
let shutdown_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
|
||||
runtime.perform_sigint_shutdown();
|
||||
}));
|
||||
assert!(shutdown_result.is_ok());
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
#[test]
|
||||
fn test_sigint_shutdown_causes_keyboard_interrupt_on_bridges() {
|
||||
let ctx = XetContext::default().unwrap();
|
||||
ctx.runtime.perform_sigint_shutdown();
|
||||
|
||||
let sync_result = ctx.runtime.bridge_sync(async { 1 });
|
||||
assert!(matches!(sync_result, Err(RuntimeError::KeyboardInterrupt)));
|
||||
|
||||
let tokio_rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
|
||||
let tp = ctx.runtime.clone();
|
||||
let async_result = tokio_rt.block_on(async move { tp.bridge_async("sigint_test", async { 1 }).await });
|
||||
assert!(matches!(async_result, Err(RuntimeError::KeyboardInterrupt)));
|
||||
}
|
||||
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
#[test]
|
||||
fn test_concurrent_bridge_sync_stress() {
|
||||
use std::sync::Barrier;
|
||||
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let n = 200;
|
||||
let barrier = Arc::new(Barrier::new(n));
|
||||
let sum = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let handles: Vec<_> = (0..n)
|
||||
.map(|i| {
|
||||
let tp = ctx.runtime.clone();
|
||||
let barrier = barrier.clone();
|
||||
let sum = sum.clone();
|
||||
std::thread::spawn(move || {
|
||||
barrier.wait();
|
||||
let result = tp.bridge_sync(async move { i }).unwrap();
|
||||
sum.fetch_add(result, Ordering::Relaxed);
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
for h in handles {
|
||||
h.join().unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(sum.load(Ordering::Relaxed), (0..n).sum::<usize>());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,5 @@ pub mod utils;
|
||||
pub use utils::configuration_utils;
|
||||
pub mod config;
|
||||
pub mod core;
|
||||
#[cfg(debug_assertions)]
|
||||
pub mod fd_diagnostics;
|
||||
pub mod logging;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::core::xet_config;
|
||||
use crate::config::XetConfig;
|
||||
use crate::utils::TemplatedPathBuf;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
@@ -11,9 +11,7 @@ pub enum LoggingMode {
|
||||
Console,
|
||||
}
|
||||
|
||||
/// The log directory cleanup configuration. By default, the values
|
||||
/// are loaded from environment variables.
|
||||
|
||||
/// The log directory cleanup configuration.
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
pub struct LogDirConfig {
|
||||
pub min_deletion_age: Duration,
|
||||
@@ -22,14 +20,13 @@ pub struct LogDirConfig {
|
||||
pub filename_prefix: String,
|
||||
}
|
||||
|
||||
impl Default for LogDirConfig {
|
||||
fn default() -> Self {
|
||||
// Load the defaults from the environmental config.
|
||||
impl LogDirConfig {
|
||||
pub fn from_config(config: &XetConfig) -> Self {
|
||||
Self {
|
||||
min_deletion_age: xet_config().log.dir_min_deletion_age,
|
||||
max_retention_age: xet_config().log.dir_max_retention_age,
|
||||
size_limit: xet_config().log.dir_max_size.as_u64(),
|
||||
filename_prefix: xet_config().log.prefix.to_string(),
|
||||
min_deletion_age: config.log.dir_min_deletion_age,
|
||||
max_retention_age: config.log.dir_max_retention_age,
|
||||
size_limit: config.log.dir_max_size.as_u64(),
|
||||
filename_prefix: config.log.prefix.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -44,11 +41,10 @@ pub struct LoggingConfig {
|
||||
}
|
||||
|
||||
impl LoggingConfig {
|
||||
/// Set up logging to a directory. Note that this can be overwritten by environmental
|
||||
pub fn default_to_directory(version: String, log_directory: impl AsRef<Path>) -> LoggingConfig {
|
||||
// Choose the logging mode.
|
||||
/// Set up logging to a directory using the given config.
|
||||
pub fn from_directory(config: &XetConfig, version: String, log_directory: impl AsRef<Path>) -> LoggingConfig {
|
||||
let logging_mode = {
|
||||
if let Some(log_dest) = &xet_config().log.dest {
|
||||
if let Some(log_dest) = &config.log.dest {
|
||||
if log_dest.as_str().is_empty() {
|
||||
LoggingMode::Console
|
||||
} else {
|
||||
@@ -69,7 +65,7 @@ impl LoggingConfig {
|
||||
};
|
||||
|
||||
let use_json = {
|
||||
if let Some(format) = &xet_config().log.format {
|
||||
if let Some(format) = &config.log.format {
|
||||
format.as_str().to_ascii_lowercase().trim() == "json"
|
||||
} else {
|
||||
logging_mode != LoggingMode::Console
|
||||
@@ -77,14 +73,14 @@ impl LoggingConfig {
|
||||
};
|
||||
|
||||
let enable_log_dir_cleanup =
|
||||
matches!(logging_mode, LoggingMode::Directory(_)) && !xet_config().log.dir_disable_cleanup;
|
||||
matches!(logging_mode, LoggingMode::Directory(_)) && !config.log.dir_disable_cleanup;
|
||||
|
||||
Self {
|
||||
logging_mode,
|
||||
use_json,
|
||||
enable_log_dir_cleanup,
|
||||
version,
|
||||
log_dir_config: LogDirConfig::default(),
|
||||
log_dir_config: LogDirConfig::from_config(config),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,7 +398,7 @@ mod tests {
|
||||
#[test]
|
||||
fn round_trip_make_and_parse() {
|
||||
let dir = Path::new("/tmp");
|
||||
let cfg = LogDirConfig::default();
|
||||
let cfg = LogDirConfig::from_config(&crate::config::XetConfig::new());
|
||||
let path = log_file_in_dir(&cfg, dir);
|
||||
let (base, ts, pid) = parse_log_file_name(&path).expect("parse");
|
||||
assert_eq!(base, cfg.filename_prefix);
|
||||
|
||||
@@ -264,7 +264,7 @@ macro_rules! test_set_constants {
|
||||
|
||||
#[cfg(not(doctest))]
|
||||
/// A macro for **tests** that sets config group environment variables **before**
|
||||
/// XetRuntime is initialized. The environment variables follow the pattern
|
||||
/// XetContext is initialized. The environment variables follow the pattern
|
||||
/// `HF_XET_{GROUP_NAME}_{FIELD_NAME}`.
|
||||
///
|
||||
/// This macro uses `ctor` to run on module load, ensuring environment variables
|
||||
|
||||
@@ -407,7 +407,7 @@ pub(crate) mod tests {
|
||||
|
||||
use super::super::errors::SingleflightError;
|
||||
use super::{Call, Group, OwnerTask};
|
||||
use crate::core::XetRuntime;
|
||||
use crate::core::XetContext;
|
||||
|
||||
/// A period of time for waiters to wait for a notification from the owner
|
||||
/// task. This is expected to be sufficient time for the test futures to
|
||||
@@ -428,10 +428,11 @@ pub(crate) mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_with_threadpool() {
|
||||
let threadpool = Arc::new(XetRuntime::new().unwrap());
|
||||
fn test_simple_with_xet_runtime() {
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let g = Group::new();
|
||||
let res = threadpool
|
||||
let res = ctx
|
||||
.runtime
|
||||
.bridge_sync(async move { g.work("key", return_res()).await })
|
||||
.unwrap()
|
||||
.0;
|
||||
@@ -449,17 +450,17 @@ pub(crate) mod tests {
|
||||
|
||||
#[test]
|
||||
#[cfg_attr(feature = "smoke-test", ignore)]
|
||||
fn test_multiple_threads_with_threadpool() {
|
||||
fn test_multiple_threads_with_xet_runtime() {
|
||||
let times_called = Arc::new(AtomicU32::new(0));
|
||||
let threadpool = Arc::new(XetRuntime::new().unwrap());
|
||||
let ctx = XetContext::default().unwrap();
|
||||
let g: Arc<Group<usize, ()>> = Arc::new(Group::new());
|
||||
let mut handlers: Vec<JoinHandle<(usize, bool)>> = Vec::new();
|
||||
let threadpool_ = threadpool.clone();
|
||||
let runtime = ctx.runtime.clone();
|
||||
let tasks = async move {
|
||||
for _ in 0..10 {
|
||||
let g = g.clone();
|
||||
let counter = times_called.clone();
|
||||
handlers.push(threadpool_.spawn(async move {
|
||||
handlers.push(runtime.spawn(async move {
|
||||
let tup = g.work("key", expensive_fn(counter, RES)).await;
|
||||
let res = tup.0;
|
||||
let fn_response = res.unwrap();
|
||||
@@ -479,7 +480,7 @@ pub(crate) mod tests {
|
||||
assert_eq!(1, num_callers);
|
||||
assert_eq!(1, times_called.load(Ordering::SeqCst));
|
||||
};
|
||||
threadpool.bridge_sync(tasks).unwrap();
|
||||
ctx.runtime.bridge_sync(tasks).unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -20,7 +20,7 @@ fn main() {
|
||||
use_json: true,
|
||||
enable_log_dir_cleanup: true,
|
||||
version: "test".to_string(),
|
||||
log_dir_config: LogDirConfig::default(),
|
||||
log_dir_config: LogDirConfig::from_config(&xet_runtime::config::XetConfig::new()),
|
||||
};
|
||||
|
||||
init(config);
|
||||
|
||||
Reference in New Issue
Block a user