mirror of
https://github.com/huggingface/xet-core.git
synced 2026-06-04 13:30:29 +08:00
Error unification and cleanup (#737)
This PR performs some housecleaning and removes some technical debt
around using different error types, unifying them with the python
interface.
- Our client code tended to do a lot with anyhow errors as an artifact
of first using them before switching to thiserror. This PR cleans these
up in favor of using ClientError or other named error types directly.
- It also removes all the aliases to the old error type names present in
the packages before the refactoring, now settling into ClientError,
FormatError, DataError, and RuntimeError, with XetError being the error
type exposed publicly.
- Also, currently, xet_session exposes SessionError as an alias of
XetError, which adds an extra public type name without adding behavior.
This PR removes that alias and standardizes the public API/docs onto
XetError directly.
- It also tightens Python-facing error behavior and moves the Python
handling to the XetError class directly, hidden behind a python feature
flag. Using these types, hf_xet now registers XetObjectNotFoundError and
XetAuthenticationError exception classes for authentication and the
not-found cases. These inherit from the current exception classes, so
all behavior is preserved.
- In addition, the `From<XetError> for PyErr` mapping routes
timeout/network/auth/not-found categories to more appropriate Python
exception types than simply RuntimeError.
This is primarily an API-surface cleanup plus error-classification
alignment.
<!-- CURSOR_SUMMARY -->
---
> [!NOTE]
> **Medium Risk**
> API-breaking error-surface changes (removal of legacy alias modules
and signature changes like `CredentialHelper::fill_credential`) may
require downstream code updates, especially where errors are
matched/converted. Runtime behavior should be mostly unchanged, but
error mapping/propagation paths (including Python exceptions) are widely
touched across crates.
>
> **Overview**
> This PR **unifies error types across the workspace** by removing
legacy re-export/alias modules (e.g. `CasClientError`, `CasTypesError`,
`DataProcessingError`, `SessionError`) and updating call sites to use
canonical errors like `xet_client::ClientError`,
`xet_core_structures::CoreError`, and `xet_data::DataError` directly.
>
> It updates CAS client code to **standardize on
`crate::error::Result`/`ClientError`**, including deleting
`cas_client/error.rs`, adjusting error conversions in retry/http
middleware paths, and updating simulation/local-server code to map
`ClientError` to HTTP responses.
>
> Python bindings (`hf_xet`) now **convert failures via `XetError`**
(with `xet_pkg` built with `python` support), register custom exceptions
on module init, and refine argument-validation errors to `PyValueError`
while routing network/timeout/auth/not-found to more appropriate Python
exception classes.
>
> Misc cleanup: `git_xet` now depends on `xet-data`, simulation binaries
switch to `anyhow::Result`/`bail!`, and lockfiles are updated for
new/updated dependencies (notably `pyo3`/`inventory`).
>
> <sup>Written by [Cursor
Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit
f3d056a909. This will update automatically
on new commits. Configure
[here](https://cursor.com/dashboard?tab=bugbot).</sup>
<!-- /CURSOR_SUMMARY -->
4
Cargo.lock
generated
@@ -1655,6 +1655,7 @@ dependencies = [
  "thiserror 2.0.18",
  "tokio",
  "xet-client",
+ "xet-data",
  "xet-runtime",
 ]
 
@@ -1879,6 +1880,7 @@ dependencies = [
  "futures",
  "http",
  "more-asserts",
+ "pyo3",
  "serde",
  "serde_json",
  "serial_test",
@@ -4541,6 +4543,7 @@ checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
 name = "simulation"
 version = "1.4.0"
 dependencies = [
+ "anyhow",
  "bytes",
  "chrono",
  "clap",
@@ -6220,7 +6223,6 @@ dependencies = [
 name = "xet-core-structures"
 version = "1.4.0"
 dependencies = [
- "anyhow",
  "async-trait",
  "base64 0.22.1",
  "bincode",
@@ -0,0 +1,97 @@
# API Update: Anyhow Removal and Error Alias Cleanup (2026-03-18)

## Overview

This change removes direct `anyhow` usage from core crates and finalizes migration to canonical
package-level error types.

Main impact:
- public legacy error alias modules were removed;
- several public error enum variants changed payload types;
- `SessionError` alias usage was removed in favor of `xet::XetError`;
- Python exception mapping was tightened for `hf_xet`;
- downstream imports should now use canonical error paths directly.

This is an API-breaking cleanup for callers still importing old alias paths.

## Canonical Error Types (now required)

Use these types directly:
- `xet_client::ClientError`
- `xet_core_structures::CoreError`
- `xet_data::DataError`
- `xet::XetError`
- `xet_runtime::RuntimeError`

## Removed Legacy Alias Modules

The following compatibility modules were removed:
- `xet_client::cas_client::error` (previous `CasClientError` alias)
- `xet_client::cas_types::error` (previous `CasTypesError` alias)
- `xet_client::hub_client::errors` (previous `HubClientError` alias)
- `xet_core_structures::metadata_shard::error` (previous `MDBShardError` alias)
- `xet_core_structures::xorb_object::error` (previous `XorbObjectError` alias)
- `xet_data::processing::errors` (previous `DataProcessingError` alias)
- `xet::xet_session::errors` (previous `SessionError` alias)
- `xet_runtime::core::errors` (previous `MultithreadedRuntimeError` alias)

## Breaking Type/Variant Changes

### `xet_client::ClientError`
- `InternalError(anyhow::Error)` -> `InternalError(String)`
- `CredentialHelper(anyhow::Error)` -> `CredentialHelper(String)`

### `xet_core_structures::CoreError` (renamed from `FormatError`)
- `Internal(anyhow::Error)` -> `InternalError(String)`
- `Format(anyhow::Error)` -> `MalformedData(String)` (or a more specific `CoreError` variant)

### `xet::xet_session` / `xet::XetError`
- `xet::xet_session::SessionError` alias was removed.
- Public session APIs now return `Result<_, xet::XetError>`.
- `ClientError::PresignedUrlExpirationError` now maps to `XetError::Authentication`.
- `XetError::Timeout(String)` is used for timeout-class network failures.

Code matching old variant names or payload types must be updated.
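
As a rough illustration of the payload change, the sketch below uses a stand-in `ClientError` enum containing only the two variants listed above; the real `xet_client::ClientError` has more variants, and the helper names `internal_from` and `describe` are hypothetical. It shows a source error being flattened to a `String` at construction time and read back by matching on the string payload.

```rust
use std::fmt;

// Stand-in for the two `ClientError` variants named above; the real
// enum lives in `xet_client` and has additional variants.
#[derive(Debug, PartialEq)]
pub enum ClientError {
    InternalError(String),
    CredentialHelper(String),
}

impl fmt::Display for ClientError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ClientError::InternalError(msg) => write!(f, "internal error: {msg}"),
            ClientError::CredentialHelper(msg) => write!(f, "credential helper error: {msg}"),
        }
    }
}

impl std::error::Error for ClientError {}

// Hypothetical helper: flatten any displayable source error into the
// string payload, since the variant no longer stores `anyhow::Error`.
pub fn internal_from(source: impl fmt::Display) -> ClientError {
    ClientError::InternalError(source.to_string())
}

// Callers now read a `String` payload instead of downcasting an
// `anyhow::Error` to recover the original error type.
pub fn describe(err: &ClientError) -> &str {
    match err {
        ClientError::InternalError(msg) | ClientError::CredentialHelper(msg) => msg.as_str(),
    }
}
```

One consequence of string payloads is that the original error type is no longer recoverable from the variant, so any structured information a caller needs should be captured in a dedicated variant before conversion.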

## Trait Signature Change

`xet_client::hub_client::CredentialHelper` now uses:

```rust
async fn fill_credential(&self, req: RequestBuilder) -> Result<RequestBuilder, xet_client::ClientError>;
```
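
Implementors now have to convert their own failures into `ClientError` at the trait boundary. The following is a synchronous, dependency-free analogue of that pattern, not the real trait: `Request` stands in for `reqwest_middleware::RequestBuilder`, `TokenHelper` is hypothetical, and the signature of `ClientError::credential_helper_error` is an assumption (the constructor name appears in this change, its exact parameter type does not).

```rust
use std::fmt;

// Stand-in for `xet_client::ClientError`; only the relevant variant.
#[derive(Debug, PartialEq)]
pub enum ClientError {
    CredentialHelper(String),
}

impl ClientError {
    // Assumed shape: accepts any displayable error and stores its message.
    pub fn credential_helper_error(e: impl fmt::Display) -> Self {
        ClientError::CredentialHelper(e.to_string())
    }
}

// Hypothetical request type standing in for `RequestBuilder`.
#[derive(Debug, Default, PartialEq)]
pub struct Request {
    pub headers: Vec<(String, String)>,
}

// Synchronous analogue of the trait; the real method is `async`.
pub trait CredentialHelper {
    fn fill_credential(&self, req: Request) -> Result<Request, ClientError>;
}

// Hypothetical helper that holds an optional bearer token.
pub struct TokenHelper {
    pub token: Option<String>,
}

impl TokenHelper {
    // Helper-specific failure, reported as a plain string in this sketch.
    fn authenticate(&self) -> Result<String, String> {
        self.token.clone().ok_or_else(|| "no token available".to_string())
    }
}

impl CredentialHelper for TokenHelper {
    fn fill_credential(&self, mut req: Request) -> Result<Request, ClientError> {
        // Convert the helper's own error type at the trait boundary.
        let token = self.authenticate().map_err(ClientError::credential_helper_error)?;
        req.headers.push(("authorization".to_string(), format!("Bearer {token}")));
        Ok(req)
    }
}
```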

## Migration Guide

Replace old imports and aliases directly:

```rust
// old
use xet_client::cas_client::CasClientError;
use xet_data::processing::errors::DataProcessingError;
use xet::xet_session::SessionError;
use xet_runtime::core::errors::MultithreadedRuntimeError;

// new
use xet_client::ClientError;
use xet_data::DataError;
use xet::XetError;
use xet_runtime::RuntimeError;
```
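
Downstream error enums that wrapped the old aliases can wrap the canonical types in the same way. The sketch below is a minimal, std-only illustration: `ClientError` and `DataError` are simplified stand-ins for the real types, `AppError`, `fetch`, `process`, and `run` are hypothetical, and the `From` impls are written by hand to show roughly what a `#[from]` attribute in `thiserror` provides.

```rust
use std::fmt;

// Simplified stand-ins for the canonical types; the real ones live in
// `xet_client` and `xet_data` and are enums with many variants.
#[derive(Debug, PartialEq)]
pub struct ClientError(pub String);

#[derive(Debug, PartialEq)]
pub struct DataError(pub String);

// Hypothetical downstream error wrapping both canonical types.
#[derive(Debug, PartialEq)]
pub enum AppError {
    Client(ClientError),
    Data(DataError),
}

impl fmt::Display for AppError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            AppError::Client(e) => write!(f, "client error: {}", e.0),
            AppError::Data(e) => write!(f, "data error: {}", e.0),
        }
    }
}

// Hand-written conversions; these are what let `?` lift each
// canonical error into the downstream enum.
impl From<ClientError> for AppError {
    fn from(e: ClientError) -> Self {
        AppError::Client(e)
    }
}

impl From<DataError> for AppError {
    fn from(e: DataError) -> Self {
        AppError::Data(e)
    }
}

// Hypothetical client-side step.
fn fetch(online: bool) -> Result<u32, ClientError> {
    if online { Ok(21) } else { Err(ClientError("offline".to_string())) }
}

// Hypothetical data-side step.
fn process(n: u32) -> Result<u32, DataError> {
    n.checked_mul(2).ok_or_else(|| DataError("overflow".to_string()))
}

// Both error types propagate through `?` without explicit `map_err`.
pub fn run(online: bool) -> Result<u32, AppError> {
    let n = fetch(online)?;
    Ok(process(n)?)
}
```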

## Python (`hf_xet`) behavior

`From<XetError> for PyErr` now maps:
- `Authentication` -> `hf_xet.XetAuthenticationError` (inherits `PermissionError`)
- `NotFound` -> `hf_xet.XetObjectNotFoundError` (inherits `FileNotFoundError`)
- `Network` -> `ConnectionError`
- `Timeout` -> `TimeoutError`
- `Cancelled` -> `RuntimeError`
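
The mapping above can be pictured as a plain lookup from error category to Python class name. The sketch below is std-only and does not use `pyo3`: the `XetError` enum is a stand-in whose payload types (other than `Timeout(String)`) are assumptions, the `Other` variant and its `RuntimeError` fallback are hypothetical, and both function names are invented for illustration.

```rust
// Stand-in for the `XetError` categories listed above; the real enum
// has more variants, and payloads other than `Timeout(String)` are assumed.
#[derive(Debug)]
pub enum XetError {
    Authentication(String),
    NotFound(String),
    Network(String),
    Timeout(String),
    Cancelled(String),
    // Hypothetical catch-all for categories not listed above.
    Other(String),
}

// Name of the Python exception class each category is raised as.
pub fn python_exception_name(err: &XetError) -> &'static str {
    match err {
        XetError::Authentication(_) => "hf_xet.XetAuthenticationError",
        XetError::NotFound(_) => "hf_xet.XetObjectNotFoundError",
        XetError::Network(_) => "ConnectionError",
        XetError::Timeout(_) => "TimeoutError",
        // The `Other` fallback is an assumption of this sketch.
        XetError::Cancelled(_) | XetError::Other(_) => "RuntimeError",
    }
}

// Built-in base class of the two custom exceptions, so existing
// `except PermissionError` / `except FileNotFoundError` handlers still match.
pub fn python_base_class(err: &XetError) -> Option<&'static str> {
    match err {
        XetError::Authentication(_) => Some("PermissionError"),
        XetError::NotFound(_) => Some("FileNotFoundError"),
        _ => None,
    }
}
```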

For constructors that previously accepted `anyhow::Error`, construct string-backed variants
instead (`InternalError`, `CredentialHelper`, `MalformedData`, and related `CoreError` variants).

## Behavior Notes

No intended runtime behavior change in retry, session, or reconstruction logic; this update is
primarily an error-surface and API cleanup.
@@ -9,6 +9,7 @@ name = "git-xet"
 path = "src/bin/main.rs"
 
 [dependencies]
+xet-data = { path = "../xet_data" }
 xet-runtime = { path = "../xet_runtime" }
 xet-client = { path = "../xet_client" }
 xet-pkg = { package = "hf-xet", path = "../xet_pkg" }
@@ -4,6 +4,7 @@ use async_trait::async_trait;
 use reqwest::header;
 use reqwest_middleware::RequestBuilder;
 use serde::{Deserialize, Serialize};
+use xet_client::error::ClientError;
 use xet_client::hub_client::{CredentialHelper, Operation};
 
 use crate::errors::{GitXetError, Result};
@@ -73,8 +74,8 @@ impl SSHCredentialHelper {
 
 #[async_trait]
 impl CredentialHelper for SSHCredentialHelper {
-    async fn fill_credential(&self, req: RequestBuilder) -> anyhow::Result<RequestBuilder> {
-        let authenticated = self.authenticate().await?;
+    async fn fill_credential(&self, req: RequestBuilder) -> std::result::Result<RequestBuilder, ClientError> {
+        let authenticated = self.authenticate().await.map_err(ClientError::credential_helper_error)?;
         Ok(req.header(header::AUTHORIZATION, authenticated.header.authorization))
     }
 
@@ -3,7 +3,7 @@ use std::path::PathBuf;
 
 use thiserror::Error;
 use xet_client::ClientError;
-use xet_pkg::legacy::DataProcessingError;
+use xet_data::DataError;
 
 use crate::lfs_agent_protocol::GitLFSProtocolError;
 
@@ -40,7 +40,7 @@ pub enum GitXetError {
     Internal(String),
 
     #[error("Transfer agent error: {0}")]
-    TransferAgent(#[from] DataProcessingError),
+    TransferAgent(#[from] DataError),
 
     #[error("Client error: {0}")]
     Client(#[from] ClientError),
@@ -220,15 +220,13 @@ pub fn to_line_delimited_json_string(value: impl Serialize) -> Result<String> {
 mod tests {
     use std::path::Path;
 
-    use anyhow::Result;
-
     use super::*;
 
     #[test]
     fn test_protocol_serde_unknown_event() -> Result<()> {
         let message = r#"
         { "event": "other", "operation": "upload", "remote": "origin", "concurrent": false }"#;
-        let parsed: Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message.parse();
+        let parsed: std::result::Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message.parse();
 
         assert!(matches!(parsed, Err(GitLFSProtocolError::Syntax(_))));
 
@@ -269,21 +267,21 @@ mod tests {
         // init event with invalid operation
         let message1 = r#"
         { "event": "init", "operation": "other", "remote": "origin", "concurrent": false }"#;
-        let parsed1: Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message1.parse();
+        let parsed1: std::result::Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message1.parse();
 
         assert!(matches!(parsed1, Err(GitLFSProtocolError::Syntax(_))));
 
         // init event missing required field
         let message2 = r#"
         { "event": "init", "operation": "upload", "remote": "origin" }"#;
-        let parsed2: Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message2.parse();
+        let parsed2: std::result::Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message2.parse();
 
         assert!(matches!(parsed2, Err(GitLFSProtocolError::Syntax(_))));
 
         // init event with invalid remote
         let message2 = r#"
         { "event": "init", "operation": "upload", "remote": "", "concurrent": false }"#;
-        let parsed2: Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message2.parse();
+        let parsed2: std::result::Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message2.parse();
 
         assert!(matches!(parsed2, Err(GitLFSProtocolError::Argument(_))));
 
@@ -354,7 +352,7 @@ mod tests {
         let message1 = r#"
         { "event": "upload", "oid": "bf3e3e2af9366a3b704ae0c31de5afa64193ebabffde2091936ad2e7510bc03a", "size": 346232,
         "path": "/path/to/file.png" }"#;
-        let parsed1: Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message1.parse();
+        let parsed1: std::result::Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message1.parse();
 
         assert!(matches!(parsed1, Err(GitLFSProtocolError::Syntax(_))));
 
@@ -362,7 +360,7 @@ mod tests {
         let message2 = r#"
         { "event": "upload", "oid": "bf3e3e2af9366abc03a", "size": 346232,
         "path": "/path/to/file.png", "action": { "href": "nfs://server/path", "header": { "key": "value" } } }"#;
-        let parsed2: Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message2.parse();
+        let parsed2: std::result::Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message2.parse();
 
         assert!(matches!(parsed2, Err(GitLFSProtocolError::Argument(_))));
 
@@ -370,7 +368,7 @@ mod tests {
         let message3 = r#"
         { "event": "upload", "oid": "bf3e3e2af9366a3b704ae0c31de5afa64193ebabffde2091936ad2e7510bc03a", "size": 0,
         "path": "/path/to/file.png", "action": { "href": "nfs://server/path", "header": { "key": "value" } } }"#;
-        let parsed3: Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message3.parse();
+        let parsed3: std::result::Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message3.parse();
 
         assert!(matches!(parsed3, Err(GitLFSProtocolError::Argument(_))));
 
@@ -378,7 +376,7 @@ mod tests {
         let message4 = r#"
         { "event": "upload", "oid": "bf3e3e2af9366a3b704ae0c31de5afa64193ebabffde2091936ad2e7510bc03a", "size": 346232,
         "action": { "href": "nfs://server/path", "header": { "key": "value" } } }"#;
-        let parsed4: Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message4.parse();
+        let parsed4: std::result::Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message4.parse();
 
         assert!(matches!(parsed4, Err(GitLFSProtocolError::Syntax(_))));
 
@@ -386,7 +384,7 @@ mod tests {
         let message5 = r#"
         { "event": "download", "oid": "bf3e3e2af9366a3b704ae0c31de5afa64193ebabffde2091936ad2e7510bc03a", "size": 12514,
         "path": "/path/to/file.png", "action": { "href": "https://server/path", "header": { "k1": "v1", "k2": "v2" } } }"#;
-        let parsed5: Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message5.parse();
+        let parsed5: std::result::Result<LFSProtocolRequestEvent, GitLFSProtocolError> = message5.parse();
 
         assert!(matches!(parsed5, Err(GitLFSProtocolError::Syntax(_))));
 
@@ -1,7 +1,7 @@
 use std::io;
 use std::sync::Arc;
 
-use anyhow::anyhow;
+use anyhow::{Result, bail};
 use rand_core::OsRng;
 use russh::keys::{Certificate, *};
 use russh::server::{Msg, Server as _, Session};
@@ -154,15 +154,15 @@ impl server::Handler for ServerImpl {
 }
 
 impl ServerImpl {
-    fn git_lfs_authenticate(&self, request: Vec<&str>) -> anyhow::Result<String> {
+    fn git_lfs_authenticate(&self, request: Vec<&str>) -> Result<String> {
         let Some(repo_id) = request.get(1) else {
-            return Err(anyhow!("invalid request, missing repo id"));
+            bail!("invalid request, missing repo id");
         };
         let Some(operation) = request.get(2) else {
-            return Err(anyhow!("invalid request, missing operation"));
+            bail!("invalid request, missing operation");
         };
         if !matches!(*operation, "upload" | "download") {
-            return Err(anyhow!("invalid request, unrecognized operation"));
+            bail!("invalid request, unrecognized operation");
         }
         let response = GitLFSAuthenticateResponse {
             header: GitLFSAuthentationResponseHeader {
@@ -60,7 +60,7 @@ impl TokenRefresher for DirectRefreshRouteTokenRefresher {
                 let req = cred_helper
                     .fill_credential(req)
                     .await
-                    .map_err(reqwest_middleware::Error::Middleware)?;
+                    .map_err(reqwest_middleware::Error::middleware)?;
                 req.send().await
             }
         })
@@ -187,10 +187,11 @@ impl CapturedCommand {
 #[cfg(test)]
 mod tests {
     use std::io::Write;
+    use std::process::Command;
 
     use anyhow::Result;
 
-    use super::*;
+    use super::{CapturedCommand, run_program_captured, run_program_captured_with_input_and_output};
 
     #[test]
     fn test_run_program_captured() -> Result<()> {
@@ -173,8 +173,9 @@ fn format_for_shell_execution(command: &str, args: &[String]) -> Result<(String,
 
 #[cfg(test)]
 mod tests {
-    use anyhow::{Ok, Result};
     use serial_test::serial;
+
+    type Result<T> = std::result::Result<T, Box<dyn std::error::Error + Send + Sync>>;
     use xet_runtime::utils::EnvVarGuard;
 
     use super::*;
87
hf_xet/Cargo.lock
generated
@@ -59,9 +59,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.21"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a"
|
||||
checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"anstyle-parse",
|
||||
@@ -74,15 +74,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "anstyle"
|
||||
version = "1.0.13"
|
||||
version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78"
|
||||
checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000"
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-parse"
|
||||
version = "0.2.7"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2"
|
||||
checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e"
|
||||
dependencies = [
|
||||
"utf8parse",
|
||||
]
|
||||
@@ -159,9 +159,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-rs"
|
||||
version = "1.16.1"
|
||||
version = "1.16.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf"
|
||||
checksum = "a054912289d18629dc78375ba2c3726a3afe3ff71b4edba9dedfca0e3446d1fc"
|
||||
dependencies = [
|
||||
"aws-lc-sys",
|
||||
"zeroize",
|
||||
@@ -169,9 +169,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "aws-lc-sys"
|
||||
version = "0.38.0"
|
||||
version = "0.39.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e"
|
||||
checksum = "1fa7e52a4c5c547c741610a2c6f123f3881e409b714cd27e6798ef020c514f0a"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cmake",
|
||||
@@ -352,9 +352,9 @@ checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.2.56"
|
||||
version = "1.2.57"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2"
|
||||
checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423"
|
||||
dependencies = [
|
||||
"find-msvc-tools",
|
||||
"jobserver",
|
||||
@@ -401,9 +401,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.5.60"
|
||||
version = "4.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a"
|
||||
checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
"clap_derive",
|
||||
@@ -411,9 +411,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.5.60"
|
||||
version = "4.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876"
|
||||
checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
@@ -423,9 +423,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap_derive"
|
||||
version = "4.5.55"
|
||||
version = "4.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5"
|
||||
checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
@@ -435,9 +435,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "1.0.0"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831"
|
||||
checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9"
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
@@ -450,9 +450,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "colorchoice"
|
||||
version = "1.0.4"
|
||||
version = "1.0.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
|
||||
checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570"
|
||||
|
||||
[[package]]
|
||||
name = "colored"
|
||||
@@ -1235,9 +1235,11 @@ checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c"
|
||||
name = "hf-xet"
|
||||
version = "1.4.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"http",
|
||||
"more-asserts",
|
||||
"pyo3",
|
||||
"serde",
|
||||
"thiserror 2.0.18",
|
||||
"tokio",
|
||||
@@ -1593,6 +1595,15 @@ dependencies = [
|
||||
"str_stack",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "inventory"
|
||||
version = "0.3.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "009ae045c87e7082cb72dab0ccd01ae075dd00141ddc108f43a0ea150a9e7227"
|
||||
dependencies = [
|
||||
"rustversion",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ipnet"
|
||||
version = "2.12.0"
|
||||
@@ -2096,9 +2107,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.3"
|
||||
version = "1.21.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
|
||||
|
||||
[[package]]
|
||||
name = "once_cell_polyfill"
|
||||
@@ -2114,9 +2125,9 @@ checksum = "269bca4c2591a28585d6bf10d9ed0332b7d76900a1b02bec41bdc3a2cdcda107"
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.75"
|
||||
version = "0.10.76"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328"
|
||||
checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf"
|
||||
dependencies = [
|
||||
"bitflags 2.11.0",
|
||||
"cfg-if 1.0.4",
|
||||
@@ -2155,9 +2166,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.111"
|
||||
version = "0.9.112"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321"
|
||||
checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
@@ -2528,6 +2539,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383"
|
||||
dependencies = [
|
||||
"indoc",
|
||||
"inventory",
|
||||
"libc",
|
||||
"memoffset",
|
||||
"once_cell",
|
||||
@@ -3577,9 +3589,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tinyvec"
|
||||
version = "1.10.0"
|
||||
version = "1.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa"
|
||||
checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3"
|
||||
dependencies = [
|
||||
"tinyvec_macros",
|
||||
]
|
||||
@@ -3829,9 +3841,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tracing-subscriber"
|
||||
version = "0.3.22"
|
||||
version = "0.3.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
|
||||
checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319"
|
||||
dependencies = [
|
||||
"matchers",
|
||||
"nu-ansi-term",
|
||||
@@ -4760,7 +4772,6 @@ dependencies = [
|
||||
name = "xet-core-structures"
|
||||
version = "1.4.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"bincode",
|
||||
@@ -4832,6 +4843,7 @@ dependencies = [
|
||||
name = "xet-runtime"
|
||||
version = "1.4.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"async-trait",
|
||||
"bytes",
|
||||
"chrono",
|
||||
@@ -4850,6 +4862,7 @@ dependencies = [
|
||||
"more-asserts",
|
||||
"oneshot",
|
||||
"pin-project",
|
||||
"pyo3",
|
||||
"rand 0.9.2",
|
||||
"reqwest",
|
||||
"serde",
|
||||
@@ -4891,18 +4904,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy"
|
||||
version = "0.8.42"
|
||||
version = "0.8.47"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3"
|
||||
checksum = "efbb2a062be311f2ba113ce66f697a4dc589f85e78a4aea276200804cea0ed87"
|
||||
dependencies = [
|
||||
"zerocopy-derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerocopy-derive"
|
||||
version = "0.8.42"
|
||||
version = "0.8.47"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f"
|
||||
checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
|
||||
@@ -12,7 +12,7 @@ crate-type = ["cdylib", "lib"]
 [dependencies]
 xet-runtime = { path = "../xet_runtime" }
 xet-client = { path = "../xet_client" }
-xet-pkg = { package = "hf-xet", path = "../xet_pkg" }
+xet-pkg = { package = "hf-xet", path = "../xet_pkg", features = ["python"] }
 
 async-trait = "0.1"
 chrono = "0.4"
@@ -68,5 +68,3 @@ split-debuginfo = "none"
 inherits = "dev"
 debug = true
 opt-level = 3
-
-# cargo-machete has detected the below unused dependency incorrectly
@@ -10,15 +10,16 @@ use std::sync::Arc;
 
 use http::header::{self, HeaderMap, HeaderName, HeaderValue};
 use itertools::Itertools;
-use pyo3::exceptions::{PyKeyboardInterrupt, PyRuntimeError};
+use pyo3::exceptions::{PyKeyboardInterrupt, PyValueError};
 use pyo3::prelude::*;
 use pyo3::pyfunction;
 use rand::Rng;
 use runtime::async_run;
 use token_refresh::WrappedTokenRefresher;
 use tracing::debug;
+use xet_pkg::XetError;
 use xet_pkg::legacy::progress_tracking::TrackingProgressUpdater;
-use xet_pkg::legacy::{DataProcessingError, Sha256Policy, XetFileInfo, data_client};
+use xet_pkg::legacy::{Sha256Policy, XetFileInfo, data_client};
 use xet_runtime::core::file_handle_limits;
 
 use crate::logging::init_logging;
@@ -40,9 +41,9 @@ fn build_headers_with_user_agent(request_headers: Option<HashMap<String, String>
         let mut map = HeaderMap::new();
         for (key, value) in headers {
             let name = HeaderName::from_bytes(key.as_bytes())
-                .map_err(|e| PyRuntimeError::new_err(format!("Invalid header name '{}': {}", key, e)))?;
+                .map_err(|e| PyValueError::new_err(format!("Invalid header name '{}': {}", key, e)))?;
             let value = HeaderValue::from_str(&value)
-                .map_err(|e| PyRuntimeError::new_err(format!("Invalid header value for '{}': {}", key, e)))?;
+                .map_err(|e| PyValueError::new_err(format!("Invalid header value for '{}': {}", key, e)))?;
             map.insert(name, value);
         }
         Ok::<_, PyErr>(map)
@@ -71,12 +72,8 @@ fn build_headers_with_user_agent(request_headers: Option<HashMap<String, String>
     Ok(Some(Arc::new(map)))
 }
 
-fn convert_data_processing_error(e: DataProcessingError) -> PyErr {
-    if cfg!(debug_assertions) {
-        PyRuntimeError::new_err(format!("Data processing error: {e:?}"))
-    } else {
-        PyRuntimeError::new_err(format!("Data processing error: {e}"))
-    }
+fn convert_xet_error(e: impl Into<XetError>) -> PyErr {
+    PyErr::from(e.into())
 }
 
 #[pyfunction]
@@ -95,13 +92,13 @@ pub fn upload_bytes(
|
||||
skip_sha256: bool,
|
||||
) -> PyResult<Vec<PyXetUploadInfo>> {
|
||||
if skip_sha256 && sha256s.is_some() {
|
||||
return Err(PyRuntimeError::new_err("skip_sha256=True and sha256s are mutually exclusive"));
|
||||
return Err(PyValueError::new_err("skip_sha256=True and sha256s are mutually exclusive"));
|
||||
}
|
||||
|
||||
if let Some(ref s) = sha256s
|
||||
&& s.len() != file_contents.len()
|
||||
{
|
||||
return Err(PyRuntimeError::new_err(format!(
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"sha256s length ({}) must match file_contents length ({})",
|
||||
s.len(),
|
||||
file_contents.len()
|
||||
@@ -138,7 +135,7 @@ pub fn upload_bytes(
|
||||
header_map,
|
||||
)
|
||||
.await
|
||||
.map_err(convert_data_processing_error)?
|
||||
.map_err(convert_xet_error)?
|
||||
.into_iter()
|
||||
.map(PyXetUploadInfo::from)
|
||||
.collect();
|
||||
@@ -165,13 +162,13 @@ pub fn upload_files(
|
||||
skip_sha256: bool,
|
||||
) -> PyResult<Vec<PyXetUploadInfo>> {
|
||||
if skip_sha256 && sha256s.is_some() {
|
||||
return Err(PyRuntimeError::new_err("skip_sha256=True and sha256s are mutually exclusive"));
|
||||
return Err(PyValueError::new_err("skip_sha256=True and sha256s are mutually exclusive"));
|
||||
}
|
||||
|
||||
if let Some(ref s) = sha256s
|
||||
&& s.len() != file_paths.len()
|
||||
{
|
||||
return Err(PyRuntimeError::new_err(format!(
|
||||
return Err(PyValueError::new_err(format!(
|
||||
"sha256s length ({}) must match file_paths length ({})",
|
||||
s.len(),
|
||||
file_paths.len()
|
||||
@@ -212,7 +209,7 @@ pub fn upload_files(
|
||||
header_map,
|
||||
)
|
||||
.await
|
||||
.map_err(convert_data_processing_error)?
|
||||
.map_err(convert_xet_error)?
|
||||
.into_iter()
|
||||
.map(PyXetUploadInfo::from)
|
||||
.collect();
|
||||
@@ -254,7 +251,7 @@ pub fn hash_files(py: Python, file_paths: Vec<String>) -> PyResult<Vec<PyXetUplo
     async_run(py, async move {
         let out: Vec<PyXetUploadInfo> = data_client::hash_files_async(file_paths)
             .await
-            .map_err(convert_data_processing_error)?
+            .map_err(convert_xet_error)?
             .into_iter()
             .map(PyXetUploadInfo::from)
             .collect();
@@ -302,7 +299,7 @@ pub fn download_files(
             header_map,
         )
         .await
-        .map_err(convert_data_processing_error)?;
+        .map_err(convert_xet_error)?;

         debug!("Download call {x:x}: Completed.");
@@ -461,7 +458,6 @@ pub fn hf_xet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(force_sigint_shutdown, m)?)?;
     m.add_class::<PyXetUploadInfo>()?;
     m.add_class::<PyXetDownloadInfo>()?;
-    m.add_class::<PyXetUploadInfo>()?;
     m.add_class::<progress_update::PyItemProgressUpdate>()?;
     m.add_class::<progress_update::PyTotalProgressUpdate>()?;
@@ -470,6 +466,8 @@ pub fn hf_xet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
     // huggingface_hub.
     m.add_class::<PyPointerFile>()?;

+    xet_pkg::register_exceptions(m)?;
+
     // Make sure the logger is set up.
     init_logging(py);
|
||||
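The argument checks in `upload_bytes` and `upload_files` above now raise `ValueError` rather than `RuntimeError`, which matches the usual Python convention for invalid arguments. The following is a minimal, standard-library-only sketch of that validation. `ArgError` and `validate_sha256_args` are hypothetical names used for illustration; the real code builds a `PyErr` through pyo3.

```rust
/// Hypothetical stand-in for the Python exception that would be raised.
#[derive(Debug, PartialEq)]
enum ArgError {
    /// Plays the role of Python's ValueError: the caller passed bad arguments.
    Value(String),
}

/// Checks the two conditions from the hunks above: the options must not
/// conflict, and a supplied hash list must match the number of items.
fn validate_sha256_args(skip_sha256: bool, sha256s: Option<&[String]>, n_items: usize) -> Result<(), ArgError> {
    if skip_sha256 && sha256s.is_some() {
        return Err(ArgError::Value(
            "skip_sha256=True and sha256s are mutually exclusive".to_string(),
        ));
    }
    if let Some(s) = sha256s {
        if s.len() != n_items {
            return Err(ArgError::Value(format!(
                "sha256s length ({}) must match item count ({})",
                s.len(),
                n_items
            )));
        }
    }
    Ok(())
}
```

Both failure cases describe a problem with the caller's input rather than a failure at run time, which is presumably why the PR moves them to `ValueError`.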
@@ -3,9 +3,10 @@ use std::sync::{Arc, Mutex, RwLock};
 use std::time::Duration;

 use lazy_static::lazy_static;
-use pyo3::exceptions::{PyKeyboardInterrupt, PyRuntimeError};
+use pyo3::exceptions::PyKeyboardInterrupt;
 use pyo3::prelude::*;
 use tracing::info;
+use xet_pkg::XetError;
 use xet_runtime::RuntimeError;
 use xet_runtime::core::XetRuntime;
 use xet_runtime::core::sync_primatives::spawn_os_thread;
@@ -207,8 +208,8 @@ fn get_threadpool() -> Result<Arc<XetRuntime>, RuntimeError> {
     init_threadpool()
 }

-pub fn convert_multithreading_error(e: impl Into<RuntimeError> + std::fmt::Display) -> PyErr {
-    PyRuntimeError::new_err(format!("Xet Runtime Error: {e}"))
+pub fn convert_multithreading_error(e: impl Into<RuntimeError>) -> PyErr {
+    PyErr::from(XetError::from(e.into()))
 }

 pub fn async_run<Out, F>(py: Python, execution_call: F) -> PyResult<Out>
|
||||
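The rewritten `convert_multithreading_error` no longer formats the error into a string. It converts through the chain `RuntimeError` → `XetError` → `PyErr`, so the failure category can survive up to the Python layer. The sketch below imitates that chain with the standard library only. All types (`RuntimeErr`, `XetErr`, `PyExc`) and their variants are hypothetical stand-ins, not the real xet_runtime, xet_pkg, or pyo3 definitions, and the category mapping is an assumption made for illustration.

```rust
/// Hypothetical stand-in for xet_runtime::RuntimeError.
#[derive(Debug)]
enum RuntimeErr {
    Timeout(String),
    Other(String),
}

/// Hypothetical stand-in for the public XetError type.
#[derive(Debug)]
enum XetErr {
    Timeout(String),
    Runtime(String),
}

/// Hypothetical stand-in for a Python exception (PyErr in pyo3).
#[derive(Debug, PartialEq)]
enum PyExc {
    TimeoutError(String),
    RuntimeError(String),
}

// Anything that converts into the runtime error is accepted by the helper below.
impl From<String> for RuntimeErr {
    fn from(msg: String) -> Self {
        RuntimeErr::Other(msg)
    }
}

impl From<RuntimeErr> for XetErr {
    fn from(e: RuntimeErr) -> Self {
        match e {
            RuntimeErr::Timeout(msg) => XetErr::Timeout(msg),
            RuntimeErr::Other(msg) => XetErr::Runtime(msg),
        }
    }
}

// The category, not just the message, decides which exception is produced.
impl From<XetErr> for PyExc {
    fn from(e: XetErr) -> Self {
        match e {
            XetErr::Timeout(msg) => PyExc::TimeoutError(msg),
            XetErr::Runtime(msg) => PyExc::RuntimeError(msg),
        }
    }
}

/// Same shape as the new helper: convert by type instead of by formatting.
fn convert_multithreading_error(e: impl Into<RuntimeErr>) -> PyExc {
    PyExc::from(XetErr::from(e.into()))
}
```

Dropping the `std::fmt::Display` bound follows from this design: the helper no longer needs to print the error, only to convert it.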
@@ -19,6 +19,7 @@ path = "src/bin/run_upload_simulations.rs"
 xet-runtime = { path = "../xet_runtime" }
 xet-client = { path = "../xet_client" }

+anyhow = { workspace = true }
 chrono = { workspace = true }
 duration-str = { workspace = true }
 clap = { workspace = true }
@@ -11,6 +11,7 @@ use std::process::Command;
 use std::sync::{Arc, mpsc};
 use std::thread;

+use anyhow::{Result, anyhow, bail};
 use clap::Parser;
 use simulation::scenario::VALID_SCENARIOS;
 use simulation::upload_concurrency::generate_summary_csv;
@@ -121,7 +122,7 @@ fn scenario_binary() -> PathBuf {
     dir.join(name)
 }

-fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+fn main() -> Result<()> {
     tracing_subscriber::fmt()
         .with_env_filter(tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into()))
         .with_ansi(false)
@@ -144,7 +145,7 @@ fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
     };
     for s in &scenarios {
         if !VALID_SCENARIOS.contains(&s.as_str()) {
-            return Err(format!("Unknown scenario: {}. Valid: {:?}", s, VALID_SCENARIOS).into());
+            bail!("Unknown scenario: {}. Valid: {:?}", s, VALID_SCENARIOS);
         }
     }
@@ -179,11 +180,10 @@ fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {

     let bin = args.scenario_bin.unwrap_or_else(scenario_binary);
     if !bin.exists() {
-        return Err(format!(
+        bail!(
             "run_upload_scenario binary not found at {}; build with cargo build --release -p simulation",
             bin.display()
-        )
-        .into());
+        );
     }

     let total_runs =
@@ -322,7 +322,7 @@ fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
     );

     for h in handles {
-        h.join().map_err(|_| "scenario thread panicked")?;
+        h.join().map_err(|_| anyhow!("scenario thread panicked"))?;
     }

     generate_summary_csv(&results_base)?;
@@ -35,7 +35,7 @@ pub enum ScenarioError {
     #[error("Scenario error: {0}")]
     Scenario(String),
     #[error(transparent)]
-    CasClient(#[from] xet_client::cas_client::CasClientError),
+    CasClient(#[from] xet_client::ClientError),
     #[error(transparent)]
     Io(#[from] std::io::Error),
     #[error(transparent)]
@@ -4,6 +4,7 @@ use std::collections::{BTreeMap, HashMap};
 use std::fs;
 use std::path::Path;

+use anyhow::{Result, anyhow, bail};
 use serde::Deserialize;

 use super::upload_simulation_client::ClientMetrics;
@@ -18,7 +19,7 @@ struct ClientTimelineData {
     stats_by_timestamp: BTreeMap<u64, ClientMetrics>,
 }

-fn load_json_lines<T>(file_path: &Path) -> Result<Vec<T>, Box<dyn std::error::Error + Send + Sync>>
+fn load_json_lines<T>(file_path: &Path) -> Result<Vec<T>>
 where
     T: for<'de> Deserialize<'de>,
 {
@@ -34,7 +35,7 @@ where
 }

 /// Generates timeline.csv in the given scenario directory from client_stats_*.json files.
-pub fn generate_timeline_csv(scenario_dir: &Path) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+pub fn generate_timeline_csv(scenario_dir: &Path) -> Result<()> {
     let mut client_timelines: HashMap<u64, ClientTimelineData> = HashMap::new();

     for entry in fs::read_dir(scenario_dir)? {
@@ -64,7 +65,7 @@ pub fn generate_timeline_csv(scenario_dir: &Path) -> Result<(), Box<dyn std::err
     }

     if client_timelines.is_empty() {
-        return Err("No client_stats_*.json found".into());
+        bail!("No client_stats_*.json found");
     }

     let mut sorted_clients: Vec<_> = client_timelines.values().collect();
@@ -73,7 +74,7 @@ pub fn generate_timeline_csv(scenario_dir: &Path) -> Result<(), Box<dyn std::err
     let start_time = sorted_clients.iter().map(|c| c.first_timestamp).min().unwrap_or(0);
     let end_time = sorted_clients.iter().map(|c| c.last_timestamp).max().unwrap_or(0);
     if start_time == 0 || end_time == 0 || end_time <= start_time {
-        return Err("Invalid time window".into());
+        bail!("Invalid time window");
     }

     let mut csv = String::new();
@@ -168,15 +169,11 @@ struct ScenarioSummaryRow {
     average_round_trip_time_ms: f64,
 }

-fn process_timeline_csv(
-    timeline_path: &Path,
-    scenario_dir: &Path,
-    scenario_name: &str,
-) -> Result<ScenarioSummaryRow, Box<dyn std::error::Error + Send + Sync>> {
+fn process_timeline_csv(timeline_path: &Path, scenario_dir: &Path, scenario_name: &str) -> Result<ScenarioSummaryRow> {
     let content = fs::read_to_string(timeline_path)?;
     let lines: Vec<&str> = content.lines().collect();
     if lines.is_empty() {
-        return Err("Empty timeline.csv".into());
+        bail!("Empty timeline.csv");
     }

     let header = lines[0];
@@ -184,20 +181,23 @@ fn process_timeline_csv(
     let total_bytes_idx = header_cols
         .iter()
         .position(|&s| s == "total_bytes")
-        .ok_or("total_bytes column")?;
-    let elapsed_ms_idx = header_cols.iter().position(|&s| s == "elapsed_ms").ok_or("elapsed_ms column")?;
+        .ok_or_else(|| anyhow!("total_bytes column"))?;
+    let elapsed_ms_idx = header_cols
+        .iter()
+        .position(|&s| s == "elapsed_ms")
+        .ok_or_else(|| anyhow!("elapsed_ms column"))?;
     let total_retries_idx = header_cols
         .iter()
         .position(|&s| s == "total_retries")
-        .ok_or("total_retries column")?;
+        .ok_or_else(|| anyhow!("total_retries column"))?;
     let total_concurrency_idx = header_cols
         .iter()
         .position(|&s| s == "total_concurrency")
-        .ok_or("total_concurrency column")?;
+        .ok_or_else(|| anyhow!("total_concurrency column"))?;
     let average_concurrency_idx = header_cols
         .iter()
         .position(|&s| s == "average_concurrency")
-        .ok_or("average_concurrency column")?;
+        .ok_or_else(|| anyhow!("average_concurrency column"))?;

     let client_concurrency_indices: Vec<usize> = header_cols
         .iter()
@@ -263,7 +263,7 @@ fn process_timeline_csv(
     }

     if row_count == 0 {
-        return Err("No data rows in timeline.csv".into());
+        bail!("No data rows in timeline.csv");
     }

     let avg_total_concurrency = total_concurrency_sum / row_count as f64;
@@ -296,11 +296,7 @@ fn process_timeline_csv(
 /// Computes network utilization as (bytes_sent / max_possible_bytes) * 100, capped at 100%.
 /// max_possible_bytes = sum over segments of (bandwidth_bytes_per_sec * segment_sec) from network_stats.json,
 /// using elapsed_sec for segment boundaries so we don't mix absolute timestamps with duration.
-fn calculate_network_utilization(
-    scenario_dir: &Path,
-    total_bytes: f64,
-    duration_sec: f64,
-) -> Result<f64, Box<dyn std::error::Error + Send + Sync>> {
+fn calculate_network_utilization(scenario_dir: &Path, total_bytes: f64, duration_sec: f64) -> Result<f64> {
     let network_stats_path = scenario_dir.join("network_stats.json");
     if !network_stats_path.exists() || duration_sec <= 0.0 {
         return Ok(0.0);
@@ -338,7 +334,7 @@ fn calculate_network_utilization(
     }
 }

-fn calculate_average_rtt(scenario_dir: &Path) -> Result<f64, Box<dyn std::error::Error + Send + Sync>> {
+fn calculate_average_rtt(scenario_dir: &Path) -> Result<f64> {
     let mut all_rtts = Vec::new();
     for entry in fs::read_dir(scenario_dir)? {
         let entry = entry?;
@@ -363,7 +359,7 @@ fn calculate_average_rtt(scenario_dir: &Path) -> Result<f64, Box<dyn std::error:
 }

 /// Generates summary.csv in the given results directory from all scenario subdirectories that have timeline.csv.
-pub fn generate_summary_csv(results_dir: &Path) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+pub fn generate_summary_csv(results_dir: &Path) -> Result<()> {
     let mut scenario_dirs = Vec::new();
     for entry in fs::read_dir(results_dir)? {
         let entry = entry?;
@@ -373,7 +369,7 @@ pub fn generate_summary_csv(results_dir: &Path) -> Result<(), Box<dyn std::error
         }
     }
     if scenario_dirs.is_empty() {
-        return Err("No scenario directories with timeline.csv found".into());
+        bail!("No scenario directories with timeline.csv found");
     }

     let mut scenario_data = Vec::new();
@@ -392,7 +388,7 @@ pub fn generate_summary_csv(results_dir: &Path) -> Result<(), Box<dyn std::error
         }
     }
     if scenario_data.is_empty() {
-        return Err("No valid timeline.csv files".into());
+        bail!("No valid timeline.csv files");
     }

     scenario_data.sort_by(|a, b| a.scenario_name.cmp(&b.scenario_name));
@@ -6,6 +6,7 @@ use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::{Arc, Mutex};
 use std::time::{Duration, Instant};

+use anyhow::Result;
 use bytes::Bytes;
 use http::HeaderValue;
 use http::header::CONTENT_LENGTH;
@@ -80,7 +81,7 @@ pub async fn run_upload_clients_until_cancelled(
     min_data_kb: u64,
     max_data_kb: u64,
     cancel: CancellationToken,
-) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+) -> Result<()> {
     run_upload_clients_impl(server_addr, output_dir, min_data_kb, max_data_kb, None, Some(cancel)).await
 }
@@ -92,7 +93,7 @@ pub async fn run_upload_clients(
     min_data_kb: u64,
     max_data_kb: u64,
     repeat_duration_seconds: u64,
-) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+) -> Result<()> {
     run_upload_clients_impl(server_addr, output_dir, min_data_kb, max_data_kb, Some(repeat_duration_seconds), None)
         .await
 }
@@ -304,12 +305,12 @@ async fn run_upload_clients_impl(
     max_data_kb: u64,
     repeat_duration_seconds: Option<u64>,
     cancel: Option<CancellationToken>,
-) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
+) -> Result<()> {
     let min_data_size = min_data_kb * 1024;
     let max_data_size = max_data_kb * 1024;
     let client_id = rand::rng().random_range(0..1000000000_u64);

-    let http_client = build_http_client("test_session", None, None).map_err(|e| e.to_string())?;
+    let http_client = build_http_client("test_session", None, None)?;

     let duration_sec = repeat_duration_seconds.unwrap_or(u64::MAX);
     let client_params = serde_json::json!({
2
wasm/hf_xet_thin_wasm/Cargo.lock
generated
@@ -3969,7 +3969,6 @@ dependencies = [
 name = "xet-core-structures"
 version = "1.4.0"
 dependencies = [
- "anyhow",
  "async-trait",
  "base64",
  "bincode",
@@ -4041,6 +4040,7 @@ dependencies = [
 name = "xet-runtime"
 version = "1.4.0"
 dependencies = [
+ "anyhow",
  "async-trait",
  "bytes",
  "chrono",
2
wasm/hf_xet_wasm/Cargo.lock
generated
@@ -4141,7 +4141,6 @@ dependencies = [
 name = "xet-core-structures"
 version = "1.4.0"
 dependencies = [
- "anyhow",
  "async-trait",
  "base64",
  "bincode",
@@ -4213,6 +4212,7 @@ dependencies = [
 name = "xet-runtime"
 version = "1.4.0"
 dependencies = [
+ "anyhow",
  "async-trait",
  "bytes",
  "chrono",
@@ -2,7 +2,7 @@ use std::fmt::Debug;

 use thiserror::Error;
 use xet_client::ClientError;
-use xet_core_structures::FormatError;
+use xet_core_structures::CoreError;
 use xet_core_structures::merklehash::DataHashHexParseError;

 #[non_exhaustive]
@@ -17,8 +17,8 @@ pub enum DataProcessingError {
     #[error("Client error: {0}")]
     ClientError(#[from] ClientError),

-    #[error("Format error: {0}")]
-    FormatError(#[from] FormatError),
+    #[error("Core structures error: {0}")]
+    CoreError(#[from] CoreError),
 }

 impl DataProcessingError {
@@ -5,7 +5,7 @@ use std::sync::Arc;

 use async_trait::async_trait;
 use tokio_with_wasm::alias as wasmtokio;
-use xet_client::cas_client::CasClientError;
+use xet_client::ClientError;
 use xet_core_structures::merklehash::{HMACKey, MerkleHash};
 use xet_core_structures::metadata_shard::MDBShardInfo;
 use xet_core_structures::metadata_shard::file_structs::FileDataSequenceEntry;
@@ -19,7 +19,7 @@ use super::wasm_file_upload_session::FileUploadSession;
 pub struct UploadSessionDataManager {
     session: Arc<FileUploadSession>,
     shard: HashMap<HMACKey, MDBInMemoryShard>,
-    query_tasks: wasmtokio::task::JoinSet<stdResult<Option<bytes::Bytes>, CasClientError>>,
+    query_tasks: wasmtokio::task::JoinSet<stdResult<Option<bytes::Bytes>, ClientError>>,
 }

 impl UploadSessionDataManager {
@@ -3,7 +3,8 @@ use std::sync::Arc;

 use async_trait::async_trait;
 use tokio_with_wasm::alias as wasmtokio;
-use xet_client::cas_client::{CasClientError, Client};
+use xet_client::cas_client::Client;
+use xet_client::ClientError;
 use xet_core_structures::xorb_object::SerializedXorbObject;

 use crate::errors::*;
@@ -49,7 +50,7 @@ impl XorbUploader for XorbUploaderLocalSequential {
 pub struct XorbUploaderSpawnParallel {
     client: Arc<dyn Client + Send + Sync>,
     cas_prefix: String,
-    tasks: wasmtokio::task::JoinSet<stdResult<u64, CasClientError>>,
+    tasks: wasmtokio::task::JoinSet<stdResult<u64, ClientError>>,
 }

 impl XorbUploaderSpawnParallel {
@@ -14,9 +14,9 @@ use xet_core_structures::ExpWeightedMovingAvg;
 use xet_runtime::core::xet_config;
 use xet_runtime::utils::adjustable_semaphore::{AdjustableSemaphore, AdjustableSemaphorePermit};

-use super::super::error::CasClientError;
 use super::super::progress_tracked_streams::ProgressCallback;
 use super::rtt_prediction::RTTPredictor;
+use crate::error::Result;

 const MIN_PARTIAL_REPORT_INTERVAL_MS: u64 = 200;
 const PARTIAL_REPORT_WEIGHT_RATIO: f64 = 0.2;
@@ -361,7 +361,7 @@ impl AdaptiveConcurrencyController {
         )
     }

-    pub async fn acquire_connection_permit(self: &Arc<Self>) -> Result<ConnectionPermit, CasClientError> {
+    pub async fn acquire_connection_permit(self: &Arc<Self>) -> Result<ConnectionPermit> {
         let _permit = self.concurrency_semaphore.acquire().await?;

         let info = Arc::new(ConnectionPermitInfo {
@@ -1 +0,0 @@
-pub use crate::error::{ClientError as CasClientError, Result};
@@ -1,6 +1,5 @@
 use std::sync::Arc;

-use anyhow::anyhow;
 use http::{Extensions, HeaderMap, StatusCode};
 use reqwest::header::{AUTHORIZATION, COOKIE, HeaderValue, SET_COOKIE};
 use reqwest::{Request, Response};
@@ -11,8 +10,8 @@ use xet_runtime::core::{XetRuntime, xet_config};
 use xet_runtime::error_printer::{ErrorPrinter, OptionPrinter};

 use super::auth::{AuthConfig, TokenProvider};
-use super::{CasClientError, error};
 use crate::cas_types::{REQUEST_ID_HEADER, SESSION_ID_HEADER};
+use crate::error::{ClientError, Result};

 /// Middleware that rewrites https:// URLs to http:// when using Unix socket.
 /// This allows the proxy to parse plain HTTP and upgrade to HTTPS when forwarding.
@@ -27,7 +26,7 @@ impl Middleware for HttpsToHttpMiddleware {
         mut req: Request,
         extensions: &mut Extensions,
         next: Next<'_>,
-    ) -> Result<Response, reqwest_middleware::Error> {
+    ) -> std::result::Result<Response, reqwest_middleware::Error> {
         let url = req.url_mut();
         if url.scheme() == "https" {
             let original_scheme = url.scheme().to_string();
@@ -53,10 +52,7 @@ fn redact_headers(headers: &HeaderMap) -> HeaderMap {

 #[allow(unused_variables)]
 #[cfg(not(target_family = "wasm"))]
-fn reqwest_client(
-    unix_socket_path: Option<&str>,
-    custom_headers: Option<Arc<HeaderMap>>,
-) -> Result<reqwest::Client, CasClientError> {
+fn reqwest_client(unix_socket_path: Option<&str>, custom_headers: Option<Arc<HeaderMap>>) -> Result<reqwest::Client> {
     // Check config if explicit socket path is not provided
     let socket_path = unix_socket_path
         .map(|s| s.to_string())
@@ -116,7 +112,7 @@ fn reqwest_client(
 fn reqwest_client_no_read_timeout(
     unix_socket_path: Option<&str>,
     custom_headers: Option<Arc<HeaderMap>>,
-) -> Result<reqwest::Client, CasClientError> {
+) -> Result<reqwest::Client> {
     let socket_path = unix_socket_path
         .map(|s| s.to_string())
         .or_else(|| xet_config().client.unix_socket_path.clone());
@@ -149,10 +145,7 @@ fn reqwest_client_no_read_timeout(
 }

 #[cfg(target_family = "wasm")]
-fn reqwest_client(
-    _unix_socket_path: Option<&str>,
-    custom_headers: Option<Arc<HeaderMap>>,
-) -> Result<reqwest::Client, CasClientError> {
+fn reqwest_client(_unix_socket_path: Option<&str>, custom_headers: Option<Arc<HeaderMap>>) -> Result<reqwest::Client> {
     // For WASM, create a new client with the specified headers, including the user-agent.
     // Note: we could cache this, but user_agent can vary, so we create per-call
     // Unix socket path is ignored on WASM
@@ -170,7 +163,7 @@ pub fn build_auth_http_client(
     session_id: &str,
     unix_socket_path: Option<&str>,
     custom_headers: Option<Arc<HeaderMap>>,
-) -> Result<ClientWithMiddleware, CasClientError> {
+) -> Result<ClientWithMiddleware> {
     let auth_middleware = auth_config.as_ref().map(AuthMiddleware::from).info_none("CAS auth disabled");
     let logging_middleware = Some(LoggingMiddleware);
     let session_middleware = (!session_id.is_empty()).then(|| SessionMiddleware(session_id.to_owned()));
@@ -201,7 +194,7 @@ pub fn build_auth_http_client_no_read_timeout(
     session_id: &str,
     unix_socket_path: Option<&str>,
     custom_headers: Option<Arc<HeaderMap>>,
-) -> Result<ClientWithMiddleware, CasClientError> {
+) -> Result<ClientWithMiddleware> {
     let auth_middleware = auth_config.as_ref().map(AuthMiddleware::from).info_none("CAS auth disabled");
     let logging_middleware = Some(LoggingMiddleware);
     let session_middleware = (!session_id.is_empty()).then(|| SessionMiddleware(session_id.to_owned()));
@@ -226,7 +219,7 @@ pub fn build_http_client(
     session_id: &str,
     unix_socket_path: Option<&str>,
     custom_headers: Option<Arc<HeaderMap>>,
-) -> Result<ClientWithMiddleware, CasClientError> {
+) -> Result<ClientWithMiddleware> {
     build_auth_http_client(&None, session_id, unix_socket_path, custom_headers)
 }
@@ -290,14 +283,14 @@ impl AuthMiddleware {
     /// (e.g. to a remote service). During this time, no other CAS requests can proceed
     /// from this client until the token has been fetched. This is expected/ok since we
     /// don't have a valid token and thus any calls would fail.
-    async fn get_token(&self) -> Result<String, anyhow::Error> {
+    async fn get_token(&self) -> Result<String> {
         let mut provider = self.token_provider.lock().await;
         provider
             .get_valid_token()
             .await
             .map_err(|err| {
                 warn!(?err, "Token refresh failed");
-                anyhow!("couldn't get token: {err:?}")
+                ClientError::AuthError(err)
             })
             .inspect(|_token| {
                 info!("Token refresh successful for CAS authentication");
@@ -322,7 +315,7 @@ impl Middleware for AuthMiddleware {
         extensions: &mut http::Extensions,
         next: Next<'_>,
     ) -> reqwest_middleware::Result<Response> {
-        let token = self.get_token().await.map_err(reqwest_middleware::Error::Middleware)?;
+        let token = self.get_token().await.map_err(reqwest_middleware::Error::middleware)?;

         let headers = req.headers_mut();
         headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {token}")).unwrap());
@@ -365,16 +358,14 @@ pub trait ResponseErrorLogger<T> {
 /// This logs an error if one occurred before receiving a response or
 /// if the status code indicates a failure.
 /// As a result of these checks, the response is also transformed into a
-/// cas_client::error::Result instead of the raw reqwest_middleware::Result.
-impl ResponseErrorLogger<error::Result<Response>> for reqwest_middleware::Result<Response> {
-    fn process_error(self, api: &str) -> error::Result<Response> {
-        let res = self
-            .map_err(CasClientError::from)
-            .log_error(format!("error invoking {api} api"))?;
+/// crate::error::Result instead of the raw reqwest_middleware::Result.
+impl ResponseErrorLogger<Result<Response>> for reqwest_middleware::Result<Response> {
+    fn process_error(self, api: &str) -> Result<Response> {
+        let res = self.map_err(ClientError::from).log_error(format!("error invoking {api} api"))?;
         let request_id = request_id_from_response(&res);
         let error_message = format!("{api} api failed: request id: {request_id}");
         let status = res.status();
-        let res = res.error_for_status().map_err(CasClientError::from);
+        let res = res.error_for_status().map_err(ClientError::from);
         match (api, status) {
             ("get_reconstruction", StatusCode::RANGE_NOT_SATISFIABLE) => res.debug_error(&error_message),
             // not all status codes mean fatal error
|
||||
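Most signatures in this file shrink from `Result<T, CasClientError>` to `Result<T>` because the file now imports the crate-level `Result` alias, while the `Middleware` implementation has to spell out `std::result::Result` since its error type is `reqwest_middleware::Error`. A minimal sketch of that aliasing pattern, using only the standard library, is shown below. `ClientErr`, `MiddlewareErr`, `build_client`, and `handle` are hypothetical stand-ins, not the real xet_client or reqwest_middleware items.

```rust
/// Hypothetical stand-in for the crate-wide error type.
#[derive(Debug, PartialEq)]
enum ClientErr {
    Other(String),
}

/// Crate-wide alias. Inside this scope it shadows the prelude's
/// two-parameter `Result`.
type Result<T> = std::result::Result<T, ClientErr>;

/// Hypothetical stand-in for a foreign error type such as a middleware error.
#[derive(Debug, PartialEq)]
struct MiddlewareErr(String);

/// Functions that fail with the crate error can use the short alias.
fn build_client(session_id: &str) -> Result<String> {
    if session_id.is_empty() {
        return Err(ClientErr::Other("empty session id".to_string()));
    }
    Ok(format!("client for {session_id}"))
}

/// A function with a different error type must name the standard `Result`
/// in full, because the one-parameter alias cannot express it.
fn handle(scheme: &str) -> std::result::Result<String, MiddlewareErr> {
    match scheme {
        "https" => Ok("http".to_string()),
        other => Err(MiddlewareErr(format!("unexpected scheme: {other}"))),
    }
}
```

With the alias in scope, writing `Result<Response, reqwest_middleware::Error>` would no longer compile, since the alias takes a single type parameter; that is the likely reason for the one signature in this file that grew longer.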
@@ -4,9 +4,9 @@ use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
 use xet_core_structures::xorb_object::SerializedXorbObject;

 use super::adaptive_concurrency::ConnectionPermit;
-use super::error::Result;
 use super::progress_tracked_streams::ProgressCallback;
 use crate::cas_types::{BatchQueryReconstructionResponse, FileRange, HttpRange, QueryReconstructionResponseV2};
+use crate::error::Result;

 #[async_trait::async_trait]
 pub trait URLProvider: Send + Sync {
@@ -1,4 +1,3 @@
-pub use error::CasClientError;
 pub use http_client::{Api, ResponseErrorLogger, build_auth_http_client, build_http_client};
 pub use interface::{Client, URLProvider};
 pub use remote_client::RemoteClient;
@@ -12,7 +11,6 @@ use tracing::Level;

 pub mod adaptive_concurrency;
 pub mod auth;
-mod error;
 pub mod exports;
 pub mod http_client;
 mod interface;
@@ -1,7 +1,7 @@
 use bytes::Bytes;

-use crate::cas_client::error::{CasClientError, Result};
 use crate::cas_types::HttpRange;
+use crate::error::{ClientError, Result};

 /// A single part from a multipart/byteranges HTTP response.
 pub struct MultipartPart {
@@ -23,7 +23,7 @@ pub fn parse_multipart_byteranges(content_type: &str, body: Bytes) -> Result<Vec

     let first_delim = format!("--{boundary}");
     let Some(start) = find_subsequence(body_slice, first_delim.as_bytes()) else {
-        return Err(CasClientError::Other("No boundary found in multipart body".to_string()));
+        return Err(ClientError::Other("No boundary found in multipart body".to_string()));
     };

     let mut remaining = &body_slice[start + first_delim.len()..];
@@ -42,7 +42,7 @@ pub fn parse_multipart_byteranges(content_type: &str, body: Bytes) -> Result<Vec
         };

         let Some(header_end) = find_subsequence(part_data, b"\r\n\r\n") else {
-            return Err(CasClientError::Other("Malformed multipart part: missing header/data separator".to_string()));
+            return Err(ClientError::Other("Malformed multipart part: missing header/data separator".to_string()));
         };

         let headers = &part_data[..header_end];
@@ -80,12 +80,12 @@ fn extract_boundary(content_type: &str) -> Result<String> {
             return Ok(boundary.to_string());
         }
     }
-    Err(CasClientError::Other(format!("No boundary found in Content-Type: {content_type}")))
+    Err(ClientError::Other(format!("No boundary found in Content-Type: {content_type}")))
 }

 fn parse_content_range(headers: &[u8]) -> Result<HttpRange> {
-    let headers_str = std::str::from_utf8(headers)
-        .map_err(|e| CasClientError::Other(format!("Invalid UTF-8 in part headers: {e}")))?;
+    let headers_str =
+        std::str::from_utf8(headers).map_err(|e| ClientError::Other(format!("Invalid UTF-8 in part headers: {e}")))?;

     for line in headers_str.split("\r\n") {
         let line_lower = line.to_ascii_lowercase();
@@ -96,24 +96,24 @@ fn parse_content_range(headers: &[u8]) -> Result<HttpRange> {
                 let original_value = range_spec.trim();
                 let slash_pos = original_value
                     .find('/')
-                    .ok_or_else(|| CasClientError::Other(format!("Invalid Content-Range: {line}")))?;
+                    .ok_or_else(|| ClientError::Other(format!("Invalid Content-Range: {line}")))?;
                 let range_part = &original_value[..slash_pos];
                 let dash_pos = range_part
                     .find('-')
-                    .ok_or_else(|| CasClientError::Other(format!("Invalid Content-Range: {line}")))?;
+                    .ok_or_else(|| ClientError::Other(format!("Invalid Content-Range: {line}")))?;
                 let start: u64 = range_part[..dash_pos]
                     .parse()
-                    .map_err(|e| CasClientError::Other(format!("Invalid Content-Range start: {e}")))?;
+                    .map_err(|e| ClientError::Other(format!("Invalid Content-Range start: {e}")))?;
                 let end: u64 = range_part[dash_pos + 1..]
                     .parse()
-                    .map_err(|e| CasClientError::Other(format!("Invalid Content-Range end: {e}")))?;
+                    .map_err(|e| ClientError::Other(format!("Invalid Content-Range end: {e}")))?;
                 // RFC 7233 Content-Range uses an inclusive end, which matches HttpRange.
                 return Ok(HttpRange::new(start, end));
             }
         }
     }

-    Err(CasClientError::Other("No Content-Range header found in multipart part".to_string()))
+    Err(ClientError::Other("No Content-Range header found in multipart part".to_string()))
 }

 fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option<usize> {
|
||||
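The `parse_content_range` hunks above only rename the error type, but the parsing steps they touch (locate `/`, locate `-`, parse both offsets) are easier to follow in isolation. The following is a minimal sketch of those steps for a single header value, using only the standard library. `RangeParseError` and `parse_content_range_value` are hypothetical names, the handling of the `bytes` unit prefix is an assumption, and the real function returns an `HttpRange` rather than a tuple.

```rust
/// Hypothetical stand-in for `ClientError::Other`.
#[derive(Debug, PartialEq)]
enum RangeParseError {
    Other(String),
}

/// Parses a Content-Range value such as "bytes 0-99/1000" into an
/// inclusive (start, end) pair.
fn parse_content_range_value(value: &str) -> Result<(u64, u64), RangeParseError> {
    let trimmed = value.trim();
    // Assumption: the unit prefix is "bytes"; it is dropped when present.
    let spec = trimmed.strip_prefix("bytes").map(str::trim_start).unwrap_or(trimmed);

    // Everything before '/' is the range; what follows is the total length.
    let slash_pos = spec
        .find('/')
        .ok_or_else(|| RangeParseError::Other(format!("Invalid Content-Range: {value}")))?;
    let range_part = &spec[..slash_pos];
    let dash_pos = range_part
        .find('-')
        .ok_or_else(|| RangeParseError::Other(format!("Invalid Content-Range: {value}")))?;

    let start = range_part[..dash_pos]
        .trim()
        .parse::<u64>()
        .map_err(|e| RangeParseError::Other(format!("Invalid Content-Range start: {e}")))?;
    let end = range_part[dash_pos + 1..]
        .trim()
        .parse::<u64>()
        .map_err(|e| RangeParseError::Other(format!("Invalid Content-Range end: {e}")))?;

    // The end offset is inclusive, as in the HTTP range specification.
    Ok((start, end))
}
```

Every failure path produces the same error variant with a descriptive message, which mirrors how the diff maps all parse failures to `ClientError::Other`.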
@@ -1,6 +1,7 @@
 use std::sync::Arc;
 use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};

+use anyhow::anyhow;
 use bytes::Bytes;
 use futures::TryStreamExt;
 use http::HeaderValue;
@@ -15,7 +16,6 @@ use xet_runtime::core::xet_config;

 use super::adaptive_concurrency::{AdaptiveConcurrencyController, ConnectionPermit};
 use super::auth::AuthConfig;
-use super::error::{CasClientError, Result};
 use super::http_client::{self, Api};
 use super::interface::URLProvider;
 use super::progress_tracked_streams::{
@@ -27,6 +27,7 @@ use crate::cas_types::{
     BatchQueryReconstructionResponse, FileRange, HttpRange, Key, QueryReconstructionResponse,
     QueryReconstructionResponseV2, UploadShardResponse, UploadShardResponseType, UploadXorbResponse,
 };
+use crate::error::{ClientError, Result};

 pub const CAS_ENDPOINT: &str = "http://localhost:8080";
 pub const PREFIX_DEFAULT: &str = "default";
@@ -187,7 +188,9 @@ impl RemoteClient {
             "v1" => "cas::get_reconstruction_v1",
             "v2" => "cas::get_reconstruction_v2",
             _ => {
-                return Err(CasClientError::internal(format!("unsupported reconstruction API version: {api_version}")));
+                return Err(ClientError::InternalError(anyhow!(
+                    "unsupported reconstruction API version: {api_version}"
+                )));
             },
         };
@@ -224,7 +227,7 @@ impl RemoteClient {
                 );
                 Ok(Some(response))
             },
-            Err(CasClientError::ReqwestError(ref e, _)) if e.status() == Some(StatusCode::RANGE_NOT_SATISFIABLE) => {
+            Err(ClientError::ReqwestError(ref e, _)) if e.status() == Some(StatusCode::RANGE_NOT_SATISFIABLE) => {
                 Ok(None)
             },
             Err(e) => Err(e),
@@ -286,7 +289,7 @@ impl RemoteClient {
                 Err(e) => Err(e),
             },
             1 => Ok(self.get_reconstruction_v1(file_id, bytes_range).await?.map(Into::into)),
-            other => Err(CasClientError::internal(format!("unsupported reconstruction API version: {other}"))),
+            other => Err(ClientError::InternalError(anyhow!("unsupported reconstruction API version: {other}"))),
         }
     }
 }
@@ -424,7 +427,7 @@ impl Client for RemoteClient {
                     let body = resp
                         .bytes()
                         .await
-                        .map_err(|e| RetryableReqwestError::RetryableError(CasClientError::from(e)))?;
+                        .map_err(|e| RetryableReqwestError::RetryableError(ClientError::from(e)))?;

                     let multipart_parts = crate::cas_client::multipart::parse_multipart_byteranges(&content_type, body)
                         .map_err(RetryableReqwestError::FatalError)?;
@@ -439,7 +442,7 @@ impl Client for RemoteClient {
|
||||
let (data, chunk_indices) =
|
||||
xet_core_structures::xorb_object::deserialize_chunks(&mut std::io::Cursor::new(part.data.as_ref()))
|
||||
.map_err(|e| {
|
||||
RetryableReqwestError::RetryableError(CasClientError::FormatError(e))
|
||||
RetryableReqwestError::RetryableError(ClientError::FormatError(e))
|
||||
})?;
|
||||
|
||||
xet_core_structures::xorb_object::append_chunk_segment(
|
||||
@@ -455,7 +458,7 @@ impl Client for RemoteClient {
|
||||
if let Some(expected) = uncompressed_size_if_known
|
||||
&& expected != all_decompressed.len()
|
||||
{
|
||||
return Err(RetryableReqwestError::RetryableError(CasClientError::Other(format!(
|
||||
return Err(RetryableReqwestError::RetryableError(ClientError::Other(format!(
|
||||
"get_file_term_data: expected {expected} uncompressed bytes, got {}",
|
||||
all_decompressed.len()
|
||||
))));
|
||||
@@ -482,14 +485,14 @@ impl Client for RemoteClient {
|
||||
if let Some(expected) = uncompressed_size_if_known
|
||||
&& expected != buffer.len()
|
||||
{
|
||||
return Err(RetryableReqwestError::RetryableError(CasClientError::Other(format!(
|
||||
return Err(RetryableReqwestError::RetryableError(ClientError::Other(format!(
|
||||
"get_file_term_data: expected {expected} uncompressed bytes, got {}",
|
||||
buffer.len()
|
||||
))));
|
||||
}
|
||||
Ok((Bytes::from(buffer), chunk_byte_indices))
|
||||
},
|
||||
Err(e) => Err(RetryableReqwestError::RetryableError(CasClientError::FormatError(e))),
|
||||
Err(e) => Err(RetryableReqwestError::RetryableError(ClientError::FormatError(e))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,13 +12,13 @@ use tracing::{error, info};
 use xet_runtime::core::xet_config;

 use super::adaptive_concurrency::ConnectionPermit;
-use super::error::CasClientError;
 use super::http_client::request_id_from_response;
+use crate::error::{ClientError, Result};

 #[derive(Debug)]
 pub enum RetryableReqwestError {
-    FatalError(CasClientError),
-    RetryableError(CasClientError),
+    FatalError(ClientError),
+    RetryableError(ClientError),
 }

 struct ConnectionPermitInfo {
@@ -102,7 +102,7 @@ impl RetryWrapper {
         error!("{msg}");
     }

-    CasClientError::from(err)
+    ClientError::from(err)
 };

 // Here's the retry logic.
@@ -122,7 +122,11 @@ impl RetryWrapper {
     }
 }

-fn process_ok_response(&self, try_idx: usize, resp: Response) -> Result<Response, RetryableReqwestError> {
+fn process_ok_response(
+    &self,
+    try_idx: usize,
+    resp: Response,
+) -> std::result::Result<Response, RetryableReqwestError> {
     let request_id = request_id_from_response(&resp).to_owned();

     let retry_str = if try_idx == 0 {
@@ -140,7 +144,7 @@ impl RetryWrapper {
     } else {
         error!("{context}: {api:?} api call failed (request id {request_id}{retry_str}): {err}");
     }
-    CasClientError::from(err)
+    ClientError::from(err)
 };

 let retriability = default_on_request_success(&resp);
@@ -206,12 +210,12 @@ impl RetryWrapper {
     self,
     make_request: ReqFn,
     process_fn: ProcFn,
-) -> Result<T, CasClientError>
+) -> Result<T>
 where
     ReqFn: Fn() -> ReqFut + Send + Sync + 'static,
-    ReqFut: std::future::Future<Output = Result<Response, reqwest_middleware::Error>> + 'static,
+    ReqFut: std::future::Future<Output = std::result::Result<Response, reqwest_middleware::Error>> + 'static,
     ProcFn: Fn(Response) -> ProcFut + Send + 'static,
-    ProcFut: Future<Output = Result<T, RetryableReqwestError>> + 'static,
+    ProcFut: Future<Output = std::result::Result<T, RetryableReqwestError>> + 'static,
 {
     let strategy = ExponentialBackoff::from_millis(self.base_delay.as_millis().min(u64::MAX as u128) as u64)
         .map(jitter)
@@ -338,19 +342,16 @@ impl RetryWrapper {
 ///
 /// This functions acts just like the json() function on a client response, but retries the entire connection on
 /// transient errors.
-pub async fn run_and_extract_json<JsonDest, ReqFn, ReqFut>(
-    self,
-    make_request: ReqFn,
-) -> Result<JsonDest, CasClientError>
+pub async fn run_and_extract_json<JsonDest, ReqFn, ReqFut>(self, make_request: ReqFn) -> Result<JsonDest>
 where
     JsonDest: for<'de> serde::Deserialize<'de>,
     ReqFn: Fn() -> ReqFut + Send + Sync + 'static,
-    ReqFut: std::future::Future<Output = Result<Response, reqwest_middleware::Error>> + 'static,
+    ReqFut: std::future::Future<Output = std::result::Result<Response, reqwest_middleware::Error>> + 'static,
 {
     self.run_and_process(make_request, |resp: Response| {
         async move {
             // Extract the json from the final result.
-            let r: Result<JsonDest, reqwest::Error> = resp.json().await;
+            let r: std::result::Result<JsonDest, reqwest::Error> = resp.json().await;

             match r {
                 Ok(v) => Ok(v),
@@ -383,15 +384,15 @@ impl RetryWrapper {
 ///
 /// This functions acts just like the bytes() function on a client response, but retries the entire connection on
 /// transient errors.
-pub async fn run_and_extract_bytes<ReqFut, ReqFn>(self, make_request: ReqFn) -> Result<Bytes, CasClientError>
+pub async fn run_and_extract_bytes<ReqFut, ReqFn>(self, make_request: ReqFn) -> Result<Bytes>
 where
     ReqFn: Fn() -> ReqFut + Send + Sync + 'static,
-    ReqFut: std::future::Future<Output = Result<Response, reqwest_middleware::Error>> + 'static,
+    ReqFut: std::future::Future<Output = std::result::Result<Response, reqwest_middleware::Error>> + 'static,
 {
     self.run_and_process(make_request, |resp: Response| {
         async move {
             // Extract the bytes from the final result.
-            let r: Result<Bytes, reqwest::Error> = resp.bytes().await;
+            let r: std::result::Result<Bytes, reqwest::Error> = resp.bytes().await;

             match r {
                 Ok(v) => Ok(v),
@@ -432,12 +433,12 @@ impl RetryWrapper {
     self,
     make_request: ReqFn,
     parse: Parse,
-) -> Result<Dest, CasClientError>
+) -> Result<Dest>
 where
     ReqFn: Fn() -> ReqFut + Send + Sync + 'static,
-    ReqFut: std::future::Future<Output = Result<Response, reqwest_middleware::Error>> + 'static,
+    ReqFut: std::future::Future<Output = std::result::Result<Response, reqwest_middleware::Error>> + 'static,
     Parse: Fn(Response) -> ParseFut + Send + Sync + 'static,
-    ParseFut: std::future::Future<Output = Result<Dest, RetryableReqwestError>> + 'static,
+    ParseFut: std::future::Future<Output = std::result::Result<Dest, RetryableReqwestError>> + 'static,
 {
     self.run_and_process(make_request, parse).await
 }
@@ -447,10 +448,10 @@ impl RetryWrapper {
 /// The `make_request` function returns a future that resolves to a Result<Response> object as is returned by the
 /// client middleware. For example, `|| client.clone().get(url).send()` returns a future (as `send()` is async)
 /// that will then be evaluated to get the response.
-pub async fn run<ReqFut, ReqFn>(self, make_request: ReqFn) -> Result<Response, CasClientError>
+pub async fn run<ReqFut, ReqFn>(self, make_request: ReqFn) -> Result<Response>
 where
     ReqFn: Fn() -> ReqFut + Send + Sync + 'static,
-    ReqFut: std::future::Future<Output = Result<Response, reqwest_middleware::Error>> + 'static,
+    ReqFut: std::future::Future<Output = std::result::Result<Response, reqwest_middleware::Error>> + 'static,
 {
     // Just have the process_fn pass through the response.
     self.run_and_process(make_request, |resp| async move { Ok(resp) }).await

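The retry-wrapper hunks above shorten `Result<T, CasClientError>` to `Result<T>` and, at the same time, spell out `std::result::Result<…>` wherever the error type is something other than `ClientError`. The following is a minimal sketch of why both changes go together, assuming `crate::error::Result` is a conventional one-parameter alias (its definition is not shown in this part of the diff). The reduced `ClientError` enum and the `parse_port` / `parse_port_raw` helpers are hypothetical stand-ins for illustration only.

```rust
use std::fmt;

// Hypothetical, reduced stand-in for the crate's error type.
#[derive(Debug, PartialEq)]
pub enum ClientError {
    InvalidArguments,
    Other(String),
}

impl fmt::Display for ClientError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ClientError::InvalidArguments => write!(f, "invalid arguments"),
            ClientError::Other(msg) => write!(f, "{msg}"),
        }
    }
}

impl std::error::Error for ClientError {}

// Assumed shape of the alias. Once it is in scope, the bare name `Result`
// takes exactly one type parameter in this module.
pub type Result<T> = std::result::Result<T, ClientError>;

// Uses the alias: the error type is implied.
pub fn parse_port(s: &str) -> Result<u16> {
    s.parse::<u16>()
        .map_err(|e| ClientError::Other(format!("bad port {s:?}: {e}")))
}

// A different error type cannot go through the alias, so the standard
// two-parameter type has to be named by its full path.
pub fn parse_port_raw(s: &str) -> std::result::Result<u16, std::num::ParseIntError> {
    s.parse::<u16>()
}
```

With the alias imported, writing `Result<Response, RetryableReqwestError>` would no longer compile, which is consistent with the fully qualified `std::result::Result` in the added lines.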
@@ -8,8 +8,8 @@ use xet_core_structures::metadata_shard::file_structs::{FileDataSequenceEntry, F
 use xet_core_structures::metadata_shard::shard_in_memory::MDBInMemoryShard;
 use xet_core_structures::xorb_object::{Chunk, RawXorbData, SerializedXorbObject};

-use super::super::error::Result;
 use super::super::interface::Client;
+use crate::error::Result;

 /// Information about a term (segment) in the file, referencing an XORB and chunk range.
 #[derive(Clone, Debug)]

@@ -10,9 +10,9 @@ use std::sync::Arc;

 use bytes::Bytes;

-use super::super::error::CasClientError;
 use super::{ClientTestingUtils, DirectAccessClient};
 use crate::cas_types::FileRange;
+use crate::error::ClientError;

 /// Runs all common Client trait tests using a factory that creates fresh clients.
 pub async fn test_client_functionality<Fut>(factory: impl Fn() -> Fut)
@@ -535,15 +535,15 @@ pub async fn test_missing_xorb(client: Arc<dyn DirectAccessClient>) {

     // get_full_xorb should return XORBNotFound
     let result = client.get_full_xorb(&fake_hash).await;
-    assert!(matches!(result, Err(CasClientError::XORBNotFound(_))));
+    assert!(matches!(result, Err(ClientError::XORBNotFound(_))));

     // xorb_length should return XORBNotFound
     let result = client.xorb_length(&fake_hash).await;
-    assert!(matches!(result, Err(CasClientError::XORBNotFound(_))));
+    assert!(matches!(result, Err(ClientError::XORBNotFound(_))));

     // get_xorb_ranges should return XORBNotFound
     let result = client.get_xorb_ranges(&fake_hash, vec![(0, 1)]).await;
-    assert!(matches!(result, Err(CasClientError::XORBNotFound(_))));
+    assert!(matches!(result, Err(ClientError::XORBNotFound(_))));
 }

 /// Tests list_xorbs and delete_xorb operations.
@@ -620,13 +620,13 @@ pub async fn test_get_file_data_with_ranges(client: Arc<dyn DirectAccessClient>)
     let result = client
         .get_file_data(&file.file_hash, Some(FileRange::new(file_size + 100, file_size + 1000)))
         .await;
-    assert!(matches!(result.unwrap_err(), CasClientError::InvalidRange));
+    assert!(matches!(result.unwrap_err(), ClientError::InvalidRange));

     // Start equals file size returns error
     let result = client
         .get_file_data(&file.file_hash, Some(FileRange::new(file_size, file_size + 100)))
         .await;
-    assert!(matches!(result.unwrap_err(), CasClientError::InvalidRange));
+    assert!(matches!(result.unwrap_err(), ClientError::InvalidRange));
 }

 /// Tests get_file_size returns correct size.
@@ -756,7 +756,7 @@ async fn test_url_expiration_after_window(client: Arc<dyn DirectAccessClient>) {
     // The fetch should fail with expiration error
     let result = client.fetch_term_data(xorb_hash, fetch_info).await;
     assert!(result.is_err(), "URL should be expired after expiration window");
-    assert!(matches!(result.unwrap_err(), CasClientError::PresignedUrlExpirationError));
+    assert!(matches!(result.unwrap_err(), ClientError::PresignedUrlExpirationError));
 }

 /// Tests that default URL expiration is effectively infinite.
@@ -813,7 +813,7 @@ async fn test_url_expiration_exact_boundary(client: Arc<dyn DirectAccessClient>)

     let result = client.fetch_term_data(xorb_hash, fetch_info).await;
     assert!(result.is_err(), "URL should be expired past boundary");
-    assert!(matches!(result.unwrap_err(), CasClientError::PresignedUrlExpirationError));
+    assert!(matches!(result.unwrap_err(), ClientError::PresignedUrlExpirationError));
 }

 // =============================================================================

@@ -2,7 +2,7 @@ use async_trait::async_trait;
 use bytes::Bytes;
 use xet_core_structures::merklehash::MerkleHash;

-use super::super::error::Result;
+use crate::error::Result;

 /// Trait for clients that support deletion and integrity operations on shards and file entries.
 ///

@@ -12,11 +12,11 @@ use bytes::Bytes;
 use xet_core_structures::merklehash::MerkleHash;
 use xet_core_structures::xorb_object::XorbObject;

-use super::super::error::Result;
 use super::super::interface::Client;
 use crate::cas_types::{
     FileRange, QueryReconstructionResponse, QueryReconstructionResponseV2, XorbReconstructionFetchInfo,
 };
+use crate::error::Result;

 /// A Client with direct access to XORB and file storage.
 ///

@@ -31,12 +31,12 @@ use super::direct_access_client::DirectAccessClient;
 use super::xorb_utils::{self, REFERENCE_INSTANT};
 use crate::cas_client::Client;
 use crate::cas_client::adaptive_concurrency::AdaptiveConcurrencyController;
-use crate::cas_client::error::{CasClientError, Result};
 use crate::cas_client::progress_tracked_streams::ProgressCallback;
 use crate::cas_types::{
     BatchQueryReconstructionResponse, FileRange, HexMerkleHash, HttpRange, QueryReconstructionResponse,
     QueryReconstructionResponseV2, XorbMultiRangeFetch, XorbRangeDescriptor, XorbReconstructionFetchInfo,
 };
+use crate::error::{ClientError, Result};

 pub struct LocalClient {
     // Note: Field order matters for Drop! heed::Env must be dropped before _tmp_dir
@@ -121,7 +121,7 @@ impl LocalClient {
     }
 }
 result.ok_or_else(|| {
-    CasClientError::Other(format!(
+    ClientError::Other(format!(
         "Error opening db at {global_dedup_dir:?} after 5 attempts: {}",
         last_err.unwrap()
     ))
@@ -130,13 +130,13 @@ impl LocalClient {

 let mut write_txn = global_dedup_db_env
     .write_txn()
-    .map_err(|e| CasClientError::Other(format!("Error opening heed write transaction: {e}")))?;
+    .map_err(|e| ClientError::Other(format!("Error opening heed write transaction: {e}")))?;
 let global_dedup_table = global_dedup_db_env
     .create_database(&mut write_txn, None)
-    .map_err(|e| CasClientError::Other(format!("Error opening heed table: {e}")))?;
+    .map_err(|e| ClientError::Other(format!("Error opening heed table: {e}")))?;
 write_txn
     .commit()
-    .map_err(|e| CasClientError::Other(format!("Error committing heed database: {e}")))?;
+    .map_err(|e| ClientError::Other(format!("Error committing heed database: {e}")))?;

 // Open / set up the shard lookup
 let shard_manager = ShardFileManager::new_in_session_directory(shard_dir.clone(), true).await?;
@@ -164,8 +164,8 @@ impl LocalClient {
 /// Returns all shard files in the shard directory as (shard_hash, path) pairs.
 fn shard_file_paths(&self) -> Result<Vec<(MerkleHash, PathBuf)>> {
     let mut result = Vec::new();
-    for entry in std::fs::read_dir(&self.shard_dir).map_err(CasClientError::internal)? {
-        let entry = entry.map_err(CasClientError::internal)?;
+    for entry in std::fs::read_dir(&self.shard_dir).map_err(ClientError::internal)? {
+        let entry = entry.map_err(ClientError::internal)?;
         let path = entry.path();
         if let Some(hash) = parse_shard_filename(&path)
             && path.is_file()
@@ -182,7 +182,7 @@ impl LocalClient {
     if path.exists() {
         Ok(path)
     } else {
-        Err(CasClientError::Other(format!("Shard file not found for hash {}", hash.hex())))
+        Err(ClientError::Other(format!("Shard file not found for hash {}", hash.hex())))
     }
 }

@@ -281,7 +281,7 @@ impl super::DeletionControlableClient for LocalClient {
 for xorb_hash in &xorb_hashes {
     let xorb_path = self.get_path_for_entry(xorb_hash);
     if !xorb_path.exists() {
-        return Err(CasClientError::Other(format!(
+        return Err(ClientError::Other(format!(
             "Integrity error: shard {} references non-existent XORB {}",
             shard_hash.hex(),
             xorb_hash.hex()
@@ -297,7 +297,7 @@ impl super::DeletionControlableClient for LocalClient {
 let segment = file_view.entry(seg_idx);

 let chunk_count = xorb_chunk_counts.get(&segment.xorb_hash).ok_or_else(|| {
-    CasClientError::Other(format!(
+    ClientError::Other(format!(
         "Integrity error: file {} references XORB block {} not present in shard {}",
         fh.hex(),
         segment.xorb_hash.hex(),
@@ -306,7 +306,7 @@ impl super::DeletionControlableClient for LocalClient {
 })?;

 if segment.chunk_index_end as usize > *chunk_count {
-    return Err(CasClientError::Other(format!(
+    return Err(ClientError::Other(format!(
         "Integrity error: file {} references chunk range {}..{} but XORB block {} only has {} chunks",
         fh.hex(),
         segment.chunk_index_start,
@@ -408,7 +408,7 @@ impl DirectAccessClient for LocalClient {
 let mut ret = Vec::new();
 self.xorb_dir
     .read_dir()
-    .map_err(CasClientError::internal)?
+    .map_err(ClientError::internal)?
     .filter_map(|x| x.ok())
     .filter_map(|x| x.file_name().into_string().ok())
     .for_each(|x| {
@@ -443,7 +443,7 @@ impl DirectAccessClient for LocalClient {
 let file_path = self.get_path_for_entry(hash);
 let file = File::open(&file_path).map_err(|_| {
     error!("Unable to find file in local CAS {:?}", file_path);
-    CasClientError::XORBNotFound(*hash)
+    ClientError::XORBNotFound(*hash)
 })?;

 let mut reader = BufReader::new(file);
@@ -460,7 +460,7 @@ impl DirectAccessClient for LocalClient {
 let file_path = self.get_path_for_entry(hash);
 let file = File::open(&file_path).map_err(|_| {
     error!("Unable to find file in local CAS {:?}", file_path);
-    CasClientError::XORBNotFound(*hash)
+    ClientError::XORBNotFound(*hash)
 })?;

 let mut reader = BufReader::new(file);
@@ -488,7 +488,7 @@ impl DirectAccessClient for LocalClient {
             let length = xorb_obj.get_all_bytes(&mut reader)?.len();
             Ok(length as u32)
         },
-        Err(_) => Err(CasClientError::XORBNotFound(*hash)),
+        Err(_) => Err(ClientError::XORBNotFound(*hash)),
     }
 }

@@ -500,13 +500,13 @@ impl DirectAccessClient for LocalClient {
 };

 if !md.is_file() {
-    return Err(CasClientError::internal(format!(
+    return Err(ClientError::InternalError(anyhow!(
         "Attempting to write to {file_path:?}, but it is not a file"
     )));
 }

 let Ok(file) = File::open(&file_path) else {
-    return Err(CasClientError::XORBNotFound(*hash));
+    return Err(ClientError::XORBNotFound(*hash));
 };

 let mut reader = BufReader::new(file);
@@ -518,7 +518,7 @@ impl DirectAccessClient for LocalClient {
 let file_path = self.get_path_for_entry(hash);
 let mut file = File::open(&file_path).map_err(|_| {
     error!("Unable to find xorb in local CAS {:?}", file_path);
-    CasClientError::XORBNotFound(*hash)
+    ClientError::XORBNotFound(*hash)
 })?;

 file.seek(SeekFrom::End(-(size_of::<u32>() as i64)))?;
@@ -541,9 +541,9 @@ impl DirectAccessClient for LocalClient {
     .shard_manager
     .get_file_reconstruction_info(hash)
     .await
-    .map_err(|e| anyhow!("{e}"))?
+    .map_err(ClientError::internal)?
 else {
-    return Err(CasClientError::FileNotFound(*hash));
+    return Err(ClientError::FileNotFound(*hash));
 };

 let mut file_vec = Vec::new();
@@ -561,7 +561,7 @@ impl DirectAccessClient for LocalClient {
 let start = byte_range.as_ref().map(|range| range.start as usize).unwrap_or(0);

 if byte_range.is_some() && start >= file_size {
-    return Err(CasClientError::InvalidRange);
+    return Err(ClientError::InvalidRange);
 }

 let end = byte_range
@@ -575,7 +575,7 @@ impl DirectAccessClient for LocalClient {

 async fn get_xorb_raw_bytes(&self, hash: &MerkleHash, byte_range: Option<FileRange>) -> Result<Bytes> {
     let file_path = self.get_path_for_entry(hash);
-    let data = std::fs::read(&file_path).map_err(|_| CasClientError::XORBNotFound(*hash))?;
+    let data = std::fs::read(&file_path).map_err(|_| ClientError::XORBNotFound(*hash))?;

     let start = byte_range.as_ref().map(|r| r.start as usize).unwrap_or(0);
     let end = byte_range
@@ -585,7 +585,7 @@ impl DirectAccessClient for LocalClient {
     .min(data.len());

 if start >= data.len() {
-    return Err(CasClientError::InvalidRange);
+    return Err(ClientError::InvalidRange);
 }

 Ok(Bytes::from(data[start..end].to_vec()))
@@ -593,7 +593,7 @@ impl DirectAccessClient for LocalClient {

 async fn xorb_raw_length(&self, hash: &MerkleHash) -> Result<u64> {
     let file_path = self.get_path_for_entry(hash);
-    let metadata = std::fs::metadata(&file_path).map_err(|_| CasClientError::XORBNotFound(*hash))?;
+    let metadata = std::fs::metadata(&file_path).map_err(|_| ClientError::XORBNotFound(*hash))?;
     Ok(metadata.len())
 }

@@ -609,7 +609,7 @@ impl DirectAccessClient for LocalClient {
 let expiration_ms = self.url_expiration_ms.load(Ordering::Relaxed);
 let elapsed_ms = Instant::now().saturating_duration_since(url_timestamp).as_millis() as u64;
 if elapsed_ms > expiration_ms {
-    return Err(CasClientError::PresignedUrlExpirationError);
+    return Err(ClientError::PresignedUrlExpirationError);
 }

 // Validate byte range matches url_range
@@ -617,11 +617,11 @@ impl DirectAccessClient for LocalClient {
 // We convert url_range to FileRange for comparison
 let fetch_byte_range = FileRange::from(fetch_term.url_range);
 if url_byte_range.start != fetch_byte_range.start || url_byte_range.end != fetch_byte_range.end {
-    return Err(CasClientError::InvalidArguments);
+    return Err(ClientError::InvalidArguments);
 }
 let file = File::open(&file_path).map_err(|_| {
     error!("Unable to find xorb in local CAS {:?}", file_path);
-    CasClientError::XORBNotFound(hash)
+    ClientError::XORBNotFound(hash)
 })?;

 let mut reader = BufReader::new(file);
@@ -638,7 +638,7 @@ impl DirectAccessClient for LocalClient {
 for chunk_idx in fetch_term.range.start..fetch_term.range.end {
     let chunk_len = xorb_obj
         .uncompressed_chunk_length(chunk_idx)
-        .map_err(|e| CasClientError::Other(format!("Failed to get chunk length: {e}")))?;
+        .map_err(|e| ClientError::Other(format!("Failed to get chunk length: {e}")))?;
     cumulative += chunk_len;
     indices.push(cumulative);
 }
@@ -666,7 +666,7 @@ impl LocalClient {
     let file_path = self.get_path_for_entry(hash);
     let mut file = File::open(&file_path).map_err(|_| {
         error!("Unable to find file in local CAS {:?}", file_path);
-        CasClientError::XORBNotFound(*hash)
+        ClientError::XORBNotFound(*hash)
     })?;
     XorbObject::deserialize(&mut file).map_err(Into::into)
 }
@@ -784,7 +784,7 @@ impl Client for LocalClient {
 let env = self
     .global_dedup_db_env
     .as_ref()
-    .ok_or_else(|| CasClientError::Other("LocalClient has been closed".to_string()))?;
+    .ok_or_else(|| ClientError::Other("LocalClient has been closed".to_string()))?;
 let read_txn = env.read_txn().map_err(map_heed_db_error)?;

 if let Some(shard) = self.global_dedup_table.get(&read_txn, chunk_hash).map_err(map_heed_db_error)? {
@@ -838,7 +838,7 @@ impl Client for LocalClient {
 let env = self
     .global_dedup_db_env
     .as_ref()
-    .ok_or_else(|| CasClientError::Other("LocalClient has been closed".to_string()))?;
+    .ok_or_else(|| ClientError::Other("LocalClient has been closed".to_string()))?;
 let mut write_txn = env.write_txn().map_err(map_heed_db_error)?;

 for chunk in chunk_hashes {
@@ -878,7 +878,7 @@ impl Client for LocalClient {
     &serialized_data,
 )?;
 if computed_hash != hash {
-    return Err(CasClientError::Other(format!(
+    return Err(ClientError::Other(format!(
         "XORB hash mismatch: expected {}, got {}",
         hash.hex(),
         computed_hash.hex(),
@@ -985,11 +985,11 @@ impl Client for LocalClient {
         url_info.refresh_url().await?;
         continue;
     }
-    return Err(CasClientError::PresignedUrlExpirationError);
+    return Err(ClientError::PresignedUrlExpirationError);
 }

 // Read each byte range from the serialized file and deserialize the chunks.
-let mut file = File::open(&file_path).map_err(|_| CasClientError::XORBNotFound(MerkleHash::default()))?;
+let mut file = File::open(&file_path).map_err(|_| ClientError::XORBNotFound(MerkleHash::default()))?;

 let mut all_decompressed = Vec::new();
 let mut all_chunk_indices = Vec::<u32>::new();
@@ -1031,14 +1031,14 @@ impl Client for LocalClient {
         }

         // Should not reach here, but return error if we do.
-        Err(CasClientError::PresignedUrlExpirationError)
+        Err(ClientError::PresignedUrlExpirationError)
     }
 }

-fn map_heed_db_error(e: heed::Error) -> CasClientError {
+fn map_heed_db_error(e: heed::Error) -> ClientError {
     let msg = format!("Global shard dedup database error: {e:?}");
     warn!("{msg}");
-    CasClientError::Other(msg)
+    ClientError::Other(msg)
 }

 fn generate_fetch_url(file_path: &Path, byte_range: &FileRange, timestamp: Instant) -> String {
@@ -1051,13 +1051,13 @@ fn parse_fetch_url(url: &str) -> Result<(PathBuf, FileRange, Instant)> {
 parts.reverse();

 if parts.len() != 4 {
-    return Err(CasClientError::InvalidArguments);
+    return Err(ClientError::InvalidArguments);
 }

 let file_path_str = parts[0];
-let start_pos: u64 = parts[1].parse().map_err(|_| CasClientError::InvalidArguments)?;
-let end_pos: u64 = parts[2].parse().map_err(|_| CasClientError::InvalidArguments)?;
-let timestamp_ms: u64 = parts[3].parse().map_err(|_| CasClientError::InvalidArguments)?;
+let start_pos: u64 = parts[1].parse().map_err(|_| ClientError::InvalidArguments)?;
+let end_pos: u64 = parts[2].parse().map_err(|_| ClientError::InvalidArguments)?;
+let timestamp_ms: u64 = parts[3].parse().map_err(|_| ClientError::InvalidArguments)?;

 let file_path: PathBuf = file_path_str.trim_matches('"').into();
 let byte_range = FileRange::new(start_pos, end_pos);
@@ -1133,7 +1133,7 @@ mod tests {
 };
 let result = client.fetch_term_data(hash, invalid_fetch_term).await;
 assert!(result.is_err(), "URL with too few parts should fail");
-assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments));
+assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments));

 // Test 3: Invalid start_pos - doesn't match url_range.start
 let wrong_byte_range = FileRange::new(fetch_byte_start as u64 + 1, fetch_byte_end as u64);
@@ -1145,7 +1145,7 @@ mod tests {
 };
 let result = client.fetch_term_data(hash, invalid_fetch_term).await;
 assert!(result.is_err(), "Wrong start_pos should fail");
-assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments));
+assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments));

 // Test 4: Invalid end_pos - doesn't match url_range.end
 let wrong_byte_range = FileRange::new(fetch_byte_start as u64, fetch_byte_end as u64 + 1);
@@ -1157,7 +1157,7 @@ mod tests {
 };
 let result = client.fetch_term_data(hash, invalid_fetch_term).await;
 assert!(result.is_err(), "Wrong end_pos should fail");
-assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments));
+assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments));

 // Test 5: Invalid start_pos - non-numeric
 let timestamp_ms = timestamp.saturating_duration_since(*REFERENCE_INSTANT).as_millis() as u64;
@@ -1169,7 +1169,7 @@ mod tests {
 };
 let result = client.fetch_term_data(hash, invalid_fetch_term).await;
 assert!(result.is_err(), "Non-numeric start_pos should fail");
-assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments));
+assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments));

 // Test 6: Invalid end_pos - non-numeric
 let non_numeric_end = format!("{:?}:{}:not_a_number:{}", file_path, fetch_byte_start, timestamp_ms);
@@ -1180,7 +1180,7 @@ mod tests {
 };
 let result = client.fetch_term_data(hash, invalid_fetch_term).await;
 assert!(result.is_err(), "Non-numeric end_pos should fail");
-assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments));
+assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments));

 // Test 7: Empty URL
 let invalid_fetch_term = XorbReconstructionFetchInfo {
@@ -1190,7 +1190,7 @@ mod tests {
 };
 let result = client.fetch_term_data(hash, invalid_fetch_term).await;
 assert!(result.is_err(), "Empty URL should fail");
-assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments));
+assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments));

 // Test 8: Invalid timestamp - non-numeric
 let non_numeric_timestamp = format!("{:?}:{}:{}:not_a_number", file_path, fetch_byte_start, fetch_byte_end);
@@ -1201,7 +1201,7 @@ mod tests {
 };
 let result = client.fetch_term_data(hash, invalid_fetch_term).await;
 assert!(result.is_err(), "Non-numeric timestamp should fail");
-assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments));
+assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments));

 // Test 9: Non-existent file path
 let non_existent_path = PathBuf::from("/nonexistent/path/file.xorb");

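The `parse_fetch_url` hunk and the tests above all funnel malformed local fetch URLs into `ClientError::InvalidArguments`. The sketch below shows that validation pattern in isolation. Several things in it are assumptions rather than the crate's code: the split via `rsplitn(4, ':')` (the diff only shows `parts.reverse()` and the length check, not how `parts` is built), the one-variant `ClientError` stand-in, the function name `parse_fetch_url_sketch`, and the plain tuple return in place of `(PathBuf, FileRange, Instant)`.

```rust
// Hypothetical one-variant stand-in for the crate's error type.
#[derive(Debug, PartialEq)]
pub enum ClientError {
    InvalidArguments,
}

/// Parses `"<path>":<start>:<end>:<timestamp_ms>` into its four fields.
/// Any structural or numeric problem is reported as `InvalidArguments`.
pub fn parse_fetch_url_sketch(url: &str) -> Result<(String, u64, u64, u64), ClientError> {
    // Assumption: split from the right so that colons inside the path
    // do not shift the three numeric fields.
    let mut parts: Vec<&str> = url.rsplitn(4, ':').collect();
    parts.reverse();

    if parts.len() != 4 {
        return Err(ClientError::InvalidArguments);
    }

    let start: u64 = parts[1].parse().map_err(|_| ClientError::InvalidArguments)?;
    let end: u64 = parts[2].parse().map_err(|_| ClientError::InvalidArguments)?;
    let timestamp_ms: u64 = parts[3].parse().map_err(|_| ClientError::InvalidArguments)?;

    // The path is written with Debug formatting, so strip the quotes.
    let path = parts[0].trim_matches('"').to_string();
    Ok((path, start, end, timestamp_ms))
}
```

Collapsing every parse failure into one variant keeps the caller's match simple, at the cost of losing which field was malformed; the tests in the diff rely only on the variant, not on a message.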
@@ -28,13 +28,13 @@ use futures_util::StreamExt;
|
||||
use http::header::RANGE;
|
||||
use xet_core_structures::merklehash::MerkleHash;
|
||||
|
||||
use super::super::super::error::CasClientError;
|
||||
use super::super::super::{DeletionControlableClient, DirectAccessClient};
|
||||
use super::latency_simulation::{LatencySimulation, ServerLatencyProfile};
|
||||
use crate::cas_types::{
|
||||
FileRange, HexKey, HexMerkleHash, QueryReconstructionResponseV2, UploadShardResponse, UploadShardResponseType,
|
||||
UploadXorbResponse, XorbRangeDescriptor, XorbReconstructionFetchInfo,
|
||||
};
|
||||
use crate::error::ClientError;
|
||||
|
||||
/// Server state passed to all handlers.
|
||||
#[derive(Clone)]
@@ -117,12 +117,12 @@ pub(super) fn parse_range_header(
     }
 }

-/// Maps CasClientError to appropriate HTTP status codes.
-pub(super) fn error_to_response(e: CasClientError) -> Response {
+/// Maps ClientError to appropriate HTTP status codes.
+pub(super) fn error_to_response(e: ClientError) -> Response {
     let status = match &e {
-        CasClientError::XORBNotFound(_) | CasClientError::FileNotFound(_) => StatusCode::NOT_FOUND,
-        CasClientError::InvalidRange => StatusCode::RANGE_NOT_SATISFIABLE,
-        CasClientError::InvalidArguments => StatusCode::BAD_REQUEST,
+        ClientError::XORBNotFound(_) | ClientError::FileNotFound(_) => StatusCode::NOT_FOUND,
+        ClientError::InvalidRange => StatusCode::RANGE_NOT_SATISFIABLE,
+        ClientError::InvalidArguments => StatusCode::BAD_REQUEST,
         _ => StatusCode::INTERNAL_SERVER_ERROR,
     };
     (status, e.to_string()).into_response()
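The mapping in `error_to_response` can be sketched without the web framework. The snippet below is a minimal illustration, not the crate's code: the `ClientError` enum here is a hypothetical stand-in that models only the variants named in the hunk, the `String` payloads are an assumption (the real variants appear to carry a hash), and plain `u16` codes replace `StatusCode`.

```rust
/// Hypothetical stand-in for the crate's `ClientError`.
/// Only the variants named in the hunk are modeled; payload types are assumed.
#[derive(Debug)]
pub enum ClientError {
    XORBNotFound(String),
    FileNotFound(String),
    InvalidRange,
    InvalidArguments,
    Other(String),
}

/// Maps an error to a numeric HTTP status code, following the same
/// match arms as the hunk: not-found, bad range, bad arguments, else 500.
pub fn status_code_for(e: &ClientError) -> u16 {
    match e {
        ClientError::XORBNotFound(_) | ClientError::FileNotFound(_) => 404,
        ClientError::InvalidRange => 416,
        ClientError::InvalidArguments => 400,
        // Every other variant falls through to an internal server error.
        _ => 500,
    }
}
```

Because the final arm is a wildcard, any variant added to the error type later would silently map to 500 unless an explicit arm is added for it.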
@@ -43,6 +43,7 @@

 use std::path::PathBuf;

+use anyhow::Result;
 use clap::Parser;
 use tracing_subscriber::EnvFilter;
 use xet_client::cas_client::{LocalServer, LocalServerConfig};
@@ -93,7 +94,7 @@ struct Args {
 }

 #[tokio::main]
-async fn main() -> anyhow::Result<()> {
+async fn main() -> Result<()> {
     // Initialize tracing with environment filter (respects RUST_LOG)
     tracing_subscriber::fmt().with_env_filter(EnvFilter::from_default_env()).init();

@@ -12,10 +12,11 @@
 //! # Example
 //!
 //! ```no_run
+//! use anyhow::Result;
 //! use xet_client::cas_client::{LocalServer, LocalServerConfig};
 //!
 //! #[tokio::main]
-//! async fn main() -> anyhow::Result<()> {
+//! async fn main() -> Result<()> {
 //!     let config = LocalServerConfig {
 //!         data_directory: "./data".into(),
 //!         host: "127.0.0.1".to_string(),
@@ -49,7 +50,6 @@ use tower_http::cors::CorsLayer;

 #[cfg(test)]
 use super::super::super::RemoteClient;
-use super::super::super::error::{CasClientError, Result};
 #[cfg(test)]
 use super::super::super::interface::Client;
 #[cfg(test)]
@@ -58,6 +58,7 @@ use super::super::socket_proxy::UnixSocketProxy;
 use super::super::{DeletionControlableClient, DirectAccessClient, LocalClient, MemoryClient};
 use super::handlers;
 use super::latency_simulation::LatencySimulation;
+use crate::error::{ClientError, Result};

 /// Configuration for the local CAS server.
 #[derive(Debug, Clone)]
@@ -204,11 +205,11 @@ impl LocalServer {
         let addr: SocketAddr = self
             .addr()
             .parse()
-            .map_err(|e| CasClientError::Other(format!("Failed to parse address: {e}")))?;
+            .map_err(|e| ClientError::Other(format!("Failed to parse address: {e}")))?;

         let listener = TcpListener::bind(addr)
             .await
-            .map_err(|e| CasClientError::Other(format!("Failed to bind to {addr}: {e}")))?;
+            .map_err(|e| ClientError::Other(format!("Failed to bind to {addr}: {e}")))?;

         tracing::info!("Local CAS server listening on {}", addr);

@@ -217,7 +218,7 @@ impl LocalServer {
         axum::serve(listener, router.into_make_service())
             .with_graceful_shutdown(shutdown_signal())
             .await
-            .map_err(|e| CasClientError::Other(format!("Server error: {e}")))
+            .map_err(|e| ClientError::Other(format!("Server error: {e}")))
     }

     /// Runs the server until a shutdown signal is received on the provided channel.
@@ -227,11 +228,11 @@ impl LocalServer {
         let addr: SocketAddr = self
             .addr()
             .parse()
-            .map_err(|e| CasClientError::Other(format!("Failed to parse address: {e}")))?;
+            .map_err(|e| ClientError::Other(format!("Failed to parse address: {e}")))?;

         let listener = TcpListener::bind(addr)
             .await
-            .map_err(|e| CasClientError::Other(format!("Failed to bind to {addr}: {e}")))?;
+            .map_err(|e| ClientError::Other(format!("Failed to bind to {addr}: {e}")))?;

         tracing::info!("Local CAS server listening on {}", addr);

@@ -242,7 +243,7 @@ impl LocalServer {
                 let _ = shutdown_rx.await;
             })
             .await
-            .map_err(|e| CasClientError::Other(format!("Server error: {e}")))
+            .map_err(|e| ClientError::Other(format!("Server error: {e}")))
     }
 }

@@ -14,10 +14,10 @@ use super::simulation_types::{
     XorbRawLengthResponse,
 };
 use crate::cas_client::RemoteClient;
-use crate::cas_client::error::{CasClientError, Result};
 use crate::cas_client::interface::Client;
 use crate::cas_client::simulation::{DeletionControlableClient, DirectAccessClient};
 use crate::cas_types::{FileRange, HexMerkleHash, QueryReconstructionResponseV2, XorbReconstructionFetchInfo};
+use crate::error::{ClientError, Result};

 /// A client that connects to a `LocalTestServer` via HTTP and provides access
 /// to both `DirectAccessClient` and `DeletionControlableClient` operations
@@ -50,26 +50,26 @@ impl SimulationControlClient {
         format!("{}/simulation{}", self.endpoint, path)
     }

-    /// Checks an HTTP response status, mapping errors to `CasClientError`.
+    /// Checks an HTTP response status, mapping errors to `ClientError`.
     async fn check_status(response: reqwest::Response) -> Result<reqwest::Response> {
         let status = response.status();
         if status.is_success() {
             Ok(response)
         } else if status == reqwest::StatusCode::NOT_IMPLEMENTED {
-            Err(CasClientError::Other("Deletion controls not available for this server backend".to_string()))
+            Err(ClientError::Other("Deletion controls not available for this server backend".to_string()))
         } else if status == reqwest::StatusCode::RANGE_NOT_SATISFIABLE {
-            Err(CasClientError::InvalidRange)
+            Err(ClientError::InvalidRange)
         } else {
             let body = response.text().await.unwrap_or_default();
-            Err(CasClientError::Other(format!("HTTP {status}: {body}")))
+            Err(ClientError::Other(format!("HTTP {status}: {body}")))
         }
     }

-    /// Like `check_status`, but maps 404 to `CasClientError::XORBNotFound` for XORB endpoints.
+    /// Like `check_status`, but maps 404 to `ClientError::XORBNotFound` for XORB endpoints.
     async fn check_xorb_status(response: reqwest::Response, hash: &MerkleHash) -> Result<reqwest::Response> {
         let status = response.status();
         if status == reqwest::StatusCode::NOT_FOUND {
-            Err(CasClientError::XORBNotFound(*hash))
+            Err(ClientError::XORBNotFound(*hash))
         } else {
             Self::check_status(response).await
         }
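The two helpers above translate HTTP statuses back into typed errors on the client side, the inverse of the server's `error_to_response`. A minimal synchronous sketch of that decision order follows, under stated assumptions: `ClientError` is a hypothetical stand-in, the hash is a plain string rather than `MerkleHash`, and the status is a bare `u16`, so the `HTTP {status}` message will not match what `reqwest::StatusCode` would print.

```rust
/// Hypothetical stand-in for the crate's `ClientError` (assumed payload types).
#[derive(Debug, PartialEq)]
pub enum ClientError {
    XORBNotFound(String),
    InvalidRange,
    Other(String),
}

/// Mirrors the branch order of `check_status`: success first, then the two
/// statuses with dedicated meanings, then a generic error carrying the body.
pub fn check_status(status: u16, body: &str) -> Result<(), ClientError> {
    if (200..300).contains(&status) {
        Ok(())
    } else if status == 501 {
        Err(ClientError::Other(
            "Deletion controls not available for this server backend".to_string(),
        ))
    } else if status == 416 {
        Err(ClientError::InvalidRange)
    } else {
        Err(ClientError::Other(format!("HTTP {status}: {body}")))
    }
}

/// Mirrors `check_xorb_status`: a 404 on a XORB endpoint becomes a typed
/// not-found error; everything else defers to `check_status`.
pub fn check_xorb_status(status: u16, body: &str, hash: &str) -> Result<(), ClientError> {
    if status == 404 {
        Err(ClientError::XORBNotFound(hash.to_string()))
    } else {
        check_status(status, body)
    }
}
```

Note that a 404 passed through plain `check_status` ends up as the generic `Other` variant, which is why the XORB endpoints in the hunks that follow call `check_xorb_status` instead.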
@@ -222,9 +222,9 @@ impl DirectAccessClient for SimulationControlClient {
             .get(self.sim_url("/xorbs"))
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_status(resp).await?;
-        resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))
+        resp.json().await.map_err(|e| ClientError::Other(e.to_string()))
     }

     /// Deletes a XORB by hash via the `/simulation/xorbs/{hash}` endpoint.
@@ -243,9 +243,9 @@ impl DirectAccessClient for SimulationControlClient {
             .get(&url)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_xorb_status(resp, hash).await?;
-        resp.bytes().await.map_err(|e| CasClientError::Other(e.to_string()))
+        resp.bytes().await.map_err(|e| ClientError::Other(e.to_string()))
     }

     /// Retrieves specific chunk ranges from a XORB via the `/simulation/xorbs/{hash}/ranges` endpoint.
@@ -259,9 +259,9 @@ impl DirectAccessClient for SimulationControlClient {
             .json(&body)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_xorb_status(resp, hash).await?;
-        let result: XorbRangesResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?;
+        let result: XorbRangesResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?;
         Ok(result.data.into_iter().map(Bytes::from).collect())
     }

@@ -274,9 +274,9 @@ impl DirectAccessClient for SimulationControlClient {
             .get(&url)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_xorb_status(resp, hash).await?;
-        let result: XorbLengthResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?;
+        let result: XorbLengthResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?;
         Ok(result.length)
     }

@@ -289,9 +289,9 @@ impl DirectAccessClient for SimulationControlClient {
             .get(&url)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_status(resp).await?;
-        let result: XorbExistsResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?;
+        let result: XorbExistsResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?;
         Ok(result.exists)
     }

@@ -299,7 +299,7 @@ impl DirectAccessClient for SimulationControlClient {
     async fn xorb_footer(&self, hash: &MerkleHash) -> Result<XorbObject> {
         let raw_bytes = self.get_xorb_raw_bytes(hash, None).await?;
         XorbObject::deserialize(&mut std::io::Cursor::new(raw_bytes))
-            .map_err(|e| CasClientError::Other(format!("Failed to deserialize XorbObject footer: {e}")))
+            .map_err(|e| ClientError::Other(format!("Failed to deserialize XorbObject footer: {e}")))
     }

     /// Returns the file size via the `/simulation/files/{hash}/size` endpoint.
@@ -311,9 +311,9 @@ impl DirectAccessClient for SimulationControlClient {
             .get(&url)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_status(resp).await?;
-        let result: FileSizeResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?;
+        let result: FileSizeResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?;
         Ok(result.size)
     }

@@ -325,9 +325,9 @@ impl DirectAccessClient for SimulationControlClient {
         if let Some(range) = byte_range {
             req = req.header(reqwest::header::RANGE, format!("bytes={}-{}", range.start, range.end.saturating_sub(1)));
         }
-        let resp = req.send().await.map_err(|e| CasClientError::Other(e.to_string()))?;
+        let resp = req.send().await.map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_status(resp).await?;
-        resp.bytes().await.map_err(|e| CasClientError::Other(e.to_string()))
+        resp.bytes().await.map_err(|e| ClientError::Other(e.to_string()))
     }

     /// Retrieves raw XORB bytes, optionally with a byte range, via the `/simulation/xorbs/{hash}/raw` endpoint.
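The `RANGE` header built in the hunk above converts an end-exclusive byte range into the end-inclusive form that HTTP range requests use. A small sketch of just that conversion, assuming a `std::ops::Range<u64>` in place of the crate's `FileRange` (the helper name `range_header_value` is hypothetical):

```rust
use std::ops::Range;

/// Formats an end-exclusive byte range as an HTTP `Range` header value.
/// HTTP byte ranges are inclusive on both ends, hence the `- 1` on `end`;
/// `saturating_sub` avoids underflow when `end` is 0.
pub fn range_header_value(range: &Range<u64>) -> String {
    format!("bytes={}-{}", range.start, range.end.saturating_sub(1))
}
```

One consequence of the saturating subtraction is that an empty range such as `0..0` is rendered as `bytes=0-0`, which asks for one byte rather than none, so callers may want to handle empty ranges before building the header.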
@@ -338,9 +338,9 @@ impl DirectAccessClient for SimulationControlClient {
         if let Some(range) = byte_range {
             req = req.header(reqwest::header::RANGE, format!("bytes={}-{}", range.start, range.end.saturating_sub(1)));
         }
-        let resp = req.send().await.map_err(|e| CasClientError::Other(e.to_string()))?;
+        let resp = req.send().await.map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_xorb_status(resp, hash).await?;
-        resp.bytes().await.map_err(|e| CasClientError::Other(e.to_string()))
+        resp.bytes().await.map_err(|e| ClientError::Other(e.to_string()))
     }

     /// Returns the raw byte length of a XORB via the `/simulation/xorbs/{hash}/raw_length` endpoint.
@@ -352,9 +352,9 @@ impl DirectAccessClient for SimulationControlClient {
             .get(&url)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_xorb_status(resp, hash).await?;
-        let result: XorbRawLengthResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?;
+        let result: XorbRawLengthResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?;
         Ok(result.length)
     }

@@ -372,9 +372,9 @@ impl DirectAccessClient for SimulationControlClient {
             .json(&body)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_status(resp).await?;
-        let result: FetchTermDataResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?;
+        let result: FetchTermDataResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?;
         Ok((Bytes::from(result.data), result.chunk_byte_indices))
     }
 }
@@ -388,9 +388,9 @@ impl DeletionControlableClient for SimulationControlClient {
             .get(self.sim_url("/shards"))
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_status(resp).await?;
-        resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))
+        resp.json().await.map_err(|e| ClientError::Other(e.to_string()))
     }

     /// Retrieves raw shard bytes by hash via the `/simulation/shards/{hash}` endpoint.
@@ -402,9 +402,9 @@ impl DeletionControlableClient for SimulationControlClient {
             .get(&url)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_status(resp).await?;
-        resp.bytes().await.map_err(|e| CasClientError::Other(e.to_string()))
+        resp.bytes().await.map_err(|e| ClientError::Other(e.to_string()))
     }

     /// Deletes a shard entry by hash via the `/simulation/shards/{hash}` endpoint.
@@ -416,7 +416,7 @@ impl DeletionControlableClient for SimulationControlClient {
             .delete(&url)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         Self::check_status(resp).await?;
         Ok(())
     }
@@ -428,9 +428,9 @@ impl DeletionControlableClient for SimulationControlClient {
             .get(self.sim_url("/file_entries"))
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         let resp = Self::check_status(resp).await?;
-        let entries: Vec<FileShardsEntry> = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?;
+        let entries: Vec<FileShardsEntry> = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?;
         Ok(entries.into_iter().map(|e| (e.file_hash, e.shard_hash)).collect())
     }

@@ -443,7 +443,7 @@ impl DeletionControlableClient for SimulationControlClient {
             .delete(&url)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         Self::check_status(resp).await?;
         Ok(())
     }
@@ -455,7 +455,7 @@ impl DeletionControlableClient for SimulationControlClient {
             .post(self.sim_url("/verify_integrity"))
             .send()
             .await
-            .map_err(|e| CasClientError::Other(e.to_string()))?;
+            .map_err(|e| ClientError::Other(e.to_string()))?;
         Self::check_status(resp).await?;
         Ok(())
     }
@@ -20,7 +20,6 @@ use xet_core_structures::xorb_object::{SerializedXorbObject, XorbObject};

 use super::super::Client;
 use super::super::adaptive_concurrency::AdaptiveConcurrencyController;
-use super::super::error::{CasClientError, Result};
 use super::super::progress_tracked_streams::ProgressCallback;
 use super::client_testing_utils::{FileTermReference, RandomFileContents};
 use super::direct_access_client::DirectAccessClient;
@@ -30,6 +29,7 @@ use crate::cas_types::{
     BatchQueryReconstructionResponse, FileRange, HexMerkleHash, HttpRange, QueryReconstructionResponse,
     QueryReconstructionResponseV2, XorbMultiRangeFetch, XorbRangeDescriptor, XorbReconstructionFetchInfo,
 };
+use crate::error::{ClientError, Result};

 /// Stored XORB data: the serialized data and the deserialized XorbObject (header/footer).
 struct MaterializedXorb {
@@ -308,7 +308,7 @@ impl DirectAccessClient for MemoryClient {
         let xorbs = self.xorbs.read().await;
         let storage = xorbs.get(hash).ok_or_else(|| {
             error!("Unable to find xorb in memory CAS {:?}", hash);
-            CasClientError::XORBNotFound(*hash)
+            ClientError::XORBNotFound(*hash)
         })?;

         match storage {
@@ -320,7 +320,7 @@ impl DirectAccessClient for MemoryClient {
             },
             XorbStorage::Random(xorb) => xorb
                 .get_chunk_range_data(0, xorb.num_chunks())
-                .ok_or(CasClientError::XORBNotFound(*hash)),
+                .ok_or(ClientError::XORBNotFound(*hash)),
         }
     }

@@ -332,7 +332,7 @@ impl DirectAccessClient for MemoryClient {
         let xorbs = self.xorbs.read().await;
         let storage = xorbs.get(hash).ok_or_else(|| {
             error!("Unable to find xorb in memory CAS {:?}", hash);
-            CasClientError::XORBNotFound(*hash)
+            ClientError::XORBNotFound(*hash)
         })?;

         match storage {
@@ -358,7 +358,7 @@ impl DirectAccessClient for MemoryClient {
                         ret.push(Bytes::new());
                         continue;
                     }
-                    let data = xorb.get_chunk_range_data(r.0, r.1).ok_or(CasClientError::XORBNotFound(*hash))?;
+                    let data = xorb.get_chunk_range_data(r.0, r.1).ok_or(ClientError::XORBNotFound(*hash))?;
                     ret.push(data);
                 }
                 Ok(ret)
@@ -379,7 +379,7 @@ impl DirectAccessClient for MemoryClient {
         let xorbs = self.xorbs.read().await;
         let storage = xorbs.get(hash).ok_or_else(|| {
             error!("Unable to find xorb in memory CAS {:?}", hash);
-            CasClientError::XORBNotFound(*hash)
+            ClientError::XORBNotFound(*hash)
         })?;

         match storage {
@@ -392,7 +392,7 @@ impl DirectAccessClient for MemoryClient {
         let shard = self.shard.read().await;
         let file_info = shard
             .get_file_reconstruction_info(hash)
-            .ok_or(CasClientError::FileNotFound(*hash))?;
+            .ok_or(ClientError::FileNotFound(*hash))?;
         Ok(file_info.file_size())
     }

@@ -401,7 +401,7 @@ impl DirectAccessClient for MemoryClient {
             let shard = self.shard.read().await;
             shard
                 .get_file_reconstruction_info(hash)
-                .ok_or(CasClientError::FileNotFound(*hash))?
+                .ok_or(ClientError::FileNotFound(*hash))?
         };

         let mut file_vec = Vec::new();
@@ -419,7 +419,7 @@ impl DirectAccessClient for MemoryClient {
         let start = byte_range.as_ref().map(|range| range.start as usize).unwrap_or(0);

         if byte_range.is_some() && start >= file_size {
-            return Err(CasClientError::InvalidRange);
+            return Err(ClientError::InvalidRange);
         }

         let end = byte_range
@@ -433,7 +433,7 @@ impl DirectAccessClient for MemoryClient {

     async fn get_xorb_raw_bytes(&self, hash: &MerkleHash, byte_range: Option<FileRange>) -> Result<Bytes> {
         let xorbs = self.xorbs.read().await;
-        let storage = xorbs.get(hash).ok_or(CasClientError::XORBNotFound(*hash))?;
+        let storage = xorbs.get(hash).ok_or(ClientError::XORBNotFound(*hash))?;

         match storage {
             XorbStorage::Materialized(entry) => {
@@ -447,7 +447,7 @@ impl DirectAccessClient for MemoryClient {
                     .min(data.len());

                 if start >= data.len() {
-                    return Err(CasClientError::InvalidRange);
+                    return Err(ClientError::InvalidRange);
                 }

                 Ok(data.slice(start..end))
@@ -458,7 +458,7 @@ impl DirectAccessClient for MemoryClient {
                 let end = byte_range.as_ref().map(|r| r.end).unwrap_or(total_len).min(total_len);

                 if start >= total_len {
-                    return Err(CasClientError::InvalidRange);
+                    return Err(ClientError::InvalidRange);
                 }

                 Ok(xorb.get_serialized_range(start, end))
@@ -468,7 +468,7 @@ impl DirectAccessClient for MemoryClient {

     async fn xorb_raw_length(&self, hash: &MerkleHash) -> Result<u64> {
         let xorbs = self.xorbs.read().await;
-        let storage = xorbs.get(hash).ok_or(CasClientError::XORBNotFound(*hash))?;
+        let storage = xorbs.get(hash).ok_or(ClientError::XORBNotFound(*hash))?;

         match storage {
             XorbStorage::Materialized(entry) => Ok(entry.serialized_data.len() as u64),
@@ -488,7 +488,7 @@ impl DirectAccessClient for MemoryClient {
         let expiration_ms = self.url_expiration_ms.load(Ordering::Relaxed);
         let elapsed_ms = Instant::now().saturating_duration_since(url_timestamp).as_millis() as u64;
         if elapsed_ms > expiration_ms {
-            return Err(CasClientError::PresignedUrlExpirationError);
+            return Err(ClientError::PresignedUrlExpirationError);
         }

         // Validate byte range matches url_range
@@ -496,13 +496,13 @@ impl DirectAccessClient for MemoryClient {
         // We convert url_range to FileRange for comparison
         let fetch_byte_range = FileRange::from(fetch_term.url_range);
         if url_byte_range.start != fetch_byte_range.start || url_byte_range.end != fetch_byte_range.end {
-            return Err(CasClientError::InvalidArguments);
+            return Err(ClientError::InvalidArguments);
         }

         let xorbs = self.xorbs.read().await;
         let storage = xorbs.get(&xorb_hash).ok_or_else(|| {
             error!("Unable to find xorb in memory CAS {:?}", hash);
-            CasClientError::XORBNotFound(hash)
+            ClientError::XORBNotFound(hash)
         })?;

         let (data, xorb_obj) = match storage {
@@ -516,7 +516,7 @@ impl DirectAccessClient for MemoryClient {
             XorbStorage::Random(xorb) => {
                 let data = xorb
                     .get_chunk_range_data(fetch_term.range.start, fetch_term.range.end)
-                    .ok_or(CasClientError::XORBNotFound(hash))?;
+                    .ok_or(ClientError::XORBNotFound(hash))?;
                 let xorb_obj = xorb.get_xorb_object();
                 (data, xorb_obj)
             },
@@ -529,7 +529,7 @@ impl DirectAccessClient for MemoryClient {
         for chunk_idx in fetch_term.range.start..fetch_term.range.end {
             let chunk_len = xorb_obj
                 .uncompressed_chunk_length(chunk_idx)
-                .map_err(|e| CasClientError::Other(format!("Failed to get chunk length: {e}")))?;
+                .map_err(|e| ClientError::Other(format!("Failed to get chunk length: {e}")))?;
             cumulative += chunk_len;
             indices.push(cumulative);
         }
@@ -558,7 +558,7 @@ impl MemoryClient {
         xorb_utils::compute_reconstruction_ranges(&file_info, bytes_range, &mut |hash| {
             let storage = xorbs.get(hash).ok_or_else(|| {
                 error!("Unable to find xorb in memory CAS {:?}", hash);
-                CasClientError::XORBNotFound(*hash)
+                ClientError::XORBNotFound(*hash)
             })?;
             Ok(match storage {
                 XorbStorage::Materialized(entry) => entry.xorb_object.clone(),
@@ -764,7 +764,7 @@ impl Client for MemoryClient {
             &serialized_data,
         )?;
         if computed_hash != hash {
-            return Err(CasClientError::Other(format!(
+            return Err(ClientError::Other(format!(
                 "XORB hash mismatch: expected {}, got {}",
                 hash.hex(),
                 computed_hash.hex(),
@@ -847,11 +847,11 @@ impl Client for MemoryClient {
         let expiration_ms = self.url_expiration_ms.load(Ordering::Relaxed);
         let elapsed_ms = Instant::now().saturating_duration_since(url_timestamp).as_millis() as u64;
         if elapsed_ms > expiration_ms {
-            return Err(CasClientError::PresignedUrlExpirationError);
+            return Err(ClientError::PresignedUrlExpirationError);
         }

         let xorbs = self.xorbs.read().await;
-        let storage = xorbs.get(&xorb_hash).ok_or(CasClientError::XORBNotFound(xorb_hash))?;
+        let storage = xorbs.get(&xorb_hash).ok_or(ClientError::XORBNotFound(xorb_hash))?;

         // Extract each byte range from the serialized data and deserialize
         let mut all_decompressed = Vec::new();
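The expiration check at the top of this hunk compares the time elapsed since the URL's timestamp against a configured limit. A minimal sketch of that comparison, assuming `std::time::Instant` in place of whichever `Instant` type the module imports, and passing `now` as a parameter so the behavior can be tested deterministically (the helper name `is_expired` is hypothetical):

```rust
use std::time::Instant;

/// Returns true when more than `expiration_ms` milliseconds separate
/// `url_timestamp` from `now`. `saturating_duration_since` yields a zero
/// duration if `url_timestamp` is later than `now`, so a timestamp from the
/// future is never treated as expired.
pub fn is_expired(url_timestamp: Instant, now: Instant, expiration_ms: u64) -> bool {
    let elapsed_ms = now.saturating_duration_since(url_timestamp).as_millis() as u64;
    elapsed_ms > expiration_ms
}
```

The comparison is strict, so a URL whose age equals the limit exactly is still accepted, matching the `elapsed_ms > expiration_ms` test in the hunk.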
@@ -909,13 +909,13 @@ fn parse_fetch_url(url: &str) -> Result<(MerkleHash, FileRange, Instant)> {
     parts.reverse();

     if parts.len() != 4 {
-        return Err(CasClientError::InvalidArguments);
+        return Err(ClientError::InvalidArguments);
     }

-    let hash = MerkleHash::from_hex(parts[0]).map_err(|_| CasClientError::InvalidArguments)?;
-    let start_pos: u64 = parts[1].parse().map_err(|_| CasClientError::InvalidArguments)?;
-    let end_pos: u64 = parts[2].parse().map_err(|_| CasClientError::InvalidArguments)?;
-    let timestamp_ms: u64 = parts[3].parse().map_err(|_| CasClientError::InvalidArguments)?;
+    let hash = MerkleHash::from_hex(parts[0]).map_err(|_| ClientError::InvalidArguments)?;
+    let start_pos: u64 = parts[1].parse().map_err(|_| ClientError::InvalidArguments)?;
+    let end_pos: u64 = parts[2].parse().map_err(|_| ClientError::InvalidArguments)?;
+    let timestamp_ms: u64 = parts[3].parse().map_err(|_| ClientError::InvalidArguments)?;

     let byte_range = FileRange::new(start_pos, end_pos);
     let timestamp = *REFERENCE_INSTANT + Duration::from_millis(timestamp_ms);
@@ -17,8 +17,8 @@ use tokio::sync::{Mutex, Semaphore};
 use tokio::time::{Instant, interval, sleep, sleep_until};
 use xet_runtime::utils::ClosureGuard;

-use super::super::super::error::{CasClientError, Result};
 use super::network_profile::{NetworkConfig, NetworkProfile};
+use crate::error::{ClientError, Result};

 const BUF_SIZE: usize = 65536;
 const REFILL_INTERVAL_MS: u64 = 50;
@@ -144,7 +144,7 @@ impl NetworkSimulationProxy {
             let mut guard = self.listener.lock().await;
             guard
                 .take()
-                .ok_or_else(|| CasClientError::Other("accept loop already started or listener taken".into()))?
+                .ok_or_else(|| ClientError::Other("accept loop already started or listener taken".into()))?
         };
         loop {
             if self.shutdown_flag.load(Ordering::Relaxed) {
@@ -331,8 +331,8 @@ where
     Ok(total)
 }

-fn map_proxy_err(e: impl std::fmt::Display) -> CasClientError {
-    CasClientError::Other(format!("Proxy error: {}", e))
+fn map_proxy_err(e: impl std::fmt::Display) -> ClientError {
+    ClientError::Other(format!("Proxy error: {}", e))
 }

 #[cfg(test)]
@@ -9,7 +9,7 @@ use std::time::Duration;

 use human_bandwidth::parse_bandwidth as parse_bandwidth_str;

-use super::super::super::error::{CasClientError, Result};
+use crate::error::{ClientError, Result};

 /// What the proxy applies: bandwidth (bytes/sec), latency, jitter, drop probability.
 /// Only `bandwidth_bytes_per_sec` is optional (None = no limit); the rest default to zero.
@@ -82,11 +82,11 @@ impl NetworkProfileOptions {
     /// Stored as bytes per second internally.
     pub fn with_bandwidth_str(mut self, s: &str) -> Result<Self> {
         let bw = parse_bandwidth_str(s.trim())
-            .map_err(|e| CasClientError::Other(format!("invalid bandwidth {:?}: {}", s, e)))?;
+            .map_err(|e| ClientError::Other(format!("invalid bandwidth {:?}: {}", s, e)))?;
         let bps = bw.as_bps();
         let bps = u64::try_from(bps).unwrap_or(u64::MAX);
         if bps == 0 {
-            return Err(CasClientError::Other(format!("invalid bandwidth: {:?}", s)));
+            return Err(ClientError::Other(format!("invalid bandwidth: {:?}", s)));
         }
         self.max_bandwidth_bytes_per_sec = Some(bps / 8);
         Ok(self)
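`with_bandwidth_str` parses a bit rate and stores it as bytes per second. The arithmetic after parsing can be sketched on its own; the function name `bits_to_bytes_per_sec`, the `u128` input type (assumed from the `u64::try_from` call in the hunk), and the `String` error are all illustrative assumptions rather than the crate's API.

```rust
use std::convert::TryFrom;

/// Converts a bit rate to bytes per second, following the steps in the hunk:
/// clamp to `u64`, reject zero, then divide by 8.
pub fn bits_to_bytes_per_sec(bits_per_sec: u128) -> Result<u64, String> {
    // Values too large for u64 are clamped rather than rejected.
    let bps = u64::try_from(bits_per_sec).unwrap_or(u64::MAX);
    if bps == 0 {
        return Err("invalid bandwidth: 0".to_string());
    }
    // Integer division truncates, so rates below 8 bits/sec become 0 bytes/sec.
    Ok(bps / 8)
}
```

Since the zero check runs before the division, a nonzero rate under 8 bits per second passes validation yet produces a stored limit of 0 bytes per second; whether that matters depends on how the proxy interprets a zero limit, which this excerpt does not show.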
@@ -129,7 +129,7 @@ impl NetworkProfileOptions {
                 bandwidth_scale: 0.7,
             }),
             _ => {
-                return Err(CasClientError::Other(format!(
+                return Err(ClientError::Other(format!(
                     "unknown congestion profile {:?}; valid: {:?}",
                     name, VALID_CONGESTION_PROFILE_NAMES
                 )));
@@ -14,13 +14,13 @@ use reqwest::{Body, Url};
 use serde_json;

 use super::super::adaptive_concurrency::ConnectionPermit;
-use super::super::error::{CasClientError, Result};
 use super::super::http_client::Api;
 use super::super::interface::Client;
 use super::super::progress_tracked_streams::{ProgressCallback, UploadProgressStream};
 use super::super::remote_client::RemoteClient;
 use super::super::retry_wrapper::RetryWrapper;
 use super::local_server::ServerLatencyProfile;
+use crate::error::{ClientError, Result};

 /// A wrapper around `RemoteClient` that provides simulation-specific methods for controlling
 /// latency profiles and uploading dummy data for benchmarking and simulation purposes.
@@ -43,7 +43,7 @@ impl RemoteSimulationClient {
         let client = self.inner.http_client();

         let json_body = serde_json::to_vec(&profile)
-            .map_err(|e| CasClientError::Other(format!("Failed to serialize ServerLatencyProfile: {e}")))?;
+            .map_err(|e| ClientError::Other(format!("Failed to serialize ServerLatencyProfile: {e}")))?;

         let response = client
             .post(url)
@@ -52,12 +52,12 @@ impl RemoteSimulationClient {
             .body(json_body)
             .send()
             .await
-            .map_err(|e| CasClientError::Other(format!("Failed to send set_config request: {e}")))?;
+            .map_err(|e| ClientError::Other(format!("Failed to send set_config request: {e}")))?;

         if !response.status().is_success() {
             let status = response.status();
             let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
-            return Err(CasClientError::Other(format!(
+            return Err(ClientError::Other(format!(
                 "set_config request failed with status {}: {}",
                 status, error_text
             )));
@@ -98,7 +98,7 @@ impl RemoteSimulationClient {
                     .send()
             })
             .await
-            .map_err(|e| CasClientError::Other(format!("Failed to upload dummy data: {e}")))?;
+            .map_err(|e| ClientError::Other(format!("Failed to upload dummy data: {e}")))?;

         Ok(n_upload_bytes)
     }
@@ -16,7 +16,6 @@ use tempfile::TempDir;
 use tokio::sync::oneshot;

 use super::super::RemoteClient;
-use super::super::error::Result;
 use super::super::interface::Client;
 use super::super::progress_tracked_streams::ProgressCallback;
 use super::local_server::{LocalServer, ServerLatencyProfile};
@@ -24,6 +23,7 @@ use super::network_simulation::{NetworkProfile, NetworkSimulationProxy};
 #[cfg(unix)]
 use super::socket_proxy::UnixSocketProxy;
 use super::{DirectAccessClient, LocalClient, MemoryClient, RemoteSimulationClient};
+use crate::error::Result;

 /// Builder for creating a `LocalTestServer` with various configuration options.
 ///
@@ -6,6 +6,7 @@

 use std::path::PathBuf;

+use anyhow::Result;
 use tokio::io::{AsyncReadExt, AsyncWriteExt};
 use tokio::net::{TcpStream, UnixListener, UnixStream};
 use tokio::sync::oneshot;
@@ -21,9 +22,10 @@ use tokio::sync::oneshot;
 /// ```no_run
 /// use std::path::PathBuf;
 ///
+/// use anyhow::Result;
 /// use xet_client::cas_client::simulation::socket_proxy::UnixSocketProxy;
 ///
-/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
+/// # async fn example() -> Result<()> {
 /// let socket_path = PathBuf::from("/tmp/test_socket.sock");
 /// let proxy = UnixSocketProxy::new(socket_path, "127.0.0.1:8080".to_string()).await?;
 /// // Proxy is now listening on the Unix socket and forwarding to TCP
|
||||
@@ -45,7 +47,7 @@ impl UnixSocketProxy {
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the socket file cannot be created or bound.
|
||||
pub async fn new(socket_path: PathBuf, tcp_endpoint: String) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
pub async fn new(socket_path: PathBuf, tcp_endpoint: String) -> Result<Self> {
|
||||
// Remove socket file if it exists
|
||||
if socket_path.exists() {
|
||||
std::fs::remove_file(&socket_path)?;
|
||||
@@ -97,7 +99,7 @@ impl UnixSocketProxy {
|
||||
}
|
||||
|
||||
/// Handles a single connection by proxying data between Unix socket and TCP.
|
||||
async fn handle_connection(unix_stream: UnixStream, tcp_endpoint: &str) -> Result<(), Box<dyn std::error::Error>> {
|
||||
async fn handle_connection(unix_stream: UnixStream, tcp_endpoint: &str) -> Result<()> {
|
||||
let tcp_stream = TcpStream::connect(tcp_endpoint).await?;
|
||||
|
||||
// Use tokio::io::split to get owned halves that can be moved into tasks
|
||||
|
||||
@@ -13,8 +13,8 @@ use xet_core_structures::merklehash::MerkleHash;
 use xet_core_structures::metadata_shard::file_structs::MDBFileInfo;
 use xet_core_structures::xorb_object::XorbObject;

-use crate::cas_client::error::{CasClientError, Result};
 use crate::cas_types::{ChunkRange, FileRange, HttpRange, XorbRangeDescriptor, XorbReconstructionTerm};
+use crate::error::{ClientError, Result};

 lazy_static::lazy_static! {
 /// Reference instant for URL timestamps. Initialized far in the past to allow
@@ -76,7 +76,7 @@ pub(crate) fn compute_reconstruction_ranges(

 loop {
 if s_idx >= file_info.segments.len() {
-return Err(CasClientError::InvalidRange);
+return Err(ClientError::InvalidRange);
 }

 let n = file_info.segments[s_idx].unpacked_segment_bytes as u64;
@@ -202,16 +202,16 @@ pub(crate) fn generate_v2_fetch_url(hash: &MerkleHash, ranges: &[XorbRangeDescri

 /// Parses a V2 fetch URL back into (hash, timestamp, byte ranges).
 pub(crate) fn parse_v2_fetch_url(url: &str) -> Result<(MerkleHash, Instant, Vec<HttpRange>)> {
-let bytes = URL_SAFE_NO_PAD.decode(url).map_err(|_| CasClientError::InvalidArguments)?;
-let payload = String::from_utf8(bytes).map_err(|_| CasClientError::InvalidArguments)?;
+let bytes = URL_SAFE_NO_PAD.decode(url).map_err(|_| ClientError::InvalidArguments)?;
+let payload = String::from_utf8(bytes).map_err(|_| ClientError::InvalidArguments)?;

 let mut parts = payload.splitn(3, ':');
-let hash_hex = parts.next().ok_or(CasClientError::InvalidArguments)?;
-let ts_str = parts.next().ok_or(CasClientError::InvalidArguments)?;
-let ranges_str = parts.next().ok_or(CasClientError::InvalidArguments)?;
+let hash_hex = parts.next().ok_or(ClientError::InvalidArguments)?;
+let ts_str = parts.next().ok_or(ClientError::InvalidArguments)?;
+let ranges_str = parts.next().ok_or(ClientError::InvalidArguments)?;

-let hash = MerkleHash::from_hex(hash_hex).map_err(|_| CasClientError::InvalidArguments)?;
-let timestamp_ms: u64 = ts_str.parse().map_err(|_| CasClientError::InvalidArguments)?;
+let hash = MerkleHash::from_hex(hash_hex).map_err(|_| ClientError::InvalidArguments)?;
+let timestamp_ms: u64 = ts_str.parse().map_err(|_| ClientError::InvalidArguments)?;
 let timestamp = *REFERENCE_INSTANT + Duration::from_millis(timestamp_ms);

 let mut ranges = Vec::new();
@@ -219,14 +219,14 @@ pub(crate) fn parse_v2_fetch_url(url: &str) -> Result<(MerkleHash, Instant, Vec<
 let mut parts = r.splitn(2, '-');
 let start: u64 = parts
 .next()
-.ok_or(CasClientError::InvalidArguments)?
+.ok_or(ClientError::InvalidArguments)?
 .parse()
-.map_err(|_| CasClientError::InvalidArguments)?;
+.map_err(|_| ClientError::InvalidArguments)?;
 let end: u64 = parts
 .next()
-.ok_or(CasClientError::InvalidArguments)?
+.ok_or(ClientError::InvalidArguments)?
 .parse()
-.map_err(|_| CasClientError::InvalidArguments)?;
+.map_err(|_| ClientError::InvalidArguments)?;
 ranges.push(HttpRange::new(start, end));
 }

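For readers following the `parse_v2_fetch_url` hunks above, here is a minimal standalone sketch of the same parsing pattern, where every failure collapses into a single `InvalidArguments` variant. It is not the crate's code: `SketchError` and `parse_payload` are hypothetical names, the base64 decoding step is omitted, the hash is kept as a plain `String` instead of a `MerkleHash`, and the comma separator between ranges is an assumption because that line is not visible in the diff.

```rust
/// Hypothetical stand-in for the crate's error type; only the one variant is modeled.
#[derive(Debug, PartialEq)]
enum SketchError {
    InvalidArguments,
}

/// Parses `<hash>:<timestamp_ms>:<start>-<end>[,<start>-<end>...]`.
/// The comma separator between ranges is an assumption made for this sketch.
fn parse_payload(payload: &str) -> Result<(String, u64, Vec<(u64, u64)>), SketchError> {
    // Split into at most three fields so the ranges field is kept whole.
    let mut parts = payload.splitn(3, ':');
    let hash = parts.next().ok_or(SketchError::InvalidArguments)?;
    let ts_str = parts.next().ok_or(SketchError::InvalidArguments)?;
    let ranges_str = parts.next().ok_or(SketchError::InvalidArguments)?;

    let timestamp_ms: u64 = ts_str.parse().map_err(|_| SketchError::InvalidArguments)?;

    let mut ranges = Vec::new();
    for r in ranges_str.split(',') {
        // Each range is "<start>-<end>"; a missing or non-numeric bound is an error.
        let mut bounds = r.splitn(2, '-');
        let start: u64 = bounds
            .next()
            .ok_or(SketchError::InvalidArguments)?
            .parse()
            .map_err(|_| SketchError::InvalidArguments)?;
        let end: u64 = bounds
            .next()
            .ok_or(SketchError::InvalidArguments)?
            .parse()
            .map_err(|_| SketchError::InvalidArguments)?;
        ranges.push((start, end));
    }

    Ok((hash.to_owned(), timestamp_ms, ranges))
}
```

Mapping every failure to one variant keeps the call sites short, at the cost of not telling the caller which field was malformed.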
@@ -451,7 +451,7 @@ mod tests {
 } else if *hash == hash_b {
 Ok(obj_b.clone())
 } else {
-Err(CasClientError::XORBNotFound(*hash))
+Err(ClientError::XORBNotFound(*hash))
 }
 })
 .unwrap();

@@ -1 +0,0 @@
-pub use crate::error::ClientError as CasTypesError;
@@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize};
 use xet_core_structures::merklehash::MerkleHash;
 use xet_core_structures::merklehash::data_hash::hex;

-use super::error::CasTypesError;
+use crate::error::ClientError;

 /// A Key indicates a prefixed merkle hash for some data stored in the CAS DB.
 #[derive(Debug, PartialEq, Default, Serialize, Deserialize, Ord, PartialOrd, Eq, Hash, Clone)]
@@ -21,15 +21,15 @@ impl Display for Key {
 }

 impl FromStr for Key {
-type Err = CasTypesError;
+type Err = ClientError;

 fn from_str(s: &str) -> Result<Self, Self::Err> {
 let parts = s.rsplit_once('/');
 let Some((prefix, hash)) = parts else {
-return Err(CasTypesError::InvalidKey(s.to_owned()));
+return Err(ClientError::InvalidKey(s.to_owned()));
 };

-let hash = MerkleHash::from_hex(hash).map_err(|_| CasTypesError::InvalidKey(s.to_owned()))?;
+let hash = MerkleHash::from_hex(hash).map_err(|_| ClientError::InvalidKey(s.to_owned()))?;

 Ok(Key {
 prefix: prefix.to_owned(),

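The `FromStr for Key` change above only swaps the error type, but the shape of the parser may be easier to see in isolation. The following is a rough sketch under stated assumptions: `SketchClientError` and `SketchKey` are hypothetical stand-ins, and the hash is validated as a non-empty hex string rather than through `MerkleHash::from_hex`, whose exact rules are not shown in this diff.

```rust
use std::str::FromStr;

/// Hypothetical stand-in for the unified client error; only `InvalidKey` is modeled.
#[derive(Debug, PartialEq)]
enum SketchClientError {
    InvalidKey(String),
}

/// Hypothetical stand-in for `Key`, holding the hash as text.
#[derive(Debug, PartialEq)]
struct SketchKey {
    prefix: String,
    hash: String,
}

impl FromStr for SketchKey {
    type Err = SketchClientError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Split on the last '/', so the prefix itself may contain slashes.
        let Some((prefix, hash)) = s.rsplit_once('/') else {
            return Err(SketchClientError::InvalidKey(s.to_owned()));
        };

        // Assumption: any non-empty hex string is accepted here.
        if hash.is_empty() || !hash.chars().all(|c| c.is_ascii_hexdigit()) {
            return Err(SketchClientError::InvalidKey(s.to_owned()));
        }

        Ok(SketchKey {
            prefix: prefix.to_owned(),
            hash: hash.to_owned(),
        })
    }
}
```

Because the associated `Err` type is the shared error enum, callers can use `?` on `"...".parse::<SketchKey>()` inside any function that already returns that error type, with no alias or conversion in between.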
@@ -9,7 +9,6 @@ use serde_repr::{Deserialize_repr, Serialize_repr};
 use thiserror::Error;
 use xet_core_structures::merklehash::MerkleHash;

-mod error;
 mod key;
 pub use key::*;


@@ -1,7 +1,6 @@
-use std::fmt::Debug;
 use std::num::TryFromIntError;

-use anyhow::anyhow;
+use anyhow::Error as AnyhowError;
 use http::StatusCode;
 use thiserror::Error;
 use tokio::sync::AcquireError;
@@ -15,7 +14,7 @@ use crate::cas_client::auth::AuthError;
 #[derive(Error, Debug)]
 pub enum ClientError {
 #[error("Format error: {0}")]
-FormatError(#[from] xet_core_structures::FormatError),
+FormatError(#[from] xet_core_structures::CoreError),

 #[error("Configuration error: {0}")]
 ConfigurationError(String),
@@ -36,7 +35,7 @@ pub enum ClientError {
 InvalidShardKey(String),

 #[error("Internal error: {0}")]
-InternalError(#[from] anyhow::Error),
+InternalError(AnyhowError),

 #[error("{0}")]
 Other(String),
@@ -66,7 +65,7 @@ pub enum ClientError {
 AuthError(#[from] AuthError),

 #[error("Credential helper error: {0}")]
-CredentialHelper(anyhow::Error),
+CredentialHelper(AnyhowError),

 #[error("Invalid repo type: {0}")]
 InvalidRepoType(String),
@@ -101,8 +100,12 @@ impl From<reqwest::Error> for ClientError {
 }

 impl ClientError {
-pub fn internal<T: Debug>(value: T) -> Self {
-ClientError::InternalError(anyhow!("{value:?}"))
+pub fn internal(value: impl std::error::Error + Send + Sync + 'static) -> Self {
+ClientError::InternalError(AnyhowError::new(value))
 }

+pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> Self {
+ClientError::CredentialHelper(AnyhowError::new(e))
+}
+
 pub fn status(&self) -> Option<StatusCode> {
@@ -135,7 +138,7 @@ impl From<xet_runtime::utils::singleflight::SingleflightError<ClientError>> for
 }
 }

-impl<T> From<std::sync::PoisonError<T>> for ClientError {
+impl<T: Send + Sync + 'static> From<std::sync::PoisonError<T>> for ClientError {
 fn from(value: std::sync::PoisonError<T>) -> Self {
 Self::internal(value)
 }
@@ -147,7 +150,7 @@ impl From<AcquireError> for ClientError {
 }
 }

-impl<T> From<SendError<T>> for ClientError {
+impl<T: Send + Sync + 'static> From<SendError<T>> for ClientError {
 fn from(value: SendError<T>) -> Self {
 Self::internal(value)
 }

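The new `internal` constructor above takes a real error value instead of anything that is `Debug`, which is why the generic `From` impls gain `Send + Sync + 'static` bounds. A minimal sketch of that relationship follows, using only the standard library: `SketchError` is a hypothetical type, and `Box<dyn Error + Send + Sync>` stands in for `anyhow::Error`, which is an assumption made to keep the example dependency-free.

```rust
use std::error::Error;
use std::sync::PoisonError;

/// Stand-in for `anyhow::Error` in this sketch.
type BoxedError = Box<dyn Error + Send + Sync + 'static>;

/// Hypothetical error type with a single wrapped-error variant.
#[derive(Debug)]
enum SketchError {
    Internal(BoxedError),
}

impl SketchError {
    /// Wraps a concrete error, keeping the original value instead of its Debug text.
    fn internal(value: impl Error + Send + Sync + 'static) -> Self {
        SketchError::Internal(Box::new(value))
    }
}

/// `PoisonError<T>` is `Send + Sync + 'static` only when `T` is,
/// so the bound on `T` is needed for the call to `internal` to compile.
impl<T: Send + Sync + 'static> From<PoisonError<T>> for SketchError {
    fn from(value: PoisonError<T>) -> Self {
        Self::internal(value)
    }
}
```

One consequence of keeping the original value is that callers can later downcast to the concrete error type, which is not possible once the error has been flattened into a formatted string.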
@@ -1,10 +1,10 @@
 use std::sync::Arc;

-use anyhow::Result;
 use async_trait::async_trait;
 use reqwest_middleware::RequestBuilder;

 use super::CredentialHelper;
+use crate::error::ClientError;

 pub struct NoopCredentialHelper {}

@@ -16,7 +16,7 @@ impl NoopCredentialHelper {

 #[async_trait]
 impl CredentialHelper for NoopCredentialHelper {
-async fn fill_credential(&self, req: RequestBuilder) -> Result<RequestBuilder> {
+async fn fill_credential(&self, req: RequestBuilder) -> Result<RequestBuilder, ClientError> {
 Ok(req)
 }

@@ -42,7 +42,7 @@ impl BearerCredentialHelper {

 #[async_trait]
 impl CredentialHelper for BearerCredentialHelper {
-async fn fill_credential(&self, req: RequestBuilder) -> Result<RequestBuilder> {
+async fn fill_credential(&self, req: RequestBuilder) -> Result<RequestBuilder, ClientError> {
 Ok(req.bearer_auth(&self.hf_token))
 }


@@ -1,11 +1,11 @@
-use anyhow::Result;
 use async_trait::async_trait;
 use reqwest_middleware::RequestBuilder;

+use crate::error::ClientError;
+
 #[async_trait]
 pub trait CredentialHelper: Send + Sync {
-async fn fill_credential(&self, req: RequestBuilder) -> Result<RequestBuilder>;
+async fn fill_credential(&self, req: RequestBuilder) -> Result<RequestBuilder, ClientError>;

 // Used in tests to identify the source of the credential.
 fn whoami(&self) -> &str;
 }

@@ -4,11 +4,11 @@ use http::header::HeaderMap;
 use urlencoding::encode;

 use super::auth::CredentialHelper;
-use super::errors::*;
 use super::types::{CasJWTInfo, RepoInfo};
 use crate::cas_client::exports::ClientWithMiddleware;
 use crate::cas_client::retry_wrapper::RetryWrapper;
 use crate::cas_client::{Api, build_http_client};
+use crate::error::Result;

 /// The type of operation to perform, either to upload files or to download files.
 /// Different operations lead to CAS access token with different authorization levels.
@@ -98,7 +98,7 @@ impl HubClient {
 let req = cred_helper
 .fill_credential(req)
 .await
-.map_err(reqwest_middleware::Error::Middleware)?;
+.map_err(reqwest_middleware::Error::middleware)?;
 req.send().await
 }
 })
@@ -114,9 +114,9 @@ mod tests {

 use http::header::{self, HeaderMap, HeaderValue};

-use super::super::errors::Result;
 use super::super::{BearerCredentialHelper, HFRepoType, Operation, RepoInfo};
 use super::HubClient;
+use crate::error::Result;

 #[tokio::test]
 #[ignore = "need valid write token"]

@@ -1,8 +0,0 @@
-pub use crate::error::ClientError as HubClientError;
-pub type Result<T> = std::result::Result<T, HubClientError>;
-
-impl HubClientError {
-pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> HubClientError {
-HubClientError::CredentialHelper(e.into())
-}
-}
@@ -1,9 +1,7 @@
 mod auth;
 mod client;
-mod errors;
 mod types;

 pub use auth::{BearerCredentialHelper, CredentialHelper, NoopCredentialHelper};
 pub use client::{HubClient, Operation};
-pub use errors::{HubClientError, Result};
 pub use types::{CasJWTInfo, HFRepoType, RepoInfo};

@@ -3,7 +3,7 @@ use std::str::FromStr;

 use serde::Deserialize;

-use super::errors::{HubClientError, Result};
+use crate::error::{ClientError, Result};

 /// This defines the response format from the Huggingface Hub Xet CAS access token API.
 #[derive(Deserialize, Debug)]
@@ -23,7 +23,7 @@ pub enum HFRepoType {
 }

 impl FromStr for HFRepoType {
-type Err = HubClientError;
+type Err = ClientError;

 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
 match s.to_lowercase().as_str() {
@@ -31,7 +31,7 @@ impl FromStr for HFRepoType {
 "model" | "models" => Ok(HFRepoType::Model),
 "dataset" | "datasets" => Ok(HFRepoType::Dataset),
 "space" | "spaces" => Ok(HFRepoType::Space),
-t => Err(HubClientError::InvalidRepoType(t.to_owned())),
+t => Err(ClientError::InvalidRepoType(t.to_owned())),
 }
 }
 }

@@ -23,7 +23,6 @@ bench = true
 [dependencies]
 xet-runtime = { version = "1.4.0", path = "../xet_runtime" }

-anyhow = { workspace = true }
 async-trait = { workspace = true }
 base64 = { workspace = true }
 blake3 = { workspace = true }

@@ -7,13 +7,13 @@ use crate::merklehash::MerkleHash;

 #[non_exhaustive]
 #[derive(Error, Debug)]
-pub enum FormatError {
+pub enum CoreError {
 // -- Common ----------------------------------------------------------
 #[error("I/O error: {0}")]
 Io(#[from] std::io::Error),

 #[error("Internal error: {0}")]
-Internal(anyhow::Error),
+InternalError(String),

 #[error("{0}")]
 Other(String),
@@ -50,14 +50,14 @@ pub enum FormatError {
 #[error("Invalid arguments")]
 InvalidArguments,

-#[error("Format error: {0}")]
-Format(anyhow::Error),
+#[error("Malformed data: {0}")]
+MalformedData(String),

 #[error("Hash mismatch")]
 HashMismatch,

 #[error("Compression error: {0}")]
-Compression(#[from] lz4_flex::frame::Error),
+CompressionError(#[from] lz4_flex::frame::Error),

 #[error("Hash parsing error: {0}")]
 HashParsing(#[from] Infallible),
@@ -67,7 +67,7 @@ pub enum FormatError {

 // -- Runtime ---------------------------------------------------------
 #[error("Runtime error: {0}")]
-Runtime(#[from] xet_runtime::RuntimeError),
+RuntimeError(#[from] xet_runtime::RuntimeError),

 #[error("Task lock error: {0}")]
 TaskRuntime(#[from] xet_runtime::utils::RwTaskLockError),
@@ -76,15 +76,15 @@ pub enum FormatError {
 TaskJoin(#[from] tokio::task::JoinError),
 }

-pub type Result<T> = std::result::Result<T, FormatError>;
+pub type Result<T> = std::result::Result<T, CoreError>;

-impl PartialEq for FormatError {
-fn eq(&self, other: &FormatError) -> bool {
+impl PartialEq for CoreError {
+fn eq(&self, other: &CoreError) -> bool {
 std::mem::discriminant(self) == std::mem::discriminant(other)
 }
 }

-impl FormatError {
+impl CoreError {
 pub fn other(inner: impl ToString) -> Self {
 Self::Other(inner.to_string())
 }
@@ -103,7 +103,7 @@ impl<T> Validate<T> for Result<T> {
 fn ok_for_format_error(self) -> Result<Option<T>> {
 match self {
 Ok(v) => Ok(Some(v)),
-Err(FormatError::Format(e)) => {
+Err(CoreError::MalformedData(e)) => {
 warn!("XORB Validation: {e}");
 Ok(None)
 },
@@ -112,14 +112,14 @@ impl<T> Validate<T> for Result<T> {
 }
 }

-impl From<crate::merklehash::DataHashHexParseError> for FormatError {
+impl From<crate::merklehash::DataHashHexParseError> for CoreError {
 fn from(_: crate::merklehash::DataHashHexParseError) -> Self {
-FormatError::Other("Invalid hex input for DataHash".to_string())
+CoreError::Other("Invalid hex input for DataHash".to_string())
 }
 }

-impl From<crate::merklehash::DataHashBytesParseError> for FormatError {
+impl From<crate::merklehash::DataHashBytesParseError> for CoreError {
 fn from(_: crate::merklehash::DataHashBytesParseError) -> Self {
-FormatError::Other("Invalid bytes input for DataHash".to_string())
+CoreError::Other("Invalid bytes input for DataHash".to_string())
 }
 }

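The `PartialEq` impl in the hunk above compares errors by variant only, via `std::mem::discriminant`. A small sketch of what that means in practice, using a hypothetical `SketchCoreError` enum with three variants borrowed from the diff for illustration:

```rust
use std::mem::discriminant;

/// Hypothetical cut-down error enum; the real one has many more variants.
#[derive(Debug)]
enum SketchCoreError {
    Other(String),
    MalformedData(String),
    HashMismatch,
}

impl PartialEq for SketchCoreError {
    /// Two errors are equal when they are the same variant,
    /// regardless of the payload each one carries.
    fn eq(&self, other: &Self) -> bool {
        discriminant(self) == discriminant(other)
    }
}
```

This is convenient in tests that only care about the kind of failure, and it sidesteps payload types that do not implement `PartialEq` themselves (for example `std::io::Error`). The trade-off is that two `Other` errors with different messages compare equal, so an assertion on the message text has to inspect the payload explicitly.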
@@ -1,7 +1,7 @@
 #![cfg_attr(feature = "strict", deny(warnings))]

 pub mod error;
-pub use error::FormatError;
+pub use error::CoreError;

 pub mod data_structures;
 pub mod merklehash;

@@ -1,2 +0,0 @@
-pub use crate::error::FormatError as MDBShardError;
-pub type Result<T> = std::result::Result<T, MDBShardError>;
@@ -5,7 +5,6 @@ use std::mem::size_of;
 use bytes::Bytes;
 use serde::Serialize;

-use super::error::MDBShardError;
 use super::shard_file::MDB_FILE_INFO_ENTRY_SIZE;
 use super::xorb_structs::{XorbChunkSequenceEntry, XorbChunkSequenceHeader};
 use crate::merklehash::data_hash::hex;
@@ -435,7 +434,7 @@ impl MDBFileInfo {
 /// Merges the content of other into the content of self if needed.
 /// After this call, self will have the verification info and metadata
 /// extension if they exist in the other object but not this one.
-pub fn merge_from(&mut self, other: &Self) -> Result<(), MDBShardError> {
+pub fn merge_from(&mut self, other: &Self) -> crate::error::Result<()> {
 FileDataSequenceHeader::verify_same_file(&self.metadata, &other.metadata);
 if self.contains_verification() != other.contains_verification() && other.contains_verification() {
 // self doesn't have verification. Copy from other

@@ -1,6 +1,5 @@
 pub mod chunk_verification;
 pub mod constants;
-pub mod error;
 pub mod file_structs;
 pub mod interpolation_search;
 pub mod session_directory;

@@ -8,10 +8,10 @@ use tokio::task::JoinHandle;
 use tracing::{error, info};
 use xet_runtime::core::{XetRuntime, check_sigint_shutdown};

-use super::error::Result;
 use super::set_operations::shard_set_union;
 use super::shard_file_handle::MDBShardFile;
 use super::{MDBShardFileFooter, MDBShardFileHeader, MDBShardInfo};
+use crate::error::Result;

 /// Merge a collection of shards, deleting the old ones.
 /// After calling this, the passed in shards may be invalid -- i.e. may refer to a shard that doesn't exist.

@@ -6,7 +6,6 @@ use std::path::Path;

 use uuid::Uuid;

-use super::error::Result;
 use super::file_structs::{
 FileDataSequenceEntry, FileDataSequenceHeader, FileMetadataExt, FileVerificationEntry, SupersetResult,
 };
@@ -14,6 +13,7 @@ use super::shard_file::MDB_FILE_INFO_ENTRY_SIZE;
 use super::shard_format::{MDBShardFileFooter, MDBShardFileHeader, MDBShardInfo};
 use super::utils::truncate_hash;
 use super::xorb_structs::{XorbChunkSequenceEntry, XorbChunkSequenceHeader};
+use crate::error::Result;
 use crate::merklehash::{HashedWrite, MerkleHash};
 use crate::serialization_utils::*;

@@ -462,10 +462,10 @@ mod tests {
 use itertools::iproduct;
 use tempfile::TempDir;

-use super::super::error::Result;
 use super::super::shard_format::test_routines::*;
 use super::super::shard_in_memory::MDBInMemoryShard;
 use super::*;
+use crate::error::Result;
 use crate::merklehash::compute_data_hash;

 fn test_operations(mem_shard_1: &MDBInMemoryShard, mem_shard_2: &MDBInMemoryShard) -> Result<()> {

@@ -5,7 +5,6 @@ use std::sync::Arc;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::time::{Duration, Instant};

-use anyhow::{Ok, Result, anyhow};
 use clap::Parser;
 use rand::rngs::StdRng;
 use rand::{Rng, SeedableRng};
@@ -54,7 +53,7 @@ async fn run_shard_benchmark(
 contiguity: usize,
 block_hit_proportion: f64,
 dir: &Path,
-) -> Result<()> {
+) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
 let mut seed = 0u64;

 eprintln!("Creating shards.");
@@ -145,14 +144,14 @@ async fn run_shard_benchmark(
 Ok(())
 }

-fn parse_arg(arg: &str) -> Result<(u64, u64)> {
+fn parse_arg(arg: &str) -> Result<(u64, u64), String> {
 let parts: Vec<&str> = arg.split(':').collect();
 if parts.len() != 2 {
-return Err(anyhow!("Failed to parse argument: {arg}"));
+return Err(format!("Failed to parse argument: {arg}"));
 }

-let size1 = u64::from_str(parts[0]).map_err(|e| anyhow!("Failed to parse size1: {e:?}"))?;
-let size2 = u64::from_str(parts[1]).map_err(|e| anyhow!("Failed to parse size2: {e:?}"))?;
+let size1 = u64::from_str(parts[0]).map_err(|e| format!("Failed to parse size1: {e:?}"))?;
+let size2 = u64::from_str(parts[1]).map_err(|e| format!("Failed to parse size2: {e:?}"))?;

 Ok((size1, size2))
 }

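The benchmark changes above replace `anyhow` with `String` errors in `parse_arg` and a boxed error in the async entry point. The two compose because the standard library provides `From<String>` for `Box<dyn Error + Send + Sync>`, so `?` converts between them. A minimal sketch, where `parse_arg` mirrors the new version from the diff and `total` is a hypothetical caller added only for illustration:

```rust
use std::error::Error;
use std::str::FromStr;

/// Parses "<size1>:<size2>" into a pair, reporting failures as plain strings.
fn parse_arg(arg: &str) -> Result<(u64, u64), String> {
    let parts: Vec<&str> = arg.split(':').collect();
    if parts.len() != 2 {
        return Err(format!("Failed to parse argument: {arg}"));
    }

    let size1 = u64::from_str(parts[0]).map_err(|e| format!("Failed to parse size1: {e:?}"))?;
    let size2 = u64::from_str(parts[1]).map_err(|e| format!("Failed to parse size2: {e:?}"))?;

    Ok((size1, size2))
}

/// Hypothetical caller: `?` turns the `String` error into a boxed error
/// through the standard `From<String>` impl.
fn total(arg: &str) -> Result<u64, Box<dyn Error + Send + Sync>> {
    let (a, b) = parse_arg(arg)?;
    Ok(a + b)
}
```

For a small command-line tool this keeps the dependency list shorter, though the boxed error carries only the message and no typed variants to match on.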
@@ -11,12 +11,12 @@ use tracing::{debug, error, info, warn};

 use super::MDBShardFileFooter;
 use super::constants::MDB_SHARD_EXPIRATION_BUFFER;
-use super::error::{MDBShardError, Result};
 use super::file_structs::{FileDataSequenceEntry, MDBFileInfo};
 use super::shard_file::current_timestamp;
 use super::shard_format::MDBShardInfo;
 use super::utils::{parse_shard_filename, shard_file_name, temp_shard_file_name, truncate_hash};
 use super::xorb_structs::XorbChunkSequenceHeader;
+use crate::error::{CoreError, Result};
 use crate::merklehash::{HMACKey, HashedWrite, MerkleHash, compute_data_hash};

 /// When a specific implementation of the
@@ -179,7 +179,7 @@ impl MDBShardFile {
 if let Some(shard_hash) = parse_shard_filename(path.to_str().unwrap()) {
 Self::load_from_hash_and_path(shard_hash, path)
 } else {
-Err(MDBShardError::BadFilename(format!("{path:?} not a valid MerkleDB filename.")))
+Err(CoreError::BadFilename(format!("{path:?} not a valid MerkleDB filename.")))
 }
 }

@@ -315,7 +315,7 @@ impl MDBShardFile {
 } else if let Some(h) = path.file_name().and_then(parse_shard_filename) {
 load_file(h, path)?;
 } else {
-return Err(MDBShardError::BadFilename(format!("Filename {path:?} not valid shard file name.")));
+return Err(CoreError::BadFilename(format!("Filename {path:?} not valid shard file name.")));
 }

 Ok(())
@@ -365,11 +365,11 @@ impl MDBShardFile {
 pub fn get_reader_if_present(&self) -> Result<Option<BufReader<std::fs::File>>> {
 match self.get_reader() {
 Ok(v) => Ok(Some(v)),
-Err(MDBShardError::Io(e)) => {
+Err(CoreError::Io(e)) => {
 if e.kind() == ErrorKind::NotFound {
 Ok(None)
 } else {
-Err(MDBShardError::Io(e))
+Err(CoreError::Io(e))
 }
 },
 Err(other_err) => Err(other_err),

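The `get_reader_if_present` hunk above keeps the same logic under the renamed error type: a missing file becomes `Ok(None)`, while every other failure is passed through. A standalone sketch of that pattern follows, with hypothetical names (`SketchError`, `open_reader`, `open_reader_if_present`) and a match guard in place of the nested `if`; the behavior is intended to be the same.

```rust
use std::fs::File;
use std::io::{self, BufReader, ErrorKind};
use std::path::Path;

/// Hypothetical error type with only the I/O variant modeled.
#[derive(Debug)]
enum SketchError {
    Io(io::Error),
}

impl From<io::Error> for SketchError {
    fn from(e: io::Error) -> Self {
        SketchError::Io(e)
    }
}

/// Opens the file, failing with `SketchError::Io` on any I/O error.
fn open_reader(path: &Path) -> Result<BufReader<File>, SketchError> {
    Ok(BufReader::new(File::open(path)?))
}

/// Like `open_reader`, but treats "file not found" as an absent value
/// rather than an error. All other errors are returned unchanged.
fn open_reader_if_present(path: &Path) -> Result<Option<BufReader<File>>, SketchError> {
    match open_reader(path) {
        Ok(reader) => Ok(Some(reader)),
        Err(SketchError::Io(e)) if e.kind() == ErrorKind::NotFound => Ok(None),
        Err(other) => Err(other),
    }
}
```

Matching on the typed variant is what makes this possible; with a string-only error the `NotFound` case could not be separated from other I/O failures without parsing the message.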
@@ -10,13 +10,13 @@ use xet_runtime::core::{XetRuntime, xet_config};
 use xet_runtime::utils::RwTaskLock;

 use super::constants::MDB_SHARD_EXPIRATION_BUFFER;
-use super::error::{MDBShardError, Result};
 use super::file_structs::*;
 use super::shard_file_handle::MDBShardFile;
 use super::shard_file_reconstructor::FileReconstructor;
 use super::shard_in_memory::MDBInMemoryShard;
 use super::utils::truncate_hash;
 use super::xorb_structs::*;
+use crate::error::{CoreError, Result};
 use crate::merklehash::{HMACKey, MerkleHash};
 use crate::{MerkleHashMap, TruncatedMerkleHashMap};

@@ -74,7 +74,7 @@ impl ShardBookkeeper {
 }

 pub struct ShardFileManager {
-shard_bookkeeper: RwTaskLock<ShardBookkeeper, MDBShardError>,
+shard_bookkeeper: RwTaskLock<ShardBookkeeper, CoreError>,
 current_state: RwLock<MDBInMemoryShard>,
 shard_directory: PathBuf,
 target_shard_max_size: u64,
@@ -367,7 +367,7 @@ impl ShardFileManager {

 #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)]
 #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))]
-impl FileReconstructor<MDBShardError> for ShardFileManager {
+impl FileReconstructor<CoreError> for ShardFileManager {
 // Given a file pointer, returns the information needed to reconstruct the file.
 // The information is stored in the destination vector dest_results. The function
 // returns true if the file hash was found, and false otherwise.
@@ -572,13 +572,13 @@ mod tests {
 use rand::prelude::*;
 use tempfile::TempDir;

-use super::super::error::Result;
 use super::super::file_structs::FileDataSequenceHeader;
 use super::super::session_directory::{consolidate_shards_in_directory, merge_shards};
 use super::super::shard_format::test_routines::{gen_random_file_info, rng_hash, simple_hash};
 use super::super::utils::parse_shard_filename;
 use super::super::xorb_structs::{XorbChunkSequenceEntry, XorbChunkSequenceHeader};
 use super::*;
+use crate::error::Result;

 #[allow(clippy::type_complexity)]
 pub async fn fill_with_specific_shard(

@@ -5,18 +5,17 @@ use std::ops::Add;
 use std::sync::Arc;
 use std::time::{Duration, UNIX_EPOCH};

-use anyhow::anyhow;
 use futures::AsyncReadExt;
 use static_assertions::const_assert;
 use tracing::debug;

-use super::error::{MDBShardError, Result};
 use super::file_structs::*;
 use super::interpolation_search::search_on_sorted_u64s;
 use super::shard_in_memory::MDBInMemoryShard;
 use super::streaming_shard::MDBMinimalShard;
 use super::utils::{shard_expiry_time, truncate_hash};
 use super::xorb_structs::*;
+use crate::error::{CoreError, Result};
 use crate::merklehash::{HMACKey, MerkleHash};
 use crate::serialization_utils::*;

@@ -87,7 +86,7 @@ impl MDBShardFileHeader {
 reader.read_exact(&mut tag)?;

 if tag != MDB_SHARD_HEADER_TAG {
-return Err(MDBShardError::ShardVersion(
+return Err(CoreError::ShardVersion(
 "File does not appear to be a valid Merkle DB Shard file (Wrong Magic Number).".to_owned(),
 ));
 }
@@ -187,7 +186,7 @@ impl MDBShardFileFooter {

 // Do a version check as a simple guard against using this in an old repository
 if version != MDB_SHARD_FOOTER_VERSION {
-return Err(MDBShardError::ShardVersion(format!(
+return Err(CoreError::ShardVersion(format!(
 "Error: Expected footer version {MDB_SHARD_FOOTER_VERSION}, got {version}"
 )));
 }
@@ -494,7 +493,7 @@ impl MDBShardInfo {
 if num_indices < dest_indices.len() {
 Ok(num_indices)
 } else {
-Err(MDBShardError::TruncatedHashCollision(truncate_hash(file_hash)))
+Err(CoreError::TruncatedHashCollision(truncate_hash(file_hash)))
 }
 }

@@ -517,7 +516,7 @@ impl MDBShardInfo {
 if num_indices < dest_indices.len() {
 Ok(num_indices)
 } else {
-Err(MDBShardError::TruncatedHashCollision(truncate_hash(xorb_hash)))
+Err(CoreError::TruncatedHashCollision(truncate_hash(xorb_hash)))
 }
 }

@@ -556,7 +555,7 @@ impl MDBShardInfo {
 ))?;

 let Some(mdb_file) = MDBFileInfo::deserialize(reader)? else {
-return Err(MDBShardError::Internal(anyhow!("invalid file entry index")));
+return Err(CoreError::InternalError("invalid file entry index".to_string()));
 };

 Ok(mdb_file)
@@ -1213,13 +1212,13 @@ pub mod test_routines {
 use rand::rngs::{SmallRng, StdRng};
 use rand::{Rng, SeedableRng};

-use super::super::error::Result;
 use super::super::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, FileMetadataExt, MDBFileInfo};
 use super::super::shard_format::MDBShardInfo;
 use super::super::shard_in_memory::MDBInMemoryShard;
 use super::super::streaming_shard::MDBMinimalShard;
 use super::super::xorb_structs::{MDBXorbInfo, XorbChunkSequenceEntry, XorbChunkSequenceHeader};
 use super::FileVerificationEntry;
+use crate::error::Result;
 use crate::merklehash::MerkleHash;

 pub fn simple_hash(n: u64) -> MerkleHash {
@@ -1694,8 +1693,8 @@ pub mod test_routines {
 #[cfg(test)]
 mod tests {

-use super::super::error::Result;
 use super::test_routines::*;
+use crate::error::Result;

 #[test]
 fn test_simple() -> Result<()> {

@@ -9,12 +9,12 @@ use std::time::Duration;

 use tracing::debug;

-use super::error::Result;
 use super::file_structs::*;
 use super::shard_format::MDBShardInfo;
 use super::utils::{shard_file_name, temp_shard_file_name};
 use super::xorb_structs::*;
 use crate::MerkleHashMap;
+use crate::error::Result;
 use crate::merklehash::{HashedWrite, MerkleHash};

 #[allow(clippy::type_complexity)]

@@ -8,12 +8,12 @@ use futures_util::io::AsyncReadExt;
 use itertools::Itertools;
 use more_asserts::debug_assert_lt;

-use super::error::{MDBShardError, Result};
 use super::file_structs::{FileDataSequenceHeader, MDBFileInfoView};
 use super::shard_file::{MDB_FILE_INFO_ENTRY_SIZE, current_timestamp};
 use super::xorb_structs::{MDBXorbInfoView, XorbChunkSequenceEntry, XorbChunkSequenceHeader};
 use super::{MDBShardFileFooter, MDBShardFileHeader};
 use crate::MerkleHashMap;
+use crate::error::{CoreError, Result};
 use crate::merklehash::MerkleHash;

 /// Runs through a shard file info section, calling the specified callback function for each entry.
@@ -243,7 +243,7 @@ impl MDBMinimalShard {
 // if only some files have verification, then we consider this shard invalid
 // either all files have verification or no files have verification
 if !file_info_views.is_empty() && !file_info_views.iter().map(|fiv| fiv.contains_verification()).all_equal() {
-return Err(MDBShardError::invalid_shard("only some files contain verification"));
+return Err(CoreError::invalid_shard("only some files contain verification"));
 }

 // XORB stuff
@@ -489,7 +489,6 @@ mod tests {
 use std::collections::{HashMap, HashSet};
 use std::io::Cursor;

-use anyhow::Result;
 use rand::rngs::SmallRng;
 use rand::{Rng, SeedableRng};

@@ -501,6 +500,7 @@ mod tests {
 use super::super::shard_in_memory::MDBInMemoryShard;
 use super::super::xorb_structs::MDBXorbInfo;
 use super::MDBMinimalShard;
+use crate::error::Result;
 use crate::merklehash::MerkleHash;

 fn verify_serialization(min_shard: &MDBMinimalShard, mem_shard: &MDBInMemoryShard) -> Result<()> {

@@ -9,7 +9,6 @@ use std::io::{Read, Seek, SeekFrom};
 use std::sync::Arc;
 use std::sync::atomic::AtomicUsize;
 
-use anyhow::Result;
 use clap::Parser;
 use csv::Writer;
 use rand::rngs::StdRng;
@@ -161,7 +160,7 @@ async fn main() {
 
 #[cfg(not(target_family = "wasm"))]
 #[tokio::main]
-async fn main() -> Result<()> {
+async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
     // Parse command-line arguments
     let args = Args::parse();
 

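The new `main` signature swaps `anyhow::Result` for a boxed standard-library error. A rough sketch of why `?` keeps working with that return type follows; `parse_inputs` and `BoxError` are hypothetical names chosen for this example and do not appear in the PR. The standard library provides a blanket conversion from any `Error + Send + Sync + 'static` type into the boxed trait object, so unrelated error types can share one function.

```rust
use std::error::Error;

// Alias for the boxed, thread-safe error type used in the new `main` signature.
type BoxError = Box<dyn Error + Send + Sync>;

/// Hypothetical helper: parses a decimal count and decodes UTF-8 bytes.
/// Two unrelated error types flow through `?` into the same boxed error.
pub fn parse_inputs(count: &str, name: &[u8]) -> Result<(u32, String), BoxError> {
    // `ParseIntError` is converted into `BoxError` by `?`.
    let count: u32 = count.trim().parse()?;
    // `Utf8Error` is converted into `BoxError` by `?`.
    let name = std::str::from_utf8(name)?.to_owned();
    Ok((count, name))
}
```

The trade-off is that callers can no longer match on a concrete error type without downcasting, which is usually acceptable at a binary's top level.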
@@ -4,12 +4,11 @@ use std::io::{Cursor, Read, Write, copy};
 use std::str::FromStr;
 use std::time::Instant;
 
-use anyhow::anyhow;
 use lz4_flex::frame::{FrameDecoder, FrameEncoder};
 
 use super::byte_grouping::BG4Predictor;
 use super::byte_grouping::bg4::{bg4_regroup, bg4_split};
-use super::error::{Result, XorbObjectError};
+use crate::error::{CoreError, Result};
 
 pub static mut BG4_SPLIT_RUNTIME: f64 = 0.;
 pub static mut BG4_REGROUP_RUNTIME: f64 = 0.;
@@ -50,7 +49,7 @@ impl From<CompressionScheme> for &'static str {
 }
 
 impl TryFrom<u8> for CompressionScheme {
-    type Error = XorbObjectError;
+    type Error = CoreError;
 
     fn try_from(value: u8) -> Result<Self> {
         match value {
@@ -58,13 +57,13 @@ impl TryFrom<u8> for CompressionScheme {
             1 => Ok(CompressionScheme::LZ4),
             2 => Ok(CompressionScheme::ByteGrouping4LZ4),
             99 => Ok(CompressionScheme::Auto),
-            _ => Err(XorbObjectError::Format(anyhow!("cannot convert value {value} to CompressionScheme"))),
+            _ => Err(CoreError::MalformedData(format!("cannot convert value {value} to CompressionScheme"))),
         }
     }
 }
 
 impl FromStr for CompressionScheme {
-    type Err = XorbObjectError;
+    type Err = CoreError;
 
     fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
         match s.trim().to_lowercase().as_str() {
@@ -72,7 +71,7 @@ impl FromStr for CompressionScheme {
             "none" => Ok(CompressionScheme::None),
             "lz4" => Ok(CompressionScheme::LZ4),
             "bg4-lz4" => Ok(CompressionScheme::ByteGrouping4LZ4),
-            _ => Err(XorbObjectError::Format(anyhow!(
+            _ => Err(CoreError::MalformedData(format!(
                 "Invalid compression scheme '{s}'. Valid values are: auto, none, lz4, bg4-lz4."
             ))),
         }
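The pattern these hunks converge on, a named error enum with a string-carrying `MalformedData` variant instead of an `anyhow!` payload, can be sketched in isolation. The example below is an assumption-laden stand-in rather than the crate's definition: the real `CoreError` has more variants and is presumably derived with thiserror, the `Display` wording is invented here, and the `0 => None` arm is assumed because it lies outside the hunk.

```rust
use std::convert::TryFrom;
use std::fmt;

// Hypothetical stand-in for the unified error type; only two variants are modeled.
#[derive(Debug)]
pub enum CoreError {
    MalformedData(String),
    Io(std::io::Error),
}

impl fmt::Display for CoreError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            // The message prefixes are illustrative, not the crate's wording.
            CoreError::MalformedData(msg) => write!(f, "malformed data: {}", msg),
            CoreError::Io(err) => write!(f, "I/O error: {}", err),
        }
    }
}

impl std::error::Error for CoreError {}

// Lets `?` lift `std::io::Error` into `CoreError`, as calls like `read_exact(..)?` need.
impl From<std::io::Error> for CoreError {
    fn from(err: std::io::Error) -> Self {
        CoreError::Io(err)
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionScheme {
    None,
    LZ4,
    ByteGrouping4LZ4,
    Auto,
}

impl TryFrom<u8> for CompressionScheme {
    type Error = CoreError;

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        match value {
            0 => Ok(CompressionScheme::None), // assumed; this arm is outside the hunk
            1 => Ok(CompressionScheme::LZ4),
            2 => Ok(CompressionScheme::ByteGrouping4LZ4),
            99 => Ok(CompressionScheme::Auto),
            _ => Err(CoreError::MalformedData(format!(
                "cannot convert value {} to CompressionScheme",
                value
            ))),
        }
    }
}
```

Because the variant holds a `String`, callers can match on `CoreError::MalformedData(_)` directly, which was not possible when the payload was an opaque `anyhow::Error`.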
@@ -101,7 +100,7 @@ impl CompressionScheme {
     pub fn decompress_from_slice<'a>(&self, data: &'a [u8]) -> Result<Cow<'a, [u8]>> {
         Ok(match self {
             CompressionScheme::Auto => {
-                return Err(XorbObjectError::Format(anyhow!("Cannot decompress with Auto scheme")));
+                return Err(CoreError::MalformedData("Cannot decompress with Auto scheme".to_string()));
             },
             CompressionScheme::None => data.into(),
             CompressionScheme::LZ4 => lz4_decompress_from_slice(data).map(Cow::from)?,
@@ -112,7 +111,7 @@ impl CompressionScheme {
     pub fn decompress_from_reader<R: Read, W: Write>(&self, reader: &mut R, writer: &mut W) -> Result<u64> {
         Ok(match self {
             CompressionScheme::Auto => {
-                return Err(XorbObjectError::Format(anyhow!("Cannot decompress with Auto scheme")));
+                return Err(CoreError::MalformedData("Cannot decompress with Auto scheme".to_string()));
             },
             CompressionScheme::None => copy(reader, writer)?,
             CompressionScheme::LZ4 => lz4_decompress_from_reader(reader, writer)?,

@@ -1 +0,0 @@
-pub use crate::error::{FormatError as XorbObjectError, Result, Validate};

@@ -2,7 +2,6 @@ pub mod byte_grouping;
 mod chunk;
 mod compression_scheme;
 pub mod constants;
-pub mod error;
 mod raw_xorb_data;
 mod xorb_chunk_format;
 mod xorb_object_format;

@@ -1,12 +1,10 @@
 use std::io::{Read, Write};
 use std::mem::size_of;
 
-use anyhow::anyhow;
-
 use super::CompressionScheme;
 use super::constants::MAX_CHUNK_SIZE;
-use super::error::XorbObjectError;
 use super::xorb_object_format::XORB_OBJECT_FORMAT_IDENT;
+use crate::error::CoreError;
 
 pub mod deserialize_async;
 
@@ -54,7 +52,7 @@ impl XorbChunkHeader {
         convert_three_byte_num(&self.uncompressed_length)
     }
 
-    pub fn get_compression_scheme(&self) -> Result<CompressionScheme, XorbObjectError> {
+    pub fn get_compression_scheme(&self) -> Result<CompressionScheme, CoreError> {
         CompressionScheme::try_from(self.compression_scheme)
     }
 
@@ -63,17 +61,16 @@ impl XorbChunkHeader {
         self.compression_scheme = compression_scheme as u8;
     }
 
-    fn validate(&self) -> Result<(), XorbObjectError> {
+    fn validate(&self) -> Result<(), CoreError> {
         let _ = self.get_compression_scheme()?;
         if self.version > CURRENT_VERSION {
-            return Err(XorbObjectError::Format(anyhow!(
+            return Err(CoreError::MalformedData(format!(
                 "chunk header version too high at {}, current version is {}",
-                self.version,
-                CURRENT_VERSION
+                self.version, CURRENT_VERSION
             )));
         }
         if self.get_compressed_length() as usize > *MAX_CHUNK_SIZE * 2 {
-            return Err(XorbObjectError::Format(anyhow!(
+            return Err(CoreError::MalformedData(format!(
                 "chunk header compressed length too large at {}, maximum: {}",
                 self.get_compressed_length(),
                 *MAX_CHUNK_SIZE
@@ -81,7 +78,7 @@ impl XorbChunkHeader {
         }
         // the max chunk size is strictly enforced
         if self.get_uncompressed_length() as usize > *MAX_CHUNK_SIZE {
-            return Err(XorbObjectError::Format(anyhow!(
+            return Err(CoreError::MalformedData(format!(
                 "chunk header uncompressed length too large at {}, maximum: {}",
                 self.get_uncompressed_length(),
                 *MAX_CHUNK_SIZE
@@ -116,11 +113,13 @@ pub fn serialize_chunk<W: Write>(
     chunk: &[u8],
     w: &mut W,
     compression_scheme: CompressionScheme,
-) -> Result<usize, XorbObjectError> {
+) -> Result<usize, CoreError> {
     let compression_scheme = compression_scheme.resolve_for_data(chunk);
     debug_assert_ne!(compression_scheme, CompressionScheme::Auto);
     if compression_scheme == CompressionScheme::Auto {
-        return Err(XorbObjectError::Format(anyhow!("CompressionScheme::Auto cannot be serialized into xorb chunks")));
+        return Err(CoreError::MalformedData(
+            "CompressionScheme::Auto cannot be serialized into xorb chunks".to_string(),
+        ));
     }
 
     let compressed = compression_scheme.compress_from_slice(chunk)?;
@@ -139,24 +138,22 @@ pub fn serialize_chunk<W: Write>(
     Ok(size_of::<XorbChunkHeader>() + compressed.len())
 }
 
-pub fn parse_chunk_header(
-    chunk_header_bytes: [u8; XORB_CHUNK_HEADER_LENGTH],
-) -> Result<XorbChunkHeader, XorbObjectError> {
+pub fn parse_chunk_header(chunk_header_bytes: [u8; XORB_CHUNK_HEADER_LENGTH]) -> Result<XorbChunkHeader, CoreError> {
     if chunk_header_bytes[..XORB_OBJECT_FORMAT_IDENT.len()] == XORB_OBJECT_FORMAT_IDENT {
-        return Err(XorbObjectError::ChunkHeaderParse);
+        return Err(CoreError::ChunkHeaderParse);
     }
     let result: XorbChunkHeader = unsafe { std::mem::transmute_copy(&chunk_header_bytes) };
     result.validate()?;
     Ok(result)
 }
 
-pub fn deserialize_chunk_header<R: Read>(reader: &mut R) -> Result<XorbChunkHeader, XorbObjectError> {
+pub fn deserialize_chunk_header<R: Read>(reader: &mut R) -> Result<XorbChunkHeader, CoreError> {
     let mut buf = [0u8; size_of::<XorbChunkHeader>()];
     reader.read_exact(&mut buf)?;
     parse_chunk_header(buf)
 }
 
-pub fn deserialize_chunk<R: Read>(reader: &mut R) -> Result<(Vec<u8>, usize, u32), XorbObjectError> {
+pub fn deserialize_chunk<R: Read>(reader: &mut R) -> Result<(Vec<u8>, usize, u32), CoreError> {
     let mut buf = Vec::new();
     let (compressed_chunk_size, uncompressed_chunk_size) = deserialize_chunk_to_writer(reader, &mut buf)?;
     Ok((buf, compressed_chunk_size, uncompressed_chunk_size))
@@ -166,7 +163,7 @@ pub fn deserialize_chunk<R: Read>(reader: &mut R) -> Result<(Vec<u8>, usize, u32
 pub fn deserialize_chunk_to_writer<R: Read, W: Write>(
     reader: &mut R,
     writer: &mut W,
-) -> Result<(usize, u32), XorbObjectError> {
+) -> Result<(usize, u32), CoreError> {
     let header = deserialize_chunk_header(reader)?;
     deserialize_chunk_with_header_to_writer(reader, writer, header)
 }
@@ -175,7 +172,7 @@ fn deserialize_chunk_with_header_to_writer<R: Read, W: Write>(
     reader: &mut R,
     writer: &mut W,
     header: XorbChunkHeader,
-) -> Result<(usize, u32), XorbObjectError> {
+) -> Result<(usize, u32), CoreError> {
     let mut compressed_data_reader = reader.take(header.get_compressed_length().into());
 
     let uncompressed_len = header
@@ -183,15 +180,15 @@ fn deserialize_chunk_with_header_to_writer<R: Read, W: Write>(
         .decompress_from_reader(&mut compressed_data_reader, writer)?;
 
     if uncompressed_len != header.get_uncompressed_length() as u64 {
-        return Err(XorbObjectError::Format(anyhow!(
-            "chunk is corrupted, uncompressed bytes len doesn't agree with chunk header"
-        )));
+        return Err(CoreError::MalformedData(
+            "chunk is corrupted, uncompressed bytes len doesn't agree with chunk header".to_string(),
+        ));
     }
 
     Ok((header.get_compressed_length() as usize + XORB_CHUNK_HEADER_LENGTH, uncompressed_len as u32))
 }
 
-pub fn deserialize_chunks<R: Read>(reader: &mut R) -> Result<(Vec<u8>, Vec<u32>), XorbObjectError> {
+pub fn deserialize_chunks<R: Read>(reader: &mut R) -> Result<(Vec<u8>, Vec<u32>), CoreError> {
     let mut buf = Vec::new();
     let (_, chunk_byte_indices) = deserialize_chunks_to_writer(reader, &mut buf)?;
     Ok((buf, chunk_byte_indices))
@@ -223,12 +220,12 @@ pub fn append_chunk_segment(
 /// Uses a single `read()` call to detect EOF (returns 0), then completes
 /// any partial header with `read_exact`. An `UnexpectedEof` from `read_exact`
 /// means the stream was truncated mid-header.
-fn try_read_chunk_header<R: Read>(reader: &mut R) -> Result<Option<XorbChunkHeader>, XorbObjectError> {
+fn try_read_chunk_header<R: Read>(reader: &mut R) -> Result<Option<XorbChunkHeader>, CoreError> {
     let mut header_buf = [0u8; XORB_CHUNK_HEADER_LENGTH];
     let n = match reader.read(&mut header_buf) {
         Ok(0) => return Ok(None),
         Ok(n) => n,
-        Err(e) => return Err(XorbObjectError::Io(e)),
+        Err(e) => return Err(CoreError::Io(e)),
     };
     if n < XORB_CHUNK_HEADER_LENGTH {
         reader.read_exact(&mut header_buf[n..])?;
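The EOF-detection technique described in the doc comment of this hunk can be tried out on its own. The sketch below is a simplified stand-in: `HEADER_LEN` and `try_read_header` are hypothetical names, the value 8 is assumed rather than taken from `XORB_CHUNK_HEADER_LENGTH`, the function returns the raw bytes instead of a parsed `XorbChunkHeader`, and `std::io::Error` is used in place of `CoreError`.

```rust
use std::io::{self, Read};

// Hypothetical header size standing in for XORB_CHUNK_HEADER_LENGTH.
const HEADER_LEN: usize = 8;

/// Distinguishes three cases with one `read` followed by `read_exact`:
/// - `Ok(None)`: the stream ended cleanly before any header byte.
/// - `Ok(Some(bytes))`: a complete header was read.
/// - `Err(UnexpectedEof)`: the stream ended partway through a header.
pub fn try_read_header<R: Read>(reader: &mut R) -> io::Result<Option<[u8; HEADER_LEN]>> {
    let mut buf = [0u8; HEADER_LEN];
    // A single `read` returning 0 bytes signals end of stream.
    let n = match reader.read(&mut buf) {
        Ok(0) => return Ok(None),
        Ok(n) => n,
        Err(e) => return Err(e),
    };
    // A short read is legal, so fill the remainder; a truncated stream
    // surfaces here as `ErrorKind::UnexpectedEof`.
    if n < HEADER_LEN {
        reader.read_exact(&mut buf[n..])?;
    }
    Ok(Some(buf))
}
```

Calling `read_exact` alone would report a clean end of stream and a truncated header with the same `UnexpectedEof` error, which is why the initial `read` is needed to tell the two apart.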
@@ -239,7 +236,7 @@ fn try_read_chunk_header<R: Read>(reader: &mut R) -> Result<Option<XorbChunkHead
 pub fn deserialize_chunks_to_writer<R: Read, W: Write>(
     reader: &mut R,
     writer: &mut W,
-) -> Result<(usize, Vec<u32>), XorbObjectError> {
+) -> Result<(usize, Vec<u32>), CoreError> {
     let mut num_compressed_written = 0;
     let mut num_uncompressed_written = 0;
 

@@ -1,16 +1,13 @@
 use std::io::Write;
 use std::mem::size_of;
 
-use anyhow::anyhow;
 use futures::io::{AsyncRead, AsyncReadExt};
 use futures::{Stream, TryStreamExt};
 
-use super::super::error::XorbObjectError;
 use super::{XORB_CHUNK_HEADER_LENGTH, XorbChunkHeader, parse_chunk_header};
+use crate::error::CoreError;
 
-pub async fn deserialize_chunk_header<R: AsyncRead + Unpin>(
-    reader: &mut R,
-) -> Result<XorbChunkHeader, XorbObjectError> {
+pub async fn deserialize_chunk_header<R: AsyncRead + Unpin>(reader: &mut R) -> Result<XorbChunkHeader, CoreError> {
     let mut buf = [0u8; size_of::<XorbChunkHeader>()];
     reader.read_exact(&mut buf).await?;
     parse_chunk_header(buf)
@@ -20,7 +17,7 @@ pub async fn deserialize_chunk_header<R: AsyncRead + Unpin>(
 pub async fn deserialize_chunk_to_writer<R: AsyncRead + Unpin, W: Write>(
     reader: &mut R,
     writer: &mut W,
-) -> Result<(usize, u32), XorbObjectError> {
+) -> Result<(usize, u32), CoreError> {
     let header = deserialize_chunk_header(reader).await?;
     deserialize_chunk_with_header_to_writer(reader, writer, header).await
 }
@@ -29,7 +26,7 @@ async fn deserialize_chunk_with_header_to_writer<R: AsyncRead + Unpin, W: Write>
     reader: &mut R,
     writer: &mut W,
     header: XorbChunkHeader,
-) -> Result<(usize, u32), XorbObjectError> {
+) -> Result<(usize, u32), CoreError> {
     let mut compressed_data = vec![0u8; header.get_compressed_length() as usize];
     reader.read_exact(&mut compressed_data).await?;
 
@@ -37,9 +34,9 @@ async fn deserialize_chunk_with_header_to_writer<R: AsyncRead + Unpin, W: Write>
     let uncompressed_len = uncompressed_data.len();
 
     if uncompressed_len != header.get_uncompressed_length() as usize {
-        return Err(XorbObjectError::Format(anyhow!(
-            "chunk is corrupted, uncompressed bytes len doesn't agree with chunk header"
-        )));
+        return Err(CoreError::MalformedData(
+            "chunk is corrupted, uncompressed bytes len doesn't agree with chunk header".to_string(),
+        ));
     }
 
     writer.write_all(&uncompressed_data)?;
@@ -48,7 +45,7 @@ async fn deserialize_chunk_with_header_to_writer<R: AsyncRead + Unpin, W: Write>
 }
 
 /// deserialize 1 chunk returning a Vec<u8>, the compressed length and the uncompressed length of the chunk
-pub async fn deserialize_chunk<R: AsyncRead + Unpin>(reader: &mut R) -> Result<(Vec<u8>, usize, u32), XorbObjectError> {
+pub async fn deserialize_chunk<R: AsyncRead + Unpin>(reader: &mut R) -> Result<(Vec<u8>, usize, u32), CoreError> {
     let mut buf = Vec::new();
     let (compressed_len, uncompressed_len) = deserialize_chunk_to_writer(reader, &mut buf).await?;
     Ok((buf, compressed_len, uncompressed_len))
@@ -61,12 +58,12 @@ pub async fn deserialize_chunk<R: AsyncRead + Unpin>(reader: &mut R) -> Result<(
 /// means the stream was truncated mid-header.
 async fn try_read_chunk_header_async<R: AsyncRead + Unpin>(
     reader: &mut R,
-) -> Result<Option<XorbChunkHeader>, XorbObjectError> {
+) -> Result<Option<XorbChunkHeader>, CoreError> {
     let mut header_buf = [0u8; XORB_CHUNK_HEADER_LENGTH];
     let n = match AsyncReadExt::read(reader, &mut header_buf).await {
         Ok(0) => return Ok(None),
         Ok(n) => n,
-        Err(e) => return Err(XorbObjectError::Io(e)),
+        Err(e) => return Err(CoreError::Io(e)),
     };
     if n < XORB_CHUNK_HEADER_LENGTH {
         reader.read_exact(&mut header_buf[n..]).await?;
@@ -77,7 +74,7 @@ async fn try_read_chunk_header_async<R: AsyncRead + Unpin>(
 pub async fn deserialize_chunks_to_writer_from_async_read<R: AsyncRead + Unpin, W: Write>(
     reader: &mut R,
     writer: &mut W,
-) -> Result<(usize, Vec<u32>), XorbObjectError> {
+) -> Result<(usize, Vec<u32>), CoreError> {
     let mut num_compressed_written = 0;
     let mut num_uncompressed_written = 0;
 
@@ -99,7 +96,7 @@ pub async fn deserialize_chunks_to_writer_from_async_read<R: AsyncRead + Unpin,
 
 pub async fn deserialize_chunks_from_async_read<R: AsyncRead + Unpin>(
     reader: &mut R,
-) -> Result<(Vec<u8>, Vec<u32>), XorbObjectError> {
+) -> Result<(Vec<u8>, Vec<u32>), CoreError> {
     let mut buf = Vec::new();
     let (_, chunk_byte_indices) = deserialize_chunks_to_writer_from_async_read(reader, &mut buf).await?;
     Ok((buf, chunk_byte_indices))
@@ -108,7 +105,7 @@ pub async fn deserialize_chunks_from_async_read<R: AsyncRead + Unpin>(
 pub async fn deserialize_chunks_to_writer_from_stream<B, E, S, W>(
     stream: S,
     writer: &mut W,
-) -> Result<(usize, Vec<u32>), XorbObjectError>
+) -> Result<(usize, Vec<u32>), CoreError>
 where
     B: AsRef<[u8]>,
     E: Into<std::io::Error>,
@@ -119,7 +116,7 @@ where
     deserialize_chunks_to_writer_from_async_read(&mut stream_reader, writer).await
 }
 
-pub async fn deserialize_chunks_from_stream<B, E, S>(stream: S) -> Result<(Vec<u8>, Vec<u32>), XorbObjectError>
+pub async fn deserialize_chunks_from_stream<B, E, S>(stream: S) -> Result<(Vec<u8>, Vec<u32>), CoreError>
 where
     B: AsRef<[u8]>,
     E: Into<std::io::Error>,

@@ -2,7 +2,6 @@ use std::cmp::min;
 use std::io::{Cursor, Read, Seek, SeekFrom, Write};
 use std::mem::{size_of, size_of_val};
 
-use anyhow::anyhow;
 use bytes::Buf;
 #[cfg(not(target_family = "wasm"))]
 use futures::AsyncReadExt;
@@ -12,9 +11,9 @@ use tracing::warn;
 use xet_runtime::core::xet_config;
 
 use super::constants::{TARGET_CHUNK_SIZE, XORB_BLOCK_SIZE};
-use super::error::{Validate, XorbObjectError};
 use super::xorb_chunk_format::{deserialize_chunk, deserialize_chunk_header, serialize_chunk, write_chunk_header};
 use super::{CompressionScheme, RawXorbData, XorbChunkHeader};
+use crate::error::{CoreError, Validate};
 use crate::merklehash::{DataHash, MerkleHash};
 use crate::metadata_shard::chunk_verification::range_hash_from_chunks;
 use crate::serialization_utils::*;
@@ -100,11 +99,11 @@ impl XorbObjectInfoV0 {
     ///
     /// Assumes caller has set position of Writer to appropriate location for serialization.
     #[deprecated]
-    pub fn serialize<W: Write>(&self, writer: &mut W) -> Result<usize, XorbObjectError> {
+    pub fn serialize<W: Write>(&self, writer: &mut W) -> Result<usize, CoreError> {
         let mut total_bytes_written = 0;
 
         // Helper function to write data and update the byte count
-        let mut write_bytes = |data: &[u8]| -> Result<(), XorbObjectError> {
+        let mut write_bytes = |data: &[u8]| -> Result<(), CoreError> {
             writer.write_all(data)?;
             total_bytes_written += data.len();
             Ok(())
@@ -134,11 +133,11 @@ impl XorbObjectInfoV0 {
     ///
     /// Expects metadata struct is found at end of Reader, written out in struct order.
     #[deprecated]
-    pub fn deserialize<R: Read>(reader: &mut R) -> Result<(Self, u32), XorbObjectError> {
+    pub fn deserialize<R: Read>(reader: &mut R) -> Result<(Self, u32), CoreError> {
         let mut total_bytes_read: u32 = 0;
 
         // Helper function to read data and update the byte count
-        let mut read_bytes = |data: &mut [u8]| -> Result<(), XorbObjectError> {
+        let mut read_bytes = |data: &mut [u8]| -> Result<(), CoreError> {
             reader.read_exact(data)?;
             total_bytes_read += data.len() as u32;
             Ok(())
@@ -148,14 +147,14 @@ impl XorbObjectInfoV0 {
         read_bytes(&mut ident)?;
 
         if ident != XORB_OBJECT_FORMAT_IDENT {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident")));
+            return Err(CoreError::MalformedData("Xorb Invalid Ident".to_string()));
         }
 
         let mut version = [0u8; 1];
         read_bytes(&mut version)?;
 
         if version[0] != XORB_OBJECT_FORMAT_VERSION_V0 {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Format Version")));
+            return Err(CoreError::MalformedData("Xorb Invalid Format Version".to_string()));
         }
 
         let (s, bytes_read_v0) = Self::deserialize_v0(reader)?;
@@ -163,11 +162,11 @@ impl XorbObjectInfoV0 {
         Ok((s, total_bytes_read + bytes_read_v0))
     }
 
-    pub fn deserialize_v0<R: Read>(reader: &mut R) -> Result<(Self, u32), XorbObjectError> {
+    pub fn deserialize_v0<R: Read>(reader: &mut R) -> Result<(Self, u32), CoreError> {
         let mut total_bytes_read: u32 = 0;
 
         // Helper function to read data and update the byte count
-        let mut read_bytes = |data: &mut [u8]| -> Result<(), XorbObjectError> {
+        let mut read_bytes = |data: &mut [u8]| -> Result<(), CoreError> {
             reader.read_exact(data)?;
             total_bytes_read += data.len() as u32;
             Ok(())
@@ -219,7 +218,7 @@ impl XorbObjectInfoV0 {
     pub async fn deserialize_async<R: futures::io::AsyncRead + Unpin>(
         reader: &mut R,
         version: u8,
-    ) -> Result<(Self, u32), XorbObjectError> {
+    ) -> Result<(Self, u32), CoreError> {
         // already read 8 bytes (ident + version)
         let mut total_bytes_read: u32 = (size_of::<XorbObjectIdent>() + size_of::<u8>()) as u32;
 
@@ -228,7 +227,7 @@ impl XorbObjectInfoV0 {
             reader: &mut R,
             total_bytes_read: &mut u32,
             buf: &mut [u8],
-        ) -> Result<(), XorbObjectError> {
+        ) -> Result<(), CoreError> {
             reader.read_exact(buf).await?;
             *total_bytes_read += buf.len() as u32;
             Ok(())
@@ -399,7 +398,7 @@ impl XorbObjectInfoV1 {
     /// Serialize XorbObjectInfoV1 to provided Writer.
     ///
     /// Assumes caller has set position of Writer to appropriate location for serialization.
-    pub fn serialize<W: Write>(&self, writer: &mut W) -> Result<usize, XorbObjectError> {
+    pub fn serialize<W: Write>(&self, writer: &mut W) -> Result<usize, CoreError> {
         let mut counting_writer = countio::Counter::new(writer);
         let w = &mut counting_writer;
 
@@ -422,7 +421,7 @@ impl XorbObjectInfoV1 {
 
         if self.num_chunks as usize != self.chunk_hashes.len() {
             debug_assert_eq!(self.num_chunks as usize, self.chunk_hashes.len());
-            return Err(XorbObjectError::Format(anyhow!(
+            return Err(CoreError::MalformedData(format!(
                 "Chunk hash vector not correct length on serialization. ({}, expected {})",
                 self.chunk_hashes.len(),
                 self.num_chunks
@@ -443,7 +442,7 @@ impl XorbObjectInfoV1 {
         // write variable field: chunk boundaries
         if self.num_chunks as usize != self.chunk_boundary_offsets.len() {
             debug_assert_eq!(self.num_chunks as usize, self.chunk_boundary_offsets.len());
-            return Err(XorbObjectError::Format(anyhow!(
+            return Err(CoreError::MalformedData(format!(
                 "Chunk boundary offset vector not correct length on serialization. ({}, expected {})",
                 self.chunk_boundary_offsets.len(),
                 self.num_chunks
@@ -454,7 +453,7 @@ impl XorbObjectInfoV1 {
         // write variable field: unpacked chunk data offsets
         if self.num_chunks as usize != self.unpacked_chunk_offsets.len() {
             debug_assert_eq!(self.num_chunks as usize, self.unpacked_chunk_offsets.len());
-            return Err(XorbObjectError::Format(anyhow!(
+            return Err(CoreError::MalformedData(format!(
                 "Unpacked chunk offset vector not correct length on serialization. ({}, expected {})",
                 self.unpacked_chunk_offsets.len(),
                 self.num_chunks
@@ -481,7 +480,7 @@ impl XorbObjectInfoV1 {
     /// Construct XorbObjectInfo object from Reader + Seek.
     ///
     /// Expects metadata struct is found at end of Reader, written out in struct order.
-    pub fn deserialize<R: Read>(reader: &mut R) -> Result<(Self, u32), XorbObjectError> {
+    pub fn deserialize<R: Read>(reader: &mut R) -> Result<(Self, u32), CoreError> {
         let mut counting_reader = countio::Counter::new(reader);
         let r = &mut counting_reader;
 
@@ -493,7 +492,7 @@ impl XorbObjectInfoV1 {
         read_bytes(r, &mut s.ident)?;
 
         if s.ident != XORB_OBJECT_FORMAT_IDENT {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident")));
+            return Err(CoreError::MalformedData("Xorb Invalid Ident".to_string()));
         }
 
         s.version = read_u8(r)?;
@@ -503,7 +502,7 @@ impl XorbObjectInfoV1 {
             // we don't have the missing info (unpacked_chunk_offsets), it's OK
             return Ok((Self::from_v0(sv0), r.reader_bytes() as u32));
         } else if s.version != XORB_OBJECT_FORMAT_VERSION {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Format Version")));
+            return Err(CoreError::MalformedData("Xorb Invalid Format Version".to_string()));
         }
 
         s.xorb_hash = read_hash(r)?;
@@ -516,13 +515,13 @@ impl XorbObjectInfoV1 {
         read_bytes(r, &mut s.ident_hash_section)?;
 
         if s.ident_hash_section != XORB_OBJECT_FORMAT_IDENT_HASHES {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident for Hash Metadata Section")));
+            return Err(CoreError::MalformedData("Xorb Invalid Ident for Hash Metadata Section".to_string()));
         }
 
         s.hashes_version = read_u8(r)?;
 
         if s.hashes_version != XORB_OBJECT_FORMAT_HASHES_VERSION {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Format Version for Hash Metadata Section")));
+            return Err(CoreError::MalformedData("Xorb Invalid Format Version for Hash Metadata Section".to_string()));
         }
 
         let num_chunks_2 = read_u32(r)?;
@@ -541,23 +540,23 @@ impl XorbObjectInfoV1 {
         read_bytes(r, &mut s.ident_boundary_section)?;
 
         if s.ident_boundary_section != XORB_OBJECT_FORMAT_IDENT_BOUNDARIES {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident for Boundary Metadata Section")));
+            return Err(CoreError::MalformedData("Xorb Invalid Ident for Boundary Metadata Section".to_string()));
         }
 
         s.boundaries_version = read_u8(r)?;
 
         if s.boundaries_version != XORB_OBJECT_FORMAT_BOUNDARIES_VERSION {
-            return Err(XorbObjectError::Format(anyhow!(
-                "Xorb Invalid Format Version for Boundaries Metadata Section"
-            )));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid Format Version for Boundaries Metadata Section".to_string(),
+            ));
         }
 
         let num_chunks_3 = read_u32(r)?;
 
         if num_chunks_2 != num_chunks_3 {
-            return Err(XorbObjectError::Format(anyhow!(
-                "Xorb Invalid: inconsistent num_chunks between hashes and boundaries section."
-            )));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid: inconsistent num_chunks between hashes and boundaries section.".to_string(),
+            ));
         }
 
         s.chunk_boundary_offsets.reserve(prealloc_num_chunks(num_chunks_3 as usize));
@@ -574,9 +573,9 @@ impl XorbObjectInfoV1 {
         s.num_chunks = read_u32(r)?;
 
         if s.num_chunks != num_chunks_2 {
-            return Err(XorbObjectError::Format(anyhow!(
-                "Xorb Invalid: inconsistent num_chunks between metadata and hashes section."
-            )));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid: inconsistent num_chunks between metadata and hashes section.".to_string(),
+            ));
         }
 
         s.hashes_section_offset_from_end = read_u32(r)?;
@@ -587,11 +586,15 @@ impl XorbObjectInfoV1 {
         let end_byte_offset = r.reader_bytes();
 
         if end_byte_offset - hash_section_begin_byte_offset != s.hashes_section_offset_from_end as usize {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid: incorrect hashes_section_offset_from_end.")));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid: incorrect hashes_section_offset_from_end.".to_string(),
+            ));
         }
 
         if end_byte_offset - boundary_section_begin_byte_offset != s.boundary_section_offset_from_end as usize {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid: incorrect boundary_section_offset_from_end.")));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid: incorrect boundary_section_offset_from_end.".to_string(),
+            ));
         }
 
         Ok((s, r.reader_bytes() as u32))
@@ -600,7 +603,7 @@ impl XorbObjectInfoV1 {
     /// Construct XorbObjectInfo object from Reader + Seek.
     ///
     /// Expects metadata struct is found at end of Reader, written out in struct order.
-    pub fn deserialize_only_boundaries_section<R: Read + Seek>(reader: &mut R) -> Result<(Self, u32), XorbObjectError> {
+    pub fn deserialize_only_boundaries_section<R: Read + Seek>(reader: &mut R) -> Result<(Self, u32), CoreError> {
         let mut s = Self::default();
 
         // info_length + size of _buffer + size of u32 for offset field
@@ -622,15 +625,15 @@ impl XorbObjectInfoV1 {
         read_bytes(r, &mut s.ident_boundary_section)?;
 
         if s.ident_boundary_section != XORB_OBJECT_FORMAT_IDENT_BOUNDARIES {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident for Boundary Metadata Section")));
+            return Err(CoreError::MalformedData("Xorb Invalid Ident for Boundary Metadata Section".to_string()));
         }
 
         s.boundaries_version = read_u8(r)?;
 
         if s.boundaries_version != XORB_OBJECT_FORMAT_BOUNDARIES_VERSION {
-            return Err(XorbObjectError::Format(anyhow!(
-                "Xorb Invalid Format Version for Boundaries Metadata Section"
-            )));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid Format Version for Boundaries Metadata Section".to_string(),
+            ));
         }
 
         let num_chunks_boundaries_section = read_u32(r)?;
@@ -645,9 +648,9 @@ impl XorbObjectInfoV1 {
         s.num_chunks = read_u32(r)?;
 
         if s.num_chunks != num_chunks_boundaries_section {
-            return Err(XorbObjectError::Format(anyhow!(
-                "Xorb Invalid: inconsistent num_chunks between metadata and hashes section."
-            )));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid: inconsistent num_chunks between metadata and hashes section.".to_string(),
+            ));
         }
 
         s.hashes_section_offset_from_end = read_u32(r)?;
@@ -658,7 +661,9 @@ impl XorbObjectInfoV1 {
         let end_byte_offset = r.reader_bytes();
 
         if end_byte_offset != s.boundary_section_offset_from_end as usize {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid: incorrect boundary_section_offset_from_end.")));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid: incorrect boundary_section_offset_from_end.".to_string(),
+            ));
         }
 
         debug_assert!(s.chunk_hashes.is_empty());
@@ -669,7 +674,7 @@ impl XorbObjectInfoV1 {
|
||||
#[cfg(not(target_family = "wasm"))]
|
||||
pub async fn deserialize_async_v1<R: futures::io::AsyncRead + Unpin>(
|
||||
reader: &mut R,
|
||||
) -> Result<(Self, u32), XorbObjectError> {
|
||||
) -> Result<(Self, u32), CoreError> {
|
||||
// already read 8 bytes (ident + version)
|
||||
let total_bytes_read: u32 = (size_of::<XorbObjectIdent>() + size_of::<u8>()) as u32;
|
||||
|
||||
@@ -692,13 +697,13 @@ impl XorbObjectInfoV1 {
|
||||
read_bytes_async(r, &mut s.ident_hash_section).await?;
|
||||
|
||||
if s.ident_hash_section != XORB_OBJECT_FORMAT_IDENT_HASHES {
|
||||
return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident for Hash Metadata Section")));
|
||||
return Err(CoreError::MalformedData("Xorb Invalid Ident for Hash Metadata Section".to_string()));
|
||||
}
|
||||
|
||||
s.hashes_version = read_u8_async(r).await?;
|
||||
|
||||
if s.hashes_version != XORB_OBJECT_FORMAT_HASHES_VERSION {
|
||||
return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Format Version for Hash Metadata Section")));
|
||||
return Err(CoreError::MalformedData("Xorb Invalid Format Version for Hash Metadata Section".to_string()));
|
||||
}
|
||||
|
||||
let num_chunks_2 = read_u32_async(r).await?;
|
||||
@@ -717,23 +722,23 @@ impl XorbObjectInfoV1 {
         read_bytes_async(r, &mut s.ident_boundary_section).await?;

         if s.ident_boundary_section != XORB_OBJECT_FORMAT_IDENT_BOUNDARIES {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident for Boundary Metadata Section")));
+            return Err(CoreError::MalformedData("Xorb Invalid Ident for Boundary Metadata Section".to_string()));
         }

         s.boundaries_version = read_u8_async(r).await?;

         if s.boundaries_version != XORB_OBJECT_FORMAT_BOUNDARIES_VERSION {
-            return Err(XorbObjectError::Format(anyhow!(
-                "Xorb Invalid Format Version for Boundaries Metadata Section"
-            )));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid Format Version for Boundaries Metadata Section".to_string(),
+            ));
         }

         let num_chunks_3 = read_u32_async(r).await?;

         if num_chunks_2 != num_chunks_3 {
-            return Err(XorbObjectError::Format(anyhow!(
-                "Xorb Invalid: inconsistent num_chunks between hashes and boundaries section."
-            )));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid: inconsistent num_chunks between hashes and boundaries section.".to_string(),
+            ));
         }

         s.chunk_boundary_offsets.reserve(prealloc_num_chunks(num_chunks_3 as usize));
@@ -749,9 +754,9 @@ impl XorbObjectInfoV1 {
         s.num_chunks = read_u32_async(r).await?;

         if s.num_chunks != num_chunks_2 {
-            return Err(XorbObjectError::Format(anyhow!(
-                "Xorb Invalid: inconsistent num_chunks between metadata and hashes section."
-            )));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid: inconsistent num_chunks between metadata and hashes section.".to_string(),
+            ));
         }

         s.hashes_section_offset_from_end = read_u32_async(r).await?;
@@ -762,11 +767,15 @@ impl XorbObjectInfoV1 {
         let end_byte_offset = r.reader_bytes();

         if end_byte_offset - hash_section_begin_byte_offset != s.hashes_section_offset_from_end as usize {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid: incorrect hashes_section_offset_from_end.")));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid: incorrect hashes_section_offset_from_end.".to_string(),
+            ));
         }

         if end_byte_offset - boundary_section_begin_byte_offset != s.boundary_section_offset_from_end as usize {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Invalid: incorrect boundary_section_offset_from_end.")));
+            return Err(CoreError::MalformedData(
+                "Xorb Invalid: incorrect boundary_section_offset_from_end.".to_string(),
+            ));
         }

         Ok((s, r.reader_bytes() as u32 + total_bytes_read))
@@ -779,7 +788,7 @@ impl XorbObjectInfoV1 {
     pub async fn deserialize_async<R: futures::io::AsyncRead + Unpin>(
         reader: &mut R,
         version: u8,
-    ) -> Result<(Self, u32), XorbObjectError> {
+    ) -> Result<(Self, u32), CoreError> {
         if version == 0 {
             let (s, n) = XorbObjectInfoV0::deserialize_async(reader, 0).await?;
             // we don't have the missing info (unpacked_chunk_offsets), it's OK
@@ -787,7 +796,7 @@ impl XorbObjectInfoV1 {
         } else if version == 1 {
             Self::deserialize_async_v1(reader).await
         } else {
-            Err(XorbObjectError::Format(anyhow!(
+            Err(CoreError::MalformedData(format!(
                 "Xorb Format Error: Version {version} not supported by this code version."
             )))
         }
@@ -908,7 +917,7 @@ impl XorbObject {
     /// make up the info portion of the xorb.
     ///
     /// Assumes reader has at least size_of::<u32>() bytes, otherwise returns an error.
-    pub fn get_info_length<R: Read + Seek>(reader: &mut R) -> Result<u32, XorbObjectError> {
+    pub fn get_info_length<R: Read + Seek>(reader: &mut R) -> Result<u32, CoreError> {
         // Go to end of Reader and get length, then jump back to it, and read sequentially
         // read last 4 bytes to get length
         reader.seek(SeekFrom::End(-(size_of::<u32>() as i64)))?;
@@ -922,7 +931,7 @@ impl XorbObject {
     /// Deserialize the XorbObjectInfo struct, the metadata for this Xorb.
     ///
     /// This allows the XorbObject to be partially constructed, allowing for range reads inside the XorbObject.
-    pub fn deserialize<R: Read + Seek>(reader: &mut R) -> Result<Self, XorbObjectError> {
+    pub fn deserialize<R: Read + Seek>(reader: &mut R) -> Result<Self, CoreError> {
         let info_length = Self::get_info_length(reader)?;

         // now seek back that many bytes + size of length (u32) and read sequentially.
@@ -932,7 +941,7 @@ impl XorbObject {

         // validate that info_length matches what we read off of header
         if total_bytes_read != info_length {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Info Format Error")));
+            return Err(CoreError::MalformedData("Xorb Info Format Error".to_string()));
         }

         Ok(Self { info, info_length })
@@ -944,7 +953,7 @@ impl XorbObject {
     pub async fn deserialize_async<R: futures::io::AsyncRead + Unpin>(
         reader: &mut R,
         version: u8,
-    ) -> Result<Self, XorbObjectError> {
+    ) -> Result<Self, CoreError> {
         let (info, total_bytes_read) = XorbObjectInfoV1::deserialize_async(reader, version).await?;

         let mut info_length_buf = [0u8; size_of::<u32>()];
@@ -954,18 +963,20 @@ impl XorbObject {
         let info_length = u32::from_le_bytes(info_length_buf);

         if info_length != total_bytes_read {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Info Format Error")));
+            return Err(CoreError::MalformedData("Xorb Info Format Error".to_string()));
         }

         // verify we've read to the end
         if reader.read(&mut [0u8; 8]).await? != 0 {
-            return Err(XorbObjectError::Format(anyhow!("Xorb Reader has content past the end of serialized xorb")));
+            return Err(CoreError::MalformedData(
+                "Xorb Reader has content past the end of serialized xorb".to_string(),
+            ));
         }

         Ok(Self { info, info_length })
     }

-    pub fn serialize_given_info<W: Write>(w: &mut W, info: XorbObjectInfoV1) -> Result<(Self, usize), XorbObjectError> {
+    pub fn serialize_given_info<W: Write>(w: &mut W, info: XorbObjectInfoV1) -> Result<(Self, usize), CoreError> {
         let mut total_written_bytes: usize = 0;
         let info_length = info.serialize(w)? as u32;
         total_written_bytes += info_length as usize;
@@ -993,7 +1004,7 @@ impl XorbObject {
     pub fn validate_xorb_object<R: Read + Seek>(
         reader: &mut R,
         hash: &MerkleHash,
-    ) -> Result<Option<XorbObject>, XorbObjectError> {
+    ) -> Result<Option<XorbObject>, CoreError> {
         // 1. deserialize to get Info
         // Errors can occur if either
         // - the object doesn't have at least 4 bytes for the "info_length";
@@ -1092,11 +1103,11 @@ impl XorbObject {
         &self,
         chunk_start_index: u32,
         chunk_end_index: u32,
-    ) -> Result<DataHash, XorbObjectError> {
+    ) -> Result<DataHash, CoreError> {
         self.validate_xorb_object_info()?;

         if chunk_end_index <= chunk_start_index || chunk_end_index > self.info.num_chunks {
-            return Err(XorbObjectError::InvalidArguments);
+            return Err(CoreError::InvalidArguments);
         }

         // Collect relevant hashes
@@ -1106,11 +1117,11 @@ impl XorbObject {
     }

     /// Return end offset of all physical chunk contents (byte index at the beginning of footer)
-    pub fn get_contents_length(&self) -> Result<u32, XorbObjectError> {
+    pub fn get_contents_length(&self) -> Result<u32, CoreError> {
         self.validate_xorb_object_info()?;
         match self.info.chunk_boundary_offsets.last() {
             Some(c) => Ok(*c),
-            None => Err(XorbObjectError::Format(anyhow!("Cannot retrieve content length"))),
+            None => Err(CoreError::MalformedData("Cannot retrieve content length".to_string())),
         }
     }

@@ -1119,14 +1130,9 @@ impl XorbObject {
     /// start and end are byte indices into the physical layout of a xorb.
     ///
     /// The start and end parameters are required to align with chunk boundaries.
-    fn get_range<R: Read + Seek>(
-        &self,
-        reader: &mut R,
-        byte_start: u32,
-        byte_end: u32,
-    ) -> Result<Vec<u8>, XorbObjectError> {
+    fn get_range<R: Read + Seek>(&self, reader: &mut R, byte_start: u32, byte_end: u32) -> Result<Vec<u8>, CoreError> {
         if byte_end < byte_start {
-            return Err(XorbObjectError::InvalidRange);
+            return Err(CoreError::InvalidRange);
         }

         self.validate_xorb_object_info()?;
@@ -1145,7 +1151,7 @@ impl XorbObject {
     }

     /// Get all the content bytes from a Xorb
-    pub fn get_all_bytes<R: Read + Seek>(&self, reader: &mut R) -> Result<Vec<u8>, XorbObjectError> {
+    pub fn get_all_bytes<R: Read + Seek>(&self, reader: &mut R) -> Result<Vec<u8>, CoreError> {
         self.validate_xorb_object_info()?;
         self.get_range(reader, 0, self.get_contents_length()?)
     }
@@ -1156,14 +1162,14 @@ impl XorbObject {
         reader: &mut R,
         chunk_index_start: u32,
         chunk_index_end: u32,
-    ) -> Result<Vec<u8>, XorbObjectError> {
+    ) -> Result<Vec<u8>, CoreError> {
         let (byte_start, byte_end) = self.get_byte_offset(chunk_index_start, chunk_index_end)?;

         self.get_range(reader, byte_start, byte_end)
     }

     /// Assumes chunk_data is 1+ complete chunks. Processes them sequentially and returns them as Vec<u8>.
-    fn get_chunk_contents(&self, chunk_data: &[u8]) -> Result<Vec<u8>, XorbObjectError> {
+    fn get_chunk_contents(&self, chunk_data: &[u8]) -> Result<Vec<u8>, CoreError> {
         // walk chunk_data, deserialize into Chunks, and then get_bytes() from each of them.
         let mut reader = Cursor::new(chunk_data);
         let mut res = Vec::<u8>::new();
@@ -1176,10 +1182,10 @@ impl XorbObject {
     }

     /// Helper function to translate a range of chunk indices to physical byte offset range.
-    pub fn get_byte_offset(&self, chunk_index_start: u32, chunk_index_end: u32) -> Result<(u32, u32), XorbObjectError> {
+    pub fn get_byte_offset(&self, chunk_index_start: u32, chunk_index_end: u32) -> Result<(u32, u32), CoreError> {
         self.validate_xorb_object_info()?;
         if chunk_index_end <= chunk_index_start || chunk_index_end > self.info.num_chunks {
-            return Err(XorbObjectError::InvalidArguments);
+            return Err(CoreError::InvalidArguments);
         }

         let byte_offset_start = match chunk_index_start {
@@ -1193,11 +1199,11 @@ impl XorbObject {

     /// given a valid chunk_index, returns the uncompressed chunk length for the chunk
     /// at the given index, chunk_index must be less than the number of chunks in the xorb
-    pub fn uncompressed_chunk_length(&self, chunk_index: u32) -> Result<u32, XorbObjectError> {
+    pub fn uncompressed_chunk_length(&self, chunk_index: u32) -> Result<u32, CoreError> {
         self.validate_xorb_object_info()?;
         let chunk_index = chunk_index as usize;
         if chunk_index >= self.info.unpacked_chunk_offsets.len() {
-            return Err(XorbObjectError::InvalidArguments);
+            return Err(CoreError::InvalidArguments);
         }
         let cumulative_sum = self.info.unpacked_chunk_offsets[chunk_index];
         let before = match chunk_index {
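The `uncompressed_chunk_length` hunk reads `unpacked_chunk_offsets[chunk_index]` into a variable named `cumulative_sum`, which suggests the offsets are running totals and a chunk's length is the difference between neighbouring entries. A minimal sketch of that arithmetic under this assumption; `chunk_length` and `RangeError` are hypothetical names for illustration, not part of xet-core:

```rust
/// Hypothetical stand-in for the `InvalidArguments` variant seen in the diff.
#[derive(Debug, PartialEq)]
pub enum RangeError {
    InvalidArguments,
}

/// Length of chunk `chunk_index`, assuming `cumulative_offsets[i]` holds the
/// total uncompressed size of chunks `0..=i`.
pub fn chunk_length(cumulative_offsets: &[u32], chunk_index: u32) -> Result<u32, RangeError> {
    let idx = chunk_index as usize;
    if idx >= cumulative_offsets.len() {
        return Err(RangeError::InvalidArguments);
    }
    let end = cumulative_offsets[idx];
    // The first chunk starts at 0; every other chunk starts where the previous one ended.
    let start = match idx {
        0 => 0,
        _ => cumulative_offsets[idx - 1],
    };
    // Non-monotonic offsets are rejected instead of underflowing.
    end.checked_sub(start).ok_or(RangeError::InvalidArguments)
}
```

Returning an error for an out-of-range index mirrors the bounds check in the hunk; the `checked_sub` guard is an extra precaution added in this sketch.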
@@ -1213,17 +1219,13 @@ impl XorbObject {
     /// chunk_index_start <= chunk_index_end &&
     /// chunk_index_end <= num_chunks &&
     /// chunk_index_start < num_chunks
-    pub fn uncompressed_range_length(
-        &self,
-        chunk_index_start: u32,
-        chunk_index_end: u32,
-    ) -> Result<u32, XorbObjectError> {
+    pub fn uncompressed_range_length(&self, chunk_index_start: u32, chunk_index_end: u32) -> Result<u32, CoreError> {
         self.validate_xorb_object_info()?;
         if chunk_index_start > chunk_index_end
             || chunk_index_end > self.info.num_chunks
             || chunk_index_start >= self.info.num_chunks
         {
-            return Err(XorbObjectError::InvalidArguments);
+            return Err(CoreError::InvalidArguments);
         }

         // this check is important if chunk_index_end is 0
@@ -1240,9 +1242,9 @@ impl XorbObject {
     }

     /// Helper method to verify that info object is complete
-    fn validate_xorb_object_info(&self) -> Result<(), XorbObjectError> {
+    fn validate_xorb_object_info(&self) -> Result<(), CoreError> {
         if self.info.num_chunks == 0 {
-            return Err(XorbObjectError::Format(anyhow!("Invalid XorbObjectInfo, no chunks in XorbObject.")));
+            return Err(CoreError::MalformedData("Invalid XorbObjectInfo, no chunks in XorbObject.".to_string()));
         }

         if self.info.num_chunks != self.info.chunk_boundary_offsets.len() as u32
@@ -1250,13 +1252,13 @@ impl XorbObject {
             || (self.info.boundaries_version == XORB_OBJECT_FORMAT_BOUNDARIES_VERSION
                 && self.info.num_chunks != self.info.unpacked_chunk_offsets.len() as u32)
         {
-            return Err(XorbObjectError::Format(anyhow!(
-                "Invalid XorbObjectInfo, num chunks not matching boundaries or hashes."
-            )));
+            return Err(CoreError::MalformedData(
+                "Invalid XorbObjectInfo, num chunks not matching boundaries or hashes.".to_string(),
+            ));
         }

         if self.info.xorb_hash == MerkleHash::default() {
-            return Err(XorbObjectError::Format(anyhow!("Invalid XorbObjectInfo, Missing xorb_hash.")));
+            return Err(CoreError::MalformedData("Invalid XorbObjectInfo, Missing xorb_hash.".to_string()));
         }

         Ok(())
@@ -1285,7 +1287,7 @@ impl SerializedXorbObject {
     ///
     /// The compression scheme is determined by `HF_XET_XORB_COMPRESSION_POLICY`:
     /// auto-detect (default) or an explicit scheme (none, lz4, bg4-lz4).
-    pub fn from_xorb(xorb: RawXorbData, serialize_footer: bool) -> Result<Self, XorbObjectError> {
+    pub fn from_xorb(xorb: RawXorbData, serialize_footer: bool) -> Result<Self, CoreError> {
         let compression_scheme: CompressionScheme = xet_config().xorb.compression_policy.parse()?;
         Self::from_xorb_with_compression(xorb, compression_scheme, serialize_footer)
     }
@@ -1295,7 +1297,7 @@ impl SerializedXorbObject {
         xorb: RawXorbData,
         compression_scheme: CompressionScheme,
         serialize_footer: bool,
-    ) -> Result<Self, XorbObjectError> {
+    ) -> Result<Self, CoreError> {
         let mut xorb_object_info = XorbObjectInfoV1::default();

         let hash = xorb.hash();
@@ -1396,7 +1398,7 @@ pub mod test_utils {
         data: &[u8],
         chunk_and_boundaries: &[(MerkleHash, u32)],
         compression_scheme: CompressionScheme,
-    ) -> Result<(XorbObject, usize, u64), XorbObjectError> {
+    ) -> Result<(XorbObject, usize, u64), CoreError> {
         let mut xorb = XorbObject::default();
         xorb.info.xorb_hash = *hash;
         xorb.info.num_chunks = chunk_and_boundaries.len() as u32;
@@ -1446,7 +1448,7 @@ pub mod test_utils {
         data: Vec<u8>,
         chunk_and_boundaries: Vec<(MerkleHash, u32)>,
         compression: CompressionScheme,
-    ) -> Result<SerializedXorbObject, XorbObjectError> {
+    ) -> Result<SerializedXorbObject, CoreError> {
         let mut writer = Cursor::new(Vec::new());

         let (_, _, footer_start) =
@@ -1647,7 +1649,7 @@ pub mod test_utils {
 pub fn reconstruct_xorb_with_footer(
     writer: &mut impl Write,
     raw_data: &[u8],
-) -> Result<(XorbObject, MerkleHash), XorbObjectError> {
+) -> Result<(XorbObject, MerkleHash), CoreError> {
     let mut reader = Cursor::new(raw_data);
     let mut chunk_hash_and_size: Vec<(MerkleHash, u64)> = Vec::new();
     let mut info = XorbObjectInfoV1::default();
@@ -1655,7 +1657,7 @@ pub fn reconstruct_xorb_with_footer(
     while (reader.position() as usize) < raw_data.len() {
         let chunk_header = match deserialize_chunk_header(&mut reader) {
             Ok(header) => header,
-            Err(XorbObjectError::ChunkHeaderParse) => {
+            Err(CoreError::ChunkHeaderParse) => {
                 // Hit footer identifier, stop processing chunks
                 break;
             },
@@ -1666,12 +1668,12 @@ pub fn reconstruct_xorb_with_footer(
         let mut compressed_buf = vec![0u8; compressed_len];
         reader
             .read_exact(&mut compressed_buf)
-            .map_err(|e| XorbObjectError::Format(anyhow!("Failed to read chunk data: {e}")))?;
+            .map_err(|e| CoreError::MalformedData(format!("Failed to read chunk data: {e}")))?;

         let uncompressed_data = chunk_header
             .get_compression_scheme()?
             .decompress_from_slice(&compressed_buf)
-            .map_err(|e| XorbObjectError::Format(anyhow!("Failed to decompress chunk: {e}")))?;
+            .map_err(|e| CoreError::MalformedData(format!("Failed to decompress chunk: {e}")))?;

         let chunk_hash = crate::merklehash::compute_data_hash(&uncompressed_data);
         chunk_hash_and_size.push((chunk_hash, uncompressed_data.len() as u64));
@@ -1790,9 +1792,9 @@ mod tests {
             build_xorb_object(5, ChunkSize::Fixed(100), CompressionScheme::None);

         // Act & Assert
-        assert_eq!(c.generate_chunk_range_hash(1, 6), Err(XorbObjectError::InvalidArguments));
-        assert_eq!(c.generate_chunk_range_hash(100, 10), Err(XorbObjectError::InvalidArguments));
-        assert_eq!(c.generate_chunk_range_hash(0, 0), Err(XorbObjectError::InvalidArguments));
+        assert_eq!(c.generate_chunk_range_hash(1, 6), Err(CoreError::InvalidArguments));
+        assert_eq!(c.generate_chunk_range_hash(100, 10), Err(CoreError::InvalidArguments));
+        assert_eq!(c.generate_chunk_range_hash(0, 0), Err(CoreError::InvalidArguments));
     }

     #[test]
@@ -1806,7 +1808,10 @@ mod tests {
         // no chunks
         let c = XorbObject::default();
         let result = c.validate_xorb_object_info();
-        assert_eq!(result, Err(XorbObjectError::Format(anyhow!("Invalid XorbObjectInfo, no chunks in XorbObject."))));
+        assert_eq!(
+            result,
+            Err(CoreError::MalformedData("Invalid XorbObjectInfo, no chunks in XorbObject.".to_string()))
+        );

         // num_chunks doesn't match chunk_boundaries.len()
         let mut c = XorbObject::default();
@@ -1814,9 +1819,9 @@ mod tests {
         let result = c.validate_xorb_object_info();
         assert_eq!(
             result,
-            Err(XorbObjectError::Format(anyhow!(
-                "Invalid XorbObjectInfo, num chunks not matching boundaries or hashes."
-            )))
+            Err(CoreError::MalformedData(
+                "Invalid XorbObjectInfo, num chunks not matching boundaries or hashes.".to_string(),
+            ))
         );

         // no hash
@@ -1824,7 +1829,7 @@ mod tests {
             build_xorb_object(1, ChunkSize::Fixed(100), CompressionScheme::None);
         c.info.xorb_hash = MerkleHash::default();
         let result = c.validate_xorb_object_info();
-        assert_eq!(result, Err(XorbObjectError::Format(anyhow!("Invalid XorbObjectInfo, Missing xorb_hash."))));
+        assert_eq!(result, Err(CoreError::MalformedData("Invalid XorbObjectInfo, Missing xorb_hash.".to_string())));
     }

     #[test]

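The recurring change in the hunks above swaps `XorbObjectError::Format(anyhow!(...))` for `CoreError::MalformedData(...)` carrying a `String`. A minimal sketch of that shape, assuming only the variant names visible in the diff; `SketchError` and `validate_counts` are hypothetical and are not the real `CoreError` definition. A `String` payload lets the enum derive `PartialEq`, which keeps `assert_eq!` on error values straightforward:

```rust
use std::fmt;

/// Hypothetical error enum mirroring two variants that appear in the diff.
#[derive(Debug, Clone, PartialEq)]
pub enum SketchError {
    MalformedData(String),
    InvalidArguments,
}

impl fmt::Display for SketchError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SketchError::MalformedData(msg) => write!(f, "malformed data: {msg}"),
            SketchError::InvalidArguments => write!(f, "invalid arguments"),
        }
    }
}

impl std::error::Error for SketchError {}

/// Hypothetical validation in the style of `validate_xorb_object_info`.
pub fn validate_counts(num_chunks: u32, num_hashes: usize) -> Result<(), SketchError> {
    if num_chunks == 0 {
        return Err(SketchError::MalformedData("no chunks".to_string()));
    }
    if num_chunks as usize != num_hashes {
        // `format!` replaces `anyhow!` when the message needs interpolation.
        return Err(SketchError::MalformedData(format!(
            "expected {num_chunks} hashes, found {num_hashes}"
        )));
    }
    Ok(())
}
```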
@@ -6,7 +6,7 @@ use tokio::sync::AcquireError;
 use tracing::error;
 use xet_client::ClientError;
 use xet_client::cas_client::auth::AuthError;
-use xet_core_structures::FormatError;
+use xet_core_structures::CoreError;
 use xet_core_structures::merklehash::DataHashHexParseError;
 use xet_runtime::RuntimeError;
 use xet_runtime::core::par_utils::ParutilsError;
@@ -44,8 +44,8 @@ pub enum DataError {
     #[error("Channel error: {0}")]
     ChannelRecvError(#[from] RecvError),

-    #[error("Format error: {0}")]
-    FormatError(#[from] FormatError),
+    #[error("Core structures error: {0}")]
+    FormatError(#[from] CoreError),

     #[error("Client error: {0}")]
     ClientError(#[from] ClientError),

@@ -7,7 +7,7 @@ use thiserror::Error;
 #[derive(Error, Debug, Clone)]
 pub enum FileReconstructionError {
     #[error("CAS Client Error: {0}")]
-    CasClientError(Arc<xet_client::cas_client::CasClientError>),
+    ClientError(Arc<xet_client::ClientError>),

     #[error("IO Error: {0}")]
     IoError(Arc<std::io::Error>),
@@ -31,7 +31,7 @@ pub enum FileReconstructionError {
     TaskJoinError(Arc<tokio::task::JoinError>),

     #[error("Runtime Error: {0}")]
-    RuntimeError(Arc<xet_runtime::core::errors::MultithreadedRuntimeError>),
+    RuntimeError(Arc<xet_runtime::RuntimeError>),
 }

 pub type Result<T> = std::result::Result<T, FileReconstructionError>;
@@ -42,9 +42,9 @@ impl From<std::io::Error> for FileReconstructionError {
     }
 }

-impl From<xet_client::cas_client::CasClientError> for FileReconstructionError {
-    fn from(err: xet_client::cas_client::CasClientError) -> Self {
-        FileReconstructionError::CasClientError(Arc::new(err))
+impl From<xet_client::ClientError> for FileReconstructionError {
+    fn from(err: xet_client::ClientError) -> Self {
+        FileReconstructionError::ClientError(Arc::new(err))
     }
 }

@@ -60,8 +60,8 @@ impl From<tokio::task::JoinError> for FileReconstructionError {
     }
 }

-impl From<xet_runtime::core::errors::MultithreadedRuntimeError> for FileReconstructionError {
-    fn from(err: xet_runtime::core::errors::MultithreadedRuntimeError) -> Self {
+impl From<xet_runtime::RuntimeError> for FileReconstructionError {
+    fn from(err: xet_runtime::RuntimeError) -> Self {
         FileReconstructionError::RuntimeError(Arc::new(err))
     }
 }

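`FileReconstructionError` derives `Clone` while carrying sources such as `std::io::Error`, which is not `Clone`; each payload therefore sits behind an `Arc`, and the `From` impls do the wrapping. A minimal sketch of that shape using only the standard library; `SketchReconstructionError` and `read_u32_le` are hypothetical names for illustration:

```rust
use std::fmt;
use std::io::Read;
use std::sync::Arc;

/// Hypothetical cloneable error: the non-`Clone` source lives behind an `Arc`.
#[derive(Debug, Clone)]
pub enum SketchReconstructionError {
    IoError(Arc<std::io::Error>),
}

impl fmt::Display for SketchReconstructionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SketchReconstructionError::IoError(e) => write!(f, "IO Error: {e}"),
        }
    }
}

impl std::error::Error for SketchReconstructionError {}

impl From<std::io::Error> for SketchReconstructionError {
    fn from(err: std::io::Error) -> Self {
        SketchReconstructionError::IoError(Arc::new(err))
    }
}

/// Reads a little-endian `u32`; `?` converts `std::io::Error` through the `From` impl.
pub fn read_u32_le(mut bytes: &[u8]) -> Result<u32, SketchReconstructionError> {
    let mut buf = [0u8; 4];
    bytes.read_exact(&mut buf)?;
    Ok(u32::from_le_bytes(buf))
}
```

Cloning such an error only bumps a reference count, so the same failure can be handed to several waiting tasks without copying the underlying source.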
@@ -139,20 +139,18 @@ pub struct XorbURLProvider {

 #[async_trait::async_trait]
 impl URLProvider for XorbURLProvider {
-    async fn retrieve_url(
-        &self,
-    ) -> std::result::Result<(String, Vec<HttpRange>), xet_client::cas_client::CasClientError> {
+    async fn retrieve_url(&self) -> std::result::Result<(String, Vec<HttpRange>), xet_client::ClientError> {
         let (unique_id, url, http_ranges) = self.url_info.get_retrieval_url(self.xorb_block_index).await;
         *self.last_acquisition_id.lock().await = unique_id;

         Ok((url, http_ranges))
     }

-    async fn refresh_url(&self) -> std::result::Result<(), xet_client::cas_client::CasClientError> {
+    async fn refresh_url(&self) -> std::result::Result<(), xet_client::ClientError> {
         self.url_info
             .refresh_retrieval_urls(self.client.clone(), *self.last_acquisition_id.lock().await)
             .await
-            .map_err(|e| xet_client::cas_client::CasClientError::Other(e.to_string()))
+            .map_err(|e| xet_client::ClientError::Other(e.to_string()))
     }
 }

@@ -1,7 +1,7 @@
 #![cfg_attr(feature = "strict", deny(warnings))]

 pub mod error;
-pub use error::DataError;
+pub use error::{DataError, Result};

 pub mod deduplication;
 #[cfg(not(target_family = "wasm"))]

@@ -128,8 +128,9 @@ async fn smudge(_name: Arc<str>, mut reader: impl Read, output_path: PathBuf) ->
     let mut input = String::new();
     reader.read_to_string(&mut input)?;

-    let xet_file: XetFileInfo = serde_json::from_str(&input)
-        .map_err(|_| anyhow::anyhow!("Failed to parse xet file info. Please check the format."))?;
+    let xet_file: XetFileInfo = serde_json::from_str(&input).map_err(|_| {
+        std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to parse xet file info. Please check the format.")
+    })?;

     // Use local config pointing to current directory
     let cas_path = std::env::current_dir()?;

@@ -220,7 +220,7 @@ async fn query_reconstruction(
     remote_client
         .get_reconstruction_v1(&file_hash, bytes_range)
         .await
-        .map_err(anyhow::Error::from)
+        .map_err(Into::into)
 }

 fn main() -> Result<()> {
@@ -231,13 +231,16 @@ fn main() -> Result<()> {
         && let Some(c) = arg.compression
     {
         let scheme = CompressionScheme::try_from(c).map_err(|_| {
-            anyhow::anyhow!("Invalid compression value {c}; expected one of: 0 (none), 1 (lz4), 2 (bg4-lz4), 99 (auto)")
+            std::io::Error::new(
+                std::io::ErrorKind::InvalidInput,
+                format!("Invalid compression value {c}; expected one of: 0 (none), 1 (lz4), 2 (bg4-lz4), 99 (auto)"),
+            )
         })?;
         config
             .xorb
             .compression_policy
             .try_set(<&str>::from(scheme))
-            .map_err(|e| anyhow::anyhow!(e))?;
+            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e.to_string()))?;
     }

     let threadpool = XetRuntime::new_with_config(config)?;

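In the CLI hunks above, `anyhow::anyhow!(...)` is replaced by `std::io::Error::new` with an explicit `ErrorKind`, so callers can branch on the kind instead of only reading a message. A minimal sketch of the same idea; `parse_compression` is a hypothetical helper, and the accepted values are copied from the error message in the diff rather than from the real `CompressionScheme` conversion:

```rust
use std::io::{Error, ErrorKind};

/// Hypothetical mapping from a numeric CLI value to a scheme name.
pub fn parse_compression(value: u8) -> Result<&'static str, Error> {
    match value {
        0 => Ok("none"),
        1 => Ok("lz4"),
        2 => Ok("bg4-lz4"),
        99 => Ok("auto"),
        // A bad user-supplied value is reported as InvalidInput with a descriptive message.
        other => Err(Error::new(
            ErrorKind::InvalidInput,
            format!("Invalid compression value {other}; expected one of: 0 (none), 1 (lz4), 2 (bg4-lz4), 99 (auto)"),
        )),
    }
}
```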
@@ -6,7 +6,7 @@ use tracing::info;
 use xet_client::cas_client::auth::AuthConfig;
 use xet_runtime::core::{xet_cache_root, xet_config};

-use super::errors::Result;
+use crate::error::Result;

 /// Session-specific configuration that varies per upload/download session.
 /// These are runtime values that cannot be configured via environment variables.

@@ -14,15 +14,16 @@ use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_config};

 use super::configurations::{SessionContext, TranslatorConfig};
 use super::file_cleaner::Sha256Policy;
-use super::{FileUploadSession, XetFileInfo, errors};
+use super::{FileUploadSession, XetFileInfo};
 use crate::deduplication::{Chunker, DeduplicationMetrics};
+use crate::error::Result;

 pub fn default_config(
     endpoint: String,
     token_info: Option<(String, u64)>,
     token_refresher: Option<Arc<dyn TokenRefresher>>,
     custom_headers: Option<Arc<HeaderMap>>,
-) -> errors::Result<TranslatorConfig> {
+) -> Result<TranslatorConfig> {
     let (token, token_expiration) = token_info.unzip();
     let auth_cfg = AuthConfig::maybe_new(token, token_expiration, token_refresher);

@@ -42,7 +43,7 @@ pub async fn clean_bytes(
     processor: Arc<FileUploadSession>,
     bytes: Vec<u8>,
     sha256_policy: Sha256Policy,
-) -> errors::Result<(XetFileInfo, DeduplicationMetrics)> {
+) -> Result<(XetFileInfo, DeduplicationMetrics)> {
     let (_id, mut handle) = processor.start_clean(None, bytes.len() as u64, sha256_policy)?;
     handle.add_data(&bytes).await?;
     handle.finish().await
@@ -53,7 +54,7 @@ pub async fn clean_file(
     processor: Arc<FileUploadSession>,
     filename: impl AsRef<Path>,
     sha256_policy: Sha256Policy,
-) -> errors::Result<(XetFileInfo, DeduplicationMetrics)> {
+) -> Result<(XetFileInfo, DeduplicationMetrics)> {
     let mut reader = File::open(&filename)?;

     let filesize = reader.metadata()?.len();
@@ -98,7 +99,7 @@ pub async fn clean_file(
 /// - Verify that downloaded files are correctly reassembled
 /// - Check if a file needs to be uploaded (by comparing hashes)
 /// - Generate cache keys for local file operations
-fn hash_single_file(filename: String, buffer_size: usize) -> errors::Result<XetFileInfo> {
+fn hash_single_file(filename: String, buffer_size: usize) -> Result<XetFileInfo> {
     let mut reader = File::open(&filename)?;
     let filesize = reader.metadata()?.len();

@@ -157,7 +158,7 @@ fn hash_single_file(filename: String, buffer_size: usize) -> errors::Result<XetF
 /// - No authentication or server connection required
 /// - Pure local computation
 #[instrument(skip_all, name = "data_client::hash_files", fields(num_files=file_paths.len()))]
-pub async fn hash_files_async(file_paths: Vec<String>) -> errors::Result<Vec<XetFileInfo>> {
+pub async fn hash_files_async(file_paths: Vec<String>) -> Result<Vec<XetFileInfo>> {
     let rt = XetRuntime::current();
     let semaphore = rt.common().file_ingestion_semaphore.clone();
     let buffer_size = *xet_config().data.ingestion_block_size as usize;

@@ -6,9 +6,9 @@ use tracing::Instrument;
 use xet_core_structures::merklehash::MerkleHash;
 use xet_core_structures::metadata_shard::file_structs::FileDataSequenceEntry;

-use super::errors::Result;
 use super::file_upload_session::FileUploadSession;
 use crate::deduplication::{DeduplicationDataInterface, RawXorbData};
+use crate::error::{DataError, Result};
 use crate::progress_tracking::upload_tracking::FileXorbDependency;

 pub struct UploadSessionDataManager {
@@ -31,7 +31,7 @@ impl UploadSessionDataManager {

 #[async_trait]
 impl DeduplicationDataInterface for UploadSessionDataManager {
-    type ErrorType = super::errors::DataProcessingError;
+    type ErrorType = DataError;

     /// Query for possible shards that may dedup some chunks.
     async fn chunk_hash_dedup_query(

@@ -1,2 +0,0 @@
-pub use crate::error::DataError as DataProcessingError;
-pub type Result<T> = std::result::Result<T, DataProcessingError>;

@@ -11,10 +11,10 @@ use xet_runtime::core::{XetRuntime, xet_config};

 use super::XetFileInfo;
 use super::deduplication_interface::UploadSessionDataManager;
-use super::errors::Result;
 use super::file_upload_session::FileUploadSession;
 use super::sha256::Sha256Generator;
 use crate::deduplication::{Chunk, Chunker, DeduplicationMetrics, FileDeduper};
+use crate::error::Result;
 use crate::progress_tracking::upload_tracking::CompletionTrackerFileId;

 /// Controls how SHA-256 is handled during file cleaning.

@@ -12,9 +12,9 @@ use xet_client::cas_types::FileRange;
 use xet_runtime::core::{XetRuntime, xet_config};

 use super::configurations::TranslatorConfig;
-use super::errors::*;
 use super::remote_client_interface::create_remote_client;
 use super::{XetFileInfo, prometheus_metrics};
+use crate::error::{DataError, Result};
 use crate::file_reconstruction::{DownloadStream, FileReconstructor};
 use crate::progress_tracking::{GroupProgress, ItemProgressUpdater, UniqueID};

@@ -159,7 +159,7 @@ impl FileDownloadSession {

     fn check_not_finalized(&self) -> Result<()> {
         if self.finalized.load(Ordering::Acquire) {
-            return Err(DataProcessingError::InvalidOperation("FileDownloadSession already finalized".to_string()));
+            return Err(DataError::InvalidOperation("FileDownloadSession already finalized".to_string()));
         }
         Ok(())
     }
@@ -167,7 +167,7 @@ impl FileDownloadSession {
     /// Finalizes the session; in debug builds, asserts all items are complete.
     pub async fn finalize(&self) -> Result<()> {
         if self.finalized.swap(true, Ordering::AcqRel) {
-            return Err(DataProcessingError::InvalidOperation("FileDownloadSession already finalized".to_string()));
+            return Err(DataError::InvalidOperation("FileDownloadSession already finalized".to_string()));
         }
         #[cfg(debug_assertions)]
         self.progress.assert_complete();

@@ -18,13 +18,13 @@ use xet_core_structures::xorb_object::SerializedXorbObject;
 use xet_runtime::core::{XetRuntime, xet_config};

 use super::configurations::TranslatorConfig;
-use super::errors::*;
 use super::file_cleaner::{Sha256Policy, SingleFileCleaner};
 use super::remote_client_interface::create_remote_client;
 use super::shard_interface::SessionShardInterface;
 use super::{XetFileInfo, prometheus_metrics};
 use crate::deduplication::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS};
 use crate::deduplication::{DataAggregator, DeduplicationMetrics, RawXorbData};
+use crate::error::{DataError, Result};
 use crate::progress_tracking::upload_tracking::{CompletionTracker, FileXorbDependency};
 use crate::progress_tracking::{GroupProgress, GroupProgressReport, ItemProgressReport, UniqueID};

@@ -472,7 +472,7 @@ impl FileUploadSession {
         return_files: bool,
     ) -> Result<(DeduplicationMetrics, Vec<MDBFileInfo>, GroupProgressReport)> {
         if self.finalized.swap(true, Ordering::AcqRel) {
-            return Err(DataProcessingError::InvalidOperation("FileUploadSession already finalized".to_string()));
+            return Err(DataError::InvalidOperation("FileUploadSession already finalized".to_string()));
         }

         // Register the remaining xorbs for upload.
@@ -537,7 +537,7 @@ impl FileUploadSession {

     fn check_not_finalized(&self) -> Result<()> {
         if self.finalized.load(Ordering::Acquire) {
-            return Err(DataProcessingError::InvalidOperation("FileUploadSession already finalized".to_string()));
+            return Err(DataError::InvalidOperation("FileUploadSession already finalized".to_string()));
         }
         Ok(())
     }

@@ -1,6 +1,5 @@
 use std::sync::Arc;

-use anyhow::{Result, anyhow};
 use http::header;
 use tracing::{Instrument, Span, info_span, instrument};
 use xet_client::cas_client::auth::TokenRefresher;
@@ -10,9 +9,9 @@ use xet_runtime::core::XetRuntime;
 use xet_runtime::core::par_utils::run_constrained;

 use super::super::data_client::{clean_file, default_config};
-use super::super::errors::DataProcessingError;
 use super::super::{FileUploadSession, Sha256Policy, XetFileInfo};
 use super::hub_client_token_refresher::HubClientTokenRefresher;
+use crate::error::{DataError, Result};

 const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));

@@ -98,7 +97,9 @@ pub async fn migrate_files_impl(
     let sha256_policies: Vec<Sha256Policy> = match sha256s {
         Some(v) => {
             if v.len() != file_paths.len() {
-                return Err(anyhow!("mismatched length of the file list and the sha256 list"));
+                return Err(DataError::ParameterError(
+                    "mismatched length of the file list and the sha256 list".to_string(),
+                ));
             }
             v.iter().map(|s| Sha256Policy::from_hex(s)).collect()
         },
@@ -109,7 +110,7 @@ pub async fn migrate_files_impl(
         let proc = processor.clone();
         async move {
             let (pf, metrics) = clean_file(proc, file_path, policy).await?;
-            Ok::<(XetFileInfo, u64), DataProcessingError>((pf, metrics.new_bytes))
+            Ok::<(XetFileInfo, u64), DataError>((pf, metrics.new_bytes))
         }
         .instrument(info_span!("clean_file"))
     });

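The `migrate_files_impl` hunks replace an ad hoc `anyhow!` message with a named `ParameterError` variant, which lets callers match on the failure instead of parsing a string. A small self-contained sketch of that idea follows; `SketchDataError` and `pair_with_hashes` are hypothetical names used only for illustration, and only the error message text is taken from the diff.

```rust
// Hypothetical error type standing in for `DataError` in the diff.
#[derive(Debug, PartialEq)]
pub enum SketchDataError {
    ParameterError(String),
}

/// Pairs each file path with an optional SHA-256 string, rejecting
/// mismatched list lengths with a typed error.
pub fn pair_with_hashes<'a>(
    file_paths: &'a [String],
    sha256s: Option<&'a [String]>,
) -> Result<Vec<(&'a str, Option<&'a str>)>, SketchDataError> {
    match sha256s {
        Some(v) => {
            if v.len() != file_paths.len() {
                return Err(SketchDataError::ParameterError(
                    "mismatched length of the file list and the sha256 list".to_string(),
                ));
            }
            Ok(file_paths
                .iter()
                .zip(v.iter())
                .map(|(p, s)| (p.as_str(), Some(s.as_str())))
                .collect())
        },
        // No hashes supplied: every file is paired with `None`.
        None => Ok(file_paths.iter().map(|p| (p.as_str(), None)).collect()),
    }
}
```

A caller can then branch on `Err(SketchDataError::ParameterError(_))` directly, which is the kind of matching the error unification is meant to allow.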
@@ -1,7 +1,6 @@
 pub mod configurations;
 pub mod data_client;
 mod deduplication_interface;
-pub mod errors;
 mod file_cleaner;
 mod file_download_session;
 mod file_upload_session;

@@ -3,7 +3,7 @@ use std::sync::Arc;
 use xet_client::cas_client::{Client, RemoteClient};

 use super::configurations::TranslatorConfig;
-use super::errors::Result;
+use crate::error::Result;

 pub(crate) async fn create_remote_client(
     config: &TranslatorConfig,

@@ -25,7 +25,7 @@ use xet_runtime::core::xet_config;
 use xet_runtime::error_printer::ErrorPrinter;

 use super::configurations::TranslatorConfig;
-use super::errors::Result;
+use crate::error::Result;

 pub struct SessionShardInterface {
     session_shard_manager: Arc<ShardFileManager>,

@@ -2,7 +2,7 @@ use std::io::Write;
 use std::path::Path;
 use std::process::Command;

-use anyhow::anyhow;
+use anyhow::Result;
 use tempfile::TempDir;
 use tracing::info;

@@ -34,7 +34,7 @@ impl IntegrationTest {
         self.assets.push((name.to_owned(), arg));
     }

-    fn run(&self) -> anyhow::Result<()> {
+    fn run(&self) -> Result<()> {
         // Create a temporary directory
         let tmp_repo_dest = TempDir::new().unwrap();
         let tmp_path_path = tmp_repo_dest.path().to_path_buf();
@@ -94,9 +94,13 @@ impl IntegrationTest {
             let captures = error_re.captures(stderr_out);

             if let Some(captured_text) = captures {
-                Err(anyhow!("Test failed: {}", captured_text.get(1).unwrap().as_str()))
+                Err(std::io::Error::new(
+                    std::io::ErrorKind::Other,
+                    format!("Test failed: {}", captured_text.get(1).unwrap().as_str()),
+                )
+                .into())
             } else {
-                Err(anyhow!("Test failed: Unknown Error."))
+                Err(std::io::Error::new(std::io::ErrorKind::Other, "Test failed: Unknown Error.").into())
             }
         }
     }
@@ -104,9 +108,11 @@ impl IntegrationTest {

 #[cfg(all(test, unix))]
 mod git_integration_tests {
+    use anyhow::Result;
+
     use super::*;
     #[test]
-    fn test_basic_read() -> anyhow::Result<()> {
+    fn test_basic_read() -> Result<()> {
         IntegrationTest::new(include_str!("integration_tests/test_basic_clean_smudge.sh")).run()
     }
 }

@@ -18,6 +18,7 @@ xet-core-structures = { version = "1.4.0", path = "../xet_core_structures" }
xet-client = { version = "1.4.0", path = "../xet_client" }
xet-data = { version = "1.4.0", path = "../xet_data" }

anyhow = { workspace = true }
async-trait = { workspace = true }
http = { workspace = true }
more-asserts = { workspace = true }
@@ -26,9 +27,10 @@ thiserror = { workspace = true }
 tokio = { workspace = true, features = ["net", "time"] }
 tracing = { workspace = true }
 ulid = { workspace = true }
+pyo3 = { workspace = true, optional = true }

 [features]
-python = ["xet-runtime/python"]
+python = ["xet-runtime/python", "dep:pyo3"]

 [dev-dependencies]
 anyhow = { workspace = true }

@@ -1,7 +1,8 @@
 use thiserror::Error;
 use xet_client::ClientError;
-use xet_core_structures::FormatError;
+use xet_core_structures::CoreError;
 use xet_data::DataError;
+use xet_data::file_reconstruction::FileReconstructionError;
 use xet_data::progress_tracking::UniqueID;
 use xet_runtime::RuntimeError;

@@ -35,10 +36,14 @@ pub enum XetError {
     #[error("Authentication error: {0}")]
     Authentication(String),

-    /// Network-level failures: DNS, timeouts, HTTP 5xx, etc.
+    /// Network-level failures: DNS, HTTP 5xx, connection reset, etc.
     #[error("Network error: {0}")]
     Network(String),

+    /// A network request timed out.
+    #[error("Timeout: {0}")]
+    Timeout(String),
+
     /// A requested resource (file, XORB, shard) does not exist.
     #[error("Not found: {0}")]
     NotFound(String),
@@ -84,31 +89,34 @@ impl XetError {
         }
     }

-    fn from_format_error_ref(fe: &FormatError) -> Self {
+    fn from_core_error_ref(fe: &CoreError) -> Self {
         match fe {
-            FormatError::Io(_) => XetError::Io(fe.to_string()),
-            FormatError::ShardNotFound(_) | FormatError::FileNotFound(_) => XetError::NotFound(fe.to_string()),
-            FormatError::HashMismatch
-            | FormatError::TruncatedHashCollision(_)
-            | FormatError::InvalidShard(_)
-            | FormatError::ShardVersion(_)
-            | FormatError::ChunkHeaderParse
-            | FormatError::Format(_)
-            | FormatError::Compression(_) => XetError::DataIntegrity(fe.to_string()),
-            FormatError::InvalidRange | FormatError::InvalidArguments | FormatError::BadFilename(_) => {
+            CoreError::Io(_) => XetError::Io(fe.to_string()),
+            CoreError::ShardNotFound(_) | CoreError::FileNotFound(_) => XetError::NotFound(fe.to_string()),
+            CoreError::HashMismatch
+            | CoreError::TruncatedHashCollision(_)
+            | CoreError::InvalidShard(_)
+            | CoreError::ShardVersion(_)
+            | CoreError::ChunkHeaderParse
+            | CoreError::MalformedData(_)
+            | CoreError::CompressionError(_) => XetError::DataIntegrity(fe.to_string()),
+            CoreError::InvalidRange | CoreError::InvalidArguments | CoreError::BadFilename(_) => {
                 XetError::Configuration(fe.to_string())
             },
-            FormatError::Runtime(re) => XetError::from_runtime_error_ref(re),
+            CoreError::RuntimeError(re) => XetError::from_runtime_error_ref(re),
             _ => XetError::Internal(fe.to_string()),
         }
     }

     fn from_client_error_ref(ce: &ClientError) -> Self {
         match ce {
-            ClientError::AuthError(_) => XetError::Authentication(ce.to_string()),
-            ClientError::ReqwestError(_, _)
-            | ClientError::ReqwestMiddlewareError(_)
-            | ClientError::PresignedUrlExpirationError => XetError::Network(ce.to_string()),
+            ClientError::AuthError(_) | ClientError::PresignedUrlExpirationError | ClientError::CredentialHelper(_) => {
+                XetError::Authentication(ce.to_string())
+            },
+            ClientError::ReqwestError(e, _) if e.is_timeout() => XetError::Timeout(ce.to_string()),
+            ClientError::ReqwestError(_, _) | ClientError::ReqwestMiddlewareError(_) => {
+                XetError::Network(ce.to_string())
+            },
             ClientError::FileNotFound(_) | ClientError::XORBNotFound(_) => XetError::NotFound(ce.to_string()),
             ClientError::ConfigurationError(_)
             | ClientError::InvalidArguments
@@ -117,16 +125,31 @@ impl XetError {
             | ClientError::InvalidKey(_)
             | ClientError::InvalidRepoType(_) => XetError::Configuration(ce.to_string()),
             ClientError::IOError(_) => XetError::Io(ce.to_string()),
-            ClientError::FormatError(fe) => XetError::from_format_error_ref(fe),
+            ClientError::FormatError(fe) => XetError::from_core_error_ref(fe),
             _ => XetError::Internal(ce.to_string()),
         }
     }

+    fn from_file_reconstruction_error_ref(fre: &FileReconstructionError) -> Self {
+        match fre {
+            FileReconstructionError::ClientError(ce) => XetError::from_client_error_ref(ce),
+            FileReconstructionError::IoError(_) => XetError::Io(fre.to_string()),
+            FileReconstructionError::RuntimeError(re) => XetError::from_runtime_error_ref(re),
+            FileReconstructionError::TaskJoinError(je) if je.is_cancelled() => {
+                XetError::Cancelled(format!("Task cancelled: {je}"))
+            },
+            FileReconstructionError::TaskJoinError(je) => XetError::Internal(format!("Task join error: {je}")),
+            FileReconstructionError::ConfigurationError(_) => XetError::Configuration(fre.to_string()),
+            FileReconstructionError::CorruptedReconstruction(_) => XetError::DataIntegrity(fre.to_string()),
+            _ => XetError::Internal(fre.to_string()),
+        }
+    }
+
     fn from_data_error_ref(de: &DataError) -> Self {
         match de {
             DataError::AuthError(_) => XetError::Authentication(de.to_string()),
             DataError::ClientError(ce) => XetError::from_client_error_ref(ce),
-            DataError::FormatError(fe) => XetError::from_format_error_ref(fe),
+            DataError::FormatError(fe) => XetError::from_core_error_ref(fe),
             DataError::IOError(_) => XetError::Io(de.to_string()),
             DataError::RuntimeError(re) => XetError::from_runtime_error_ref(re),
             DataError::FileQueryPolicyError(_)
@@ -138,6 +161,7 @@ impl XetError {
             DataError::HashNotFound => XetError::NotFound(de.to_string()),
             DataError::HashStringParsingFailure(_) => XetError::DataIntegrity(de.to_string()),
             DataError::InvalidOperation(_) => XetError::Configuration(de.to_string()),
+            DataError::FileReconstructionError(fre) => XetError::from_file_reconstruction_error_ref(fre),
             _ => XetError::Internal(de.to_string()),
         }
     }
@@ -151,9 +175,9 @@ impl From<RuntimeError> for XetError {
     }
 }

-impl From<FormatError> for XetError {
-    fn from(e: FormatError) -> Self {
-        XetError::from_format_error_ref(&e)
+impl From<CoreError> for XetError {
+    fn from(e: CoreError) -> Self {
+        XetError::from_core_error_ref(&e)
     }
 }

@@ -169,6 +193,12 @@ impl From<DataError> for XetError {
     }
 }

+impl From<FileReconstructionError> for XetError {
+    fn from(e: FileReconstructionError) -> Self {
+        XetError::from_file_reconstruction_error_ref(&e)
+    }
+}
+
 // -- Convenience From impls for common error types -----------------------

 impl From<std::io::Error> for XetError {
@@ -211,6 +241,58 @@ impl<T> From<std::sync::PoisonError<std::sync::RwLockReadGuard<'_, T>>> for XetE
     }
 }

+// -- Python exception classes & conversion --------------------------------
+
+#[cfg(feature = "python")]
+mod py_exceptions {
+    // Inherits from Python's PermissionError so `except PermissionError` still catches it.
+    pyo3::create_exception!(hf_xet, XetAuthenticationError, pyo3::exceptions::PyPermissionError);
+
+    // Inherits from Python's FileNotFoundError so `except FileNotFoundError` still catches it.
+    pyo3::create_exception!(hf_xet, XetObjectNotFoundError, pyo3::exceptions::PyFileNotFoundError);
+
+    /// Register the custom exception classes on a Python module.
+    ///
+    /// Call this from the `#[pymodule]` init function so that the exceptions
+    /// are importable as `hf_xet.XetAuthenticationError`, etc.
+    pub fn register_exceptions(m: &pyo3::Bound<'_, pyo3::types::PyModule>) -> pyo3::PyResult<()> {
+        use pyo3::types::PyModuleMethods;
+
+        m.add("XetAuthenticationError", m.py().get_type::<XetAuthenticationError>())?;
+        m.add("XetObjectNotFoundError", m.py().get_type::<XetObjectNotFoundError>())?;
+        Ok(())
+    }
+}
+
+#[cfg(feature = "python")]
+pub use py_exceptions::{XetAuthenticationError, XetObjectNotFoundError, register_exceptions};
+
+#[cfg(feature = "python")]
+impl From<XetError> for pyo3::PyErr {
+    fn from(err: XetError) -> pyo3::PyErr {
+        use pyo3::exceptions::{PyConnectionError, PyOSError, PyRuntimeError, PyTimeoutError, PyValueError};
+
+        let msg = err.to_string();
+        #[allow(unreachable_patterns)] // XetError is #[non_exhaustive]
+        match err {
+            XetError::Authentication(_) => XetAuthenticationError::new_err(msg),
+            XetError::NotFound(_) => XetObjectNotFoundError::new_err(msg),
+            XetError::Network(_) => PyConnectionError::new_err(msg),
+            XetError::Timeout(_) => PyTimeoutError::new_err(msg),
+            XetError::Io(_) => PyOSError::new_err(msg),
+            XetError::Configuration(_) | XetError::InvalidTaskID(_) => PyValueError::new_err(msg),
+            XetError::DataIntegrity(_)
+            | XetError::Internal(_)
+            | XetError::WrongRuntimeMode(_)
+            | XetError::AlreadyCommitted
+            | XetError::AlreadyFinished
+            | XetError::Aborted
+            | XetError::Cancelled(_) => PyRuntimeError::new_err(msg),
+            _ => PyRuntimeError::new_err(msg),
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use xet_client::cas_client::auth::AuthError;
@@ -226,13 +308,13 @@ mod tests {

     #[test]
     fn format_not_found_maps_to_not_found() {
-        let err = XetError::from(FormatError::ShardNotFound(MerkleHash::default()));
+        let err = XetError::from(CoreError::ShardNotFound(MerkleHash::default()));
         assert!(matches!(err, XetError::NotFound(_)));
     }

     #[test]
     fn format_invalid_args_maps_to_configuration() {
-        let err = XetError::from(FormatError::InvalidArguments);
+        let err = XetError::from(CoreError::InvalidArguments);
         assert!(matches!(err, XetError::Configuration(_)));
     }

@@ -244,7 +326,7 @@ mod tests {

     #[test]
     fn client_nested_format_maps_using_format_rules() {
-        let err = XetError::from(ClientError::FormatError(FormatError::InvalidRange));
+        let err = XetError::from(ClientError::FormatError(CoreError::InvalidRange));
         assert!(matches!(err, XetError::Configuration(_)));
     }

@@ -259,4 +341,37 @@ mod tests {
         let err = XetError::from(DataError::RuntimeError(RuntimeError::TaskCanceled("cancelled".to_string())));
         assert!(matches!(err, XetError::Cancelled(_)));
     }
+
+    #[test]
+    fn presigned_url_expiration_maps_to_authentication() {
+        let err = XetError::from(ClientError::PresignedUrlExpirationError);
+        assert!(matches!(err, XetError::Authentication(_)));
+    }
+
+    #[test]
+    fn credential_helper_maps_to_authentication() {
+        let err = XetError::from(ClientError::credential_helper_error(std::io::Error::new(
+            std::io::ErrorKind::Other,
+            "cred fail",
+        )));
+        assert!(matches!(err, XetError::Authentication(_)));
+    }
+
+    #[test]
+    fn client_not_found_maps_to_not_found() {
+        let err = XetError::from(ClientError::FileNotFound(MerkleHash::default()));
+        assert!(matches!(err, XetError::NotFound(_)));
+    }
+
+    #[test]
+    fn client_xorb_not_found_maps_to_not_found() {
+        let err = XetError::from(ClientError::XORBNotFound(MerkleHash::default()));
+        assert!(matches!(err, XetError::NotFound(_)));
+    }
+
+    #[test]
+    fn client_io_maps_to_io() {
+        let err = XetError::from(ClientError::IOError(std::io::Error::new(std::io::ErrorKind::NotFound, "gone")));
+        assert!(matches!(err, XetError::Io(_)));
+    }
 }

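The error module above classifies nested errors by reference: each layer has a `from_*_ref` helper, wrapper variants delegate to the inner helper, and a match guard splits timeouts out from other request failures before the general network arm. The sketch below reproduces that structure with standalone types so it compiles without the workspace crates. `InnerError`, `OuterError`, and `PublicError` are hypothetical stand-ins for `ClientError`, `DataError`, and `XetError`, and the `timed_out` field stands in for the `e.is_timeout()` check on the wrapped request error.

```rust
use std::fmt;

// Hypothetical low-level error, standing in for a client-layer type.
#[derive(Debug)]
pub enum InnerError {
    Auth(String),
    Request { timed_out: bool, detail: String },
    NotFound(String),
    Other(String),
}

impl fmt::Display for InnerError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            InnerError::Auth(m) => write!(f, "auth failed: {m}"),
            InnerError::Request { detail, .. } => write!(f, "request failed: {detail}"),
            InnerError::NotFound(m) => write!(f, "not found: {m}"),
            InnerError::Other(m) => write!(f, "other: {m}"),
        }
    }
}

// Hypothetical mid-level error that wraps the low-level one.
#[derive(Debug)]
pub enum OuterError {
    Inner(InnerError),
    Io(String),
}

impl fmt::Display for OuterError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            OuterError::Inner(e) => write!(f, "{e}"),
            OuterError::Io(m) => write!(f, "io error: {m}"),
        }
    }
}

// Hypothetical public error with coarse categories.
#[derive(Debug, PartialEq)]
pub enum PublicError {
    Authentication(String),
    Timeout(String),
    Network(String),
    NotFound(String),
    Io(String),
    Internal(String),
}

impl PublicError {
    fn from_inner_ref(e: &InnerError) -> Self {
        match e {
            InnerError::Auth(_) => PublicError::Authentication(e.to_string()),
            // The guarded arm must come before the general request arm.
            InnerError::Request { timed_out, .. } if *timed_out => PublicError::Timeout(e.to_string()),
            InnerError::Request { .. } => PublicError::Network(e.to_string()),
            InnerError::NotFound(_) => PublicError::NotFound(e.to_string()),
            _ => PublicError::Internal(e.to_string()),
        }
    }

    fn from_outer_ref(e: &OuterError) -> Self {
        match e {
            // Wrapper variants delegate so nested errors keep their category.
            OuterError::Inner(ie) => PublicError::from_inner_ref(ie),
            OuterError::Io(_) => PublicError::Io(e.to_string()),
        }
    }
}

impl From<OuterError> for PublicError {
    fn from(e: OuterError) -> Self {
        PublicError::from_outer_ref(&e)
    }
}
```

Taking the error by reference in the helpers means the same function serves both the owning `From` impl and the delegating arms, which is presumably why the diff names them `from_*_ref`.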
@@ -6,17 +6,12 @@ use tracing::{Instrument, Span, info_span, instrument};
 use xet_client::cas_client::auth::TokenRefresher;
 pub use xet_data::processing::data_client::hash_files_async;
 use xet_data::processing::data_client::{clean_bytes, default_config};
-use xet_data::processing::errors::DataProcessingError;
 use xet_data::processing::{FileDownloadSession, FileUploadSession, Sha256Policy, XetFileInfo};
+use xet_data::{DataError, Result};
 use xet_runtime::core::par_utils::run_constrained_with_semaphore;
 use xet_runtime::core::{XetRuntime, xet_config};

 use super::progress_tracking::{GroupProgressCallbackUpdater, ItemProgressCallbackUpdater, TrackingProgressUpdater};
-use crate::legacy::data_client::errors::Result;
-
-mod errors {
-    pub use xet_data::processing::errors::Result;
-}

 #[instrument(skip_all, name = "data_client::upload_bytes", fields(session_id = tracing::field::Empty, num_files=file_contents.len()))]
 pub async fn upload_bytes_async(
@@ -29,7 +24,7 @@ pub async fn upload_bytes_async(
     custom_headers: Option<Arc<HeaderMap>>,
 ) -> Result<Vec<XetFileInfo>> {
     if sha256_policies.len() != file_contents.len() {
-        return Err(DataProcessingError::ParameterError(format!(
+        return Err(DataError::ParameterError(format!(
             "sha256_policies length ({}) must match file_contents length ({})",
             sha256_policies.len(),
             file_contents.len()
@@ -86,7 +81,7 @@ pub async fn upload_async(
     custom_headers: Option<Arc<HeaderMap>>,
 ) -> Result<Vec<XetFileInfo>> {
     if sha256_policies.len() != file_paths.len() {
-        return Err(DataProcessingError::ParameterError(format!(
+        return Err(DataError::ParameterError(format!(
             "sha256_policies length ({}) must match file_paths length ({})",
             sha256_policies.len(),
             file_paths.len()
@@ -140,7 +135,7 @@ pub async fn download_async(
     if let Some(updaters) = &progress_updaters
         && updaters.len() != file_infos.len()
     {
-        return Err(DataProcessingError::ParameterError("updaters are not same length as pointer_files".to_string()));
+        return Err(DataError::ParameterError("updaters are not same length as pointer_files".to_string()));
     }
     let config: Arc<_> = default_config(
         endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()),
