diff --git a/Cargo.lock b/Cargo.lock index 5bfd4f6b..c414abed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1655,6 +1655,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "xet-client", + "xet-data", "xet-runtime", ] @@ -1879,6 +1880,7 @@ dependencies = [ "futures", "http", "more-asserts", + "pyo3", "serde", "serde_json", "serial_test", @@ -4541,6 +4543,7 @@ checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" name = "simulation" version = "1.4.0" dependencies = [ + "anyhow", "bytes", "chrono", "clap", @@ -6220,7 +6223,6 @@ dependencies = [ name = "xet-core-structures" version = "1.4.0" dependencies = [ - "anyhow", "async-trait", "base64 0.22.1", "bincode", diff --git a/api_changes/update_260318_anyhow_removal_error_alias_cleanup.md b/api_changes/update_260318_anyhow_removal_error_alias_cleanup.md new file mode 100644 index 00000000..b6caa628 --- /dev/null +++ b/api_changes/update_260318_anyhow_removal_error_alias_cleanup.md @@ -0,0 +1,97 @@ +# API Update: Anyhow Removal and Error Alias Cleanup (2026-03-18) + +## Overview + +This change removes direct `anyhow` usage from core crates and finalizes migration to canonical +package-level error types. + +Main impact: +- public legacy error alias modules were removed; +- several public error enum variants changed payload types; +- `SessionError` alias usage was removed in favor of `xet::XetError`; +- Python exception mapping was tightened for `hf_xet`; +- downstream imports should now use canonical error paths directly. + +This is an API-breaking cleanup for callers still importing old alias paths. + +## Canonical Error Types (now required) + +Use these types directly: +- `xet_client::ClientError` +- `xet_core_structures::CoreError` +- `xet_data::DataError` +- `xet::XetError` +- `xet_runtime::RuntimeError` + +## Removed Legacy Alias Modules + +The following compatibility modules were removed: +- `xet_client::cas_client::error` (previous `CasClientError` alias) +- `xet_client::cas_types::error` (previous `CasTypesError` alias) +- `xet_client::hub_client::errors` (previous `HubClientError` alias) +- `xet_core_structures::metadata_shard::error` (previous `MDBShardError` alias) +- `xet_core_structures::xorb_object::error` (previous `XorbObjectError` alias) +- `xet_data::processing::errors` (previous `DataProcessingError` alias) +- `xet::xet_session::errors` (previous `SessionError` alias) +- `xet_runtime::core::errors` (previous `MultithreadedRuntimeError` alias) + +## Breaking Type/Variant Changes + +### `xet_client::ClientError` +- `InternalError(anyhow::Error)` -> `InternalError(String)` +- `CredentialHelper(anyhow::Error)` -> `CredentialHelper(String)` + +### `xet_core_structures::CoreError` (renamed from `FormatError`) +- `Internal(anyhow::Error)` -> `InternalError(String)` +- `Format(anyhow::Error)` -> `MalformedData(String)` (or a more specific `CoreError` variant) + +### `xet::xet_session` / `xet::XetError` +- `xet::xet_session::SessionError` alias was removed. +- Public session APIs now return `Result<_, xet::XetError>`. +- `ClientError::PresignedUrlExpirationError` now maps to `XetError::Authentication`. +- `XetError::Timeout(String)` is used for timeout-class network failures. + +Code matching old variant names or payload types must be updated. + +## Trait Signature Change + +`xet_client::hub_client::CredentialHelper` now uses: + +```rust +async fn fill_credential(&self, req: RequestBuilder) -> Result; +``` + +## Migration Guide + +Replace old imports and aliases directly: + +```rust +// old +use xet_client::cas_client::CasClientError; +use xet_data::processing::errors::DataProcessingError; +use xet::xet_session::SessionError; +use xet_runtime::core::errors::MultithreadedRuntimeError; + +// new +use xet_client::ClientError; +use xet_data::DataError; +use xet::XetError; +use xet_runtime::RuntimeError; +``` + +## Python (`hf_xet`) behavior + +`From for PyErr` now maps: +- `Authentication` -> `hf_xet.XetAuthenticationError` (inherits `PermissionError`) +- `NotFound` -> `hf_xet.XetObjectNotFoundError` (inherits `FileNotFoundError`) +- `Network` -> `ConnectionError` +- `Timeout` -> `TimeoutError` +- `Cancelled` -> `RuntimeError` + +For constructors that previously accepted `anyhow::Error`, construct string-backed variants +instead (`InternalError`, `CredentialHelper`, `MalformedData`, and related `CoreError` variants). + +## Behavior Notes + +No intended runtime behavior change in retry, session, or reconstruction logic; this update is +primarily an error-surface and API cleanup. diff --git a/git_xet/Cargo.toml b/git_xet/Cargo.toml index 2973d85d..77117cf5 100644 --- a/git_xet/Cargo.toml +++ b/git_xet/Cargo.toml @@ -9,6 +9,7 @@ name = "git-xet" path = "src/bin/main.rs" [dependencies] +xet-data = { path = "../xet_data" } xet-runtime = { path = "../xet_runtime" } xet-client = { path = "../xet_client" } xet-pkg = { package = "hf-xet", path = "../xet_pkg" } diff --git a/git_xet/src/auth/ssh.rs b/git_xet/src/auth/ssh.rs index ff1b009e..bbfc2736 100644 --- a/git_xet/src/auth/ssh.rs +++ b/git_xet/src/auth/ssh.rs @@ -4,6 +4,7 @@ use async_trait::async_trait; use reqwest::header; use reqwest_middleware::RequestBuilder; use serde::{Deserialize, Serialize}; +use xet_client::error::ClientError; use xet_client::hub_client::{CredentialHelper, Operation}; use crate::errors::{GitXetError, Result}; @@ -73,8 +74,8 @@ impl SSHCredentialHelper { #[async_trait] impl CredentialHelper for SSHCredentialHelper { - async fn fill_credential(&self, req: RequestBuilder) -> anyhow::Result { - let authenticated = self.authenticate().await?; + async fn fill_credential(&self, req: RequestBuilder) -> std::result::Result { + let authenticated = self.authenticate().await.map_err(ClientError::credential_helper_error)?; Ok(req.header(header::AUTHORIZATION, authenticated.header.authorization)) } diff --git a/git_xet/src/errors.rs b/git_xet/src/errors.rs index bca68aa7..c86a154a 100644 --- a/git_xet/src/errors.rs +++ b/git_xet/src/errors.rs @@ -3,7 +3,7 @@ use std::path::PathBuf; use thiserror::Error; use xet_client::ClientError; -use xet_pkg::legacy::DataProcessingError; +use xet_data::DataError; use crate::lfs_agent_protocol::GitLFSProtocolError; @@ -40,7 +40,7 @@ pub enum GitXetError { Internal(String), #[error("Transfer agent error: {0}")] - TransferAgent(#[from] DataProcessingError), + TransferAgent(#[from] DataError), #[error("Client error: {0}")] Client(#[from] ClientError), diff --git a/git_xet/src/lfs_agent_protocol/protocol_spec.rs b/git_xet/src/lfs_agent_protocol/protocol_spec.rs index 52edc89b..a6f6b1a0 100644 --- a/git_xet/src/lfs_agent_protocol/protocol_spec.rs +++ b/git_xet/src/lfs_agent_protocol/protocol_spec.rs @@ -220,15 +220,13 @@ pub fn to_line_delimited_json_string(value: impl Serialize) -> Result { mod tests { use std::path::Path; - use anyhow::Result; - use super::*; #[test] fn test_protocol_serde_unknown_event() -> Result<()> { let message = r#" { "event": "other", "operation": "upload", "remote": "origin", "concurrent": false }"#; - let parsed: Result = message.parse(); + let parsed: std::result::Result = message.parse(); assert!(matches!(parsed, Err(GitLFSProtocolError::Syntax(_)))); @@ -269,21 +267,21 @@ mod tests { // init event with invalid operation let message1 = r#" { "event": "init", "operation": "other", "remote": "origin", "concurrent": false }"#; - let parsed1: Result = message1.parse(); + let parsed1: std::result::Result = message1.parse(); assert!(matches!(parsed1, Err(GitLFSProtocolError::Syntax(_)))); // init event missing required field let message2 = r#" { "event": "init", "operation": "upload", "remote": "origin" }"#; - let parsed2: Result = message2.parse(); + let parsed2: std::result::Result = message2.parse(); assert!(matches!(parsed2, Err(GitLFSProtocolError::Syntax(_)))); // init event with invalid remote let message2 = r#" { "event": "init", "operation": "upload", "remote": "", "concurrent": false }"#; - let parsed2: Result = message2.parse(); + let parsed2: std::result::Result = message2.parse(); assert!(matches!(parsed2, Err(GitLFSProtocolError::Argument(_)))); @@ -354,7 +352,7 @@ mod tests { let message1 = r#" { "event": "upload", "oid": "bf3e3e2af9366a3b704ae0c31de5afa64193ebabffde2091936ad2e7510bc03a", "size": 346232, "path": "/path/to/file.png" }"#; - let parsed1: Result = message1.parse(); + let parsed1: std::result::Result = message1.parse(); assert!(matches!(parsed1, Err(GitLFSProtocolError::Syntax(_)))); @@ -362,7 +360,7 @@ mod tests { let message2 = r#" { "event": "upload", "oid": "bf3e3e2af9366abc03a", "size": 346232, "path": "/path/to/file.png", "action": { "href": "nfs://server/path", "header": { "key": "value" } } }"#; - let parsed2: Result = message2.parse(); + let parsed2: std::result::Result = message2.parse(); assert!(matches!(parsed2, Err(GitLFSProtocolError::Argument(_)))); @@ -370,7 +368,7 @@ mod tests { let message3 = r#" { "event": "upload", "oid": "bf3e3e2af9366a3b704ae0c31de5afa64193ebabffde2091936ad2e7510bc03a", "size": 0, "path": "/path/to/file.png", "action": { "href": "nfs://server/path", "header": { "key": "value" } } }"#; - let parsed3: Result = message3.parse(); + let parsed3: std::result::Result = message3.parse(); assert!(matches!(parsed3, Err(GitLFSProtocolError::Argument(_)))); @@ -378,7 +376,7 @@ mod tests { let message4 = r#" { "event": "upload", "oid": "bf3e3e2af9366a3b704ae0c31de5afa64193ebabffde2091936ad2e7510bc03a", "size": 346232, "action": { "href": "nfs://server/path", "header": { "key": "value" } } }"#; - let parsed4: Result = message4.parse(); + let parsed4: std::result::Result = message4.parse(); assert!(matches!(parsed4, Err(GitLFSProtocolError::Syntax(_)))); @@ -386,7 +384,7 @@ mod tests { let message5 = r#" { "event": "download", "oid": "bf3e3e2af9366a3b704ae0c31de5afa64193ebabffde2091936ad2e7510bc03a", "size": 12514, "path": "/path/to/file.png", "action": { "href": "https://server/path", "header": { "k1": "v1", "k2": "v2" } } }"#; - let parsed5: Result = message5.parse(); + let parsed5: std::result::Result = message5.parse(); assert!(matches!(parsed5, Err(GitLFSProtocolError::Syntax(_)))); diff --git a/git_xet/src/test_utils/ssh_server.rs b/git_xet/src/test_utils/ssh_server.rs index 7a99484d..f4dd8dcd 100644 --- a/git_xet/src/test_utils/ssh_server.rs +++ b/git_xet/src/test_utils/ssh_server.rs @@ -1,7 +1,7 @@ use std::io; use std::sync::Arc; -use anyhow::anyhow; +use anyhow::{Result, bail}; use rand_core::OsRng; use russh::keys::{Certificate, *}; use russh::server::{Msg, Server as _, Session}; @@ -154,15 +154,15 @@ impl server::Handler for ServerImpl { } impl ServerImpl { - fn git_lfs_authenticate(&self, request: Vec<&str>) -> anyhow::Result { + fn git_lfs_authenticate(&self, request: Vec<&str>) -> Result { let Some(repo_id) = request.get(1) else { - return Err(anyhow!("invalid request, missing repo id")); + bail!("invalid request, missing repo id"); }; let Some(operation) = request.get(2) else { - return Err(anyhow!("invalid request, missing operation")); + bail!("invalid request, missing operation"); }; if !matches!(*operation, "upload" | "download") { - return Err(anyhow!("invalid request, unrecognized operation")); + bail!("invalid request, unrecognized operation"); } let response = GitLFSAuthenticateResponse { header: GitLFSAuthentationResponseHeader { diff --git a/git_xet/src/token_refresher.rs b/git_xet/src/token_refresher.rs index 3afaf531..53464d14 100644 --- a/git_xet/src/token_refresher.rs +++ b/git_xet/src/token_refresher.rs @@ -60,7 +60,7 @@ impl TokenRefresher for DirectRefreshRouteTokenRefresher { let req = cred_helper .fill_credential(req) .await - .map_err(reqwest_middleware::Error::Middleware)?; + .map_err(reqwest_middleware::Error::middleware)?; req.send().await } }) diff --git a/git_xet/src/utils/process_wrapping.rs b/git_xet/src/utils/process_wrapping.rs index 3986622d..3346d3dc 100644 --- a/git_xet/src/utils/process_wrapping.rs +++ b/git_xet/src/utils/process_wrapping.rs @@ -187,10 +187,11 @@ impl CapturedCommand { #[cfg(test)] mod tests { use std::io::Write; + use std::process::Command; use anyhow::Result; - use super::*; + use super::{CapturedCommand, run_program_captured, run_program_captured_with_input_and_output}; #[test] fn test_run_program_captured() -> Result<()> { diff --git a/git_xet/src/utils/ssh_connect.rs b/git_xet/src/utils/ssh_connect.rs index 1d2e4959..3278b5f5 100644 --- a/git_xet/src/utils/ssh_connect.rs +++ b/git_xet/src/utils/ssh_connect.rs @@ -173,8 +173,9 @@ fn format_for_shell_execution(command: &str, args: &[String]) -> Result<(String, #[cfg(test)] mod tests { - use anyhow::{Ok, Result}; use serial_test::serial; + + type Result = std::result::Result>; use xet_runtime::utils::EnvVarGuard; use super::*; diff --git a/hf_xet/Cargo.lock b/hf_xet/Cargo.lock index b0b396dd..c9a70a66 100644 --- a/hf_xet/Cargo.lock +++ b/hf_xet/Cargo.lock @@ -59,9 +59,9 @@ dependencies = [ [[package]] name = "anstream" -version = "0.6.21" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +checksum = "824a212faf96e9acacdbd09febd34438f8f711fb84e09a8916013cd7815ca28d" dependencies = [ "anstyle", "anstyle-parse", @@ -74,15 +74,15 @@ dependencies = [ [[package]] name = "anstyle" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anstyle-parse" -version = "0.2.7" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +checksum = "52ce7f38b242319f7cabaa6813055467063ecdc9d355bbb4ce0c68908cd8130e" dependencies = [ "utf8parse", ] @@ -159,9 +159,9 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "aws-lc-rs" -version = "1.16.1" +version = "1.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94bffc006df10ac2a68c83692d734a465f8ee6c5b384d8545a636f81d858f4bf" +checksum = "a054912289d18629dc78375ba2c3726a3afe3ff71b4edba9dedfca0e3446d1fc" dependencies = [ "aws-lc-sys", "zeroize", @@ -169,9 +169,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.38.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4321e568ed89bb5a7d291a7f37997c2c0df89809d7b6d12062c81ddb54aa782e" +checksum = "1fa7e52a4c5c547c741610a2c6f123f3881e409b714cd27e6798ef020c514f0a" dependencies = [ "cc", "cmake", @@ -352,9 +352,9 @@ checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" [[package]] name = "cc" -version = "1.2.56" +version = "1.2.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" dependencies = [ "find-msvc-tools", "jobserver", @@ -401,9 +401,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" dependencies = [ "clap_builder", "clap_derive", @@ -411,9 +411,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ "anstream", "anstyle", @@ -423,9 +423,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.55" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +checksum = "1110bd8a634a1ab8cb04345d8d878267d57c3cf1b38d91b71af6686408bbca6a" dependencies = [ "heck", "proc-macro2", @@ -435,9 +435,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "cmake" @@ -450,9 +450,9 @@ dependencies = [ [[package]] name = "colorchoice" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" [[package]] name = "colored" @@ -1235,9 +1235,11 @@ checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" name = "hf-xet" version = "1.4.0" dependencies = [ + "anyhow", "async-trait", "http", "more-asserts", + "pyo3", "serde", "thiserror 2.0.18", "tokio", @@ -1593,6 +1595,15 @@ dependencies = [ "str_stack", ] +[[package]] +name = "inventory" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "009ae045c87e7082cb72dab0ccd01ae075dd00141ddc108f43a0ea150a9e7227" +dependencies = [ + "rustversion", +] + [[package]] name = "ipnet" version = "2.12.0" @@ -2096,9 +2107,9 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.21.3" +version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "once_cell_polyfill" @@ -2114,9 +2125,9 @@ checksum = "269bca4c2591a28585d6bf10d9ed0332b7d76900a1b02bec41bdc3a2cdcda107" [[package]] name = "openssl" -version = "0.10.75" +version = "0.10.76" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +checksum = "951c002c75e16ea2c65b8c7e4d3d51d5530d8dfa7d060b4776828c88cfb18ecf" dependencies = [ "bitflags 2.11.0", "cfg-if 1.0.4", @@ -2155,9 +2166,9 @@ dependencies = [ [[package]] name = "openssl-sys" -version = "0.9.111" +version = "0.9.112" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +checksum = "57d55af3b3e226502be1526dfdba67ab0e9c96fc293004e79576b2b9edb0dbdb" dependencies = [ "cc", "libc", @@ -2528,6 +2539,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ba0117f4212101ee6544044dae45abe1083d30ce7b29c4b5cbdfa2354e07383" dependencies = [ "indoc", + "inventory", "libc", "memoffset", "once_cell", @@ -3577,9 +3589,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" dependencies = [ "tinyvec_macros", ] @@ -3829,9 +3841,9 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.22" +version = "0.3.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +checksum = "cb7f578e5945fb242538965c2d0b04418d38ec25c79d160cd279bf0731c8d319" dependencies = [ "matchers", "nu-ansi-term", @@ -4760,7 +4772,6 @@ dependencies = [ name = "xet-core-structures" version = "1.4.0" dependencies = [ - "anyhow", "async-trait", "base64 0.22.1", "bincode", @@ -4832,6 +4843,7 @@ dependencies = [ name = "xet-runtime" version = "1.4.0" dependencies = [ + "anyhow", "async-trait", "bytes", "chrono", @@ -4850,6 +4862,7 @@ dependencies = [ "more-asserts", "oneshot", "pin-project", + "pyo3", "rand 0.9.2", "reqwest", "serde", @@ -4891,18 +4904,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.42" +version = "0.8.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2578b716f8a7a858b7f02d5bd870c14bf4ddbbcf3a4c05414ba6503640505e3" +checksum = "efbb2a062be311f2ba113ce66f697a4dc589f85e78a4aea276200804cea0ed87" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.42" +version = "0.8.47" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e6cc098ea4d3bd6246687de65af3f920c430e236bee1e3bf2e441463f08a02f" +checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89" dependencies = [ "proc-macro2", "quote", diff --git a/hf_xet/Cargo.toml b/hf_xet/Cargo.toml index 5aeab0d5..a12bc25d 100644 --- a/hf_xet/Cargo.toml +++ b/hf_xet/Cargo.toml @@ -12,7 +12,7 @@ crate-type = ["cdylib", "lib"] [dependencies] xet-runtime = { path = "../xet_runtime" } xet-client = { path = "../xet_client" } -xet-pkg = { package = "hf-xet", path = "../xet_pkg" } +xet-pkg = { package = "hf-xet", path = "../xet_pkg", features = ["python"] } async-trait = "0.1" chrono = "0.4" @@ -68,5 +68,3 @@ split-debuginfo = "none" inherits = "dev" debug = true opt-level = 3 - -# cargo-machete has detected the below unused dependency incorrectly diff --git a/hf_xet/src/lib.rs b/hf_xet/src/lib.rs index b8574bfa..d005c43a 100644 --- a/hf_xet/src/lib.rs +++ b/hf_xet/src/lib.rs @@ -10,15 +10,16 @@ use std::sync::Arc; use http::header::{self, HeaderMap, HeaderName, HeaderValue}; use itertools::Itertools; -use pyo3::exceptions::{PyKeyboardInterrupt, PyRuntimeError}; +use pyo3::exceptions::{PyKeyboardInterrupt, PyValueError}; use pyo3::prelude::*; use pyo3::pyfunction; use rand::Rng; use runtime::async_run; use token_refresh::WrappedTokenRefresher; use tracing::debug; +use xet_pkg::XetError; use xet_pkg::legacy::progress_tracking::TrackingProgressUpdater; -use xet_pkg::legacy::{DataProcessingError, Sha256Policy, XetFileInfo, data_client}; +use xet_pkg::legacy::{Sha256Policy, XetFileInfo, data_client}; use xet_runtime::core::file_handle_limits; use crate::logging::init_logging; @@ -40,9 +41,9 @@ fn build_headers_with_user_agent(request_headers: Option let mut map = HeaderMap::new(); for (key, value) in headers { let name = HeaderName::from_bytes(key.as_bytes()) - .map_err(|e| PyRuntimeError::new_err(format!("Invalid header name '{}': {}", key, e)))?; + .map_err(|e| PyValueError::new_err(format!("Invalid header name '{}': {}", key, e)))?; let value = HeaderValue::from_str(&value) - .map_err(|e| PyRuntimeError::new_err(format!("Invalid header value for '{}': {}", key, e)))?; + .map_err(|e| PyValueError::new_err(format!("Invalid header value for '{}': {}", key, e)))?; map.insert(name, value); } Ok::<_, PyErr>(map) @@ -71,12 +72,8 @@ fn build_headers_with_user_agent(request_headers: Option Ok(Some(Arc::new(map))) } -fn convert_data_processing_error(e: DataProcessingError) -> PyErr { - if cfg!(debug_assertions) { - PyRuntimeError::new_err(format!("Data processing error: {e:?}")) - } else { - PyRuntimeError::new_err(format!("Data processing error: {e}")) - } +fn convert_xet_error(e: impl Into) -> PyErr { + PyErr::from(e.into()) } #[pyfunction] @@ -95,13 +92,13 @@ pub fn upload_bytes( skip_sha256: bool, ) -> PyResult> { if skip_sha256 && sha256s.is_some() { - return Err(PyRuntimeError::new_err("skip_sha256=True and sha256s are mutually exclusive")); + return Err(PyValueError::new_err("skip_sha256=True and sha256s are mutually exclusive")); } if let Some(ref s) = sha256s && s.len() != file_contents.len() { - return Err(PyRuntimeError::new_err(format!( + return Err(PyValueError::new_err(format!( "sha256s length ({}) must match file_contents length ({})", s.len(), file_contents.len() @@ -138,7 +135,7 @@ pub fn upload_bytes( header_map, ) .await - .map_err(convert_data_processing_error)? + .map_err(convert_xet_error)? .into_iter() .map(PyXetUploadInfo::from) .collect(); @@ -165,13 +162,13 @@ pub fn upload_files( skip_sha256: bool, ) -> PyResult> { if skip_sha256 && sha256s.is_some() { - return Err(PyRuntimeError::new_err("skip_sha256=True and sha256s are mutually exclusive")); + return Err(PyValueError::new_err("skip_sha256=True and sha256s are mutually exclusive")); } if let Some(ref s) = sha256s && s.len() != file_paths.len() { - return Err(PyRuntimeError::new_err(format!( + return Err(PyValueError::new_err(format!( "sha256s length ({}) must match file_paths length ({})", s.len(), file_paths.len() @@ -212,7 +209,7 @@ pub fn upload_files( header_map, ) .await - .map_err(convert_data_processing_error)? + .map_err(convert_xet_error)? .into_iter() .map(PyXetUploadInfo::from) .collect(); @@ -254,7 +251,7 @@ pub fn hash_files(py: Python, file_paths: Vec) -> PyResult = data_client::hash_files_async(file_paths) .await - .map_err(convert_data_processing_error)? + .map_err(convert_xet_error)? .into_iter() .map(PyXetUploadInfo::from) .collect(); @@ -302,7 +299,7 @@ pub fn download_files( header_map, ) .await - .map_err(convert_data_processing_error)?; + .map_err(convert_xet_error)?; debug!("Download call {x:x}: Completed."); @@ -461,7 +458,6 @@ pub fn hf_xet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_function(wrap_pyfunction!(force_sigint_shutdown, m)?)?; m.add_class::()?; m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; @@ -470,6 +466,8 @@ pub fn hf_xet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { // huggingface_hub. m.add_class::()?; + xet_pkg::register_exceptions(m)?; + // Make sure the logger is set up. init_logging(py); diff --git a/hf_xet/src/runtime.rs b/hf_xet/src/runtime.rs index 1cd40e65..c03252c6 100644 --- a/hf_xet/src/runtime.rs +++ b/hf_xet/src/runtime.rs @@ -3,9 +3,10 @@ use std::sync::{Arc, Mutex, RwLock}; use std::time::Duration; use lazy_static::lazy_static; -use pyo3::exceptions::{PyKeyboardInterrupt, PyRuntimeError}; +use pyo3::exceptions::PyKeyboardInterrupt; use pyo3::prelude::*; use tracing::info; +use xet_pkg::XetError; use xet_runtime::RuntimeError; use xet_runtime::core::XetRuntime; use xet_runtime::core::sync_primatives::spawn_os_thread; @@ -207,8 +208,8 @@ fn get_threadpool() -> Result, RuntimeError> { init_threadpool() } -pub fn convert_multithreading_error(e: impl Into + std::fmt::Display) -> PyErr { - PyRuntimeError::new_err(format!("Xet Runtime Error: {e}")) +pub fn convert_multithreading_error(e: impl Into) -> PyErr { + PyErr::from(XetError::from(e.into())) } pub fn async_run(py: Python, execution_call: F) -> PyResult diff --git a/simulation/Cargo.toml b/simulation/Cargo.toml index 6399b37b..cb5caea8 100644 --- a/simulation/Cargo.toml +++ b/simulation/Cargo.toml @@ -19,6 +19,7 @@ path = "src/bin/run_upload_simulations.rs" xet-runtime = { path = "../xet_runtime" } xet-client = { path = "../xet_client" } +anyhow = { workspace = true } chrono = { workspace = true } duration-str = { workspace = true } clap = { workspace = true } diff --git a/simulation/src/bin/run_upload_simulations.rs b/simulation/src/bin/run_upload_simulations.rs index 17d9254e..a9281a9b 100644 --- a/simulation/src/bin/run_upload_simulations.rs +++ b/simulation/src/bin/run_upload_simulations.rs @@ -11,6 +11,7 @@ use std::process::Command; use std::sync::{Arc, mpsc}; use std::thread; +use anyhow::{Result, anyhow, bail}; use clap::Parser; use simulation::scenario::VALID_SCENARIOS; use simulation::upload_concurrency::generate_summary_csv; @@ -121,7 +122,7 @@ fn scenario_binary() -> PathBuf { dir.join(name) } -fn main() -> Result<(), Box> { +fn main() -> Result<()> { tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| "info".into())) .with_ansi(false) @@ -144,7 +145,7 @@ fn main() -> Result<(), Box> { }; for s in &scenarios { if !VALID_SCENARIOS.contains(&s.as_str()) { - return Err(format!("Unknown scenario: {}. Valid: {:?}", s, VALID_SCENARIOS).into()); + bail!("Unknown scenario: {}. Valid: {:?}", s, VALID_SCENARIOS); } } @@ -179,11 +180,10 @@ fn main() -> Result<(), Box> { let bin = args.scenario_bin.unwrap_or_else(scenario_binary); if !bin.exists() { - return Err(format!( + bail!( "run_upload_scenario binary not found at {}; build with cargo build --release -p simulation", bin.display() - ) - .into()); + ); } let total_runs = @@ -322,7 +322,7 @@ fn main() -> Result<(), Box> { ); for h in handles { - h.join().map_err(|_| "scenario thread panicked")?; + h.join().map_err(|_| anyhow!("scenario thread panicked"))?; } generate_summary_csv(&results_base)?; diff --git a/simulation/src/scenario.rs b/simulation/src/scenario.rs index 15a12da3..eaca09e8 100644 --- a/simulation/src/scenario.rs +++ b/simulation/src/scenario.rs @@ -35,7 +35,7 @@ pub enum ScenarioError { #[error("Scenario error: {0}")] Scenario(String), #[error(transparent)] - CasClient(#[from] xet_client::cas_client::CasClientError), + CasClient(#[from] xet_client::ClientError), #[error(transparent)] Io(#[from] std::io::Error), #[error(transparent)] diff --git a/simulation/src/upload_concurrency/reporting.rs b/simulation/src/upload_concurrency/reporting.rs index f4733f36..0b8c899d 100644 --- a/simulation/src/upload_concurrency/reporting.rs +++ b/simulation/src/upload_concurrency/reporting.rs @@ -4,6 +4,7 @@ use std::collections::{BTreeMap, HashMap}; use std::fs; use std::path::Path; +use anyhow::{Result, anyhow, bail}; use serde::Deserialize; use super::upload_simulation_client::ClientMetrics; @@ -18,7 +19,7 @@ struct ClientTimelineData { stats_by_timestamp: BTreeMap, } -fn load_json_lines(file_path: &Path) -> Result, Box> +fn load_json_lines(file_path: &Path) -> Result> where T: for<'de> Deserialize<'de>, { @@ -34,7 +35,7 @@ where } /// Generates timeline.csv in the given scenario directory from client_stats_*.json files. -pub fn generate_timeline_csv(scenario_dir: &Path) -> Result<(), Box> { +pub fn generate_timeline_csv(scenario_dir: &Path) -> Result<()> { let mut client_timelines: HashMap = HashMap::new(); for entry in fs::read_dir(scenario_dir)? { @@ -64,7 +65,7 @@ pub fn generate_timeline_csv(scenario_dir: &Path) -> Result<(), Box = client_timelines.values().collect(); @@ -73,7 +74,7 @@ pub fn generate_timeline_csv(scenario_dir: &Path) -> Result<(), Box Result> { +fn process_timeline_csv(timeline_path: &Path, scenario_dir: &Path, scenario_name: &str) -> Result { let content = fs::read_to_string(timeline_path)?; let lines: Vec<&str> = content.lines().collect(); if lines.is_empty() { - return Err("Empty timeline.csv".into()); + bail!("Empty timeline.csv"); } let header = lines[0]; @@ -184,20 +181,23 @@ fn process_timeline_csv( let total_bytes_idx = header_cols .iter() .position(|&s| s == "total_bytes") - .ok_or("total_bytes column")?; - let elapsed_ms_idx = header_cols.iter().position(|&s| s == "elapsed_ms").ok_or("elapsed_ms column")?; + .ok_or_else(|| anyhow!("total_bytes column"))?; + let elapsed_ms_idx = header_cols + .iter() + .position(|&s| s == "elapsed_ms") + .ok_or_else(|| anyhow!("elapsed_ms column"))?; let total_retries_idx = header_cols .iter() .position(|&s| s == "total_retries") - .ok_or("total_retries column")?; + .ok_or_else(|| anyhow!("total_retries column"))?; let total_concurrency_idx = header_cols .iter() .position(|&s| s == "total_concurrency") - .ok_or("total_concurrency column")?; + .ok_or_else(|| anyhow!("total_concurrency column"))?; let average_concurrency_idx = header_cols .iter() .position(|&s| s == "average_concurrency") - .ok_or("average_concurrency column")?; + .ok_or_else(|| anyhow!("average_concurrency column"))?; let client_concurrency_indices: Vec = header_cols .iter() @@ -263,7 +263,7 @@ fn process_timeline_csv( } if row_count == 0 { - return Err("No data rows in timeline.csv".into()); + bail!("No data rows in timeline.csv"); } let avg_total_concurrency = total_concurrency_sum / row_count as f64; @@ -296,11 +296,7 @@ fn process_timeline_csv( /// Computes network utilization as (bytes_sent / max_possible_bytes) * 100, capped at 100%. /// max_possible_bytes = sum over segments of (bandwidth_bytes_per_sec * segment_sec) from network_stats.json, /// using elapsed_sec for segment boundaries so we don't mix absolute timestamps with duration. -fn calculate_network_utilization( - scenario_dir: &Path, - total_bytes: f64, - duration_sec: f64, -) -> Result> { +fn calculate_network_utilization(scenario_dir: &Path, total_bytes: f64, duration_sec: f64) -> Result { let network_stats_path = scenario_dir.join("network_stats.json"); if !network_stats_path.exists() || duration_sec <= 0.0 { return Ok(0.0); @@ -338,7 +334,7 @@ fn calculate_network_utilization( } } -fn calculate_average_rtt(scenario_dir: &Path) -> Result> { +fn calculate_average_rtt(scenario_dir: &Path) -> Result { let mut all_rtts = Vec::new(); for entry in fs::read_dir(scenario_dir)? { let entry = entry?; @@ -363,7 +359,7 @@ fn calculate_average_rtt(scenario_dir: &Path) -> Result Result<(), Box> { +pub fn generate_summary_csv(results_dir: &Path) -> Result<()> { let mut scenario_dirs = Vec::new(); for entry in fs::read_dir(results_dir)? { let entry = entry?; @@ -373,7 +369,7 @@ pub fn generate_summary_csv(results_dir: &Path) -> Result<(), Box Result<(), Box Result<(), Box> { +) -> Result<()> { run_upload_clients_impl(server_addr, output_dir, min_data_kb, max_data_kb, None, Some(cancel)).await } @@ -92,7 +93,7 @@ pub async fn run_upload_clients( min_data_kb: u64, max_data_kb: u64, repeat_duration_seconds: u64, -) -> Result<(), Box> { +) -> Result<()> { run_upload_clients_impl(server_addr, output_dir, min_data_kb, max_data_kb, Some(repeat_duration_seconds), None) .await } @@ -304,12 +305,12 @@ async fn run_upload_clients_impl( max_data_kb: u64, repeat_duration_seconds: Option, cancel: Option, -) -> Result<(), Box> { +) -> Result<()> { let min_data_size = min_data_kb * 1024; let max_data_size = max_data_kb * 1024; let client_id = rand::rng().random_range(0..1000000000_u64); - let http_client = build_http_client("test_session", None, None).map_err(|e| e.to_string())?; + let http_client = build_http_client("test_session", None, None)?; let duration_sec = repeat_duration_seconds.unwrap_or(u64::MAX); let client_params = serde_json::json!({ diff --git a/wasm/hf_xet_thin_wasm/Cargo.lock b/wasm/hf_xet_thin_wasm/Cargo.lock index b2ccc23b..5e6eb580 100644 --- a/wasm/hf_xet_thin_wasm/Cargo.lock +++ b/wasm/hf_xet_thin_wasm/Cargo.lock @@ -3969,7 +3969,6 @@ dependencies = [ name = "xet-core-structures" version = "1.4.0" dependencies = [ - "anyhow", "async-trait", "base64", "bincode", @@ -4041,6 +4040,7 @@ dependencies = [ name = "xet-runtime" version = "1.4.0" dependencies = [ + "anyhow", "async-trait", "bytes", "chrono", diff --git a/wasm/hf_xet_wasm/Cargo.lock b/wasm/hf_xet_wasm/Cargo.lock index cd136912..3d949c32 100644 --- a/wasm/hf_xet_wasm/Cargo.lock +++ b/wasm/hf_xet_wasm/Cargo.lock @@ -4141,7 +4141,6 @@ dependencies = [ name = "xet-core-structures" version = "1.4.0" dependencies = [ - "anyhow", "async-trait", "base64", "bincode", @@ -4213,6 +4212,7 @@ dependencies = [ name = "xet-runtime" version = "1.4.0" dependencies = [ + "anyhow", "async-trait", "bytes", "chrono", diff --git a/wasm/hf_xet_wasm/src/errors.rs b/wasm/hf_xet_wasm/src/errors.rs index 252cf20a..8b20886a 100644 --- a/wasm/hf_xet_wasm/src/errors.rs +++ b/wasm/hf_xet_wasm/src/errors.rs @@ -2,7 +2,7 @@ use std::fmt::Debug; use thiserror::Error; use xet_client::ClientError; -use xet_core_structures::FormatError; +use xet_core_structures::CoreError; use xet_core_structures::merklehash::DataHashHexParseError; #[non_exhaustive] @@ -17,8 +17,8 @@ pub enum DataProcessingError { #[error("Client error: {0}")] ClientError(#[from] ClientError), - #[error("Format error: {0}")] - FormatError(#[from] FormatError), + #[error("Core structures error: {0}")] + CoreError(#[from] CoreError), } impl DataProcessingError { diff --git a/wasm/hf_xet_wasm/src/wasm_deduplication_interface.rs b/wasm/hf_xet_wasm/src/wasm_deduplication_interface.rs index bd59546a..3b9f0c1e 100644 --- a/wasm/hf_xet_wasm/src/wasm_deduplication_interface.rs +++ b/wasm/hf_xet_wasm/src/wasm_deduplication_interface.rs @@ -5,7 +5,7 @@ use std::sync::Arc; use async_trait::async_trait; use tokio_with_wasm::alias as wasmtokio; -use xet_client::cas_client::CasClientError; +use xet_client::ClientError; use xet_core_structures::merklehash::{HMACKey, MerkleHash}; use xet_core_structures::metadata_shard::MDBShardInfo; use xet_core_structures::metadata_shard::file_structs::FileDataSequenceEntry; @@ -19,7 +19,7 @@ use super::wasm_file_upload_session::FileUploadSession; pub struct UploadSessionDataManager { session: Arc, shard: HashMap, - query_tasks: wasmtokio::task::JoinSet, CasClientError>>, + query_tasks: wasmtokio::task::JoinSet, ClientError>>, } impl UploadSessionDataManager { diff --git a/wasm/hf_xet_wasm/src/xorb_uploader.rs b/wasm/hf_xet_wasm/src/xorb_uploader.rs index 9aa636f3..e5f234c6 100644 --- a/wasm/hf_xet_wasm/src/xorb_uploader.rs +++ b/wasm/hf_xet_wasm/src/xorb_uploader.rs @@ -3,7 +3,8 @@ use std::sync::Arc; use async_trait::async_trait; use tokio_with_wasm::alias as wasmtokio; -use xet_client::cas_client::{CasClientError, Client}; +use xet_client::cas_client::Client; +use xet_client::ClientError; use xet_core_structures::xorb_object::SerializedXorbObject; use crate::errors::*; @@ -49,7 +50,7 @@ impl XorbUploader for XorbUploaderLocalSequential { pub struct XorbUploaderSpawnParallel { client: Arc, cas_prefix: String, - tasks: wasmtokio::task::JoinSet>, + tasks: wasmtokio::task::JoinSet>, } impl XorbUploaderSpawnParallel { diff --git a/xet_client/src/cas_client/adaptive_concurrency/controller.rs b/xet_client/src/cas_client/adaptive_concurrency/controller.rs index 39f57635..8af46b01 100644 --- a/xet_client/src/cas_client/adaptive_concurrency/controller.rs +++ b/xet_client/src/cas_client/adaptive_concurrency/controller.rs @@ -14,9 +14,9 @@ use xet_core_structures::ExpWeightedMovingAvg; use xet_runtime::core::xet_config; use xet_runtime::utils::adjustable_semaphore::{AdjustableSemaphore, AdjustableSemaphorePermit}; -use super::super::error::CasClientError; use super::super::progress_tracked_streams::ProgressCallback; use super::rtt_prediction::RTTPredictor; +use crate::error::Result; const MIN_PARTIAL_REPORT_INTERVAL_MS: u64 = 200; const PARTIAL_REPORT_WEIGHT_RATIO: f64 = 0.2; @@ -361,7 +361,7 @@ impl AdaptiveConcurrencyController { ) } - pub async fn acquire_connection_permit(self: &Arc) -> Result { + pub async fn acquire_connection_permit(self: &Arc) -> Result { let _permit = self.concurrency_semaphore.acquire().await?; let info = Arc::new(ConnectionPermitInfo { diff --git a/xet_client/src/cas_client/error.rs b/xet_client/src/cas_client/error.rs deleted file mode 100644 index c74cd6a4..00000000 --- a/xet_client/src/cas_client/error.rs +++ /dev/null @@ -1 +0,0 @@ -pub use crate::error::{ClientError as CasClientError, Result}; diff --git a/xet_client/src/cas_client/http_client.rs b/xet_client/src/cas_client/http_client.rs index f58122dc..43dce0ed 100644 --- a/xet_client/src/cas_client/http_client.rs +++ b/xet_client/src/cas_client/http_client.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use anyhow::anyhow; use http::{Extensions, HeaderMap, StatusCode}; use reqwest::header::{AUTHORIZATION, COOKIE, HeaderValue, SET_COOKIE}; use reqwest::{Request, Response}; @@ -11,8 +10,8 @@ use xet_runtime::core::{XetRuntime, xet_config}; use xet_runtime::error_printer::{ErrorPrinter, OptionPrinter}; use super::auth::{AuthConfig, TokenProvider}; -use super::{CasClientError, error}; use crate::cas_types::{REQUEST_ID_HEADER, SESSION_ID_HEADER}; +use crate::error::{ClientError, Result}; /// Middleware that rewrites https:// URLs to http:// when using Unix socket. /// This allows the proxy to parse plain HTTP and upgrade to HTTPS when forwarding. @@ -27,7 +26,7 @@ impl Middleware for HttpsToHttpMiddleware { mut req: Request, extensions: &mut Extensions, next: Next<'_>, - ) -> Result { + ) -> std::result::Result { let url = req.url_mut(); if url.scheme() == "https" { let original_scheme = url.scheme().to_string(); @@ -53,10 +52,7 @@ fn redact_headers(headers: &HeaderMap) -> HeaderMap { #[allow(unused_variables)] #[cfg(not(target_family = "wasm"))] -fn reqwest_client( - unix_socket_path: Option<&str>, - custom_headers: Option>, -) -> Result { +fn reqwest_client(unix_socket_path: Option<&str>, custom_headers: Option>) -> Result { // Check config if explicit socket path is not provided let socket_path = unix_socket_path .map(|s| s.to_string()) @@ -116,7 +112,7 @@ fn reqwest_client( fn reqwest_client_no_read_timeout( unix_socket_path: Option<&str>, custom_headers: Option>, -) -> Result { +) -> Result { let socket_path = unix_socket_path .map(|s| s.to_string()) .or_else(|| xet_config().client.unix_socket_path.clone()); @@ -149,10 +145,7 @@ fn reqwest_client_no_read_timeout( } #[cfg(target_family = "wasm")] -fn reqwest_client( - _unix_socket_path: Option<&str>, - custom_headers: Option>, -) -> Result { +fn reqwest_client(_unix_socket_path: Option<&str>, custom_headers: Option>) -> Result { // For WASM, create a new client with the specified headers, including the user-agent. // Note: we could cache this, but user_agent can vary, so we create per-call // Unix socket path is ignored on WASM @@ -170,7 +163,7 @@ pub fn build_auth_http_client( session_id: &str, unix_socket_path: Option<&str>, custom_headers: Option>, -) -> Result { +) -> Result { let auth_middleware = auth_config.as_ref().map(AuthMiddleware::from).info_none("CAS auth disabled"); let logging_middleware = Some(LoggingMiddleware); let session_middleware = (!session_id.is_empty()).then(|| SessionMiddleware(session_id.to_owned())); @@ -201,7 +194,7 @@ pub fn build_auth_http_client_no_read_timeout( session_id: &str, unix_socket_path: Option<&str>, custom_headers: Option>, -) -> Result { +) -> Result { let auth_middleware = auth_config.as_ref().map(AuthMiddleware::from).info_none("CAS auth disabled"); let logging_middleware = Some(LoggingMiddleware); let session_middleware = (!session_id.is_empty()).then(|| SessionMiddleware(session_id.to_owned())); @@ -226,7 +219,7 @@ pub fn build_http_client( session_id: &str, unix_socket_path: Option<&str>, custom_headers: Option>, -) -> Result { +) -> Result { build_auth_http_client(&None, session_id, unix_socket_path, custom_headers) } @@ -290,14 +283,14 @@ impl AuthMiddleware { /// (e.g. to a remote service). During this time, no other CAS requests can proceed /// from this client until the token has been fetched. This is expected/ok since we /// don't have a valid token and thus any calls would fail. - async fn get_token(&self) -> Result { + async fn get_token(&self) -> Result { let mut provider = self.token_provider.lock().await; provider .get_valid_token() .await .map_err(|err| { warn!(?err, "Token refresh failed"); - anyhow!("couldn't get token: {err:?}") + ClientError::AuthError(err) }) .inspect(|_token| { info!("Token refresh successful for CAS authentication"); @@ -322,7 +315,7 @@ impl Middleware for AuthMiddleware { extensions: &mut http::Extensions, next: Next<'_>, ) -> reqwest_middleware::Result { - let token = self.get_token().await.map_err(reqwest_middleware::Error::Middleware)?; + let token = self.get_token().await.map_err(reqwest_middleware::Error::middleware)?; let headers = req.headers_mut(); headers.insert(AUTHORIZATION, HeaderValue::from_str(&format!("Bearer {token}")).unwrap()); @@ -365,16 +358,14 @@ pub trait ResponseErrorLogger { /// This logs an error if one occurred before receiving a response or /// if the status code indicates a failure. /// As a result of these checks, the response is also transformed into a -/// cas_client::error::Result instead of the raw reqwest_middleware::Result. -impl ResponseErrorLogger> for reqwest_middleware::Result { - fn process_error(self, api: &str) -> error::Result { - let res = self - .map_err(CasClientError::from) - .log_error(format!("error invoking {api} api"))?; +/// crate::error::Result instead of the raw reqwest_middleware::Result. +impl ResponseErrorLogger> for reqwest_middleware::Result { + fn process_error(self, api: &str) -> Result { + let res = self.map_err(ClientError::from).log_error(format!("error invoking {api} api"))?; let request_id = request_id_from_response(&res); let error_message = format!("{api} api failed: request id: {request_id}"); let status = res.status(); - let res = res.error_for_status().map_err(CasClientError::from); + let res = res.error_for_status().map_err(ClientError::from); match (api, status) { ("get_reconstruction", StatusCode::RANGE_NOT_SATISFIABLE) => res.debug_error(&error_message), // not all status codes mean fatal error diff --git a/xet_client/src/cas_client/interface.rs b/xet_client/src/cas_client/interface.rs index a355d097..8664336f 100644 --- a/xet_client/src/cas_client/interface.rs +++ b/xet_client/src/cas_client/interface.rs @@ -4,9 +4,9 @@ use xet_core_structures::metadata_shard::file_structs::MDBFileInfo; use xet_core_structures::xorb_object::SerializedXorbObject; use super::adaptive_concurrency::ConnectionPermit; -use super::error::Result; use super::progress_tracked_streams::ProgressCallback; use crate::cas_types::{BatchQueryReconstructionResponse, FileRange, HttpRange, QueryReconstructionResponseV2}; +use crate::error::Result; #[async_trait::async_trait] pub trait URLProvider: Send + Sync { diff --git a/xet_client/src/cas_client/mod.rs b/xet_client/src/cas_client/mod.rs index e81486c0..8a63c11e 100644 --- a/xet_client/src/cas_client/mod.rs +++ b/xet_client/src/cas_client/mod.rs @@ -1,4 +1,3 @@ -pub use error::CasClientError; pub use http_client::{Api, ResponseErrorLogger, build_auth_http_client, build_http_client}; pub use interface::{Client, URLProvider}; pub use remote_client::RemoteClient; @@ -12,7 +11,6 @@ use tracing::Level; pub mod adaptive_concurrency; pub mod auth; -mod error; pub mod exports; pub mod http_client; mod interface; diff --git a/xet_client/src/cas_client/multipart.rs b/xet_client/src/cas_client/multipart.rs index 8bc9a226..b2403764 100644 --- a/xet_client/src/cas_client/multipart.rs +++ b/xet_client/src/cas_client/multipart.rs @@ -1,7 +1,7 @@ use bytes::Bytes; -use crate::cas_client::error::{CasClientError, Result}; use crate::cas_types::HttpRange; +use crate::error::{ClientError, Result}; /// A single part from a multipart/byteranges HTTP response. pub struct MultipartPart { @@ -23,7 +23,7 @@ pub fn parse_multipart_byteranges(content_type: &str, body: Bytes) -> Result Result Result { return Ok(boundary.to_string()); } } - Err(CasClientError::Other(format!("No boundary found in Content-Type: {content_type}"))) + Err(ClientError::Other(format!("No boundary found in Content-Type: {content_type}"))) } fn parse_content_range(headers: &[u8]) -> Result { - let headers_str = std::str::from_utf8(headers) - .map_err(|e| CasClientError::Other(format!("Invalid UTF-8 in part headers: {e}")))?; + let headers_str = + std::str::from_utf8(headers).map_err(|e| ClientError::Other(format!("Invalid UTF-8 in part headers: {e}")))?; for line in headers_str.split("\r\n") { let line_lower = line.to_ascii_lowercase(); @@ -96,24 +96,24 @@ fn parse_content_range(headers: &[u8]) -> Result { let original_value = range_spec.trim(); let slash_pos = original_value .find('/') - .ok_or_else(|| CasClientError::Other(format!("Invalid Content-Range: {line}")))?; + .ok_or_else(|| ClientError::Other(format!("Invalid Content-Range: {line}")))?; let range_part = &original_value[..slash_pos]; let dash_pos = range_part .find('-') - .ok_or_else(|| CasClientError::Other(format!("Invalid Content-Range: {line}")))?; + .ok_or_else(|| ClientError::Other(format!("Invalid Content-Range: {line}")))?; let start: u64 = range_part[..dash_pos] .parse() - .map_err(|e| CasClientError::Other(format!("Invalid Content-Range start: {e}")))?; + .map_err(|e| ClientError::Other(format!("Invalid Content-Range start: {e}")))?; let end: u64 = range_part[dash_pos + 1..] .parse() - .map_err(|e| CasClientError::Other(format!("Invalid Content-Range end: {e}")))?; + .map_err(|e| ClientError::Other(format!("Invalid Content-Range end: {e}")))?; // RFC 7233 Content-Range uses an inclusive end, which matches HttpRange. return Ok(HttpRange::new(start, end)); } } } - Err(CasClientError::Other("No Content-Range header found in multipart part".to_string())) + Err(ClientError::Other("No Content-Range header found in multipart part".to_string())) } fn find_subsequence(haystack: &[u8], needle: &[u8]) -> Option { diff --git a/xet_client/src/cas_client/remote_client.rs b/xet_client/src/cas_client/remote_client.rs index 88c7a7e0..c34dc826 100644 --- a/xet_client/src/cas_client/remote_client.rs +++ b/xet_client/src/cas_client/remote_client.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use std::sync::atomic::{AtomicU32, AtomicU64, Ordering}; +use anyhow::anyhow; use bytes::Bytes; use futures::TryStreamExt; use http::HeaderValue; @@ -15,7 +16,6 @@ use xet_runtime::core::xet_config; use super::adaptive_concurrency::{AdaptiveConcurrencyController, ConnectionPermit}; use super::auth::AuthConfig; -use super::error::{CasClientError, Result}; use super::http_client::{self, Api}; use super::interface::URLProvider; use super::progress_tracked_streams::{ @@ -27,6 +27,7 @@ use crate::cas_types::{ BatchQueryReconstructionResponse, FileRange, HttpRange, Key, QueryReconstructionResponse, QueryReconstructionResponseV2, UploadShardResponse, UploadShardResponseType, UploadXorbResponse, }; +use crate::error::{ClientError, Result}; pub const CAS_ENDPOINT: &str = "http://localhost:8080"; pub const PREFIX_DEFAULT: &str = "default"; @@ -187,7 +188,9 @@ impl RemoteClient { "v1" => "cas::get_reconstruction_v1", "v2" => "cas::get_reconstruction_v2", _ => { - return Err(CasClientError::internal(format!("unsupported reconstruction API version: {api_version}"))); + return Err(ClientError::InternalError(anyhow!( + "unsupported reconstruction API version: {api_version}" + ))); }, }; @@ -224,7 +227,7 @@ impl RemoteClient { ); Ok(Some(response)) }, - Err(CasClientError::ReqwestError(ref e, _)) if e.status() == Some(StatusCode::RANGE_NOT_SATISFIABLE) => { + Err(ClientError::ReqwestError(ref e, _)) if e.status() == Some(StatusCode::RANGE_NOT_SATISFIABLE) => { Ok(None) }, Err(e) => Err(e), @@ -286,7 +289,7 @@ impl RemoteClient { Err(e) => Err(e), }, 1 => Ok(self.get_reconstruction_v1(file_id, bytes_range).await?.map(Into::into)), - other => Err(CasClientError::internal(format!("unsupported reconstruction API version: {other}"))), + other => Err(ClientError::InternalError(anyhow!("unsupported reconstruction API version: {other}"))), } } } @@ -424,7 +427,7 @@ impl Client for RemoteClient { let body = resp .bytes() .await - .map_err(|e| RetryableReqwestError::RetryableError(CasClientError::from(e)))?; + .map_err(|e| RetryableReqwestError::RetryableError(ClientError::from(e)))?; let multipart_parts = crate::cas_client::multipart::parse_multipart_byteranges(&content_type, body) .map_err(RetryableReqwestError::FatalError)?; @@ -439,7 +442,7 @@ impl Client for RemoteClient { let (data, chunk_indices) = xet_core_structures::xorb_object::deserialize_chunks(&mut std::io::Cursor::new(part.data.as_ref())) .map_err(|e| { - RetryableReqwestError::RetryableError(CasClientError::FormatError(e)) + RetryableReqwestError::RetryableError(ClientError::FormatError(e)) })?; xet_core_structures::xorb_object::append_chunk_segment( @@ -455,7 +458,7 @@ impl Client for RemoteClient { if let Some(expected) = uncompressed_size_if_known && expected != all_decompressed.len() { - return Err(RetryableReqwestError::RetryableError(CasClientError::Other(format!( + return Err(RetryableReqwestError::RetryableError(ClientError::Other(format!( "get_file_term_data: expected {expected} uncompressed bytes, got {}", all_decompressed.len() )))); @@ -482,14 +485,14 @@ impl Client for RemoteClient { if let Some(expected) = uncompressed_size_if_known && expected != buffer.len() { - return Err(RetryableReqwestError::RetryableError(CasClientError::Other(format!( + return Err(RetryableReqwestError::RetryableError(ClientError::Other(format!( "get_file_term_data: expected {expected} uncompressed bytes, got {}", buffer.len() )))); } Ok((Bytes::from(buffer), chunk_byte_indices)) }, - Err(e) => Err(RetryableReqwestError::RetryableError(CasClientError::FormatError(e))), + Err(e) => Err(RetryableReqwestError::RetryableError(ClientError::FormatError(e))), } } } diff --git a/xet_client/src/cas_client/retry_wrapper.rs b/xet_client/src/cas_client/retry_wrapper.rs index 99682baa..878d9acb 100644 --- a/xet_client/src/cas_client/retry_wrapper.rs +++ b/xet_client/src/cas_client/retry_wrapper.rs @@ -12,13 +12,13 @@ use tracing::{error, info}; use xet_runtime::core::xet_config; use super::adaptive_concurrency::ConnectionPermit; -use super::error::CasClientError; use super::http_client::request_id_from_response; +use crate::error::{ClientError, Result}; #[derive(Debug)] pub enum RetryableReqwestError { - FatalError(CasClientError), - RetryableError(CasClientError), + FatalError(ClientError), + RetryableError(ClientError), } struct ConnectionPermitInfo { @@ -102,7 +102,7 @@ impl RetryWrapper { error!("{msg}"); } - CasClientError::from(err) + ClientError::from(err) }; // Here's the retry logic. @@ -122,7 +122,11 @@ impl RetryWrapper { } } - fn process_ok_response(&self, try_idx: usize, resp: Response) -> Result { + fn process_ok_response( + &self, + try_idx: usize, + resp: Response, + ) -> std::result::Result { let request_id = request_id_from_response(&resp).to_owned(); let retry_str = if try_idx == 0 { @@ -140,7 +144,7 @@ impl RetryWrapper { } else { error!("{context}: {api:?} api call failed (request id {request_id}{retry_str}): {err}"); } - CasClientError::from(err) + ClientError::from(err) }; let retriability = default_on_request_success(&resp); @@ -206,12 +210,12 @@ impl RetryWrapper { self, make_request: ReqFn, process_fn: ProcFn, - ) -> Result + ) -> Result where ReqFn: Fn() -> ReqFut + Send + Sync + 'static, - ReqFut: std::future::Future> + 'static, + ReqFut: std::future::Future> + 'static, ProcFn: Fn(Response) -> ProcFut + Send + 'static, - ProcFut: Future> + 'static, + ProcFut: Future> + 'static, { let strategy = ExponentialBackoff::from_millis(self.base_delay.as_millis().min(u64::MAX as u128) as u64) .map(jitter) @@ -338,19 +342,16 @@ impl RetryWrapper { /// /// This functions acts just like the json() function on a client response, but retries the entire connection on /// transient errors. - pub async fn run_and_extract_json( - self, - make_request: ReqFn, - ) -> Result + pub async fn run_and_extract_json(self, make_request: ReqFn) -> Result where JsonDest: for<'de> serde::Deserialize<'de>, ReqFn: Fn() -> ReqFut + Send + Sync + 'static, - ReqFut: std::future::Future> + 'static, + ReqFut: std::future::Future> + 'static, { self.run_and_process(make_request, |resp: Response| { async move { // Extract the json from the final result. - let r: Result = resp.json().await; + let r: std::result::Result = resp.json().await; match r { Ok(v) => Ok(v), @@ -383,15 +384,15 @@ impl RetryWrapper { /// /// This functions acts just like the bytes() function on a client response, but retries the entire connection on /// transient errors. - pub async fn run_and_extract_bytes(self, make_request: ReqFn) -> Result + pub async fn run_and_extract_bytes(self, make_request: ReqFn) -> Result where ReqFn: Fn() -> ReqFut + Send + Sync + 'static, - ReqFut: std::future::Future> + 'static, + ReqFut: std::future::Future> + 'static, { self.run_and_process(make_request, |resp: Response| { async move { // Extract the bytes from the final result. - let r: Result = resp.bytes().await; + let r: std::result::Result = resp.bytes().await; match r { Ok(v) => Ok(v), @@ -432,12 +433,12 @@ impl RetryWrapper { self, make_request: ReqFn, parse: Parse, - ) -> Result + ) -> Result where ReqFn: Fn() -> ReqFut + Send + Sync + 'static, - ReqFut: std::future::Future> + 'static, + ReqFut: std::future::Future> + 'static, Parse: Fn(Response) -> ParseFut + Send + Sync + 'static, - ParseFut: std::future::Future> + 'static, + ParseFut: std::future::Future> + 'static, { self.run_and_process(make_request, parse).await } @@ -447,10 +448,10 @@ impl RetryWrapper { /// The `make_request` function returns a future that resolves to a Result object as is returned by the /// client middleware. For example, `|| client.clone().get(url).send()` returns a future (as `send()` is async) /// that will then be evaluated to get the response. - pub async fn run(self, make_request: ReqFn) -> Result + pub async fn run(self, make_request: ReqFn) -> Result where ReqFn: Fn() -> ReqFut + Send + Sync + 'static, - ReqFut: std::future::Future> + 'static, + ReqFut: std::future::Future> + 'static, { // Just have the process_fn pass through the response. self.run_and_process(make_request, |resp| async move { Ok(resp) }).await diff --git a/xet_client/src/cas_client/simulation/client_testing_utils.rs b/xet_client/src/cas_client/simulation/client_testing_utils.rs index 016f95ea..c58b5a3d 100644 --- a/xet_client/src/cas_client/simulation/client_testing_utils.rs +++ b/xet_client/src/cas_client/simulation/client_testing_utils.rs @@ -8,8 +8,8 @@ use xet_core_structures::metadata_shard::file_structs::{FileDataSequenceEntry, F use xet_core_structures::metadata_shard::shard_in_memory::MDBInMemoryShard; use xet_core_structures::xorb_object::{Chunk, RawXorbData, SerializedXorbObject}; -use super::super::error::Result; use super::super::interface::Client; +use crate::error::Result; /// Information about a term (segment) in the file, referencing an XORB and chunk range. #[derive(Clone, Debug)] diff --git a/xet_client/src/cas_client/simulation/client_unit_testing.rs b/xet_client/src/cas_client/simulation/client_unit_testing.rs index 3424cbf5..e46d130d 100644 --- a/xet_client/src/cas_client/simulation/client_unit_testing.rs +++ b/xet_client/src/cas_client/simulation/client_unit_testing.rs @@ -10,9 +10,9 @@ use std::sync::Arc; use bytes::Bytes; -use super::super::error::CasClientError; use super::{ClientTestingUtils, DirectAccessClient}; use crate::cas_types::FileRange; +use crate::error::ClientError; /// Runs all common Client trait tests using a factory that creates fresh clients. pub async fn test_client_functionality(factory: impl Fn() -> Fut) @@ -535,15 +535,15 @@ pub async fn test_missing_xorb(client: Arc) { // get_full_xorb should return XORBNotFound let result = client.get_full_xorb(&fake_hash).await; - assert!(matches!(result, Err(CasClientError::XORBNotFound(_)))); + assert!(matches!(result, Err(ClientError::XORBNotFound(_)))); // xorb_length should return XORBNotFound let result = client.xorb_length(&fake_hash).await; - assert!(matches!(result, Err(CasClientError::XORBNotFound(_)))); + assert!(matches!(result, Err(ClientError::XORBNotFound(_)))); // get_xorb_ranges should return XORBNotFound let result = client.get_xorb_ranges(&fake_hash, vec![(0, 1)]).await; - assert!(matches!(result, Err(CasClientError::XORBNotFound(_)))); + assert!(matches!(result, Err(ClientError::XORBNotFound(_)))); } /// Tests list_xorbs and delete_xorb operations. @@ -620,13 +620,13 @@ pub async fn test_get_file_data_with_ranges(client: Arc) let result = client .get_file_data(&file.file_hash, Some(FileRange::new(file_size + 100, file_size + 1000))) .await; - assert!(matches!(result.unwrap_err(), CasClientError::InvalidRange)); + assert!(matches!(result.unwrap_err(), ClientError::InvalidRange)); // Start equals file size returns error let result = client .get_file_data(&file.file_hash, Some(FileRange::new(file_size, file_size + 100))) .await; - assert!(matches!(result.unwrap_err(), CasClientError::InvalidRange)); + assert!(matches!(result.unwrap_err(), ClientError::InvalidRange)); } /// Tests get_file_size returns correct size. @@ -756,7 +756,7 @@ async fn test_url_expiration_after_window(client: Arc) { // The fetch should fail with expiration error let result = client.fetch_term_data(xorb_hash, fetch_info).await; assert!(result.is_err(), "URL should be expired after expiration window"); - assert!(matches!(result.unwrap_err(), CasClientError::PresignedUrlExpirationError)); + assert!(matches!(result.unwrap_err(), ClientError::PresignedUrlExpirationError)); } /// Tests that default URL expiration is effectively infinite. @@ -813,7 +813,7 @@ async fn test_url_expiration_exact_boundary(client: Arc) let result = client.fetch_term_data(xorb_hash, fetch_info).await; assert!(result.is_err(), "URL should be expired past boundary"); - assert!(matches!(result.unwrap_err(), CasClientError::PresignedUrlExpirationError)); + assert!(matches!(result.unwrap_err(), ClientError::PresignedUrlExpirationError)); } // ============================================================================= diff --git a/xet_client/src/cas_client/simulation/deletion_controls.rs b/xet_client/src/cas_client/simulation/deletion_controls.rs index e5982315..f1700500 100644 --- a/xet_client/src/cas_client/simulation/deletion_controls.rs +++ b/xet_client/src/cas_client/simulation/deletion_controls.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use bytes::Bytes; use xet_core_structures::merklehash::MerkleHash; -use super::super::error::Result; +use crate::error::Result; /// Trait for clients that support deletion and integrity operations on shards and file entries. /// diff --git a/xet_client/src/cas_client/simulation/direct_access_client.rs b/xet_client/src/cas_client/simulation/direct_access_client.rs index 89fd4d54..8314a615 100644 --- a/xet_client/src/cas_client/simulation/direct_access_client.rs +++ b/xet_client/src/cas_client/simulation/direct_access_client.rs @@ -12,11 +12,11 @@ use bytes::Bytes; use xet_core_structures::merklehash::MerkleHash; use xet_core_structures::xorb_object::XorbObject; -use super::super::error::Result; use super::super::interface::Client; use crate::cas_types::{ FileRange, QueryReconstructionResponse, QueryReconstructionResponseV2, XorbReconstructionFetchInfo, }; +use crate::error::Result; /// A Client with direct access to XORB and file storage. /// diff --git a/xet_client/src/cas_client/simulation/local_client.rs b/xet_client/src/cas_client/simulation/local_client.rs index 1686c233..80985236 100644 --- a/xet_client/src/cas_client/simulation/local_client.rs +++ b/xet_client/src/cas_client/simulation/local_client.rs @@ -31,12 +31,12 @@ use super::direct_access_client::DirectAccessClient; use super::xorb_utils::{self, REFERENCE_INSTANT}; use crate::cas_client::Client; use crate::cas_client::adaptive_concurrency::AdaptiveConcurrencyController; -use crate::cas_client::error::{CasClientError, Result}; use crate::cas_client::progress_tracked_streams::ProgressCallback; use crate::cas_types::{ BatchQueryReconstructionResponse, FileRange, HexMerkleHash, HttpRange, QueryReconstructionResponse, QueryReconstructionResponseV2, XorbMultiRangeFetch, XorbRangeDescriptor, XorbReconstructionFetchInfo, }; +use crate::error::{ClientError, Result}; pub struct LocalClient { // Note: Field order matters for Drop! heed::Env must be dropped before _tmp_dir @@ -121,7 +121,7 @@ impl LocalClient { } } result.ok_or_else(|| { - CasClientError::Other(format!( + ClientError::Other(format!( "Error opening db at {global_dedup_dir:?} after 5 attempts: {}", last_err.unwrap() )) @@ -130,13 +130,13 @@ impl LocalClient { let mut write_txn = global_dedup_db_env .write_txn() - .map_err(|e| CasClientError::Other(format!("Error opening heed write transaction: {e}")))?; + .map_err(|e| ClientError::Other(format!("Error opening heed write transaction: {e}")))?; let global_dedup_table = global_dedup_db_env .create_database(&mut write_txn, None) - .map_err(|e| CasClientError::Other(format!("Error opening heed table: {e}")))?; + .map_err(|e| ClientError::Other(format!("Error opening heed table: {e}")))?; write_txn .commit() - .map_err(|e| CasClientError::Other(format!("Error committing heed database: {e}")))?; + .map_err(|e| ClientError::Other(format!("Error committing heed database: {e}")))?; // Open / set up the shard lookup let shard_manager = ShardFileManager::new_in_session_directory(shard_dir.clone(), true).await?; @@ -164,8 +164,8 @@ impl LocalClient { /// Returns all shard files in the shard directory as (shard_hash, path) pairs. fn shard_file_paths(&self) -> Result> { let mut result = Vec::new(); - for entry in std::fs::read_dir(&self.shard_dir).map_err(CasClientError::internal)? { - let entry = entry.map_err(CasClientError::internal)?; + for entry in std::fs::read_dir(&self.shard_dir).map_err(ClientError::internal)? { + let entry = entry.map_err(ClientError::internal)?; let path = entry.path(); if let Some(hash) = parse_shard_filename(&path) && path.is_file() @@ -182,7 +182,7 @@ impl LocalClient { if path.exists() { Ok(path) } else { - Err(CasClientError::Other(format!("Shard file not found for hash {}", hash.hex()))) + Err(ClientError::Other(format!("Shard file not found for hash {}", hash.hex()))) } } @@ -281,7 +281,7 @@ impl super::DeletionControlableClient for LocalClient { for xorb_hash in &xorb_hashes { let xorb_path = self.get_path_for_entry(xorb_hash); if !xorb_path.exists() { - return Err(CasClientError::Other(format!( + return Err(ClientError::Other(format!( "Integrity error: shard {} references non-existent XORB {}", shard_hash.hex(), xorb_hash.hex() @@ -297,7 +297,7 @@ impl super::DeletionControlableClient for LocalClient { let segment = file_view.entry(seg_idx); let chunk_count = xorb_chunk_counts.get(&segment.xorb_hash).ok_or_else(|| { - CasClientError::Other(format!( + ClientError::Other(format!( "Integrity error: file {} references XORB block {} not present in shard {}", fh.hex(), segment.xorb_hash.hex(), @@ -306,7 +306,7 @@ impl super::DeletionControlableClient for LocalClient { })?; if segment.chunk_index_end as usize > *chunk_count { - return Err(CasClientError::Other(format!( + return Err(ClientError::Other(format!( "Integrity error: file {} references chunk range {}..{} but XORB block {} only has {} chunks", fh.hex(), segment.chunk_index_start, @@ -408,7 +408,7 @@ impl DirectAccessClient for LocalClient { let mut ret = Vec::new(); self.xorb_dir .read_dir() - .map_err(CasClientError::internal)? + .map_err(ClientError::internal)? .filter_map(|x| x.ok()) .filter_map(|x| x.file_name().into_string().ok()) .for_each(|x| { @@ -443,7 +443,7 @@ impl DirectAccessClient for LocalClient { let file_path = self.get_path_for_entry(hash); let file = File::open(&file_path).map_err(|_| { error!("Unable to find file in local CAS {:?}", file_path); - CasClientError::XORBNotFound(*hash) + ClientError::XORBNotFound(*hash) })?; let mut reader = BufReader::new(file); @@ -460,7 +460,7 @@ impl DirectAccessClient for LocalClient { let file_path = self.get_path_for_entry(hash); let file = File::open(&file_path).map_err(|_| { error!("Unable to find file in local CAS {:?}", file_path); - CasClientError::XORBNotFound(*hash) + ClientError::XORBNotFound(*hash) })?; let mut reader = BufReader::new(file); @@ -488,7 +488,7 @@ impl DirectAccessClient for LocalClient { let length = xorb_obj.get_all_bytes(&mut reader)?.len(); Ok(length as u32) }, - Err(_) => Err(CasClientError::XORBNotFound(*hash)), + Err(_) => Err(ClientError::XORBNotFound(*hash)), } } @@ -500,13 +500,13 @@ impl DirectAccessClient for LocalClient { }; if !md.is_file() { - return Err(CasClientError::internal(format!( + return Err(ClientError::InternalError(anyhow!( "Attempting to write to {file_path:?}, but it is not a file" ))); } let Ok(file) = File::open(&file_path) else { - return Err(CasClientError::XORBNotFound(*hash)); + return Err(ClientError::XORBNotFound(*hash)); }; let mut reader = BufReader::new(file); @@ -518,7 +518,7 @@ impl DirectAccessClient for LocalClient { let file_path = self.get_path_for_entry(hash); let mut file = File::open(&file_path).map_err(|_| { error!("Unable to find xorb in local CAS {:?}", file_path); - CasClientError::XORBNotFound(*hash) + ClientError::XORBNotFound(*hash) })?; file.seek(SeekFrom::End(-(size_of::() as i64)))?; @@ -541,9 +541,9 @@ impl DirectAccessClient for LocalClient { .shard_manager .get_file_reconstruction_info(hash) .await - .map_err(|e| anyhow!("{e}"))? + .map_err(ClientError::internal)? else { - return Err(CasClientError::FileNotFound(*hash)); + return Err(ClientError::FileNotFound(*hash)); }; let mut file_vec = Vec::new(); @@ -561,7 +561,7 @@ impl DirectAccessClient for LocalClient { let start = byte_range.as_ref().map(|range| range.start as usize).unwrap_or(0); if byte_range.is_some() && start >= file_size { - return Err(CasClientError::InvalidRange); + return Err(ClientError::InvalidRange); } let end = byte_range @@ -575,7 +575,7 @@ impl DirectAccessClient for LocalClient { async fn get_xorb_raw_bytes(&self, hash: &MerkleHash, byte_range: Option) -> Result { let file_path = self.get_path_for_entry(hash); - let data = std::fs::read(&file_path).map_err(|_| CasClientError::XORBNotFound(*hash))?; + let data = std::fs::read(&file_path).map_err(|_| ClientError::XORBNotFound(*hash))?; let start = byte_range.as_ref().map(|r| r.start as usize).unwrap_or(0); let end = byte_range @@ -585,7 +585,7 @@ impl DirectAccessClient for LocalClient { .min(data.len()); if start >= data.len() { - return Err(CasClientError::InvalidRange); + return Err(ClientError::InvalidRange); } Ok(Bytes::from(data[start..end].to_vec())) @@ -593,7 +593,7 @@ impl DirectAccessClient for LocalClient { async fn xorb_raw_length(&self, hash: &MerkleHash) -> Result { let file_path = self.get_path_for_entry(hash); - let metadata = std::fs::metadata(&file_path).map_err(|_| CasClientError::XORBNotFound(*hash))?; + let metadata = std::fs::metadata(&file_path).map_err(|_| ClientError::XORBNotFound(*hash))?; Ok(metadata.len()) } @@ -609,7 +609,7 @@ impl DirectAccessClient for LocalClient { let expiration_ms = self.url_expiration_ms.load(Ordering::Relaxed); let elapsed_ms = Instant::now().saturating_duration_since(url_timestamp).as_millis() as u64; if elapsed_ms > expiration_ms { - return Err(CasClientError::PresignedUrlExpirationError); + return Err(ClientError::PresignedUrlExpirationError); } // Validate byte range matches url_range @@ -617,11 +617,11 @@ impl DirectAccessClient for LocalClient { // We convert url_range to FileRange for comparison let fetch_byte_range = FileRange::from(fetch_term.url_range); if url_byte_range.start != fetch_byte_range.start || url_byte_range.end != fetch_byte_range.end { - return Err(CasClientError::InvalidArguments); + return Err(ClientError::InvalidArguments); } let file = File::open(&file_path).map_err(|_| { error!("Unable to find xorb in local CAS {:?}", file_path); - CasClientError::XORBNotFound(hash) + ClientError::XORBNotFound(hash) })?; let mut reader = BufReader::new(file); @@ -638,7 +638,7 @@ impl DirectAccessClient for LocalClient { for chunk_idx in fetch_term.range.start..fetch_term.range.end { let chunk_len = xorb_obj .uncompressed_chunk_length(chunk_idx) - .map_err(|e| CasClientError::Other(format!("Failed to get chunk length: {e}")))?; + .map_err(|e| ClientError::Other(format!("Failed to get chunk length: {e}")))?; cumulative += chunk_len; indices.push(cumulative); } @@ -666,7 +666,7 @@ impl LocalClient { let file_path = self.get_path_for_entry(hash); let mut file = File::open(&file_path).map_err(|_| { error!("Unable to find file in local CAS {:?}", file_path); - CasClientError::XORBNotFound(*hash) + ClientError::XORBNotFound(*hash) })?; XorbObject::deserialize(&mut file).map_err(Into::into) } @@ -784,7 +784,7 @@ impl Client for LocalClient { let env = self .global_dedup_db_env .as_ref() - .ok_or_else(|| CasClientError::Other("LocalClient has been closed".to_string()))?; + .ok_or_else(|| ClientError::Other("LocalClient has been closed".to_string()))?; let read_txn = env.read_txn().map_err(map_heed_db_error)?; if let Some(shard) = self.global_dedup_table.get(&read_txn, chunk_hash).map_err(map_heed_db_error)? { @@ -838,7 +838,7 @@ impl Client for LocalClient { let env = self .global_dedup_db_env .as_ref() - .ok_or_else(|| CasClientError::Other("LocalClient has been closed".to_string()))?; + .ok_or_else(|| ClientError::Other("LocalClient has been closed".to_string()))?; let mut write_txn = env.write_txn().map_err(map_heed_db_error)?; for chunk in chunk_hashes { @@ -878,7 +878,7 @@ impl Client for LocalClient { &serialized_data, )?; if computed_hash != hash { - return Err(CasClientError::Other(format!( + return Err(ClientError::Other(format!( "XORB hash mismatch: expected {}, got {}", hash.hex(), computed_hash.hex(), @@ -985,11 +985,11 @@ impl Client for LocalClient { url_info.refresh_url().await?; continue; } - return Err(CasClientError::PresignedUrlExpirationError); + return Err(ClientError::PresignedUrlExpirationError); } // Read each byte range from the serialized file and deserialize the chunks. - let mut file = File::open(&file_path).map_err(|_| CasClientError::XORBNotFound(MerkleHash::default()))?; + let mut file = File::open(&file_path).map_err(|_| ClientError::XORBNotFound(MerkleHash::default()))?; let mut all_decompressed = Vec::new(); let mut all_chunk_indices = Vec::::new(); @@ -1031,14 +1031,14 @@ impl Client for LocalClient { } // Should not reach here, but return error if we do. - Err(CasClientError::PresignedUrlExpirationError) + Err(ClientError::PresignedUrlExpirationError) } } -fn map_heed_db_error(e: heed::Error) -> CasClientError { +fn map_heed_db_error(e: heed::Error) -> ClientError { let msg = format!("Global shard dedup database error: {e:?}"); warn!("{msg}"); - CasClientError::Other(msg) + ClientError::Other(msg) } fn generate_fetch_url(file_path: &Path, byte_range: &FileRange, timestamp: Instant) -> String { @@ -1051,13 +1051,13 @@ fn parse_fetch_url(url: &str) -> Result<(PathBuf, FileRange, Instant)> { parts.reverse(); if parts.len() != 4 { - return Err(CasClientError::InvalidArguments); + return Err(ClientError::InvalidArguments); } let file_path_str = parts[0]; - let start_pos: u64 = parts[1].parse().map_err(|_| CasClientError::InvalidArguments)?; - let end_pos: u64 = parts[2].parse().map_err(|_| CasClientError::InvalidArguments)?; - let timestamp_ms: u64 = parts[3].parse().map_err(|_| CasClientError::InvalidArguments)?; + let start_pos: u64 = parts[1].parse().map_err(|_| ClientError::InvalidArguments)?; + let end_pos: u64 = parts[2].parse().map_err(|_| ClientError::InvalidArguments)?; + let timestamp_ms: u64 = parts[3].parse().map_err(|_| ClientError::InvalidArguments)?; let file_path: PathBuf = file_path_str.trim_matches('"').into(); let byte_range = FileRange::new(start_pos, end_pos); @@ -1133,7 +1133,7 @@ mod tests { }; let result = client.fetch_term_data(hash, invalid_fetch_term).await; assert!(result.is_err(), "URL with too few parts should fail"); - assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments)); + assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments)); // Test 3: Invalid start_pos - doesn't match url_range.start let wrong_byte_range = FileRange::new(fetch_byte_start as u64 + 1, fetch_byte_end as u64); @@ -1145,7 +1145,7 @@ mod tests { }; let result = client.fetch_term_data(hash, invalid_fetch_term).await; assert!(result.is_err(), "Wrong start_pos should fail"); - assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments)); + assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments)); // Test 4: Invalid end_pos - doesn't match url_range.end let wrong_byte_range = FileRange::new(fetch_byte_start as u64, fetch_byte_end as u64 + 1); @@ -1157,7 +1157,7 @@ mod tests { }; let result = client.fetch_term_data(hash, invalid_fetch_term).await; assert!(result.is_err(), "Wrong end_pos should fail"); - assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments)); + assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments)); // Test 5: Invalid start_pos - non-numeric let timestamp_ms = timestamp.saturating_duration_since(*REFERENCE_INSTANT).as_millis() as u64; @@ -1169,7 +1169,7 @@ mod tests { }; let result = client.fetch_term_data(hash, invalid_fetch_term).await; assert!(result.is_err(), "Non-numeric start_pos should fail"); - assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments)); + assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments)); // Test 6: Invalid end_pos - non-numeric let non_numeric_end = format!("{:?}:{}:not_a_number:{}", file_path, fetch_byte_start, timestamp_ms); @@ -1180,7 +1180,7 @@ mod tests { }; let result = client.fetch_term_data(hash, invalid_fetch_term).await; assert!(result.is_err(), "Non-numeric end_pos should fail"); - assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments)); + assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments)); // Test 7: Empty URL let invalid_fetch_term = XorbReconstructionFetchInfo { @@ -1190,7 +1190,7 @@ mod tests { }; let result = client.fetch_term_data(hash, invalid_fetch_term).await; assert!(result.is_err(), "Empty URL should fail"); - assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments)); + assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments)); // Test 8: Invalid timestamp - non-numeric let non_numeric_timestamp = format!("{:?}:{}:{}:not_a_number", file_path, fetch_byte_start, fetch_byte_end); @@ -1201,7 +1201,7 @@ mod tests { }; let result = client.fetch_term_data(hash, invalid_fetch_term).await; assert!(result.is_err(), "Non-numeric timestamp should fail"); - assert!(matches!(result.unwrap_err(), CasClientError::InvalidArguments)); + assert!(matches!(result.unwrap_err(), ClientError::InvalidArguments)); // Test 9: Non-existent file path let non_existent_path = PathBuf::from("/nonexistent/path/file.xorb"); diff --git a/xet_client/src/cas_client/simulation/local_server/handlers.rs b/xet_client/src/cas_client/simulation/local_server/handlers.rs index 8753f6d8..d44a0810 100644 --- a/xet_client/src/cas_client/simulation/local_server/handlers.rs +++ b/xet_client/src/cas_client/simulation/local_server/handlers.rs @@ -28,13 +28,13 @@ use futures_util::StreamExt; use http::header::RANGE; use xet_core_structures::merklehash::MerkleHash; -use super::super::super::error::CasClientError; use super::super::super::{DeletionControlableClient, DirectAccessClient}; use super::latency_simulation::{LatencySimulation, ServerLatencyProfile}; use crate::cas_types::{ FileRange, HexKey, HexMerkleHash, QueryReconstructionResponseV2, UploadShardResponse, UploadShardResponseType, UploadXorbResponse, XorbRangeDescriptor, XorbReconstructionFetchInfo, }; +use crate::error::ClientError; /// Server state passed to all handlers. #[derive(Clone)] @@ -117,12 +117,12 @@ pub(super) fn parse_range_header( } } -/// Maps CasClientError to appropriate HTTP status codes. -pub(super) fn error_to_response(e: CasClientError) -> Response { +/// Maps ClientError to appropriate HTTP status codes. +pub(super) fn error_to_response(e: ClientError) -> Response { let status = match &e { - CasClientError::XORBNotFound(_) | CasClientError::FileNotFound(_) => StatusCode::NOT_FOUND, - CasClientError::InvalidRange => StatusCode::RANGE_NOT_SATISFIABLE, - CasClientError::InvalidArguments => StatusCode::BAD_REQUEST, + ClientError::XORBNotFound(_) | ClientError::FileNotFound(_) => StatusCode::NOT_FOUND, + ClientError::InvalidRange => StatusCode::RANGE_NOT_SATISFIABLE, + ClientError::InvalidArguments => StatusCode::BAD_REQUEST, _ => StatusCode::INTERNAL_SERVER_ERROR, }; (status, e.to_string()).into_response() diff --git a/xet_client/src/cas_client/simulation/local_server/main.rs b/xet_client/src/cas_client/simulation/local_server/main.rs index cb3830de..b3fd3219 100644 --- a/xet_client/src/cas_client/simulation/local_server/main.rs +++ b/xet_client/src/cas_client/simulation/local_server/main.rs @@ -43,6 +43,7 @@ use std::path::PathBuf; +use anyhow::Result; use clap::Parser; use tracing_subscriber::EnvFilter; use xet_client::cas_client::{LocalServer, LocalServerConfig}; @@ -93,7 +94,7 @@ struct Args { } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> Result<()> { // Initialize tracing with environment filter (respects RUST_LOG) tracing_subscriber::fmt().with_env_filter(EnvFilter::from_default_env()).init(); diff --git a/xet_client/src/cas_client/simulation/local_server/server.rs b/xet_client/src/cas_client/simulation/local_server/server.rs index 9cda702b..c001d6f4 100644 --- a/xet_client/src/cas_client/simulation/local_server/server.rs +++ b/xet_client/src/cas_client/simulation/local_server/server.rs @@ -12,10 +12,11 @@ //! # Example //! //! ```no_run +//! use anyhow::Result; //! use xet_client::cas_client::{LocalServer, LocalServerConfig}; //! //! #[tokio::main] -//! async fn main() -> anyhow::Result<()> { +//! async fn main() -> Result<()> { //! let config = LocalServerConfig { //! data_directory: "./data".into(), //! host: "127.0.0.1".to_string(), @@ -49,7 +50,6 @@ use tower_http::cors::CorsLayer; #[cfg(test)] use super::super::super::RemoteClient; -use super::super::super::error::{CasClientError, Result}; #[cfg(test)] use super::super::super::interface::Client; #[cfg(test)] @@ -58,6 +58,7 @@ use super::super::socket_proxy::UnixSocketProxy; use super::super::{DeletionControlableClient, DirectAccessClient, LocalClient, MemoryClient}; use super::handlers; use super::latency_simulation::LatencySimulation; +use crate::error::{ClientError, Result}; /// Configuration for the local CAS server. #[derive(Debug, Clone)] @@ -204,11 +205,11 @@ impl LocalServer { let addr: SocketAddr = self .addr() .parse() - .map_err(|e| CasClientError::Other(format!("Failed to parse address: {e}")))?; + .map_err(|e| ClientError::Other(format!("Failed to parse address: {e}")))?; let listener = TcpListener::bind(addr) .await - .map_err(|e| CasClientError::Other(format!("Failed to bind to {addr}: {e}")))?; + .map_err(|e| ClientError::Other(format!("Failed to bind to {addr}: {e}")))?; tracing::info!("Local CAS server listening on {}", addr); @@ -217,7 +218,7 @@ impl LocalServer { axum::serve(listener, router.into_make_service()) .with_graceful_shutdown(shutdown_signal()) .await - .map_err(|e| CasClientError::Other(format!("Server error: {e}"))) + .map_err(|e| ClientError::Other(format!("Server error: {e}"))) } /// Runs the server until a shutdown signal is received on the provided channel. @@ -227,11 +228,11 @@ impl LocalServer { let addr: SocketAddr = self .addr() .parse() - .map_err(|e| CasClientError::Other(format!("Failed to parse address: {e}")))?; + .map_err(|e| ClientError::Other(format!("Failed to parse address: {e}")))?; let listener = TcpListener::bind(addr) .await - .map_err(|e| CasClientError::Other(format!("Failed to bind to {addr}: {e}")))?; + .map_err(|e| ClientError::Other(format!("Failed to bind to {addr}: {e}")))?; tracing::info!("Local CAS server listening on {}", addr); @@ -242,7 +243,7 @@ impl LocalServer { let _ = shutdown_rx.await; }) .await - .map_err(|e| CasClientError::Other(format!("Server error: {e}"))) + .map_err(|e| ClientError::Other(format!("Server error: {e}"))) } } diff --git a/xet_client/src/cas_client/simulation/local_server/simulation_control_client.rs b/xet_client/src/cas_client/simulation/local_server/simulation_control_client.rs index 0368b12b..103cd26e 100644 --- a/xet_client/src/cas_client/simulation/local_server/simulation_control_client.rs +++ b/xet_client/src/cas_client/simulation/local_server/simulation_control_client.rs @@ -14,10 +14,10 @@ use super::simulation_types::{ XorbRawLengthResponse, }; use crate::cas_client::RemoteClient; -use crate::cas_client::error::{CasClientError, Result}; use crate::cas_client::interface::Client; use crate::cas_client::simulation::{DeletionControlableClient, DirectAccessClient}; use crate::cas_types::{FileRange, HexMerkleHash, QueryReconstructionResponseV2, XorbReconstructionFetchInfo}; +use crate::error::{ClientError, Result}; /// A client that connects to a `LocalTestServer` via HTTP and provides access /// to both `DirectAccessClient` and `DeletionControlableClient` operations @@ -50,26 +50,26 @@ impl SimulationControlClient { format!("{}/simulation{}", self.endpoint, path) } - /// Checks an HTTP response status, mapping errors to `CasClientError`. + /// Checks an HTTP response status, mapping errors to `ClientError`. async fn check_status(response: reqwest::Response) -> Result { let status = response.status(); if status.is_success() { Ok(response) } else if status == reqwest::StatusCode::NOT_IMPLEMENTED { - Err(CasClientError::Other("Deletion controls not available for this server backend".to_string())) + Err(ClientError::Other("Deletion controls not available for this server backend".to_string())) } else if status == reqwest::StatusCode::RANGE_NOT_SATISFIABLE { - Err(CasClientError::InvalidRange) + Err(ClientError::InvalidRange) } else { let body = response.text().await.unwrap_or_default(); - Err(CasClientError::Other(format!("HTTP {status}: {body}"))) + Err(ClientError::Other(format!("HTTP {status}: {body}"))) } } - /// Like `check_status`, but maps 404 to `CasClientError::XORBNotFound` for XORB endpoints. + /// Like `check_status`, but maps 404 to `ClientError::XORBNotFound` for XORB endpoints. async fn check_xorb_status(response: reqwest::Response, hash: &MerkleHash) -> Result { let status = response.status(); if status == reqwest::StatusCode::NOT_FOUND { - Err(CasClientError::XORBNotFound(*hash)) + Err(ClientError::XORBNotFound(*hash)) } else { Self::check_status(response).await } @@ -222,9 +222,9 @@ impl DirectAccessClient for SimulationControlClient { .get(self.sim_url("/xorbs")) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_status(resp).await?; - resp.json().await.map_err(|e| CasClientError::Other(e.to_string())) + resp.json().await.map_err(|e| ClientError::Other(e.to_string())) } /// Deletes a XORB by hash via the `/simulation/xorbs/{hash}` endpoint. @@ -243,9 +243,9 @@ impl DirectAccessClient for SimulationControlClient { .get(&url) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_xorb_status(resp, hash).await?; - resp.bytes().await.map_err(|e| CasClientError::Other(e.to_string())) + resp.bytes().await.map_err(|e| ClientError::Other(e.to_string())) } /// Retrieves specific chunk ranges from a XORB via the `/simulation/xorbs/{hash}/ranges` endpoint. @@ -259,9 +259,9 @@ impl DirectAccessClient for SimulationControlClient { .json(&body) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_xorb_status(resp, hash).await?; - let result: XorbRangesResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?; + let result: XorbRangesResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?; Ok(result.data.into_iter().map(Bytes::from).collect()) } @@ -274,9 +274,9 @@ impl DirectAccessClient for SimulationControlClient { .get(&url) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_xorb_status(resp, hash).await?; - let result: XorbLengthResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?; + let result: XorbLengthResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?; Ok(result.length) } @@ -289,9 +289,9 @@ impl DirectAccessClient for SimulationControlClient { .get(&url) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_status(resp).await?; - let result: XorbExistsResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?; + let result: XorbExistsResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?; Ok(result.exists) } @@ -299,7 +299,7 @@ impl DirectAccessClient for SimulationControlClient { async fn xorb_footer(&self, hash: &MerkleHash) -> Result { let raw_bytes = self.get_xorb_raw_bytes(hash, None).await?; XorbObject::deserialize(&mut std::io::Cursor::new(raw_bytes)) - .map_err(|e| CasClientError::Other(format!("Failed to deserialize XorbObject footer: {e}"))) + .map_err(|e| ClientError::Other(format!("Failed to deserialize XorbObject footer: {e}"))) } /// Returns the file size via the `/simulation/files/{hash}/size` endpoint. @@ -311,9 +311,9 @@ impl DirectAccessClient for SimulationControlClient { .get(&url) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_status(resp).await?; - let result: FileSizeResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?; + let result: FileSizeResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?; Ok(result.size) } @@ -325,9 +325,9 @@ impl DirectAccessClient for SimulationControlClient { if let Some(range) = byte_range { req = req.header(reqwest::header::RANGE, format!("bytes={}-{}", range.start, range.end.saturating_sub(1))); } - let resp = req.send().await.map_err(|e| CasClientError::Other(e.to_string()))?; + let resp = req.send().await.map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_status(resp).await?; - resp.bytes().await.map_err(|e| CasClientError::Other(e.to_string())) + resp.bytes().await.map_err(|e| ClientError::Other(e.to_string())) } /// Retrieves raw XORB bytes, optionally with a byte range, via the `/simulation/xorbs/{hash}/raw` endpoint. @@ -338,9 +338,9 @@ impl DirectAccessClient for SimulationControlClient { if let Some(range) = byte_range { req = req.header(reqwest::header::RANGE, format!("bytes={}-{}", range.start, range.end.saturating_sub(1))); } - let resp = req.send().await.map_err(|e| CasClientError::Other(e.to_string()))?; + let resp = req.send().await.map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_xorb_status(resp, hash).await?; - resp.bytes().await.map_err(|e| CasClientError::Other(e.to_string())) + resp.bytes().await.map_err(|e| ClientError::Other(e.to_string())) } /// Returns the raw byte length of a XORB via the `/simulation/xorbs/{hash}/raw_length` endpoint. @@ -352,9 +352,9 @@ impl DirectAccessClient for SimulationControlClient { .get(&url) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_xorb_status(resp, hash).await?; - let result: XorbRawLengthResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?; + let result: XorbRawLengthResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?; Ok(result.length) } @@ -372,9 +372,9 @@ impl DirectAccessClient for SimulationControlClient { .json(&body) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_status(resp).await?; - let result: FetchTermDataResponse = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?; + let result: FetchTermDataResponse = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?; Ok((Bytes::from(result.data), result.chunk_byte_indices)) } } @@ -388,9 +388,9 @@ impl DeletionControlableClient for SimulationControlClient { .get(self.sim_url("/shards")) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_status(resp).await?; - resp.json().await.map_err(|e| CasClientError::Other(e.to_string())) + resp.json().await.map_err(|e| ClientError::Other(e.to_string())) } /// Retrieves raw shard bytes by hash via the `/simulation/shards/{hash}` endpoint. @@ -402,9 +402,9 @@ impl DeletionControlableClient for SimulationControlClient { .get(&url) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_status(resp).await?; - resp.bytes().await.map_err(|e| CasClientError::Other(e.to_string())) + resp.bytes().await.map_err(|e| ClientError::Other(e.to_string())) } /// Deletes a shard entry by hash via the `/simulation/shards/{hash}` endpoint. @@ -416,7 +416,7 @@ impl DeletionControlableClient for SimulationControlClient { .delete(&url) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; Self::check_status(resp).await?; Ok(()) } @@ -428,9 +428,9 @@ impl DeletionControlableClient for SimulationControlClient { .get(self.sim_url("/file_entries")) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; let resp = Self::check_status(resp).await?; - let entries: Vec = resp.json().await.map_err(|e| CasClientError::Other(e.to_string()))?; + let entries: Vec = resp.json().await.map_err(|e| ClientError::Other(e.to_string()))?; Ok(entries.into_iter().map(|e| (e.file_hash, e.shard_hash)).collect()) } @@ -443,7 +443,7 @@ impl DeletionControlableClient for SimulationControlClient { .delete(&url) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; Self::check_status(resp).await?; Ok(()) } @@ -455,7 +455,7 @@ impl DeletionControlableClient for SimulationControlClient { .post(self.sim_url("/verify_integrity")) .send() .await - .map_err(|e| CasClientError::Other(e.to_string()))?; + .map_err(|e| ClientError::Other(e.to_string()))?; Self::check_status(resp).await?; Ok(()) } diff --git a/xet_client/src/cas_client/simulation/memory_client.rs b/xet_client/src/cas_client/simulation/memory_client.rs index 53ec028c..b50fbe87 100644 --- a/xet_client/src/cas_client/simulation/memory_client.rs +++ b/xet_client/src/cas_client/simulation/memory_client.rs @@ -20,7 +20,6 @@ use xet_core_structures::xorb_object::{SerializedXorbObject, XorbObject}; use super::super::Client; use super::super::adaptive_concurrency::AdaptiveConcurrencyController; -use super::super::error::{CasClientError, Result}; use super::super::progress_tracked_streams::ProgressCallback; use super::client_testing_utils::{FileTermReference, RandomFileContents}; use super::direct_access_client::DirectAccessClient; @@ -30,6 +29,7 @@ use crate::cas_types::{ BatchQueryReconstructionResponse, FileRange, HexMerkleHash, HttpRange, QueryReconstructionResponse, QueryReconstructionResponseV2, XorbMultiRangeFetch, XorbRangeDescriptor, XorbReconstructionFetchInfo, }; +use crate::error::{ClientError, Result}; /// Stored XORB data: the serialized data and the deserialized XorbObject (header/footer). struct MaterializedXorb { @@ -308,7 +308,7 @@ impl DirectAccessClient for MemoryClient { let xorbs = self.xorbs.read().await; let storage = xorbs.get(hash).ok_or_else(|| { error!("Unable to find xorb in memory CAS {:?}", hash); - CasClientError::XORBNotFound(*hash) + ClientError::XORBNotFound(*hash) })?; match storage { @@ -320,7 +320,7 @@ impl DirectAccessClient for MemoryClient { }, XorbStorage::Random(xorb) => xorb .get_chunk_range_data(0, xorb.num_chunks()) - .ok_or(CasClientError::XORBNotFound(*hash)), + .ok_or(ClientError::XORBNotFound(*hash)), } } @@ -332,7 +332,7 @@ impl DirectAccessClient for MemoryClient { let xorbs = self.xorbs.read().await; let storage = xorbs.get(hash).ok_or_else(|| { error!("Unable to find xorb in memory CAS {:?}", hash); - CasClientError::XORBNotFound(*hash) + ClientError::XORBNotFound(*hash) })?; match storage { @@ -358,7 +358,7 @@ impl DirectAccessClient for MemoryClient { ret.push(Bytes::new()); continue; } - let data = xorb.get_chunk_range_data(r.0, r.1).ok_or(CasClientError::XORBNotFound(*hash))?; + let data = xorb.get_chunk_range_data(r.0, r.1).ok_or(ClientError::XORBNotFound(*hash))?; ret.push(data); } Ok(ret) @@ -379,7 +379,7 @@ impl DirectAccessClient for MemoryClient { let xorbs = self.xorbs.read().await; let storage = xorbs.get(hash).ok_or_else(|| { error!("Unable to find xorb in memory CAS {:?}", hash); - CasClientError::XORBNotFound(*hash) + ClientError::XORBNotFound(*hash) })?; match storage { @@ -392,7 +392,7 @@ impl DirectAccessClient for MemoryClient { let shard = self.shard.read().await; let file_info = shard .get_file_reconstruction_info(hash) - .ok_or(CasClientError::FileNotFound(*hash))?; + .ok_or(ClientError::FileNotFound(*hash))?; Ok(file_info.file_size()) } @@ -401,7 +401,7 @@ impl DirectAccessClient for MemoryClient { let shard = self.shard.read().await; shard .get_file_reconstruction_info(hash) - .ok_or(CasClientError::FileNotFound(*hash))? + .ok_or(ClientError::FileNotFound(*hash))? }; let mut file_vec = Vec::new(); @@ -419,7 +419,7 @@ impl DirectAccessClient for MemoryClient { let start = byte_range.as_ref().map(|range| range.start as usize).unwrap_or(0); if byte_range.is_some() && start >= file_size { - return Err(CasClientError::InvalidRange); + return Err(ClientError::InvalidRange); } let end = byte_range @@ -433,7 +433,7 @@ impl DirectAccessClient for MemoryClient { async fn get_xorb_raw_bytes(&self, hash: &MerkleHash, byte_range: Option) -> Result { let xorbs = self.xorbs.read().await; - let storage = xorbs.get(hash).ok_or(CasClientError::XORBNotFound(*hash))?; + let storage = xorbs.get(hash).ok_or(ClientError::XORBNotFound(*hash))?; match storage { XorbStorage::Materialized(entry) => { @@ -447,7 +447,7 @@ impl DirectAccessClient for MemoryClient { .min(data.len()); if start >= data.len() { - return Err(CasClientError::InvalidRange); + return Err(ClientError::InvalidRange); } Ok(data.slice(start..end)) @@ -458,7 +458,7 @@ impl DirectAccessClient for MemoryClient { let end = byte_range.as_ref().map(|r| r.end).unwrap_or(total_len).min(total_len); if start >= total_len { - return Err(CasClientError::InvalidRange); + return Err(ClientError::InvalidRange); } Ok(xorb.get_serialized_range(start, end)) @@ -468,7 +468,7 @@ impl DirectAccessClient for MemoryClient { async fn xorb_raw_length(&self, hash: &MerkleHash) -> Result { let xorbs = self.xorbs.read().await; - let storage = xorbs.get(hash).ok_or(CasClientError::XORBNotFound(*hash))?; + let storage = xorbs.get(hash).ok_or(ClientError::XORBNotFound(*hash))?; match storage { XorbStorage::Materialized(entry) => Ok(entry.serialized_data.len() as u64), @@ -488,7 +488,7 @@ impl DirectAccessClient for MemoryClient { let expiration_ms = self.url_expiration_ms.load(Ordering::Relaxed); let elapsed_ms = Instant::now().saturating_duration_since(url_timestamp).as_millis() as u64; if elapsed_ms > expiration_ms { - return Err(CasClientError::PresignedUrlExpirationError); + return Err(ClientError::PresignedUrlExpirationError); } // Validate byte range matches url_range @@ -496,13 +496,13 @@ impl DirectAccessClient for MemoryClient { // We convert url_range to FileRange for comparison let fetch_byte_range = FileRange::from(fetch_term.url_range); if url_byte_range.start != fetch_byte_range.start || url_byte_range.end != fetch_byte_range.end { - return Err(CasClientError::InvalidArguments); + return Err(ClientError::InvalidArguments); } let xorbs = self.xorbs.read().await; let storage = xorbs.get(&xorb_hash).ok_or_else(|| { error!("Unable to find xorb in memory CAS {:?}", hash); - CasClientError::XORBNotFound(hash) + ClientError::XORBNotFound(hash) })?; let (data, xorb_obj) = match storage { @@ -516,7 +516,7 @@ impl DirectAccessClient for MemoryClient { XorbStorage::Random(xorb) => { let data = xorb .get_chunk_range_data(fetch_term.range.start, fetch_term.range.end) - .ok_or(CasClientError::XORBNotFound(hash))?; + .ok_or(ClientError::XORBNotFound(hash))?; let xorb_obj = xorb.get_xorb_object(); (data, xorb_obj) }, @@ -529,7 +529,7 @@ impl DirectAccessClient for MemoryClient { for chunk_idx in fetch_term.range.start..fetch_term.range.end { let chunk_len = xorb_obj .uncompressed_chunk_length(chunk_idx) - .map_err(|e| CasClientError::Other(format!("Failed to get chunk length: {e}")))?; + .map_err(|e| ClientError::Other(format!("Failed to get chunk length: {e}")))?; cumulative += chunk_len; indices.push(cumulative); } @@ -558,7 +558,7 @@ impl MemoryClient { xorb_utils::compute_reconstruction_ranges(&file_info, bytes_range, &mut |hash| { let storage = xorbs.get(hash).ok_or_else(|| { error!("Unable to find xorb in memory CAS {:?}", hash); - CasClientError::XORBNotFound(*hash) + ClientError::XORBNotFound(*hash) })?; Ok(match storage { XorbStorage::Materialized(entry) => entry.xorb_object.clone(), @@ -764,7 +764,7 @@ impl Client for MemoryClient { &serialized_data, )?; if computed_hash != hash { - return Err(CasClientError::Other(format!( + return Err(ClientError::Other(format!( "XORB hash mismatch: expected {}, got {}", hash.hex(), computed_hash.hex(), @@ -847,11 +847,11 @@ impl Client for MemoryClient { let expiration_ms = self.url_expiration_ms.load(Ordering::Relaxed); let elapsed_ms = Instant::now().saturating_duration_since(url_timestamp).as_millis() as u64; if elapsed_ms > expiration_ms { - return Err(CasClientError::PresignedUrlExpirationError); + return Err(ClientError::PresignedUrlExpirationError); } let xorbs = self.xorbs.read().await; - let storage = xorbs.get(&xorb_hash).ok_or(CasClientError::XORBNotFound(xorb_hash))?; + let storage = xorbs.get(&xorb_hash).ok_or(ClientError::XORBNotFound(xorb_hash))?; // Extract each byte range from the serialized data and deserialize let mut all_decompressed = Vec::new(); @@ -909,13 +909,13 @@ fn parse_fetch_url(url: &str) -> Result<(MerkleHash, FileRange, Instant)> { parts.reverse(); if parts.len() != 4 { - return Err(CasClientError::InvalidArguments); + return Err(ClientError::InvalidArguments); } - let hash = MerkleHash::from_hex(parts[0]).map_err(|_| CasClientError::InvalidArguments)?; - let start_pos: u64 = parts[1].parse().map_err(|_| CasClientError::InvalidArguments)?; - let end_pos: u64 = parts[2].parse().map_err(|_| CasClientError::InvalidArguments)?; - let timestamp_ms: u64 = parts[3].parse().map_err(|_| CasClientError::InvalidArguments)?; + let hash = MerkleHash::from_hex(parts[0]).map_err(|_| ClientError::InvalidArguments)?; + let start_pos: u64 = parts[1].parse().map_err(|_| ClientError::InvalidArguments)?; + let end_pos: u64 = parts[2].parse().map_err(|_| ClientError::InvalidArguments)?; + let timestamp_ms: u64 = parts[3].parse().map_err(|_| ClientError::InvalidArguments)?; let byte_range = FileRange::new(start_pos, end_pos); let timestamp = *REFERENCE_INSTANT + Duration::from_millis(timestamp_ms); diff --git a/xet_client/src/cas_client/simulation/network_simulation/bandwidth_limit_router.rs b/xet_client/src/cas_client/simulation/network_simulation/bandwidth_limit_router.rs index 6c6f23f0..f7b04b14 100644 --- a/xet_client/src/cas_client/simulation/network_simulation/bandwidth_limit_router.rs +++ b/xet_client/src/cas_client/simulation/network_simulation/bandwidth_limit_router.rs @@ -17,8 +17,8 @@ use tokio::sync::{Mutex, Semaphore}; use tokio::time::{Instant, interval, sleep, sleep_until}; use xet_runtime::utils::ClosureGuard; -use super::super::super::error::{CasClientError, Result}; use super::network_profile::{NetworkConfig, NetworkProfile}; +use crate::error::{ClientError, Result}; const BUF_SIZE: usize = 65536; const REFILL_INTERVAL_MS: u64 = 50; @@ -144,7 +144,7 @@ impl NetworkSimulationProxy { let mut guard = self.listener.lock().await; guard .take() - .ok_or_else(|| CasClientError::Other("accept loop already started or listener taken".into()))? + .ok_or_else(|| ClientError::Other("accept loop already started or listener taken".into()))? }; loop { if self.shutdown_flag.load(Ordering::Relaxed) { @@ -331,8 +331,8 @@ where Ok(total) } -fn map_proxy_err(e: impl std::fmt::Display) -> CasClientError { - CasClientError::Other(format!("Proxy error: {}", e)) +fn map_proxy_err(e: impl std::fmt::Display) -> ClientError { + ClientError::Other(format!("Proxy error: {}", e)) } #[cfg(test)] diff --git a/xet_client/src/cas_client/simulation/network_simulation/network_profile.rs b/xet_client/src/cas_client/simulation/network_simulation/network_profile.rs index d104b362..12799336 100644 --- a/xet_client/src/cas_client/simulation/network_simulation/network_profile.rs +++ b/xet_client/src/cas_client/simulation/network_simulation/network_profile.rs @@ -9,7 +9,7 @@ use std::time::Duration; use human_bandwidth::parse_bandwidth as parse_bandwidth_str; -use super::super::super::error::{CasClientError, Result}; +use crate::error::{ClientError, Result}; /// What the proxy applies: bandwidth (bytes/sec), latency, jitter, drop probability. /// Only `bandwidth_bytes_per_sec` is optional (None = no limit); the rest default to zero. @@ -82,11 +82,11 @@ impl NetworkProfileOptions { /// Stored as bytes per second internally. pub fn with_bandwidth_str(mut self, s: &str) -> Result { let bw = parse_bandwidth_str(s.trim()) - .map_err(|e| CasClientError::Other(format!("invalid bandwidth {:?}: {}", s, e)))?; + .map_err(|e| ClientError::Other(format!("invalid bandwidth {:?}: {}", s, e)))?; let bps = bw.as_bps(); let bps = u64::try_from(bps).unwrap_or(u64::MAX); if bps == 0 { - return Err(CasClientError::Other(format!("invalid bandwidth: {:?}", s))); + return Err(ClientError::Other(format!("invalid bandwidth: {:?}", s))); } self.max_bandwidth_bytes_per_sec = Some(bps / 8); Ok(self) @@ -129,7 +129,7 @@ impl NetworkProfileOptions { bandwidth_scale: 0.7, }), _ => { - return Err(CasClientError::Other(format!( + return Err(ClientError::Other(format!( "unknown congestion profile {:?}; valid: {:?}", name, VALID_CONGESTION_PROFILE_NAMES ))); diff --git a/xet_client/src/cas_client/simulation/simulation_client.rs b/xet_client/src/cas_client/simulation/simulation_client.rs index 8836c06a..0cbc7fb9 100644 --- a/xet_client/src/cas_client/simulation/simulation_client.rs +++ b/xet_client/src/cas_client/simulation/simulation_client.rs @@ -14,13 +14,13 @@ use reqwest::{Body, Url}; use serde_json; use super::super::adaptive_concurrency::ConnectionPermit; -use super::super::error::{CasClientError, Result}; use super::super::http_client::Api; use super::super::interface::Client; use super::super::progress_tracked_streams::{ProgressCallback, UploadProgressStream}; use super::super::remote_client::RemoteClient; use super::super::retry_wrapper::RetryWrapper; use super::local_server::ServerLatencyProfile; +use crate::error::{ClientError, Result}; /// A wrapper around `RemoteClient` that provides simulation-specific methods for controlling /// latency profiles and uploading dummy data for benchmarking and simulation purposes. @@ -43,7 +43,7 @@ impl RemoteSimulationClient { let client = self.inner.http_client(); let json_body = serde_json::to_vec(&profile) - .map_err(|e| CasClientError::Other(format!("Failed to serialize ServerLatencyProfile: {e}")))?; + .map_err(|e| ClientError::Other(format!("Failed to serialize ServerLatencyProfile: {e}")))?; let response = client .post(url) @@ -52,12 +52,12 @@ impl RemoteSimulationClient { .body(json_body) .send() .await - .map_err(|e| CasClientError::Other(format!("Failed to send set_config request: {e}")))?; + .map_err(|e| ClientError::Other(format!("Failed to send set_config request: {e}")))?; if !response.status().is_success() { let status = response.status(); let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string()); - return Err(CasClientError::Other(format!( + return Err(ClientError::Other(format!( "set_config request failed with status {}: {}", status, error_text ))); @@ -98,7 +98,7 @@ impl RemoteSimulationClient { .send() }) .await - .map_err(|e| CasClientError::Other(format!("Failed to upload dummy data: {e}")))?; + .map_err(|e| ClientError::Other(format!("Failed to upload dummy data: {e}")))?; Ok(n_upload_bytes) } diff --git a/xet_client/src/cas_client/simulation/simulation_server.rs b/xet_client/src/cas_client/simulation/simulation_server.rs index f1d58f9c..b424e0ed 100644 --- a/xet_client/src/cas_client/simulation/simulation_server.rs +++ b/xet_client/src/cas_client/simulation/simulation_server.rs @@ -16,7 +16,6 @@ use tempfile::TempDir; use tokio::sync::oneshot; use super::super::RemoteClient; -use super::super::error::Result; use super::super::interface::Client; use super::super::progress_tracked_streams::ProgressCallback; use super::local_server::{LocalServer, ServerLatencyProfile}; @@ -24,6 +23,7 @@ use super::network_simulation::{NetworkProfile, NetworkSimulationProxy}; #[cfg(unix)] use super::socket_proxy::UnixSocketProxy; use super::{DirectAccessClient, LocalClient, MemoryClient, RemoteSimulationClient}; +use crate::error::Result; /// Builder for creating a `LocalTestServer` with various configuration options. /// diff --git a/xet_client/src/cas_client/simulation/socket_proxy.rs b/xet_client/src/cas_client/simulation/socket_proxy.rs index 06ef5e56..4ead6517 100644 --- a/xet_client/src/cas_client/simulation/socket_proxy.rs +++ b/xet_client/src/cas_client/simulation/socket_proxy.rs @@ -6,6 +6,7 @@ use std::path::PathBuf; +use anyhow::Result; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpStream, UnixListener, UnixStream}; use tokio::sync::oneshot; @@ -21,9 +22,10 @@ use tokio::sync::oneshot; /// ```no_run /// use std::path::PathBuf; /// +/// use anyhow::Result; /// use xet_client::cas_client::simulation::socket_proxy::UnixSocketProxy; /// -/// # async fn example() -> Result<(), Box> { +/// # async fn example() -> Result<()> { /// let socket_path = PathBuf::from("/tmp/test_socket.sock"); /// let proxy = UnixSocketProxy::new(socket_path, "127.0.0.1:8080".to_string()).await?; /// // Proxy is now listening on the Unix socket and forwarding to TCP @@ -45,7 +47,7 @@ impl UnixSocketProxy { /// /// # Errors /// Returns an error if the socket file cannot be created or bound. - pub async fn new(socket_path: PathBuf, tcp_endpoint: String) -> Result> { + pub async fn new(socket_path: PathBuf, tcp_endpoint: String) -> Result { // Remove socket file if it exists if socket_path.exists() { std::fs::remove_file(&socket_path)?; @@ -97,7 +99,7 @@ impl UnixSocketProxy { } /// Handles a single connection by proxying data between Unix socket and TCP. - async fn handle_connection(unix_stream: UnixStream, tcp_endpoint: &str) -> Result<(), Box> { + async fn handle_connection(unix_stream: UnixStream, tcp_endpoint: &str) -> Result<()> { let tcp_stream = TcpStream::connect(tcp_endpoint).await?; // Use tokio::io::split to get owned halves that can be moved into tasks diff --git a/xet_client/src/cas_client/simulation/xorb_utils.rs b/xet_client/src/cas_client/simulation/xorb_utils.rs index 3828e5db..0a20724c 100644 --- a/xet_client/src/cas_client/simulation/xorb_utils.rs +++ b/xet_client/src/cas_client/simulation/xorb_utils.rs @@ -13,8 +13,8 @@ use xet_core_structures::merklehash::MerkleHash; use xet_core_structures::metadata_shard::file_structs::MDBFileInfo; use xet_core_structures::xorb_object::XorbObject; -use crate::cas_client::error::{CasClientError, Result}; use crate::cas_types::{ChunkRange, FileRange, HttpRange, XorbRangeDescriptor, XorbReconstructionTerm}; +use crate::error::{ClientError, Result}; lazy_static::lazy_static! { /// Reference instant for URL timestamps. Initialized far in the past to allow @@ -76,7 +76,7 @@ pub(crate) fn compute_reconstruction_ranges( loop { if s_idx >= file_info.segments.len() { - return Err(CasClientError::InvalidRange); + return Err(ClientError::InvalidRange); } let n = file_info.segments[s_idx].unpacked_segment_bytes as u64; @@ -202,16 +202,16 @@ pub(crate) fn generate_v2_fetch_url(hash: &MerkleHash, ranges: &[XorbRangeDescri /// Parses a V2 fetch URL back into (hash, timestamp, byte ranges). pub(crate) fn parse_v2_fetch_url(url: &str) -> Result<(MerkleHash, Instant, Vec)> { - let bytes = URL_SAFE_NO_PAD.decode(url).map_err(|_| CasClientError::InvalidArguments)?; - let payload = String::from_utf8(bytes).map_err(|_| CasClientError::InvalidArguments)?; + let bytes = URL_SAFE_NO_PAD.decode(url).map_err(|_| ClientError::InvalidArguments)?; + let payload = String::from_utf8(bytes).map_err(|_| ClientError::InvalidArguments)?; let mut parts = payload.splitn(3, ':'); - let hash_hex = parts.next().ok_or(CasClientError::InvalidArguments)?; - let ts_str = parts.next().ok_or(CasClientError::InvalidArguments)?; - let ranges_str = parts.next().ok_or(CasClientError::InvalidArguments)?; + let hash_hex = parts.next().ok_or(ClientError::InvalidArguments)?; + let ts_str = parts.next().ok_or(ClientError::InvalidArguments)?; + let ranges_str = parts.next().ok_or(ClientError::InvalidArguments)?; - let hash = MerkleHash::from_hex(hash_hex).map_err(|_| CasClientError::InvalidArguments)?; - let timestamp_ms: u64 = ts_str.parse().map_err(|_| CasClientError::InvalidArguments)?; + let hash = MerkleHash::from_hex(hash_hex).map_err(|_| ClientError::InvalidArguments)?; + let timestamp_ms: u64 = ts_str.parse().map_err(|_| ClientError::InvalidArguments)?; let timestamp = *REFERENCE_INSTANT + Duration::from_millis(timestamp_ms); let mut ranges = Vec::new(); @@ -219,14 +219,14 @@ pub(crate) fn parse_v2_fetch_url(url: &str) -> Result<(MerkleHash, Instant, Vec< let mut parts = r.splitn(2, '-'); let start: u64 = parts .next() - .ok_or(CasClientError::InvalidArguments)? + .ok_or(ClientError::InvalidArguments)? .parse() - .map_err(|_| CasClientError::InvalidArguments)?; + .map_err(|_| ClientError::InvalidArguments)?; let end: u64 = parts .next() - .ok_or(CasClientError::InvalidArguments)? + .ok_or(ClientError::InvalidArguments)? .parse() - .map_err(|_| CasClientError::InvalidArguments)?; + .map_err(|_| ClientError::InvalidArguments)?; ranges.push(HttpRange::new(start, end)); } @@ -451,7 +451,7 @@ mod tests { } else if *hash == hash_b { Ok(obj_b.clone()) } else { - Err(CasClientError::XORBNotFound(*hash)) + Err(ClientError::XORBNotFound(*hash)) } }) .unwrap(); diff --git a/xet_client/src/cas_types/error.rs b/xet_client/src/cas_types/error.rs deleted file mode 100644 index 44b810f1..00000000 --- a/xet_client/src/cas_types/error.rs +++ /dev/null @@ -1 +0,0 @@ -pub use crate::error::ClientError as CasTypesError; diff --git a/xet_client/src/cas_types/key.rs b/xet_client/src/cas_types/key.rs index 861d1484..67472a70 100644 --- a/xet_client/src/cas_types/key.rs +++ b/xet_client/src/cas_types/key.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use xet_core_structures::merklehash::MerkleHash; use xet_core_structures::merklehash::data_hash::hex; -use super::error::CasTypesError; +use crate::error::ClientError; /// A Key indicates a prefixed merkle hash for some data stored in the CAS DB. #[derive(Debug, PartialEq, Default, Serialize, Deserialize, Ord, PartialOrd, Eq, Hash, Clone)] @@ -21,15 +21,15 @@ impl Display for Key { } impl FromStr for Key { - type Err = CasTypesError; + type Err = ClientError; fn from_str(s: &str) -> Result { let parts = s.rsplit_once('/'); let Some((prefix, hash)) = parts else { - return Err(CasTypesError::InvalidKey(s.to_owned())); + return Err(ClientError::InvalidKey(s.to_owned())); }; - let hash = MerkleHash::from_hex(hash).map_err(|_| CasTypesError::InvalidKey(s.to_owned()))?; + let hash = MerkleHash::from_hex(hash).map_err(|_| ClientError::InvalidKey(s.to_owned()))?; Ok(Key { prefix: prefix.to_owned(), diff --git a/xet_client/src/cas_types/mod.rs b/xet_client/src/cas_types/mod.rs index f28c0fc4..b6d5bb6e 100644 --- a/xet_client/src/cas_types/mod.rs +++ b/xet_client/src/cas_types/mod.rs @@ -9,7 +9,6 @@ use serde_repr::{Deserialize_repr, Serialize_repr}; use thiserror::Error; use xet_core_structures::merklehash::MerkleHash; -mod error; mod key; pub use key::*; diff --git a/xet_client/src/error.rs b/xet_client/src/error.rs index bd902b87..92537d3b 100644 --- a/xet_client/src/error.rs +++ b/xet_client/src/error.rs @@ -1,7 +1,6 @@ -use std::fmt::Debug; use std::num::TryFromIntError; -use anyhow::anyhow; +use anyhow::Error as AnyhowError; use http::StatusCode; use thiserror::Error; use tokio::sync::AcquireError; @@ -15,7 +14,7 @@ use crate::cas_client::auth::AuthError; #[derive(Error, Debug)] pub enum ClientError { #[error("Format error: {0}")] - FormatError(#[from] xet_core_structures::FormatError), + FormatError(#[from] xet_core_structures::CoreError), #[error("Configuration error: {0}")] ConfigurationError(String), @@ -36,7 +35,7 @@ pub enum ClientError { InvalidShardKey(String), #[error("Internal error: {0}")] - InternalError(#[from] anyhow::Error), + InternalError(AnyhowError), #[error("{0}")] Other(String), @@ -66,7 +65,7 @@ pub enum ClientError { AuthError(#[from] AuthError), #[error("Credential helper error: {0}")] - CredentialHelper(anyhow::Error), + CredentialHelper(AnyhowError), #[error("Invalid repo type: {0}")] InvalidRepoType(String), @@ -101,8 +100,12 @@ impl From for ClientError { } impl ClientError { - pub fn internal(value: T) -> Self { - ClientError::InternalError(anyhow!("{value:?}")) + pub fn internal(value: impl std::error::Error + Send + Sync + 'static) -> Self { + ClientError::InternalError(AnyhowError::new(value)) + } + + pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> Self { + ClientError::CredentialHelper(AnyhowError::new(e)) } pub fn status(&self) -> Option { @@ -135,7 +138,7 @@ impl From> for } } -impl From> for ClientError { +impl From> for ClientError { fn from(value: std::sync::PoisonError) -> Self { Self::internal(value) } @@ -147,7 +150,7 @@ impl From for ClientError { } } -impl From> for ClientError { +impl From> for ClientError { fn from(value: SendError) -> Self { Self::internal(value) } diff --git a/xet_client/src/hub_client/auth/basics.rs b/xet_client/src/hub_client/auth/basics.rs index 4e33307c..6fa409d6 100644 --- a/xet_client/src/hub_client/auth/basics.rs +++ b/xet_client/src/hub_client/auth/basics.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use anyhow::Result; use async_trait::async_trait; use reqwest_middleware::RequestBuilder; use super::CredentialHelper; +use crate::error::ClientError; pub struct NoopCredentialHelper {} @@ -16,7 +16,7 @@ impl NoopCredentialHelper { #[async_trait] impl CredentialHelper for NoopCredentialHelper { - async fn fill_credential(&self, req: RequestBuilder) -> Result { + async fn fill_credential(&self, req: RequestBuilder) -> Result { Ok(req) } @@ -42,7 +42,7 @@ impl BearerCredentialHelper { #[async_trait] impl CredentialHelper for BearerCredentialHelper { - async fn fill_credential(&self, req: RequestBuilder) -> Result { + async fn fill_credential(&self, req: RequestBuilder) -> Result { Ok(req.bearer_auth(&self.hf_token)) } diff --git a/xet_client/src/hub_client/auth/interface.rs b/xet_client/src/hub_client/auth/interface.rs index 8d93e06e..0fe6fa96 100644 --- a/xet_client/src/hub_client/auth/interface.rs +++ b/xet_client/src/hub_client/auth/interface.rs @@ -1,11 +1,11 @@ -use anyhow::Result; use async_trait::async_trait; use reqwest_middleware::RequestBuilder; +use crate::error::ClientError; + #[async_trait] pub trait CredentialHelper: Send + Sync { - async fn fill_credential(&self, req: RequestBuilder) -> Result; + async fn fill_credential(&self, req: RequestBuilder) -> Result; - // Used in tests to identify the source of the credential. fn whoami(&self) -> &str; } diff --git a/xet_client/src/hub_client/client.rs b/xet_client/src/hub_client/client.rs index 098c88d0..474e771a 100644 --- a/xet_client/src/hub_client/client.rs +++ b/xet_client/src/hub_client/client.rs @@ -4,11 +4,11 @@ use http::header::HeaderMap; use urlencoding::encode; use super::auth::CredentialHelper; -use super::errors::*; use super::types::{CasJWTInfo, RepoInfo}; use crate::cas_client::exports::ClientWithMiddleware; use crate::cas_client::retry_wrapper::RetryWrapper; use crate::cas_client::{Api, build_http_client}; +use crate::error::Result; /// The type of operation to perform, either to upload files or to download files. /// Different operations lead to CAS access token with different authorization levels. @@ -98,7 +98,7 @@ impl HubClient { let req = cred_helper .fill_credential(req) .await - .map_err(reqwest_middleware::Error::Middleware)?; + .map_err(reqwest_middleware::Error::middleware)?; req.send().await } }) @@ -114,9 +114,9 @@ mod tests { use http::header::{self, HeaderMap, HeaderValue}; - use super::super::errors::Result; use super::super::{BearerCredentialHelper, HFRepoType, Operation, RepoInfo}; use super::HubClient; + use crate::error::Result; #[tokio::test] #[ignore = "need valid write token"] diff --git a/xet_client/src/hub_client/errors.rs b/xet_client/src/hub_client/errors.rs deleted file mode 100644 index e29aab8d..00000000 --- a/xet_client/src/hub_client/errors.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub use crate::error::ClientError as HubClientError; -pub type Result = std::result::Result; - -impl HubClientError { - pub fn credential_helper_error(e: impl std::error::Error + Send + Sync + 'static) -> HubClientError { - HubClientError::CredentialHelper(e.into()) - } -} diff --git a/xet_client/src/hub_client/mod.rs b/xet_client/src/hub_client/mod.rs index f6b7015f..45395256 100644 --- a/xet_client/src/hub_client/mod.rs +++ b/xet_client/src/hub_client/mod.rs @@ -1,9 +1,7 @@ mod auth; mod client; -mod errors; mod types; pub use auth::{BearerCredentialHelper, CredentialHelper, NoopCredentialHelper}; pub use client::{HubClient, Operation}; -pub use errors::{HubClientError, Result}; pub use types::{CasJWTInfo, HFRepoType, RepoInfo}; diff --git a/xet_client/src/hub_client/types.rs b/xet_client/src/hub_client/types.rs index e138e45a..6fb30674 100644 --- a/xet_client/src/hub_client/types.rs +++ b/xet_client/src/hub_client/types.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use serde::Deserialize; -use super::errors::{HubClientError, Result}; +use crate::error::{ClientError, Result}; /// This defines the response format from the Huggingface Hub Xet CAS access token API. #[derive(Deserialize, Debug)] @@ -23,7 +23,7 @@ pub enum HFRepoType { } impl FromStr for HFRepoType { - type Err = HubClientError; + type Err = ClientError; fn from_str(s: &str) -> std::result::Result { match s.to_lowercase().as_str() { @@ -31,7 +31,7 @@ impl FromStr for HFRepoType { "model" | "models" => Ok(HFRepoType::Model), "dataset" | "datasets" => Ok(HFRepoType::Dataset), "space" | "spaces" => Ok(HFRepoType::Space), - t => Err(HubClientError::InvalidRepoType(t.to_owned())), + t => Err(ClientError::InvalidRepoType(t.to_owned())), } } } diff --git a/xet_core_structures/Cargo.toml b/xet_core_structures/Cargo.toml index c453cb38..4e1f0de0 100644 --- a/xet_core_structures/Cargo.toml +++ b/xet_core_structures/Cargo.toml @@ -23,7 +23,6 @@ bench = true [dependencies] xet-runtime = { version = "1.4.0", path = "../xet_runtime" } -anyhow = { workspace = true } async-trait = { workspace = true } base64 = { workspace = true } blake3 = { workspace = true } diff --git a/xet_core_structures/src/error.rs b/xet_core_structures/src/error.rs index b69ed8e0..d710f8c1 100644 --- a/xet_core_structures/src/error.rs +++ b/xet_core_structures/src/error.rs @@ -7,13 +7,13 @@ use crate::merklehash::MerkleHash; #[non_exhaustive] #[derive(Error, Debug)] -pub enum FormatError { +pub enum CoreError { // -- Common ---------------------------------------------------------- #[error("I/O error: {0}")] Io(#[from] std::io::Error), #[error("Internal error: {0}")] - Internal(anyhow::Error), + InternalError(String), #[error("{0}")] Other(String), @@ -50,14 +50,14 @@ pub enum FormatError { #[error("Invalid arguments")] InvalidArguments, - #[error("Format error: {0}")] - Format(anyhow::Error), + #[error("Malformed data: {0}")] + MalformedData(String), #[error("Hash mismatch")] HashMismatch, #[error("Compression error: {0}")] - Compression(#[from] lz4_flex::frame::Error), + CompressionError(#[from] lz4_flex::frame::Error), #[error("Hash parsing error: {0}")] HashParsing(#[from] Infallible), @@ -67,7 +67,7 @@ pub enum FormatError { // -- Runtime --------------------------------------------------------- #[error("Runtime error: {0}")] - Runtime(#[from] xet_runtime::RuntimeError), + RuntimeError(#[from] xet_runtime::RuntimeError), #[error("Task lock error: {0}")] TaskRuntime(#[from] xet_runtime::utils::RwTaskLockError), @@ -76,15 +76,15 @@ pub enum FormatError { TaskJoin(#[from] tokio::task::JoinError), } -pub type Result = std::result::Result; +pub type Result = std::result::Result; -impl PartialEq for FormatError { - fn eq(&self, other: &FormatError) -> bool { +impl PartialEq for CoreError { + fn eq(&self, other: &CoreError) -> bool { std::mem::discriminant(self) == std::mem::discriminant(other) } } -impl FormatError { +impl CoreError { pub fn other(inner: impl ToString) -> Self { Self::Other(inner.to_string()) } @@ -103,7 +103,7 @@ impl Validate for Result { fn ok_for_format_error(self) -> Result> { match self { Ok(v) => Ok(Some(v)), - Err(FormatError::Format(e)) => { + Err(CoreError::MalformedData(e)) => { warn!("XORB Validation: {e}"); Ok(None) }, @@ -112,14 +112,14 @@ impl Validate for Result { } } -impl From for FormatError { +impl From for CoreError { fn from(_: crate::merklehash::DataHashHexParseError) -> Self { - FormatError::Other("Invalid hex input for DataHash".to_string()) + CoreError::Other("Invalid hex input for DataHash".to_string()) } } -impl From for FormatError { +impl From for CoreError { fn from(_: crate::merklehash::DataHashBytesParseError) -> Self { - FormatError::Other("Invalid bytes input for DataHash".to_string()) + CoreError::Other("Invalid bytes input for DataHash".to_string()) } } diff --git a/xet_core_structures/src/lib.rs b/xet_core_structures/src/lib.rs index d776c3aa..e4ae8c38 100644 --- a/xet_core_structures/src/lib.rs +++ b/xet_core_structures/src/lib.rs @@ -1,7 +1,7 @@ #![cfg_attr(feature = "strict", deny(warnings))] pub mod error; -pub use error::FormatError; +pub use error::CoreError; pub mod data_structures; pub mod merklehash; diff --git a/xet_core_structures/src/metadata_shard/error.rs b/xet_core_structures/src/metadata_shard/error.rs deleted file mode 100644 index 303bfceb..00000000 --- a/xet_core_structures/src/metadata_shard/error.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub use crate::error::FormatError as MDBShardError; -pub type Result = std::result::Result; diff --git a/xet_core_structures/src/metadata_shard/file_structs.rs b/xet_core_structures/src/metadata_shard/file_structs.rs index 61890526..a05e012a 100644 --- a/xet_core_structures/src/metadata_shard/file_structs.rs +++ b/xet_core_structures/src/metadata_shard/file_structs.rs @@ -5,7 +5,6 @@ use std::mem::size_of; use bytes::Bytes; use serde::Serialize; -use super::error::MDBShardError; use super::shard_file::MDB_FILE_INFO_ENTRY_SIZE; use super::xorb_structs::{XorbChunkSequenceEntry, XorbChunkSequenceHeader}; use crate::merklehash::data_hash::hex; @@ -435,7 +434,7 @@ impl MDBFileInfo { /// Merges the content of other into the content of self if needed. /// After this call, self will have the verification info and metadata /// extension if they exist in the other object but not this one. - pub fn merge_from(&mut self, other: &Self) -> Result<(), MDBShardError> { + pub fn merge_from(&mut self, other: &Self) -> crate::error::Result<()> { FileDataSequenceHeader::verify_same_file(&self.metadata, &other.metadata); if self.contains_verification() != other.contains_verification() && other.contains_verification() { // self doesn't have verification. Copy from other diff --git a/xet_core_structures/src/metadata_shard/mod.rs b/xet_core_structures/src/metadata_shard/mod.rs index 953dc1ac..2ce83e5e 100644 --- a/xet_core_structures/src/metadata_shard/mod.rs +++ b/xet_core_structures/src/metadata_shard/mod.rs @@ -1,6 +1,5 @@ pub mod chunk_verification; pub mod constants; -pub mod error; pub mod file_structs; pub mod interpolation_search; pub mod session_directory; diff --git a/xet_core_structures/src/metadata_shard/session_directory.rs b/xet_core_structures/src/metadata_shard/session_directory.rs index e9e8703c..acd26e60 100644 --- a/xet_core_structures/src/metadata_shard/session_directory.rs +++ b/xet_core_structures/src/metadata_shard/session_directory.rs @@ -8,10 +8,10 @@ use tokio::task::JoinHandle; use tracing::{error, info}; use xet_runtime::core::{XetRuntime, check_sigint_shutdown}; -use super::error::Result; use super::set_operations::shard_set_union; use super::shard_file_handle::MDBShardFile; use super::{MDBShardFileFooter, MDBShardFileHeader, MDBShardInfo}; +use crate::error::Result; /// Merge a collection of shards, deleting the old ones. /// After calling this, the passed in shards may be invalid -- i.e. may refer to a shard that doesn't exist. diff --git a/xet_core_structures/src/metadata_shard/set_operations.rs b/xet_core_structures/src/metadata_shard/set_operations.rs index 59f4adb3..669d39aa 100644 --- a/xet_core_structures/src/metadata_shard/set_operations.rs +++ b/xet_core_structures/src/metadata_shard/set_operations.rs @@ -6,7 +6,6 @@ use std::path::Path; use uuid::Uuid; -use super::error::Result; use super::file_structs::{ FileDataSequenceEntry, FileDataSequenceHeader, FileMetadataExt, FileVerificationEntry, SupersetResult, }; @@ -14,6 +13,7 @@ use super::shard_file::MDB_FILE_INFO_ENTRY_SIZE; use super::shard_format::{MDBShardFileFooter, MDBShardFileHeader, MDBShardInfo}; use super::utils::truncate_hash; use super::xorb_structs::{XorbChunkSequenceEntry, XorbChunkSequenceHeader}; +use crate::error::Result; use crate::merklehash::{HashedWrite, MerkleHash}; use crate::serialization_utils::*; @@ -462,10 +462,10 @@ mod tests { use itertools::iproduct; use tempfile::TempDir; - use super::super::error::Result; use super::super::shard_format::test_routines::*; use super::super::shard_in_memory::MDBInMemoryShard; use super::*; + use crate::error::Result; use crate::merklehash::compute_data_hash; fn test_operations(mem_shard_1: &MDBInMemoryShard, mem_shard_2: &MDBInMemoryShard) -> Result<()> { diff --git a/xet_core_structures/src/metadata_shard/shard_benchmark.rs b/xet_core_structures/src/metadata_shard/shard_benchmark.rs index d51df245..66b2fa2e 100644 --- a/xet_core_structures/src/metadata_shard/shard_benchmark.rs +++ b/xet_core_structures/src/metadata_shard/shard_benchmark.rs @@ -5,7 +5,6 @@ use std::sync::Arc; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::{Duration, Instant}; -use anyhow::{Ok, Result, anyhow}; use clap::Parser; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -54,7 +53,7 @@ async fn run_shard_benchmark( contiguity: usize, block_hit_proportion: f64, dir: &Path, -) -> Result<()> { +) -> Result<(), Box> { let mut seed = 0u64; eprintln!("Creating shards."); @@ -145,14 +144,14 @@ async fn run_shard_benchmark( Ok(()) } -fn parse_arg(arg: &str) -> Result<(u64, u64)> { +fn parse_arg(arg: &str) -> Result<(u64, u64), String> { let parts: Vec<&str> = arg.split(':').collect(); if parts.len() != 2 { - return Err(anyhow!("Failed to parse argument: {arg}")); + return Err(format!("Failed to parse argument: {arg}")); } - let size1 = u64::from_str(parts[0]).map_err(|e| anyhow!("Failed to parse size1: {e:?}"))?; - let size2 = u64::from_str(parts[1]).map_err(|e| anyhow!("Failed to parse size2: {e:?}"))?; + let size1 = u64::from_str(parts[0]).map_err(|e| format!("Failed to parse size1: {e:?}"))?; + let size2 = u64::from_str(parts[1]).map_err(|e| format!("Failed to parse size2: {e:?}"))?; Ok((size1, size2)) } diff --git a/xet_core_structures/src/metadata_shard/shard_file_handle.rs b/xet_core_structures/src/metadata_shard/shard_file_handle.rs index 8010dccc..78c6c098 100644 --- a/xet_core_structures/src/metadata_shard/shard_file_handle.rs +++ b/xet_core_structures/src/metadata_shard/shard_file_handle.rs @@ -11,12 +11,12 @@ use tracing::{debug, error, info, warn}; use super::MDBShardFileFooter; use super::constants::MDB_SHARD_EXPIRATION_BUFFER; -use super::error::{MDBShardError, Result}; use super::file_structs::{FileDataSequenceEntry, MDBFileInfo}; use super::shard_file::current_timestamp; use super::shard_format::MDBShardInfo; use super::utils::{parse_shard_filename, shard_file_name, temp_shard_file_name, truncate_hash}; use super::xorb_structs::XorbChunkSequenceHeader; +use crate::error::{CoreError, Result}; use crate::merklehash::{HMACKey, HashedWrite, MerkleHash, compute_data_hash}; /// When a specific implementation of the @@ -179,7 +179,7 @@ impl MDBShardFile { if let Some(shard_hash) = parse_shard_filename(path.to_str().unwrap()) { Self::load_from_hash_and_path(shard_hash, path) } else { - Err(MDBShardError::BadFilename(format!("{path:?} not a valid MerkleDB filename."))) + Err(CoreError::BadFilename(format!("{path:?} not a valid MerkleDB filename."))) } } @@ -315,7 +315,7 @@ impl MDBShardFile { } else if let Some(h) = path.file_name().and_then(parse_shard_filename) { load_file(h, path)?; } else { - return Err(MDBShardError::BadFilename(format!("Filename {path:?} not valid shard file name."))); + return Err(CoreError::BadFilename(format!("Filename {path:?} not valid shard file name."))); } Ok(()) @@ -365,11 +365,11 @@ impl MDBShardFile { pub fn get_reader_if_present(&self) -> Result>> { match self.get_reader() { Ok(v) => Ok(Some(v)), - Err(MDBShardError::Io(e)) => { + Err(CoreError::Io(e)) => { if e.kind() == ErrorKind::NotFound { Ok(None) } else { - Err(MDBShardError::Io(e)) + Err(CoreError::Io(e)) } }, Err(other_err) => Err(other_err), diff --git a/xet_core_structures/src/metadata_shard/shard_file_manager.rs b/xet_core_structures/src/metadata_shard/shard_file_manager.rs index eb7d612d..8e4ddf44 100644 --- a/xet_core_structures/src/metadata_shard/shard_file_manager.rs +++ b/xet_core_structures/src/metadata_shard/shard_file_manager.rs @@ -10,13 +10,13 @@ use xet_runtime::core::{XetRuntime, xet_config}; use xet_runtime::utils::RwTaskLock; use super::constants::MDB_SHARD_EXPIRATION_BUFFER; -use super::error::{MDBShardError, Result}; use super::file_structs::*; use super::shard_file_handle::MDBShardFile; use super::shard_file_reconstructor::FileReconstructor; use super::shard_in_memory::MDBInMemoryShard; use super::utils::truncate_hash; use super::xorb_structs::*; +use crate::error::{CoreError, Result}; use crate::merklehash::{HMACKey, MerkleHash}; use crate::{MerkleHashMap, TruncatedMerkleHashMap}; @@ -74,7 +74,7 @@ impl ShardBookkeeper { } pub struct ShardFileManager { - shard_bookkeeper: RwTaskLock, + shard_bookkeeper: RwTaskLock, current_state: RwLock, shard_directory: PathBuf, target_shard_max_size: u64, @@ -367,7 +367,7 @@ impl ShardFileManager { #[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] #[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] -impl FileReconstructor for ShardFileManager { +impl FileReconstructor for ShardFileManager { // Given a file pointer, returns the information needed to reconstruct the file. // The information is stored in the destination vector dest_results. The function // returns true if the file hash was found, and false otherwise. @@ -572,13 +572,13 @@ mod tests { use rand::prelude::*; use tempfile::TempDir; - use super::super::error::Result; use super::super::file_structs::FileDataSequenceHeader; use super::super::session_directory::{consolidate_shards_in_directory, merge_shards}; use super::super::shard_format::test_routines::{gen_random_file_info, rng_hash, simple_hash}; use super::super::utils::parse_shard_filename; use super::super::xorb_structs::{XorbChunkSequenceEntry, XorbChunkSequenceHeader}; use super::*; + use crate::error::Result; #[allow(clippy::type_complexity)] pub async fn fill_with_specific_shard( diff --git a/xet_core_structures/src/metadata_shard/shard_format.rs b/xet_core_structures/src/metadata_shard/shard_format.rs index 31d06f81..fceddaa1 100644 --- a/xet_core_structures/src/metadata_shard/shard_format.rs +++ b/xet_core_structures/src/metadata_shard/shard_format.rs @@ -5,18 +5,17 @@ use std::ops::Add; use std::sync::Arc; use std::time::{Duration, UNIX_EPOCH}; -use anyhow::anyhow; use futures::AsyncReadExt; use static_assertions::const_assert; use tracing::debug; -use super::error::{MDBShardError, Result}; use super::file_structs::*; use super::interpolation_search::search_on_sorted_u64s; use super::shard_in_memory::MDBInMemoryShard; use super::streaming_shard::MDBMinimalShard; use super::utils::{shard_expiry_time, truncate_hash}; use super::xorb_structs::*; +use crate::error::{CoreError, Result}; use crate::merklehash::{HMACKey, MerkleHash}; use crate::serialization_utils::*; @@ -87,7 +86,7 @@ impl MDBShardFileHeader { reader.read_exact(&mut tag)?; if tag != MDB_SHARD_HEADER_TAG { - return Err(MDBShardError::ShardVersion( + return Err(CoreError::ShardVersion( "File does not appear to be a valid Merkle DB Shard file (Wrong Magic Number).".to_owned(), )); } @@ -187,7 +186,7 @@ impl MDBShardFileFooter { // Do a version check as a simple guard against using this in an old repository if version != MDB_SHARD_FOOTER_VERSION { - return Err(MDBShardError::ShardVersion(format!( + return Err(CoreError::ShardVersion(format!( "Error: Expected footer version {MDB_SHARD_FOOTER_VERSION}, got {version}" ))); } @@ -494,7 +493,7 @@ impl MDBShardInfo { if num_indices < dest_indices.len() { Ok(num_indices) } else { - Err(MDBShardError::TruncatedHashCollision(truncate_hash(file_hash))) + Err(CoreError::TruncatedHashCollision(truncate_hash(file_hash))) } } @@ -517,7 +516,7 @@ impl MDBShardInfo { if num_indices < dest_indices.len() { Ok(num_indices) } else { - Err(MDBShardError::TruncatedHashCollision(truncate_hash(xorb_hash))) + Err(CoreError::TruncatedHashCollision(truncate_hash(xorb_hash))) } } @@ -556,7 +555,7 @@ impl MDBShardInfo { ))?; let Some(mdb_file) = MDBFileInfo::deserialize(reader)? else { - return Err(MDBShardError::Internal(anyhow!("invalid file entry index"))); + return Err(CoreError::InternalError("invalid file entry index".to_string())); }; Ok(mdb_file) @@ -1213,13 +1212,13 @@ pub mod test_routines { use rand::rngs::{SmallRng, StdRng}; use rand::{Rng, SeedableRng}; - use super::super::error::Result; use super::super::file_structs::{FileDataSequenceEntry, FileDataSequenceHeader, FileMetadataExt, MDBFileInfo}; use super::super::shard_format::MDBShardInfo; use super::super::shard_in_memory::MDBInMemoryShard; use super::super::streaming_shard::MDBMinimalShard; use super::super::xorb_structs::{MDBXorbInfo, XorbChunkSequenceEntry, XorbChunkSequenceHeader}; use super::FileVerificationEntry; + use crate::error::Result; use crate::merklehash::MerkleHash; pub fn simple_hash(n: u64) -> MerkleHash { @@ -1694,8 +1693,8 @@ pub mod test_routines { #[cfg(test)] mod tests { - use super::super::error::Result; use super::test_routines::*; + use crate::error::Result; #[test] fn test_simple() -> Result<()> { diff --git a/xet_core_structures/src/metadata_shard/shard_in_memory.rs b/xet_core_structures/src/metadata_shard/shard_in_memory.rs index 944dbf98..0db94638 100644 --- a/xet_core_structures/src/metadata_shard/shard_in_memory.rs +++ b/xet_core_structures/src/metadata_shard/shard_in_memory.rs @@ -9,12 +9,12 @@ use std::time::Duration; use tracing::debug; -use super::error::Result; use super::file_structs::*; use super::shard_format::MDBShardInfo; use super::utils::{shard_file_name, temp_shard_file_name}; use super::xorb_structs::*; use crate::MerkleHashMap; +use crate::error::Result; use crate::merklehash::{HashedWrite, MerkleHash}; #[allow(clippy::type_complexity)] diff --git a/xet_core_structures/src/metadata_shard/streaming_shard.rs b/xet_core_structures/src/metadata_shard/streaming_shard.rs index 70bdfbd2..bad1bbd2 100644 --- a/xet_core_structures/src/metadata_shard/streaming_shard.rs +++ b/xet_core_structures/src/metadata_shard/streaming_shard.rs @@ -8,12 +8,12 @@ use futures_util::io::AsyncReadExt; use itertools::Itertools; use more_asserts::debug_assert_lt; -use super::error::{MDBShardError, Result}; use super::file_structs::{FileDataSequenceHeader, MDBFileInfoView}; use super::shard_file::{MDB_FILE_INFO_ENTRY_SIZE, current_timestamp}; use super::xorb_structs::{MDBXorbInfoView, XorbChunkSequenceEntry, XorbChunkSequenceHeader}; use super::{MDBShardFileFooter, MDBShardFileHeader}; use crate::MerkleHashMap; +use crate::error::{CoreError, Result}; use crate::merklehash::MerkleHash; /// Runs through a shard file info section, calling the specified callback function for each entry. @@ -243,7 +243,7 @@ impl MDBMinimalShard { // if only some files have verification, then we consider this shard invalid // either all files have verification or no files have verification if !file_info_views.is_empty() && !file_info_views.iter().map(|fiv| fiv.contains_verification()).all_equal() { - return Err(MDBShardError::invalid_shard("only some files contain verification")); + return Err(CoreError::invalid_shard("only some files contain verification")); } // XORB stuff @@ -489,7 +489,6 @@ mod tests { use std::collections::{HashMap, HashSet}; use std::io::Cursor; - use anyhow::Result; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; @@ -501,6 +500,7 @@ mod tests { use super::super::shard_in_memory::MDBInMemoryShard; use super::super::xorb_structs::MDBXorbInfo; use super::MDBMinimalShard; + use crate::error::Result; use crate::merklehash::MerkleHash; fn verify_serialization(min_shard: &MDBMinimalShard, mem_shard: &MDBInMemoryShard) -> Result<()> { diff --git a/xet_core_structures/src/xorb_object/byte_grouping/compression_stats/collect_compression_stats.rs b/xet_core_structures/src/xorb_object/byte_grouping/compression_stats/collect_compression_stats.rs index 8a0fbc43..498b64d2 100644 --- a/xet_core_structures/src/xorb_object/byte_grouping/compression_stats/collect_compression_stats.rs +++ b/xet_core_structures/src/xorb_object/byte_grouping/compression_stats/collect_compression_stats.rs @@ -9,7 +9,6 @@ use std::io::{Read, Seek, SeekFrom}; use std::sync::Arc; use std::sync::atomic::AtomicUsize; -use anyhow::Result; use clap::Parser; use csv::Writer; use rand::rngs::StdRng; @@ -161,7 +160,7 @@ async fn main() { #[cfg(not(target_family = "wasm"))] #[tokio::main] -async fn main() -> Result<()> { +async fn main() -> Result<(), Box> { // Parse command-line arguments let args = Args::parse(); diff --git a/xet_core_structures/src/xorb_object/compression_scheme.rs b/xet_core_structures/src/xorb_object/compression_scheme.rs index a0eb220b..9c69fdd7 100644 --- a/xet_core_structures/src/xorb_object/compression_scheme.rs +++ b/xet_core_structures/src/xorb_object/compression_scheme.rs @@ -4,12 +4,11 @@ use std::io::{Cursor, Read, Write, copy}; use std::str::FromStr; use std::time::Instant; -use anyhow::anyhow; use lz4_flex::frame::{FrameDecoder, FrameEncoder}; use super::byte_grouping::BG4Predictor; use super::byte_grouping::bg4::{bg4_regroup, bg4_split}; -use super::error::{Result, XorbObjectError}; +use crate::error::{CoreError, Result}; pub static mut BG4_SPLIT_RUNTIME: f64 = 0.; pub static mut BG4_REGROUP_RUNTIME: f64 = 0.; @@ -50,7 +49,7 @@ impl From for &'static str { } impl TryFrom for CompressionScheme { - type Error = XorbObjectError; + type Error = CoreError; fn try_from(value: u8) -> Result { match value { @@ -58,13 +57,13 @@ impl TryFrom for CompressionScheme { 1 => Ok(CompressionScheme::LZ4), 2 => Ok(CompressionScheme::ByteGrouping4LZ4), 99 => Ok(CompressionScheme::Auto), - _ => Err(XorbObjectError::Format(anyhow!("cannot convert value {value} to CompressionScheme"))), + _ => Err(CoreError::MalformedData(format!("cannot convert value {value} to CompressionScheme"))), } } } impl FromStr for CompressionScheme { - type Err = XorbObjectError; + type Err = CoreError; fn from_str(s: &str) -> std::result::Result { match s.trim().to_lowercase().as_str() { @@ -72,7 +71,7 @@ impl FromStr for CompressionScheme { "none" => Ok(CompressionScheme::None), "lz4" => Ok(CompressionScheme::LZ4), "bg4-lz4" => Ok(CompressionScheme::ByteGrouping4LZ4), - _ => Err(XorbObjectError::Format(anyhow!( + _ => Err(CoreError::MalformedData(format!( "Invalid compression scheme '{s}'. Valid values are: auto, none, lz4, bg4-lz4." ))), } @@ -101,7 +100,7 @@ impl CompressionScheme { pub fn decompress_from_slice<'a>(&self, data: &'a [u8]) -> Result> { Ok(match self { CompressionScheme::Auto => { - return Err(XorbObjectError::Format(anyhow!("Cannot decompress with Auto scheme"))); + return Err(CoreError::MalformedData("Cannot decompress with Auto scheme".to_string())); }, CompressionScheme::None => data.into(), CompressionScheme::LZ4 => lz4_decompress_from_slice(data).map(Cow::from)?, @@ -112,7 +111,7 @@ impl CompressionScheme { pub fn decompress_from_reader(&self, reader: &mut R, writer: &mut W) -> Result { Ok(match self { CompressionScheme::Auto => { - return Err(XorbObjectError::Format(anyhow!("Cannot decompress with Auto scheme"))); + return Err(CoreError::MalformedData("Cannot decompress with Auto scheme".to_string())); }, CompressionScheme::None => copy(reader, writer)?, CompressionScheme::LZ4 => lz4_decompress_from_reader(reader, writer)?, diff --git a/xet_core_structures/src/xorb_object/error.rs b/xet_core_structures/src/xorb_object/error.rs deleted file mode 100644 index 636f6cad..00000000 --- a/xet_core_structures/src/xorb_object/error.rs +++ /dev/null @@ -1 +0,0 @@ -pub use crate::error::{FormatError as XorbObjectError, Result, Validate}; diff --git a/xet_core_structures/src/xorb_object/mod.rs b/xet_core_structures/src/xorb_object/mod.rs index eab8479a..3e75728b 100644 --- a/xet_core_structures/src/xorb_object/mod.rs +++ b/xet_core_structures/src/xorb_object/mod.rs @@ -2,7 +2,6 @@ pub mod byte_grouping; mod chunk; mod compression_scheme; pub mod constants; -pub mod error; mod raw_xorb_data; mod xorb_chunk_format; mod xorb_object_format; diff --git a/xet_core_structures/src/xorb_object/xorb_chunk_format.rs b/xet_core_structures/src/xorb_object/xorb_chunk_format.rs index 9a7ca012..d9286a24 100644 --- a/xet_core_structures/src/xorb_object/xorb_chunk_format.rs +++ b/xet_core_structures/src/xorb_object/xorb_chunk_format.rs @@ -1,12 +1,10 @@ use std::io::{Read, Write}; use std::mem::size_of; -use anyhow::anyhow; - use super::CompressionScheme; use super::constants::MAX_CHUNK_SIZE; -use super::error::XorbObjectError; use super::xorb_object_format::XORB_OBJECT_FORMAT_IDENT; +use crate::error::CoreError; pub mod deserialize_async; @@ -54,7 +52,7 @@ impl XorbChunkHeader { convert_three_byte_num(&self.uncompressed_length) } - pub fn get_compression_scheme(&self) -> Result { + pub fn get_compression_scheme(&self) -> Result { CompressionScheme::try_from(self.compression_scheme) } @@ -63,17 +61,16 @@ impl XorbChunkHeader { self.compression_scheme = compression_scheme as u8; } - fn validate(&self) -> Result<(), XorbObjectError> { + fn validate(&self) -> Result<(), CoreError> { let _ = self.get_compression_scheme()?; if self.version > CURRENT_VERSION { - return Err(XorbObjectError::Format(anyhow!( + return Err(CoreError::MalformedData(format!( "chunk header version too high at {}, current version is {}", - self.version, - CURRENT_VERSION + self.version, CURRENT_VERSION ))); } if self.get_compressed_length() as usize > *MAX_CHUNK_SIZE * 2 { - return Err(XorbObjectError::Format(anyhow!( + return Err(CoreError::MalformedData(format!( "chunk header compressed length too large at {}, maximum: {}", self.get_compressed_length(), *MAX_CHUNK_SIZE @@ -81,7 +78,7 @@ impl XorbChunkHeader { } // the max chunk size is strictly enforced if self.get_uncompressed_length() as usize > *MAX_CHUNK_SIZE { - return Err(XorbObjectError::Format(anyhow!( + return Err(CoreError::MalformedData(format!( "chunk header uncompressed length too large at {}, maximum: {}", self.get_uncompressed_length(), *MAX_CHUNK_SIZE @@ -116,11 +113,13 @@ pub fn serialize_chunk( chunk: &[u8], w: &mut W, compression_scheme: CompressionScheme, -) -> Result { +) -> Result { let compression_scheme = compression_scheme.resolve_for_data(chunk); debug_assert_ne!(compression_scheme, CompressionScheme::Auto); if compression_scheme == CompressionScheme::Auto { - return Err(XorbObjectError::Format(anyhow!("CompressionScheme::Auto cannot be serialized into xorb chunks"))); + return Err(CoreError::MalformedData( + "CompressionScheme::Auto cannot be serialized into xorb chunks".to_string(), + )); } let compressed = compression_scheme.compress_from_slice(chunk)?; @@ -139,24 +138,22 @@ pub fn serialize_chunk( Ok(size_of::() + compressed.len()) } -pub fn parse_chunk_header( - chunk_header_bytes: [u8; XORB_CHUNK_HEADER_LENGTH], -) -> Result { +pub fn parse_chunk_header(chunk_header_bytes: [u8; XORB_CHUNK_HEADER_LENGTH]) -> Result { if chunk_header_bytes[..XORB_OBJECT_FORMAT_IDENT.len()] == XORB_OBJECT_FORMAT_IDENT { - return Err(XorbObjectError::ChunkHeaderParse); + return Err(CoreError::ChunkHeaderParse); } let result: XorbChunkHeader = unsafe { std::mem::transmute_copy(&chunk_header_bytes) }; result.validate()?; Ok(result) } -pub fn deserialize_chunk_header(reader: &mut R) -> Result { +pub fn deserialize_chunk_header(reader: &mut R) -> Result { let mut buf = [0u8; size_of::()]; reader.read_exact(&mut buf)?; parse_chunk_header(buf) } -pub fn deserialize_chunk(reader: &mut R) -> Result<(Vec, usize, u32), XorbObjectError> { +pub fn deserialize_chunk(reader: &mut R) -> Result<(Vec, usize, u32), CoreError> { let mut buf = Vec::new(); let (compressed_chunk_size, uncompressed_chunk_size) = deserialize_chunk_to_writer(reader, &mut buf)?; Ok((buf, compressed_chunk_size, uncompressed_chunk_size)) @@ -166,7 +163,7 @@ pub fn deserialize_chunk(reader: &mut R) -> Result<(Vec, usize, u32 pub fn deserialize_chunk_to_writer( reader: &mut R, writer: &mut W, -) -> Result<(usize, u32), XorbObjectError> { +) -> Result<(usize, u32), CoreError> { let header = deserialize_chunk_header(reader)?; deserialize_chunk_with_header_to_writer(reader, writer, header) } @@ -175,7 +172,7 @@ fn deserialize_chunk_with_header_to_writer( reader: &mut R, writer: &mut W, header: XorbChunkHeader, -) -> Result<(usize, u32), XorbObjectError> { +) -> Result<(usize, u32), CoreError> { let mut compressed_data_reader = reader.take(header.get_compressed_length().into()); let uncompressed_len = header @@ -183,15 +180,15 @@ fn deserialize_chunk_with_header_to_writer( .decompress_from_reader(&mut compressed_data_reader, writer)?; if uncompressed_len != header.get_uncompressed_length() as u64 { - return Err(XorbObjectError::Format(anyhow!( - "chunk is corrupted, uncompressed bytes len doesn't agree with chunk header" - ))); + return Err(CoreError::MalformedData( + "chunk is corrupted, uncompressed bytes len doesn't agree with chunk header".to_string(), + )); } Ok((header.get_compressed_length() as usize + XORB_CHUNK_HEADER_LENGTH, uncompressed_len as u32)) } -pub fn deserialize_chunks(reader: &mut R) -> Result<(Vec, Vec), XorbObjectError> { +pub fn deserialize_chunks(reader: &mut R) -> Result<(Vec, Vec), CoreError> { let mut buf = Vec::new(); let (_, chunk_byte_indices) = deserialize_chunks_to_writer(reader, &mut buf)?; Ok((buf, chunk_byte_indices)) @@ -223,12 +220,12 @@ pub fn append_chunk_segment( /// Uses a single `read()` call to detect EOF (returns 0), then completes /// any partial header with `read_exact`. An `UnexpectedEof` from `read_exact` /// means the stream was truncated mid-header. -fn try_read_chunk_header(reader: &mut R) -> Result, XorbObjectError> { +fn try_read_chunk_header(reader: &mut R) -> Result, CoreError> { let mut header_buf = [0u8; XORB_CHUNK_HEADER_LENGTH]; let n = match reader.read(&mut header_buf) { Ok(0) => return Ok(None), Ok(n) => n, - Err(e) => return Err(XorbObjectError::Io(e)), + Err(e) => return Err(CoreError::Io(e)), }; if n < XORB_CHUNK_HEADER_LENGTH { reader.read_exact(&mut header_buf[n..])?; @@ -239,7 +236,7 @@ fn try_read_chunk_header(reader: &mut R) -> Result( reader: &mut R, writer: &mut W, -) -> Result<(usize, Vec), XorbObjectError> { +) -> Result<(usize, Vec), CoreError> { let mut num_compressed_written = 0; let mut num_uncompressed_written = 0; diff --git a/xet_core_structures/src/xorb_object/xorb_chunk_format/deserialize_async.rs b/xet_core_structures/src/xorb_object/xorb_chunk_format/deserialize_async.rs index a5cc5198..3a3e6394 100644 --- a/xet_core_structures/src/xorb_object/xorb_chunk_format/deserialize_async.rs +++ b/xet_core_structures/src/xorb_object/xorb_chunk_format/deserialize_async.rs @@ -1,16 +1,13 @@ use std::io::Write; use std::mem::size_of; -use anyhow::anyhow; use futures::io::{AsyncRead, AsyncReadExt}; use futures::{Stream, TryStreamExt}; -use super::super::error::XorbObjectError; use super::{XORB_CHUNK_HEADER_LENGTH, XorbChunkHeader, parse_chunk_header}; +use crate::error::CoreError; -pub async fn deserialize_chunk_header( - reader: &mut R, -) -> Result { +pub async fn deserialize_chunk_header(reader: &mut R) -> Result { let mut buf = [0u8; size_of::()]; reader.read_exact(&mut buf).await?; parse_chunk_header(buf) @@ -20,7 +17,7 @@ pub async fn deserialize_chunk_header( pub async fn deserialize_chunk_to_writer( reader: &mut R, writer: &mut W, -) -> Result<(usize, u32), XorbObjectError> { +) -> Result<(usize, u32), CoreError> { let header = deserialize_chunk_header(reader).await?; deserialize_chunk_with_header_to_writer(reader, writer, header).await } @@ -29,7 +26,7 @@ async fn deserialize_chunk_with_header_to_writer reader: &mut R, writer: &mut W, header: XorbChunkHeader, -) -> Result<(usize, u32), XorbObjectError> { +) -> Result<(usize, u32), CoreError> { let mut compressed_data = vec![0u8; header.get_compressed_length() as usize]; reader.read_exact(&mut compressed_data).await?; @@ -37,9 +34,9 @@ async fn deserialize_chunk_with_header_to_writer let uncompressed_len = uncompressed_data.len(); if uncompressed_len != header.get_uncompressed_length() as usize { - return Err(XorbObjectError::Format(anyhow!( - "chunk is corrupted, uncompressed bytes len doesn't agree with chunk header" - ))); + return Err(CoreError::MalformedData( + "chunk is corrupted, uncompressed bytes len doesn't agree with chunk header".to_string(), + )); } writer.write_all(&uncompressed_data)?; @@ -48,7 +45,7 @@ async fn deserialize_chunk_with_header_to_writer } /// deserialize 1 chunk returning a Vec, the compressed length and the uncompressed length of the chunk -pub async fn deserialize_chunk(reader: &mut R) -> Result<(Vec, usize, u32), XorbObjectError> { +pub async fn deserialize_chunk(reader: &mut R) -> Result<(Vec, usize, u32), CoreError> { let mut buf = Vec::new(); let (compressed_len, uncompressed_len) = deserialize_chunk_to_writer(reader, &mut buf).await?; Ok((buf, compressed_len, uncompressed_len)) @@ -61,12 +58,12 @@ pub async fn deserialize_chunk(reader: &mut R) -> Result<( /// means the stream was truncated mid-header. async fn try_read_chunk_header_async( reader: &mut R, -) -> Result, XorbObjectError> { +) -> Result, CoreError> { let mut header_buf = [0u8; XORB_CHUNK_HEADER_LENGTH]; let n = match AsyncReadExt::read(reader, &mut header_buf).await { Ok(0) => return Ok(None), Ok(n) => n, - Err(e) => return Err(XorbObjectError::Io(e)), + Err(e) => return Err(CoreError::Io(e)), }; if n < XORB_CHUNK_HEADER_LENGTH { reader.read_exact(&mut header_buf[n..]).await?; @@ -77,7 +74,7 @@ async fn try_read_chunk_header_async( pub async fn deserialize_chunks_to_writer_from_async_read( reader: &mut R, writer: &mut W, -) -> Result<(usize, Vec), XorbObjectError> { +) -> Result<(usize, Vec), CoreError> { let mut num_compressed_written = 0; let mut num_uncompressed_written = 0; @@ -99,7 +96,7 @@ pub async fn deserialize_chunks_to_writer_from_async_read( reader: &mut R, -) -> Result<(Vec, Vec), XorbObjectError> { +) -> Result<(Vec, Vec), CoreError> { let mut buf = Vec::new(); let (_, chunk_byte_indices) = deserialize_chunks_to_writer_from_async_read(reader, &mut buf).await?; Ok((buf, chunk_byte_indices)) @@ -108,7 +105,7 @@ pub async fn deserialize_chunks_from_async_read( pub async fn deserialize_chunks_to_writer_from_stream( stream: S, writer: &mut W, -) -> Result<(usize, Vec), XorbObjectError> +) -> Result<(usize, Vec), CoreError> where B: AsRef<[u8]>, E: Into, @@ -119,7 +116,7 @@ where deserialize_chunks_to_writer_from_async_read(&mut stream_reader, writer).await } -pub async fn deserialize_chunks_from_stream(stream: S) -> Result<(Vec, Vec), XorbObjectError> +pub async fn deserialize_chunks_from_stream(stream: S) -> Result<(Vec, Vec), CoreError> where B: AsRef<[u8]>, E: Into, diff --git a/xet_core_structures/src/xorb_object/xorb_object_format.rs b/xet_core_structures/src/xorb_object/xorb_object_format.rs index 7e91468d..b1084471 100644 --- a/xet_core_structures/src/xorb_object/xorb_object_format.rs +++ b/xet_core_structures/src/xorb_object/xorb_object_format.rs @@ -2,7 +2,6 @@ use std::cmp::min; use std::io::{Cursor, Read, Seek, SeekFrom, Write}; use std::mem::{size_of, size_of_val}; -use anyhow::anyhow; use bytes::Buf; #[cfg(not(target_family = "wasm"))] use futures::AsyncReadExt; @@ -12,9 +11,9 @@ use tracing::warn; use xet_runtime::core::xet_config; use super::constants::{TARGET_CHUNK_SIZE, XORB_BLOCK_SIZE}; -use super::error::{Validate, XorbObjectError}; use super::xorb_chunk_format::{deserialize_chunk, deserialize_chunk_header, serialize_chunk, write_chunk_header}; use super::{CompressionScheme, RawXorbData, XorbChunkHeader}; +use crate::error::{CoreError, Validate}; use crate::merklehash::{DataHash, MerkleHash}; use crate::metadata_shard::chunk_verification::range_hash_from_chunks; use crate::serialization_utils::*; @@ -100,11 +99,11 @@ impl XorbObjectInfoV0 { /// /// Assumes caller has set position of Writer to appropriate location for serialization. #[deprecated] - pub fn serialize(&self, writer: &mut W) -> Result { + pub fn serialize(&self, writer: &mut W) -> Result { let mut total_bytes_written = 0; // Helper function to write data and update the byte count - let mut write_bytes = |data: &[u8]| -> Result<(), XorbObjectError> { + let mut write_bytes = |data: &[u8]| -> Result<(), CoreError> { writer.write_all(data)?; total_bytes_written += data.len(); Ok(()) @@ -134,11 +133,11 @@ impl XorbObjectInfoV0 { /// /// Expects metadata struct is found at end of Reader, written out in struct order. #[deprecated] - pub fn deserialize(reader: &mut R) -> Result<(Self, u32), XorbObjectError> { + pub fn deserialize(reader: &mut R) -> Result<(Self, u32), CoreError> { let mut total_bytes_read: u32 = 0; // Helper function to read data and update the byte count - let mut read_bytes = |data: &mut [u8]| -> Result<(), XorbObjectError> { + let mut read_bytes = |data: &mut [u8]| -> Result<(), CoreError> { reader.read_exact(data)?; total_bytes_read += data.len() as u32; Ok(()) @@ -148,14 +147,14 @@ impl XorbObjectInfoV0 { read_bytes(&mut ident)?; if ident != XORB_OBJECT_FORMAT_IDENT { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident"))); + return Err(CoreError::MalformedData("Xorb Invalid Ident".to_string())); } let mut version = [0u8; 1]; read_bytes(&mut version)?; if version[0] != XORB_OBJECT_FORMAT_VERSION_V0 { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Format Version"))); + return Err(CoreError::MalformedData("Xorb Invalid Format Version".to_string())); } let (s, bytes_read_v0) = Self::deserialize_v0(reader)?; @@ -163,11 +162,11 @@ impl XorbObjectInfoV0 { Ok((s, total_bytes_read + bytes_read_v0)) } - pub fn deserialize_v0(reader: &mut R) -> Result<(Self, u32), XorbObjectError> { + pub fn deserialize_v0(reader: &mut R) -> Result<(Self, u32), CoreError> { let mut total_bytes_read: u32 = 0; // Helper function to read data and update the byte count - let mut read_bytes = |data: &mut [u8]| -> Result<(), XorbObjectError> { + let mut read_bytes = |data: &mut [u8]| -> Result<(), CoreError> { reader.read_exact(data)?; total_bytes_read += data.len() as u32; Ok(()) @@ -219,7 +218,7 @@ impl XorbObjectInfoV0 { pub async fn deserialize_async( reader: &mut R, version: u8, - ) -> Result<(Self, u32), XorbObjectError> { + ) -> Result<(Self, u32), CoreError> { // already read 8 bytes (ident + version) let mut total_bytes_read: u32 = (size_of::() + size_of::()) as u32; @@ -228,7 +227,7 @@ impl XorbObjectInfoV0 { reader: &mut R, total_bytes_read: &mut u32, buf: &mut [u8], - ) -> Result<(), XorbObjectError> { + ) -> Result<(), CoreError> { reader.read_exact(buf).await?; *total_bytes_read += buf.len() as u32; Ok(()) @@ -399,7 +398,7 @@ impl XorbObjectInfoV1 { /// Serialize XorbObjectInfoV1 to provided Writer. /// /// Assumes caller has set position of Writer to appropriate location for serialization. - pub fn serialize(&self, writer: &mut W) -> Result { + pub fn serialize(&self, writer: &mut W) -> Result { let mut counting_writer = countio::Counter::new(writer); let w = &mut counting_writer; @@ -422,7 +421,7 @@ impl XorbObjectInfoV1 { if self.num_chunks as usize != self.chunk_hashes.len() { debug_assert_eq!(self.num_chunks as usize, self.chunk_hashes.len()); - return Err(XorbObjectError::Format(anyhow!( + return Err(CoreError::MalformedData(format!( "Chunk hash vector not correct length on serialization. ({}, expected {})", self.chunk_hashes.len(), self.num_chunks @@ -443,7 +442,7 @@ impl XorbObjectInfoV1 { // write variable field: chunk boundaries if self.num_chunks as usize != self.chunk_boundary_offsets.len() { debug_assert_eq!(self.num_chunks as usize, self.chunk_boundary_offsets.len()); - return Err(XorbObjectError::Format(anyhow!( + return Err(CoreError::MalformedData(format!( "Chunk boundary offset vector not correct length on serialization. ({}, expected {})", self.chunk_boundary_offsets.len(), self.num_chunks @@ -454,7 +453,7 @@ impl XorbObjectInfoV1 { // write variable field: unpacked chunk data offsets if self.num_chunks as usize != self.unpacked_chunk_offsets.len() { debug_assert_eq!(self.num_chunks as usize, self.unpacked_chunk_offsets.len()); - return Err(XorbObjectError::Format(anyhow!( + return Err(CoreError::MalformedData(format!( "Unpacked chunk offset vector not correct length on serialization. ({}, expected {})", self.unpacked_chunk_offsets.len(), self.num_chunks @@ -481,7 +480,7 @@ impl XorbObjectInfoV1 { /// Construct XorbObjectInfo object from Reader + Seek. /// /// Expects metadata struct is found at end of Reader, written out in struct order. - pub fn deserialize(reader: &mut R) -> Result<(Self, u32), XorbObjectError> { + pub fn deserialize(reader: &mut R) -> Result<(Self, u32), CoreError> { let mut counting_reader = countio::Counter::new(reader); let r = &mut counting_reader; @@ -493,7 +492,7 @@ impl XorbObjectInfoV1 { read_bytes(r, &mut s.ident)?; if s.ident != XORB_OBJECT_FORMAT_IDENT { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident"))); + return Err(CoreError::MalformedData("Xorb Invalid Ident".to_string())); } s.version = read_u8(r)?; @@ -503,7 +502,7 @@ impl XorbObjectInfoV1 { // we don't have the missing info (unpacked_chunk_offsets), it's OK return Ok((Self::from_v0(sv0), r.reader_bytes() as u32)); } else if s.version != XORB_OBJECT_FORMAT_VERSION { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Format Version"))); + return Err(CoreError::MalformedData("Xorb Invalid Format Version".to_string())); } s.xorb_hash = read_hash(r)?; @@ -516,13 +515,13 @@ impl XorbObjectInfoV1 { read_bytes(r, &mut s.ident_hash_section)?; if s.ident_hash_section != XORB_OBJECT_FORMAT_IDENT_HASHES { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident for Hash Metadata Section"))); + return Err(CoreError::MalformedData("Xorb Invalid Ident for Hash Metadata Section".to_string())); } s.hashes_version = read_u8(r)?; if s.hashes_version != XORB_OBJECT_FORMAT_HASHES_VERSION { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Format Version for Hash Metadata Section"))); + return Err(CoreError::MalformedData("Xorb Invalid Format Version for Hash Metadata Section".to_string())); } let num_chunks_2 = read_u32(r)?; @@ -541,23 +540,23 @@ impl XorbObjectInfoV1 { read_bytes(r, &mut s.ident_boundary_section)?; if s.ident_boundary_section != XORB_OBJECT_FORMAT_IDENT_BOUNDARIES { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident for Boundary Metadata Section"))); + return Err(CoreError::MalformedData("Xorb Invalid Ident for Boundary Metadata Section".to_string())); } s.boundaries_version = read_u8(r)?; if s.boundaries_version != XORB_OBJECT_FORMAT_BOUNDARIES_VERSION { - return Err(XorbObjectError::Format(anyhow!( - "Xorb Invalid Format Version for Boundaries Metadata Section" - ))); + return Err(CoreError::MalformedData( + "Xorb Invalid Format Version for Boundaries Metadata Section".to_string(), + )); } let num_chunks_3 = read_u32(r)?; if num_chunks_2 != num_chunks_3 { - return Err(XorbObjectError::Format(anyhow!( - "Xorb Invalid: inconsistent num_chunks between hashes and boundaries section." - ))); + return Err(CoreError::MalformedData( + "Xorb Invalid: inconsistent num_chunks between hashes and boundaries section.".to_string(), + )); } s.chunk_boundary_offsets.reserve(prealloc_num_chunks(num_chunks_3 as usize)); @@ -574,9 +573,9 @@ impl XorbObjectInfoV1 { s.num_chunks = read_u32(r)?; if s.num_chunks != num_chunks_2 { - return Err(XorbObjectError::Format(anyhow!( - "Xorb Invalid: inconsistent num_chunks between metadata and hashes section." - ))); + return Err(CoreError::MalformedData( + "Xorb Invalid: inconsistent num_chunks between metadata and hashes section.".to_string(), + )); } s.hashes_section_offset_from_end = read_u32(r)?; @@ -587,11 +586,15 @@ impl XorbObjectInfoV1 { let end_byte_offset = r.reader_bytes(); if end_byte_offset - hash_section_begin_byte_offset != s.hashes_section_offset_from_end as usize { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid: incorrect hashes_section_offset_from_end."))); + return Err(CoreError::MalformedData( + "Xorb Invalid: incorrect hashes_section_offset_from_end.".to_string(), + )); } if end_byte_offset - boundary_section_begin_byte_offset != s.boundary_section_offset_from_end as usize { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid: incorrect boundary_section_offset_from_end."))); + return Err(CoreError::MalformedData( + "Xorb Invalid: incorrect boundary_section_offset_from_end.".to_string(), + )); } Ok((s, r.reader_bytes() as u32)) @@ -600,7 +603,7 @@ impl XorbObjectInfoV1 { /// Construct XorbObjectInfo object from Reader + Seek. /// /// Expects metadata struct is found at end of Reader, written out in struct order. - pub fn deserialize_only_boundaries_section(reader: &mut R) -> Result<(Self, u32), XorbObjectError> { + pub fn deserialize_only_boundaries_section(reader: &mut R) -> Result<(Self, u32), CoreError> { let mut s = Self::default(); // info_length + size of _buffer + size of u32 for offset field @@ -622,15 +625,15 @@ impl XorbObjectInfoV1 { read_bytes(r, &mut s.ident_boundary_section)?; if s.ident_boundary_section != XORB_OBJECT_FORMAT_IDENT_BOUNDARIES { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident for Boundary Metadata Section"))); + return Err(CoreError::MalformedData("Xorb Invalid Ident for Boundary Metadata Section".to_string())); } s.boundaries_version = read_u8(r)?; if s.boundaries_version != XORB_OBJECT_FORMAT_BOUNDARIES_VERSION { - return Err(XorbObjectError::Format(anyhow!( - "Xorb Invalid Format Version for Boundaries Metadata Section" - ))); + return Err(CoreError::MalformedData( + "Xorb Invalid Format Version for Boundaries Metadata Section".to_string(), + )); } let num_chunks_boundaries_section = read_u32(r)?; @@ -645,9 +648,9 @@ impl XorbObjectInfoV1 { s.num_chunks = read_u32(r)?; if s.num_chunks != num_chunks_boundaries_section { - return Err(XorbObjectError::Format(anyhow!( - "Xorb Invalid: inconsistent num_chunks between metadata and hashes section." - ))); + return Err(CoreError::MalformedData( + "Xorb Invalid: inconsistent num_chunks between metadata and hashes section.".to_string(), + )); } s.hashes_section_offset_from_end = read_u32(r)?; @@ -658,7 +661,9 @@ impl XorbObjectInfoV1 { let end_byte_offset = r.reader_bytes(); if end_byte_offset != s.boundary_section_offset_from_end as usize { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid: incorrect boundary_section_offset_from_end."))); + return Err(CoreError::MalformedData( + "Xorb Invalid: incorrect boundary_section_offset_from_end.".to_string(), + )); } debug_assert!(s.chunk_hashes.is_empty()); @@ -669,7 +674,7 @@ impl XorbObjectInfoV1 { #[cfg(not(target_family = "wasm"))] pub async fn deserialize_async_v1( reader: &mut R, - ) -> Result<(Self, u32), XorbObjectError> { + ) -> Result<(Self, u32), CoreError> { // already read 8 bytes (ident + version) let total_bytes_read: u32 = (size_of::() + size_of::()) as u32; @@ -692,13 +697,13 @@ impl XorbObjectInfoV1 { read_bytes_async(r, &mut s.ident_hash_section).await?; if s.ident_hash_section != XORB_OBJECT_FORMAT_IDENT_HASHES { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident for Hash Metadata Section"))); + return Err(CoreError::MalformedData("Xorb Invalid Ident for Hash Metadata Section".to_string())); } s.hashes_version = read_u8_async(r).await?; if s.hashes_version != XORB_OBJECT_FORMAT_HASHES_VERSION { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Format Version for Hash Metadata Section"))); + return Err(CoreError::MalformedData("Xorb Invalid Format Version for Hash Metadata Section".to_string())); } let num_chunks_2 = read_u32_async(r).await?; @@ -717,23 +722,23 @@ impl XorbObjectInfoV1 { read_bytes_async(r, &mut s.ident_boundary_section).await?; if s.ident_boundary_section != XORB_OBJECT_FORMAT_IDENT_BOUNDARIES { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid Ident for Boundary Metadata Section"))); + return Err(CoreError::MalformedData("Xorb Invalid Ident for Boundary Metadata Section".to_string())); } s.boundaries_version = read_u8_async(r).await?; if s.boundaries_version != XORB_OBJECT_FORMAT_BOUNDARIES_VERSION { - return Err(XorbObjectError::Format(anyhow!( - "Xorb Invalid Format Version for Boundaries Metadata Section" - ))); + return Err(CoreError::MalformedData( + "Xorb Invalid Format Version for Boundaries Metadata Section".to_string(), + )); } let num_chunks_3 = read_u32_async(r).await?; if num_chunks_2 != num_chunks_3 { - return Err(XorbObjectError::Format(anyhow!( - "Xorb Invalid: inconsistent num_chunks between hashes and boundaries section." - ))); + return Err(CoreError::MalformedData( + "Xorb Invalid: inconsistent num_chunks between hashes and boundaries section.".to_string(), + )); } s.chunk_boundary_offsets.reserve(prealloc_num_chunks(num_chunks_3 as usize)); @@ -749,9 +754,9 @@ impl XorbObjectInfoV1 { s.num_chunks = read_u32_async(r).await?; if s.num_chunks != num_chunks_2 { - return Err(XorbObjectError::Format(anyhow!( - "Xorb Invalid: inconsistent num_chunks between metadata and hashes section." - ))); + return Err(CoreError::MalformedData( + "Xorb Invalid: inconsistent num_chunks between metadata and hashes section.".to_string(), + )); } s.hashes_section_offset_from_end = read_u32_async(r).await?; @@ -762,11 +767,15 @@ impl XorbObjectInfoV1 { let end_byte_offset = r.reader_bytes(); if end_byte_offset - hash_section_begin_byte_offset != s.hashes_section_offset_from_end as usize { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid: incorrect hashes_section_offset_from_end."))); + return Err(CoreError::MalformedData( + "Xorb Invalid: incorrect hashes_section_offset_from_end.".to_string(), + )); } if end_byte_offset - boundary_section_begin_byte_offset != s.boundary_section_offset_from_end as usize { - return Err(XorbObjectError::Format(anyhow!("Xorb Invalid: incorrect boundary_section_offset_from_end."))); + return Err(CoreError::MalformedData( + "Xorb Invalid: incorrect boundary_section_offset_from_end.".to_string(), + )); } Ok((s, r.reader_bytes() as u32 + total_bytes_read)) @@ -779,7 +788,7 @@ impl XorbObjectInfoV1 { pub async fn deserialize_async( reader: &mut R, version: u8, - ) -> Result<(Self, u32), XorbObjectError> { + ) -> Result<(Self, u32), CoreError> { if version == 0 { let (s, n) = XorbObjectInfoV0::deserialize_async(reader, 0).await?; // we don't have the missing info (unpacked_chunk_offsets), it's OK @@ -787,7 +796,7 @@ impl XorbObjectInfoV1 { } else if version == 1 { Self::deserialize_async_v1(reader).await } else { - Err(XorbObjectError::Format(anyhow!( + Err(CoreError::MalformedData(format!( "Xorb Format Error: Version {version} not supported by this code version." ))) } @@ -908,7 +917,7 @@ impl XorbObject { /// make up the info portion of the xorb. /// /// Assumes reader has at least size_of::() bytes, otherwise returns an error. - pub fn get_info_length(reader: &mut R) -> Result { + pub fn get_info_length(reader: &mut R) -> Result { // Go to end of Reader and get length, then jump back to it, and read sequentially // read last 4 bytes to get length reader.seek(SeekFrom::End(-(size_of::() as i64)))?; @@ -922,7 +931,7 @@ impl XorbObject { /// Deserialize the XorbObjectInfo struct, the metadata for this Xorb. /// /// This allows the XorbObject to be partially constructed, allowing for range reads inside the XorbObject. - pub fn deserialize(reader: &mut R) -> Result { + pub fn deserialize(reader: &mut R) -> Result { let info_length = Self::get_info_length(reader)?; // now seek back that many bytes + size of length (u32) and read sequentially. @@ -932,7 +941,7 @@ impl XorbObject { // validate that info_length matches what we read off of header if total_bytes_read != info_length { - return Err(XorbObjectError::Format(anyhow!("Xorb Info Format Error"))); + return Err(CoreError::MalformedData("Xorb Info Format Error".to_string())); } Ok(Self { info, info_length }) @@ -944,7 +953,7 @@ impl XorbObject { pub async fn deserialize_async( reader: &mut R, version: u8, - ) -> Result { + ) -> Result { let (info, total_bytes_read) = XorbObjectInfoV1::deserialize_async(reader, version).await?; let mut info_length_buf = [0u8; size_of::()]; @@ -954,18 +963,20 @@ impl XorbObject { let info_length = u32::from_le_bytes(info_length_buf); if info_length != total_bytes_read { - return Err(XorbObjectError::Format(anyhow!("Xorb Info Format Error"))); + return Err(CoreError::MalformedData("Xorb Info Format Error".to_string())); } // verify we've read to the end if reader.read(&mut [0u8; 8]).await? != 0 { - return Err(XorbObjectError::Format(anyhow!("Xorb Reader has content past the end of serialized xorb"))); + return Err(CoreError::MalformedData( + "Xorb Reader has content past the end of serialized xorb".to_string(), + )); } Ok(Self { info, info_length }) } - pub fn serialize_given_info(w: &mut W, info: XorbObjectInfoV1) -> Result<(Self, usize), XorbObjectError> { + pub fn serialize_given_info(w: &mut W, info: XorbObjectInfoV1) -> Result<(Self, usize), CoreError> { let mut total_written_bytes: usize = 0; let info_length = info.serialize(w)? as u32; total_written_bytes += info_length as usize; @@ -993,7 +1004,7 @@ impl XorbObject { pub fn validate_xorb_object( reader: &mut R, hash: &MerkleHash, - ) -> Result, XorbObjectError> { + ) -> Result, CoreError> { // 1. deserialize to get Info // Errors can occur if either // - the object doesn't have at least 4 bytes for the "info_length"; @@ -1092,11 +1103,11 @@ impl XorbObject { &self, chunk_start_index: u32, chunk_end_index: u32, - ) -> Result { + ) -> Result { self.validate_xorb_object_info()?; if chunk_end_index <= chunk_start_index || chunk_end_index > self.info.num_chunks { - return Err(XorbObjectError::InvalidArguments); + return Err(CoreError::InvalidArguments); } // Collect relevant hashes @@ -1106,11 +1117,11 @@ impl XorbObject { } /// Return end offset of all physical chunk contents (byte index at the beginning of footer) - pub fn get_contents_length(&self) -> Result { + pub fn get_contents_length(&self) -> Result { self.validate_xorb_object_info()?; match self.info.chunk_boundary_offsets.last() { Some(c) => Ok(*c), - None => Err(XorbObjectError::Format(anyhow!("Cannot retrieve content length"))), + None => Err(CoreError::MalformedData("Cannot retrieve content length".to_string())), } } @@ -1119,14 +1130,9 @@ impl XorbObject { /// start and end are byte indices into the physical layout of a xorb. /// /// The start and end parameters are required to align with chunk boundaries. - fn get_range( - &self, - reader: &mut R, - byte_start: u32, - byte_end: u32, - ) -> Result, XorbObjectError> { + fn get_range(&self, reader: &mut R, byte_start: u32, byte_end: u32) -> Result, CoreError> { if byte_end < byte_start { - return Err(XorbObjectError::InvalidRange); + return Err(CoreError::InvalidRange); } self.validate_xorb_object_info()?; @@ -1145,7 +1151,7 @@ impl XorbObject { } /// Get all the content bytes from a Xorb - pub fn get_all_bytes(&self, reader: &mut R) -> Result, XorbObjectError> { + pub fn get_all_bytes(&self, reader: &mut R) -> Result, CoreError> { self.validate_xorb_object_info()?; self.get_range(reader, 0, self.get_contents_length()?) } @@ -1156,14 +1162,14 @@ impl XorbObject { reader: &mut R, chunk_index_start: u32, chunk_index_end: u32, - ) -> Result, XorbObjectError> { + ) -> Result, CoreError> { let (byte_start, byte_end) = self.get_byte_offset(chunk_index_start, chunk_index_end)?; self.get_range(reader, byte_start, byte_end) } /// Assumes chunk_data is 1+ complete chunks. Processes them sequentially and returns them as Vec. - fn get_chunk_contents(&self, chunk_data: &[u8]) -> Result, XorbObjectError> { + fn get_chunk_contents(&self, chunk_data: &[u8]) -> Result, CoreError> { // walk chunk_data, deserialize into Chunks, and then get_bytes() from each of them. let mut reader = Cursor::new(chunk_data); let mut res = Vec::::new(); @@ -1176,10 +1182,10 @@ impl XorbObject { } /// Helper function to translate a range of chunk indices to physical byte offset range. - pub fn get_byte_offset(&self, chunk_index_start: u32, chunk_index_end: u32) -> Result<(u32, u32), XorbObjectError> { + pub fn get_byte_offset(&self, chunk_index_start: u32, chunk_index_end: u32) -> Result<(u32, u32), CoreError> { self.validate_xorb_object_info()?; if chunk_index_end <= chunk_index_start || chunk_index_end > self.info.num_chunks { - return Err(XorbObjectError::InvalidArguments); + return Err(CoreError::InvalidArguments); } let byte_offset_start = match chunk_index_start { @@ -1193,11 +1199,11 @@ impl XorbObject { /// given a valid chunk_index, returns the uncompressed chunk length for the chunk /// at the given index, chunk_index must be less than the number of chunks in the xorb - pub fn uncompressed_chunk_length(&self, chunk_index: u32) -> Result { + pub fn uncompressed_chunk_length(&self, chunk_index: u32) -> Result { self.validate_xorb_object_info()?; let chunk_index = chunk_index as usize; if chunk_index >= self.info.unpacked_chunk_offsets.len() { - return Err(XorbObjectError::InvalidArguments); + return Err(CoreError::InvalidArguments); } let cumulative_sum = self.info.unpacked_chunk_offsets[chunk_index]; let before = match chunk_index { @@ -1213,17 +1219,13 @@ impl XorbObject { /// chunk_index_start <= chunk_index_end && /// chunk_index_end <= num_chunks && /// chunk_index_start < num_chunks - pub fn uncompressed_range_length( - &self, - chunk_index_start: u32, - chunk_index_end: u32, - ) -> Result { + pub fn uncompressed_range_length(&self, chunk_index_start: u32, chunk_index_end: u32) -> Result { self.validate_xorb_object_info()?; if chunk_index_start > chunk_index_end || chunk_index_end > self.info.num_chunks || chunk_index_start >= self.info.num_chunks { - return Err(XorbObjectError::InvalidArguments); + return Err(CoreError::InvalidArguments); } // this check is important if chunk_index_end is 0 @@ -1240,9 +1242,9 @@ impl XorbObject { } /// Helper method to verify that info object is complete - fn validate_xorb_object_info(&self) -> Result<(), XorbObjectError> { + fn validate_xorb_object_info(&self) -> Result<(), CoreError> { if self.info.num_chunks == 0 { - return Err(XorbObjectError::Format(anyhow!("Invalid XorbObjectInfo, no chunks in XorbObject."))); + return Err(CoreError::MalformedData("Invalid XorbObjectInfo, no chunks in XorbObject.".to_string())); } if self.info.num_chunks != self.info.chunk_boundary_offsets.len() as u32 @@ -1250,13 +1252,13 @@ impl XorbObject { || (self.info.boundaries_version == XORB_OBJECT_FORMAT_BOUNDARIES_VERSION && self.info.num_chunks != self.info.unpacked_chunk_offsets.len() as u32) { - return Err(XorbObjectError::Format(anyhow!( - "Invalid XorbObjectInfo, num chunks not matching boundaries or hashes." - ))); + return Err(CoreError::MalformedData( + "Invalid XorbObjectInfo, num chunks not matching boundaries or hashes.".to_string(), + )); } if self.info.xorb_hash == MerkleHash::default() { - return Err(XorbObjectError::Format(anyhow!("Invalid XorbObjectInfo, Missing xorb_hash."))); + return Err(CoreError::MalformedData("Invalid XorbObjectInfo, Missing xorb_hash.".to_string())); } Ok(()) @@ -1285,7 +1287,7 @@ impl SerializedXorbObject { /// /// The compression scheme is determined by `HF_XET_XORB_COMPRESSION_POLICY`: /// auto-detect (default) or an explicit scheme (none, lz4, bg4-lz4). - pub fn from_xorb(xorb: RawXorbData, serialize_footer: bool) -> Result { + pub fn from_xorb(xorb: RawXorbData, serialize_footer: bool) -> Result { let compression_scheme: CompressionScheme = xet_config().xorb.compression_policy.parse()?; Self::from_xorb_with_compression(xorb, compression_scheme, serialize_footer) } @@ -1295,7 +1297,7 @@ impl SerializedXorbObject { xorb: RawXorbData, compression_scheme: CompressionScheme, serialize_footer: bool, - ) -> Result { + ) -> Result { let mut xorb_object_info = XorbObjectInfoV1::default(); let hash = xorb.hash(); @@ -1396,7 +1398,7 @@ pub mod test_utils { data: &[u8], chunk_and_boundaries: &[(MerkleHash, u32)], compression_scheme: CompressionScheme, - ) -> Result<(XorbObject, usize, u64), XorbObjectError> { + ) -> Result<(XorbObject, usize, u64), CoreError> { let mut xorb = XorbObject::default(); xorb.info.xorb_hash = *hash; xorb.info.num_chunks = chunk_and_boundaries.len() as u32; @@ -1446,7 +1448,7 @@ pub mod test_utils { data: Vec, chunk_and_boundaries: Vec<(MerkleHash, u32)>, compression: CompressionScheme, - ) -> Result { + ) -> Result { let mut writer = Cursor::new(Vec::new()); let (_, _, footer_start) = @@ -1647,7 +1649,7 @@ pub mod test_utils { pub fn reconstruct_xorb_with_footer( writer: &mut impl Write, raw_data: &[u8], -) -> Result<(XorbObject, MerkleHash), XorbObjectError> { +) -> Result<(XorbObject, MerkleHash), CoreError> { let mut reader = Cursor::new(raw_data); let mut chunk_hash_and_size: Vec<(MerkleHash, u64)> = Vec::new(); let mut info = XorbObjectInfoV1::default(); @@ -1655,7 +1657,7 @@ pub fn reconstruct_xorb_with_footer( while (reader.position() as usize) < raw_data.len() { let chunk_header = match deserialize_chunk_header(&mut reader) { Ok(header) => header, - Err(XorbObjectError::ChunkHeaderParse) => { + Err(CoreError::ChunkHeaderParse) => { // Hit footer identifier, stop processing chunks break; }, @@ -1666,12 +1668,12 @@ pub fn reconstruct_xorb_with_footer( let mut compressed_buf = vec![0u8; compressed_len]; reader .read_exact(&mut compressed_buf) - .map_err(|e| XorbObjectError::Format(anyhow!("Failed to read chunk data: {e}")))?; + .map_err(|e| CoreError::MalformedData(format!("Failed to read chunk data: {e}")))?; let uncompressed_data = chunk_header .get_compression_scheme()? .decompress_from_slice(&compressed_buf) - .map_err(|e| XorbObjectError::Format(anyhow!("Failed to decompress chunk: {e}")))?; + .map_err(|e| CoreError::MalformedData(format!("Failed to decompress chunk: {e}")))?; let chunk_hash = crate::merklehash::compute_data_hash(&uncompressed_data); chunk_hash_and_size.push((chunk_hash, uncompressed_data.len() as u64)); @@ -1790,9 +1792,9 @@ mod tests { build_xorb_object(5, ChunkSize::Fixed(100), CompressionScheme::None); // Act & Assert - assert_eq!(c.generate_chunk_range_hash(1, 6), Err(XorbObjectError::InvalidArguments)); - assert_eq!(c.generate_chunk_range_hash(100, 10), Err(XorbObjectError::InvalidArguments)); - assert_eq!(c.generate_chunk_range_hash(0, 0), Err(XorbObjectError::InvalidArguments)); + assert_eq!(c.generate_chunk_range_hash(1, 6), Err(CoreError::InvalidArguments)); + assert_eq!(c.generate_chunk_range_hash(100, 10), Err(CoreError::InvalidArguments)); + assert_eq!(c.generate_chunk_range_hash(0, 0), Err(CoreError::InvalidArguments)); } #[test] @@ -1806,7 +1808,10 @@ mod tests { // no chunks let c = XorbObject::default(); let result = c.validate_xorb_object_info(); - assert_eq!(result, Err(XorbObjectError::Format(anyhow!("Invalid XorbObjectInfo, no chunks in XorbObject.")))); + assert_eq!( + result, + Err(CoreError::MalformedData("Invalid XorbObjectInfo, no chunks in XorbObject.".to_string())) + ); // num_chunks doesn't match chunk_boundaries.len() let mut c = XorbObject::default(); @@ -1814,9 +1819,9 @@ mod tests { let result = c.validate_xorb_object_info(); assert_eq!( result, - Err(XorbObjectError::Format(anyhow!( - "Invalid XorbObjectInfo, num chunks not matching boundaries or hashes." - ))) + Err(CoreError::MalformedData( + "Invalid XorbObjectInfo, num chunks not matching boundaries or hashes.".to_string(), + )) ); // no hash @@ -1824,7 +1829,7 @@ mod tests { build_xorb_object(1, ChunkSize::Fixed(100), CompressionScheme::None); c.info.xorb_hash = MerkleHash::default(); let result = c.validate_xorb_object_info(); - assert_eq!(result, Err(XorbObjectError::Format(anyhow!("Invalid XorbObjectInfo, Missing xorb_hash.")))); + assert_eq!(result, Err(CoreError::MalformedData("Invalid XorbObjectInfo, Missing xorb_hash.".to_string()))); } #[test] diff --git a/xet_data/src/error.rs b/xet_data/src/error.rs index a58acaab..74804628 100644 --- a/xet_data/src/error.rs +++ b/xet_data/src/error.rs @@ -6,7 +6,7 @@ use tokio::sync::AcquireError; use tracing::error; use xet_client::ClientError; use xet_client::cas_client::auth::AuthError; -use xet_core_structures::FormatError; +use xet_core_structures::CoreError; use xet_core_structures::merklehash::DataHashHexParseError; use xet_runtime::RuntimeError; use xet_runtime::core::par_utils::ParutilsError; @@ -44,8 +44,8 @@ pub enum DataError { #[error("Channel error: {0}")] ChannelRecvError(#[from] RecvError), - #[error("Format error: {0}")] - FormatError(#[from] FormatError), + #[error("Core structures error: {0}")] + FormatError(#[from] CoreError), #[error("Client error: {0}")] ClientError(#[from] ClientError), diff --git a/xet_data/src/file_reconstruction/error.rs b/xet_data/src/file_reconstruction/error.rs index a025c8b2..06fe4e3f 100644 --- a/xet_data/src/file_reconstruction/error.rs +++ b/xet_data/src/file_reconstruction/error.rs @@ -7,7 +7,7 @@ use thiserror::Error; #[derive(Error, Debug, Clone)] pub enum FileReconstructionError { #[error("CAS Client Error: {0}")] - CasClientError(Arc), + ClientError(Arc), #[error("IO Error: {0}")] IoError(Arc), @@ -31,7 +31,7 @@ pub enum FileReconstructionError { TaskJoinError(Arc), #[error("Runtime Error: {0}")] - RuntimeError(Arc), + RuntimeError(Arc), } pub type Result = std::result::Result; @@ -42,9 +42,9 @@ impl From for FileReconstructionError { } } -impl From for FileReconstructionError { - fn from(err: xet_client::cas_client::CasClientError) -> Self { - FileReconstructionError::CasClientError(Arc::new(err)) +impl From for FileReconstructionError { + fn from(err: xet_client::ClientError) -> Self { + FileReconstructionError::ClientError(Arc::new(err)) } } @@ -60,8 +60,8 @@ impl From for FileReconstructionError { } } -impl From for FileReconstructionError { - fn from(err: xet_runtime::core::errors::MultithreadedRuntimeError) -> Self { +impl From for FileReconstructionError { + fn from(err: xet_runtime::RuntimeError) -> Self { FileReconstructionError::RuntimeError(Arc::new(err)) } } diff --git a/xet_data/src/file_reconstruction/reconstruction_terms/retrieval_urls.rs b/xet_data/src/file_reconstruction/reconstruction_terms/retrieval_urls.rs index c3394dfd..c585371b 100644 --- a/xet_data/src/file_reconstruction/reconstruction_terms/retrieval_urls.rs +++ b/xet_data/src/file_reconstruction/reconstruction_terms/retrieval_urls.rs @@ -139,20 +139,18 @@ pub struct XorbURLProvider { #[async_trait::async_trait] impl URLProvider for XorbURLProvider { - async fn retrieve_url( - &self, - ) -> std::result::Result<(String, Vec), xet_client::cas_client::CasClientError> { + async fn retrieve_url(&self) -> std::result::Result<(String, Vec), xet_client::ClientError> { let (unique_id, url, http_ranges) = self.url_info.get_retrieval_url(self.xorb_block_index).await; *self.last_acquisition_id.lock().await = unique_id; Ok((url, http_ranges)) } - async fn refresh_url(&self) -> std::result::Result<(), xet_client::cas_client::CasClientError> { + async fn refresh_url(&self) -> std::result::Result<(), xet_client::ClientError> { self.url_info .refresh_retrieval_urls(self.client.clone(), *self.last_acquisition_id.lock().await) .await - .map_err(|e| xet_client::cas_client::CasClientError::Other(e.to_string())) + .map_err(|e| xet_client::ClientError::Other(e.to_string())) } } diff --git a/xet_data/src/lib.rs b/xet_data/src/lib.rs index 427e0657..3890697f 100644 --- a/xet_data/src/lib.rs +++ b/xet_data/src/lib.rs @@ -1,7 +1,7 @@ #![cfg_attr(feature = "strict", deny(warnings))] pub mod error; -pub use error::DataError; +pub use error::{DataError, Result}; pub mod deduplication; #[cfg(not(target_family = "wasm"))] diff --git a/xet_data/src/processing/bin/example.rs b/xet_data/src/processing/bin/example.rs index aa41a98a..1a6147b7 100644 --- a/xet_data/src/processing/bin/example.rs +++ b/xet_data/src/processing/bin/example.rs @@ -128,8 +128,9 @@ async fn smudge(_name: Arc, mut reader: impl Read, output_path: PathBuf) -> let mut input = String::new(); reader.read_to_string(&mut input)?; - let xet_file: XetFileInfo = serde_json::from_str(&input) - .map_err(|_| anyhow::anyhow!("Failed to parse xet file info. Please check the format."))?; + let xet_file: XetFileInfo = serde_json::from_str(&input).map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidData, "Failed to parse xet file info. Please check the format.") + })?; // Use local config pointing to current directory let cas_path = std::env::current_dir()?; diff --git a/xet_data/src/processing/bin/xtool.rs b/xet_data/src/processing/bin/xtool.rs index 7cbc8af0..8b7b4107 100644 --- a/xet_data/src/processing/bin/xtool.rs +++ b/xet_data/src/processing/bin/xtool.rs @@ -220,7 +220,7 @@ async fn query_reconstruction( remote_client .get_reconstruction_v1(&file_hash, bytes_range) .await - .map_err(anyhow::Error::from) + .map_err(Into::into) } fn main() -> Result<()> { @@ -231,13 +231,16 @@ fn main() -> Result<()> { && let Some(c) = arg.compression { let scheme = CompressionScheme::try_from(c).map_err(|_| { - anyhow::anyhow!("Invalid compression value {c}; expected one of: 0 (none), 1 (lz4), 2 (bg4-lz4), 99 (auto)") + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("Invalid compression value {c}; expected one of: 0 (none), 1 (lz4), 2 (bg4-lz4), 99 (auto)"), + ) })?; config .xorb .compression_policy .try_set(<&str>::from(scheme)) - .map_err(|e| anyhow::anyhow!(e))?; + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidInput, e.to_string()))?; } let threadpool = XetRuntime::new_with_config(config)?; diff --git a/xet_data/src/processing/configurations.rs b/xet_data/src/processing/configurations.rs index bfd2bdba..1a27f578 100644 --- a/xet_data/src/processing/configurations.rs +++ b/xet_data/src/processing/configurations.rs @@ -6,7 +6,7 @@ use tracing::info; use xet_client::cas_client::auth::AuthConfig; use xet_runtime::core::{xet_cache_root, xet_config}; -use super::errors::Result; +use crate::error::Result; /// Session-specific configuration that varies per upload/download session. /// These are runtime values that cannot be configured via environment variables. diff --git a/xet_data/src/processing/data_client.rs b/xet_data/src/processing/data_client.rs index 60aca129..ce064930 100644 --- a/xet_data/src/processing/data_client.rs +++ b/xet_data/src/processing/data_client.rs @@ -14,15 +14,16 @@ use xet_runtime::core::{XetRuntime, check_sigint_shutdown, xet_config}; use super::configurations::{SessionContext, TranslatorConfig}; use super::file_cleaner::Sha256Policy; -use super::{FileUploadSession, XetFileInfo, errors}; +use super::{FileUploadSession, XetFileInfo}; use crate::deduplication::{Chunker, DeduplicationMetrics}; +use crate::error::Result; pub fn default_config( endpoint: String, token_info: Option<(String, u64)>, token_refresher: Option>, custom_headers: Option>, -) -> errors::Result { +) -> Result { let (token, token_expiration) = token_info.unzip(); let auth_cfg = AuthConfig::maybe_new(token, token_expiration, token_refresher); @@ -42,7 +43,7 @@ pub async fn clean_bytes( processor: Arc, bytes: Vec, sha256_policy: Sha256Policy, -) -> errors::Result<(XetFileInfo, DeduplicationMetrics)> { +) -> Result<(XetFileInfo, DeduplicationMetrics)> { let (_id, mut handle) = processor.start_clean(None, bytes.len() as u64, sha256_policy)?; handle.add_data(&bytes).await?; handle.finish().await @@ -53,7 +54,7 @@ pub async fn clean_file( processor: Arc, filename: impl AsRef, sha256_policy: Sha256Policy, -) -> errors::Result<(XetFileInfo, DeduplicationMetrics)> { +) -> Result<(XetFileInfo, DeduplicationMetrics)> { let mut reader = File::open(&filename)?; let filesize = reader.metadata()?.len(); @@ -98,7 +99,7 @@ pub async fn clean_file( /// - Verify that downloaded files are correctly reassembled /// - Check if a file needs to be uploaded (by comparing hashes) /// - Generate cache keys for local file operations -fn hash_single_file(filename: String, buffer_size: usize) -> errors::Result { +fn hash_single_file(filename: String, buffer_size: usize) -> Result { let mut reader = File::open(&filename)?; let filesize = reader.metadata()?.len(); @@ -157,7 +158,7 @@ fn hash_single_file(filename: String, buffer_size: usize) -> errors::Result) -> errors::Result> { +pub async fn hash_files_async(file_paths: Vec) -> Result> { let rt = XetRuntime::current(); let semaphore = rt.common().file_ingestion_semaphore.clone(); let buffer_size = *xet_config().data.ingestion_block_size as usize; diff --git a/xet_data/src/processing/deduplication_interface.rs b/xet_data/src/processing/deduplication_interface.rs index 2d7cd02a..daefddbe 100644 --- a/xet_data/src/processing/deduplication_interface.rs +++ b/xet_data/src/processing/deduplication_interface.rs @@ -6,9 +6,9 @@ use tracing::Instrument; use xet_core_structures::merklehash::MerkleHash; use xet_core_structures::metadata_shard::file_structs::FileDataSequenceEntry; -use super::errors::Result; use super::file_upload_session::FileUploadSession; use crate::deduplication::{DeduplicationDataInterface, RawXorbData}; +use crate::error::{DataError, Result}; use crate::progress_tracking::upload_tracking::FileXorbDependency; pub struct UploadSessionDataManager { @@ -31,7 +31,7 @@ impl UploadSessionDataManager { #[async_trait] impl DeduplicationDataInterface for UploadSessionDataManager { - type ErrorType = super::errors::DataProcessingError; + type ErrorType = DataError; /// Query for possible shards that may dedup some chunks. async fn chunk_hash_dedup_query( diff --git a/xet_data/src/processing/errors.rs b/xet_data/src/processing/errors.rs deleted file mode 100644 index 91f2eadb..00000000 --- a/xet_data/src/processing/errors.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub use crate::error::DataError as DataProcessingError; -pub type Result = std::result::Result; diff --git a/xet_data/src/processing/file_cleaner.rs b/xet_data/src/processing/file_cleaner.rs index af5bacd2..dfc5b891 100644 --- a/xet_data/src/processing/file_cleaner.rs +++ b/xet_data/src/processing/file_cleaner.rs @@ -11,10 +11,10 @@ use xet_runtime::core::{XetRuntime, xet_config}; use super::XetFileInfo; use super::deduplication_interface::UploadSessionDataManager; -use super::errors::Result; use super::file_upload_session::FileUploadSession; use super::sha256::Sha256Generator; use crate::deduplication::{Chunk, Chunker, DeduplicationMetrics, FileDeduper}; +use crate::error::Result; use crate::progress_tracking::upload_tracking::CompletionTrackerFileId; /// Controls how SHA-256 is handled during file cleaning. diff --git a/xet_data/src/processing/file_download_session.rs b/xet_data/src/processing/file_download_session.rs index 42303a38..7780c15f 100644 --- a/xet_data/src/processing/file_download_session.rs +++ b/xet_data/src/processing/file_download_session.rs @@ -12,9 +12,9 @@ use xet_client::cas_types::FileRange; use xet_runtime::core::{XetRuntime, xet_config}; use super::configurations::TranslatorConfig; -use super::errors::*; use super::remote_client_interface::create_remote_client; use super::{XetFileInfo, prometheus_metrics}; +use crate::error::{DataError, Result}; use crate::file_reconstruction::{DownloadStream, FileReconstructor}; use crate::progress_tracking::{GroupProgress, ItemProgressUpdater, UniqueID}; @@ -159,7 +159,7 @@ impl FileDownloadSession { fn check_not_finalized(&self) -> Result<()> { if self.finalized.load(Ordering::Acquire) { - return Err(DataProcessingError::InvalidOperation("FileDownloadSession already finalized".to_string())); + return Err(DataError::InvalidOperation("FileDownloadSession already finalized".to_string())); } Ok(()) } @@ -167,7 +167,7 @@ impl FileDownloadSession { /// Finalizes the session; in debug builds, asserts all items are complete. pub async fn finalize(&self) -> Result<()> { if self.finalized.swap(true, Ordering::AcqRel) { - return Err(DataProcessingError::InvalidOperation("FileDownloadSession already finalized".to_string())); + return Err(DataError::InvalidOperation("FileDownloadSession already finalized".to_string())); } #[cfg(debug_assertions)] self.progress.assert_complete(); diff --git a/xet_data/src/processing/file_upload_session.rs b/xet_data/src/processing/file_upload_session.rs index bca33bdc..2af5aba0 100644 --- a/xet_data/src/processing/file_upload_session.rs +++ b/xet_data/src/processing/file_upload_session.rs @@ -18,13 +18,13 @@ use xet_core_structures::xorb_object::SerializedXorbObject; use xet_runtime::core::{XetRuntime, xet_config}; use super::configurations::TranslatorConfig; -use super::errors::*; use super::file_cleaner::{Sha256Policy, SingleFileCleaner}; use super::remote_client_interface::create_remote_client; use super::shard_interface::SessionShardInterface; use super::{XetFileInfo, prometheus_metrics}; use crate::deduplication::constants::{MAX_XORB_BYTES, MAX_XORB_CHUNKS}; use crate::deduplication::{DataAggregator, DeduplicationMetrics, RawXorbData}; +use crate::error::{DataError, Result}; use crate::progress_tracking::upload_tracking::{CompletionTracker, FileXorbDependency}; use crate::progress_tracking::{GroupProgress, GroupProgressReport, ItemProgressReport, UniqueID}; @@ -472,7 +472,7 @@ impl FileUploadSession { return_files: bool, ) -> Result<(DeduplicationMetrics, Vec, GroupProgressReport)> { if self.finalized.swap(true, Ordering::AcqRel) { - return Err(DataProcessingError::InvalidOperation("FileUploadSession already finalized".to_string())); + return Err(DataError::InvalidOperation("FileUploadSession already finalized".to_string())); } // Register the remaining xorbs for upload. @@ -537,7 +537,7 @@ impl FileUploadSession { fn check_not_finalized(&self) -> Result<()> { if self.finalized.load(Ordering::Acquire) { - return Err(DataProcessingError::InvalidOperation("FileUploadSession already finalized".to_string())); + return Err(DataError::InvalidOperation("FileUploadSession already finalized".to_string())); } Ok(()) } diff --git a/xet_data/src/processing/migration_tool/migrate.rs b/xet_data/src/processing/migration_tool/migrate.rs index a73c7c37..728b343e 100644 --- a/xet_data/src/processing/migration_tool/migrate.rs +++ b/xet_data/src/processing/migration_tool/migrate.rs @@ -1,6 +1,5 @@ use std::sync::Arc; -use anyhow::{Result, anyhow}; use http::header; use tracing::{Instrument, Span, info_span, instrument}; use xet_client::cas_client::auth::TokenRefresher; @@ -10,9 +9,9 @@ use xet_runtime::core::XetRuntime; use xet_runtime::core::par_utils::run_constrained; use super::super::data_client::{clean_file, default_config}; -use super::super::errors::DataProcessingError; use super::super::{FileUploadSession, Sha256Policy, XetFileInfo}; use super::hub_client_token_refresher::HubClientTokenRefresher; +use crate::error::{DataError, Result}; const USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); @@ -98,7 +97,9 @@ pub async fn migrate_files_impl( let sha256_policies: Vec = match sha256s { Some(v) => { if v.len() != file_paths.len() { - return Err(anyhow!("mismatched length of the file list and the sha256 list")); + return Err(DataError::ParameterError( + "mismatched length of the file list and the sha256 list".to_string(), + )); } v.iter().map(|s| Sha256Policy::from_hex(s)).collect() }, @@ -109,7 +110,7 @@ pub async fn migrate_files_impl( let proc = processor.clone(); async move { let (pf, metrics) = clean_file(proc, file_path, policy).await?; - Ok::<(XetFileInfo, u64), DataProcessingError>((pf, metrics.new_bytes)) + Ok::<(XetFileInfo, u64), DataError>((pf, metrics.new_bytes)) } .instrument(info_span!("clean_file")) }); diff --git a/xet_data/src/processing/mod.rs b/xet_data/src/processing/mod.rs index daf84170..ee52fd53 100644 --- a/xet_data/src/processing/mod.rs +++ b/xet_data/src/processing/mod.rs @@ -1,7 +1,6 @@ pub mod configurations; pub mod data_client; mod deduplication_interface; -pub mod errors; mod file_cleaner; mod file_download_session; mod file_upload_session; diff --git a/xet_data/src/processing/remote_client_interface.rs b/xet_data/src/processing/remote_client_interface.rs index 4d44ee39..4de2aade 100644 --- a/xet_data/src/processing/remote_client_interface.rs +++ b/xet_data/src/processing/remote_client_interface.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use xet_client::cas_client::{Client, RemoteClient}; use super::configurations::TranslatorConfig; -use super::errors::Result; +use crate::error::Result; pub(crate) async fn create_remote_client( config: &TranslatorConfig, diff --git a/xet_data/src/processing/shard_interface.rs b/xet_data/src/processing/shard_interface.rs index 96af3396..60a95581 100644 --- a/xet_data/src/processing/shard_interface.rs +++ b/xet_data/src/processing/shard_interface.rs @@ -25,7 +25,7 @@ use xet_runtime::core::xet_config; use xet_runtime::error_printer::ErrorPrinter; use super::configurations::TranslatorConfig; -use super::errors::Result; +use crate::error::Result; pub struct SessionShardInterface { session_shard_manager: Arc, diff --git a/xet_data/tests/integration_tests.rs b/xet_data/tests/integration_tests.rs index be274e06..0590e062 100644 --- a/xet_data/tests/integration_tests.rs +++ b/xet_data/tests/integration_tests.rs @@ -2,7 +2,7 @@ use std::io::Write; use std::path::Path; use std::process::Command; -use anyhow::anyhow; +use anyhow::Result; use tempfile::TempDir; use tracing::info; @@ -34,7 +34,7 @@ impl IntegrationTest { self.assets.push((name.to_owned(), arg)); } - fn run(&self) -> anyhow::Result<()> { + fn run(&self) -> Result<()> { // Create a temporary directory let tmp_repo_dest = TempDir::new().unwrap(); let tmp_path_path = tmp_repo_dest.path().to_path_buf(); @@ -94,9 +94,13 @@ impl IntegrationTest { let captures = error_re.captures(stderr_out); if let Some(captured_text) = captures { - Err(anyhow!("Test failed: {}", captured_text.get(1).unwrap().as_str())) + Err(std::io::Error::new( + std::io::ErrorKind::Other, + format!("Test failed: {}", captured_text.get(1).unwrap().as_str()), + ) + .into()) } else { - Err(anyhow!("Test failed: Unknown Error.")) + Err(std::io::Error::new(std::io::ErrorKind::Other, "Test failed: Unknown Error.").into()) } } } @@ -104,9 +108,11 @@ impl IntegrationTest { #[cfg(all(test, unix))] mod git_integration_tests { + use anyhow::Result; + use super::*; #[test] - fn test_basic_read() -> anyhow::Result<()> { + fn test_basic_read() -> Result<()> { IntegrationTest::new(include_str!("integration_tests/test_basic_clean_smudge.sh")).run() } } diff --git a/xet_pkg/Cargo.toml b/xet_pkg/Cargo.toml index 14dcc0be..cc39063d 100644 --- a/xet_pkg/Cargo.toml +++ b/xet_pkg/Cargo.toml @@ -18,6 +18,7 @@ xet-core-structures = { version = "1.4.0", path = "../xet_core_structures" } xet-client = { version = "1.4.0", path = "../xet_client" } xet-data = { version = "1.4.0", path = "../xet_data" } +anyhow = { workspace = true } async-trait = { workspace = true } http = { workspace = true } more-asserts = { workspace = true } @@ -26,9 +27,10 @@ thiserror = { workspace = true } tokio = { workspace = true, features = ["net", "time"] } tracing = { workspace = true } ulid = { workspace = true } +pyo3 = { workspace = true, optional = true } [features] -python = ["xet-runtime/python"] +python = ["xet-runtime/python", "dep:pyo3"] [dev-dependencies] anyhow = { workspace = true } diff --git a/xet_pkg/src/error.rs b/xet_pkg/src/error.rs index e57dd3da..6b92cbfc 100644 --- a/xet_pkg/src/error.rs +++ b/xet_pkg/src/error.rs @@ -1,7 +1,8 @@ use thiserror::Error; use xet_client::ClientError; -use xet_core_structures::FormatError; +use xet_core_structures::CoreError; use xet_data::DataError; +use xet_data::file_reconstruction::FileReconstructionError; use xet_data::progress_tracking::UniqueID; use xet_runtime::RuntimeError; @@ -35,10 +36,14 @@ pub enum XetError { #[error("Authentication error: {0}")] Authentication(String), - /// Network-level failures: DNS, timeouts, HTTP 5xx, etc. + /// Network-level failures: DNS, HTTP 5xx, connection reset, etc. #[error("Network error: {0}")] Network(String), + /// A network request timed out. + #[error("Timeout: {0}")] + Timeout(String), + /// A requested resource (file, XORB, shard) does not exist. #[error("Not found: {0}")] NotFound(String), @@ -84,31 +89,34 @@ impl XetError { } } - fn from_format_error_ref(fe: &FormatError) -> Self { + fn from_core_error_ref(fe: &CoreError) -> Self { match fe { - FormatError::Io(_) => XetError::Io(fe.to_string()), - FormatError::ShardNotFound(_) | FormatError::FileNotFound(_) => XetError::NotFound(fe.to_string()), - FormatError::HashMismatch - | FormatError::TruncatedHashCollision(_) - | FormatError::InvalidShard(_) - | FormatError::ShardVersion(_) - | FormatError::ChunkHeaderParse - | FormatError::Format(_) - | FormatError::Compression(_) => XetError::DataIntegrity(fe.to_string()), - FormatError::InvalidRange | FormatError::InvalidArguments | FormatError::BadFilename(_) => { + CoreError::Io(_) => XetError::Io(fe.to_string()), + CoreError::ShardNotFound(_) | CoreError::FileNotFound(_) => XetError::NotFound(fe.to_string()), + CoreError::HashMismatch + | CoreError::TruncatedHashCollision(_) + | CoreError::InvalidShard(_) + | CoreError::ShardVersion(_) + | CoreError::ChunkHeaderParse + | CoreError::MalformedData(_) + | CoreError::CompressionError(_) => XetError::DataIntegrity(fe.to_string()), + CoreError::InvalidRange | CoreError::InvalidArguments | CoreError::BadFilename(_) => { XetError::Configuration(fe.to_string()) }, - FormatError::Runtime(re) => XetError::from_runtime_error_ref(re), + CoreError::RuntimeError(re) => XetError::from_runtime_error_ref(re), _ => XetError::Internal(fe.to_string()), } } fn from_client_error_ref(ce: &ClientError) -> Self { match ce { - ClientError::AuthError(_) => XetError::Authentication(ce.to_string()), - ClientError::ReqwestError(_, _) - | ClientError::ReqwestMiddlewareError(_) - | ClientError::PresignedUrlExpirationError => XetError::Network(ce.to_string()), + ClientError::AuthError(_) | ClientError::PresignedUrlExpirationError | ClientError::CredentialHelper(_) => { + XetError::Authentication(ce.to_string()) + }, + ClientError::ReqwestError(e, _) if e.is_timeout() => XetError::Timeout(ce.to_string()), + ClientError::ReqwestError(_, _) | ClientError::ReqwestMiddlewareError(_) => { + XetError::Network(ce.to_string()) + }, ClientError::FileNotFound(_) | ClientError::XORBNotFound(_) => XetError::NotFound(ce.to_string()), ClientError::ConfigurationError(_) | ClientError::InvalidArguments @@ -117,16 +125,31 @@ impl XetError { | ClientError::InvalidKey(_) | ClientError::InvalidRepoType(_) => XetError::Configuration(ce.to_string()), ClientError::IOError(_) => XetError::Io(ce.to_string()), - ClientError::FormatError(fe) => XetError::from_format_error_ref(fe), + ClientError::FormatError(fe) => XetError::from_core_error_ref(fe), _ => XetError::Internal(ce.to_string()), } } + fn from_file_reconstruction_error_ref(fre: &FileReconstructionError) -> Self { + match fre { + FileReconstructionError::ClientError(ce) => XetError::from_client_error_ref(ce), + FileReconstructionError::IoError(_) => XetError::Io(fre.to_string()), + FileReconstructionError::RuntimeError(re) => XetError::from_runtime_error_ref(re), + FileReconstructionError::TaskJoinError(je) if je.is_cancelled() => { + XetError::Cancelled(format!("Task cancelled: {je}")) + }, + FileReconstructionError::TaskJoinError(je) => XetError::Internal(format!("Task join error: {je}")), + FileReconstructionError::ConfigurationError(_) => XetError::Configuration(fre.to_string()), + FileReconstructionError::CorruptedReconstruction(_) => XetError::DataIntegrity(fre.to_string()), + _ => XetError::Internal(fre.to_string()), + } + } + fn from_data_error_ref(de: &DataError) -> Self { match de { DataError::AuthError(_) => XetError::Authentication(de.to_string()), DataError::ClientError(ce) => XetError::from_client_error_ref(ce), - DataError::FormatError(fe) => XetError::from_format_error_ref(fe), + DataError::FormatError(fe) => XetError::from_core_error_ref(fe), DataError::IOError(_) => XetError::Io(de.to_string()), DataError::RuntimeError(re) => XetError::from_runtime_error_ref(re), DataError::FileQueryPolicyError(_) @@ -138,6 +161,7 @@ impl XetError { DataError::HashNotFound => XetError::NotFound(de.to_string()), DataError::HashStringParsingFailure(_) => XetError::DataIntegrity(de.to_string()), DataError::InvalidOperation(_) => XetError::Configuration(de.to_string()), + DataError::FileReconstructionError(fre) => XetError::from_file_reconstruction_error_ref(fre), _ => XetError::Internal(de.to_string()), } } @@ -151,9 +175,9 @@ impl From for XetError { } } -impl From for XetError { - fn from(e: FormatError) -> Self { - XetError::from_format_error_ref(&e) +impl From for XetError { + fn from(e: CoreError) -> Self { + XetError::from_core_error_ref(&e) } } @@ -169,6 +193,12 @@ impl From for XetError { } } +impl From for XetError { + fn from(e: FileReconstructionError) -> Self { + XetError::from_file_reconstruction_error_ref(&e) + } +} + // -- Convenience From impls for common error types ----------------------- impl From for XetError { @@ -211,6 +241,58 @@ impl From>> for XetE } } +// -- Python exception classes & conversion -------------------------------- + +#[cfg(feature = "python")] +mod py_exceptions { + // Inherits from Python's PermissionError so `except PermissionError` still catches it. + pyo3::create_exception!(hf_xet, XetAuthenticationError, pyo3::exceptions::PyPermissionError); + + // Inherits from Python's FileNotFoundError so `except FileNotFoundError` still catches it. + pyo3::create_exception!(hf_xet, XetObjectNotFoundError, pyo3::exceptions::PyFileNotFoundError); + + /// Register the custom exception classes on a Python module. + /// + /// Call this from the `#[pymodule]` init function so that the exceptions + /// are importable as `hf_xet.XetAuthenticationError`, etc. + pub fn register_exceptions(m: &pyo3::Bound<'_, pyo3::types::PyModule>) -> pyo3::PyResult<()> { + use pyo3::types::PyModuleMethods; + + m.add("XetAuthenticationError", m.py().get_type::())?; + m.add("XetObjectNotFoundError", m.py().get_type::())?; + Ok(()) + } +} + +#[cfg(feature = "python")] +pub use py_exceptions::{XetAuthenticationError, XetObjectNotFoundError, register_exceptions}; + +#[cfg(feature = "python")] +impl From for pyo3::PyErr { + fn from(err: XetError) -> pyo3::PyErr { + use pyo3::exceptions::{PyConnectionError, PyOSError, PyRuntimeError, PyTimeoutError, PyValueError}; + + let msg = err.to_string(); + #[allow(unreachable_patterns)] // XetError is #[non_exhaustive] + match err { + XetError::Authentication(_) => XetAuthenticationError::new_err(msg), + XetError::NotFound(_) => XetObjectNotFoundError::new_err(msg), + XetError::Network(_) => PyConnectionError::new_err(msg), + XetError::Timeout(_) => PyTimeoutError::new_err(msg), + XetError::Io(_) => PyOSError::new_err(msg), + XetError::Configuration(_) | XetError::InvalidTaskID(_) => PyValueError::new_err(msg), + XetError::DataIntegrity(_) + | XetError::Internal(_) + | XetError::WrongRuntimeMode(_) + | XetError::AlreadyCommitted + | XetError::AlreadyFinished + | XetError::Aborted + | XetError::Cancelled(_) => PyRuntimeError::new_err(msg), + _ => PyRuntimeError::new_err(msg), + } + } +} + #[cfg(test)] mod tests { use xet_client::cas_client::auth::AuthError; @@ -226,13 +308,13 @@ mod tests { #[test] fn format_not_found_maps_to_not_found() { - let err = XetError::from(FormatError::ShardNotFound(MerkleHash::default())); + let err = XetError::from(CoreError::ShardNotFound(MerkleHash::default())); assert!(matches!(err, XetError::NotFound(_))); } #[test] fn format_invalid_args_maps_to_configuration() { - let err = XetError::from(FormatError::InvalidArguments); + let err = XetError::from(CoreError::InvalidArguments); assert!(matches!(err, XetError::Configuration(_))); } @@ -244,7 +326,7 @@ mod tests { #[test] fn client_nested_format_maps_using_format_rules() { - let err = XetError::from(ClientError::FormatError(FormatError::InvalidRange)); + let err = XetError::from(ClientError::FormatError(CoreError::InvalidRange)); assert!(matches!(err, XetError::Configuration(_))); } @@ -259,4 +341,37 @@ mod tests { let err = XetError::from(DataError::RuntimeError(RuntimeError::TaskCanceled("cancelled".to_string()))); assert!(matches!(err, XetError::Cancelled(_))); } + + #[test] + fn presigned_url_expiration_maps_to_authentication() { + let err = XetError::from(ClientError::PresignedUrlExpirationError); + assert!(matches!(err, XetError::Authentication(_))); + } + + #[test] + fn credential_helper_maps_to_authentication() { + let err = XetError::from(ClientError::credential_helper_error(std::io::Error::new( + std::io::ErrorKind::Other, + "cred fail", + ))); + assert!(matches!(err, XetError::Authentication(_))); + } + + #[test] + fn client_not_found_maps_to_not_found() { + let err = XetError::from(ClientError::FileNotFound(MerkleHash::default())); + assert!(matches!(err, XetError::NotFound(_))); + } + + #[test] + fn client_xorb_not_found_maps_to_not_found() { + let err = XetError::from(ClientError::XORBNotFound(MerkleHash::default())); + assert!(matches!(err, XetError::NotFound(_))); + } + + #[test] + fn client_io_maps_to_io() { + let err = XetError::from(ClientError::IOError(std::io::Error::new(std::io::ErrorKind::NotFound, "gone"))); + assert!(matches!(err, XetError::Io(_))); + } } diff --git a/xet_pkg/src/legacy/data_client.rs b/xet_pkg/src/legacy/data_client.rs index 4fceff03..5c111a2a 100644 --- a/xet_pkg/src/legacy/data_client.rs +++ b/xet_pkg/src/legacy/data_client.rs @@ -6,17 +6,12 @@ use tracing::{Instrument, Span, info_span, instrument}; use xet_client::cas_client::auth::TokenRefresher; pub use xet_data::processing::data_client::hash_files_async; use xet_data::processing::data_client::{clean_bytes, default_config}; -use xet_data::processing::errors::DataProcessingError; use xet_data::processing::{FileDownloadSession, FileUploadSession, Sha256Policy, XetFileInfo}; +use xet_data::{DataError, Result}; use xet_runtime::core::par_utils::run_constrained_with_semaphore; use xet_runtime::core::{XetRuntime, xet_config}; use super::progress_tracking::{GroupProgressCallbackUpdater, ItemProgressCallbackUpdater, TrackingProgressUpdater}; -use crate::legacy::data_client::errors::Result; - -mod errors { - pub use xet_data::processing::errors::Result; -} #[instrument(skip_all, name = "data_client::upload_bytes", fields(session_id = tracing::field::Empty, num_files=file_contents.len()))] pub async fn upload_bytes_async( @@ -29,7 +24,7 @@ pub async fn upload_bytes_async( custom_headers: Option>, ) -> Result> { if sha256_policies.len() != file_contents.len() { - return Err(DataProcessingError::ParameterError(format!( + return Err(DataError::ParameterError(format!( "sha256_policies length ({}) must match file_contents length ({})", sha256_policies.len(), file_contents.len() @@ -86,7 +81,7 @@ pub async fn upload_async( custom_headers: Option>, ) -> Result> { if sha256_policies.len() != file_paths.len() { - return Err(DataProcessingError::ParameterError(format!( + return Err(DataError::ParameterError(format!( "sha256_policies length ({}) must match file_paths length ({})", sha256_policies.len(), file_paths.len() @@ -140,7 +135,7 @@ pub async fn download_async( if let Some(updaters) = &progress_updaters && updaters.len() != file_infos.len() { - return Err(DataProcessingError::ParameterError("updaters are not same length as pointer_files".to_string())); + return Err(DataError::ParameterError("updaters are not same length as pointer_files".to_string())); } let config: Arc<_> = default_config( endpoint.unwrap_or_else(|| xet_config().data.default_cas_endpoint.clone()), diff --git a/xet_pkg/src/legacy/mod.rs b/xet_pkg/src/legacy/mod.rs index ecf4cc82..fd8f8ec7 100644 --- a/xet_pkg/src/legacy/mod.rs +++ b/xet_pkg/src/legacy/mod.rs @@ -5,5 +5,4 @@ pub mod progress_tracking; // a direct xet_data dependency. pub use xet_data::processing::configurations::{SessionContext, TranslatorConfig}; pub use xet_data::processing::data_client::{clean_bytes, clean_file, default_config, hash_files_async}; -pub use xet_data::processing::errors::DataProcessingError; pub use xet_data::processing::{FileDownloadSession, FileUploadSession, Sha256Policy, XetFileInfo}; diff --git a/xet_pkg/src/lib.rs b/xet_pkg/src/lib.rs index fe7bfa29..9ef561cc 100644 --- a/xet_pkg/src/lib.rs +++ b/xet_pkg/src/lib.rs @@ -1,5 +1,7 @@ pub mod error; pub use error::XetError; +#[cfg(feature = "python")] +pub use error::{XetAuthenticationError, XetObjectNotFoundError, register_exceptions}; pub mod legacy; pub mod xet_session; diff --git a/xet_pkg/src/xet_session/common.rs b/xet_pkg/src/xet_session/common.rs index 0828d7aa..b3dba1f5 100644 --- a/xet_pkg/src/xet_session/common.rs +++ b/xet_pkg/src/xet_session/common.rs @@ -1,9 +1,10 @@ use xet_data::processing::configurations::TranslatorConfig; -use super::{SessionError, XetSession}; +use super::XetSession; +use crate::error::XetError; // Helper function to create TranslatorConfig -pub(super) fn create_translator_config(session: &XetSession) -> Result { +pub(super) fn create_translator_config(session: &XetSession) -> Result { let endpoint = session .endpoint .clone() diff --git a/xet_pkg/src/xet_session/download_group.rs b/xet_pkg/src/xet_session/download_group.rs index 86af72b0..35d4730d 100644 --- a/xet_pkg/src/xet_session/download_group.rs +++ b/xet_pkg/src/xet_session/download_group.rs @@ -11,9 +11,9 @@ use xet_data::progress_tracking::{GroupProgressReport, UniqueID}; use xet_runtime::core::XetRuntime; use super::common::{GroupState, create_translator_config}; -use super::errors::SessionError; use super::session::{RuntimeMode, XetSession}; use super::tasks::{DownloadTaskHandle, TaskHandle, TaskStatus}; +use crate::error::XetError; /// API for grouping related file downloads into a single unit of work. /// @@ -34,8 +34,8 @@ use super::tasks::{DownloadTaskHandle, TaskHandle, TaskStatus}; /// /// # Errors /// -/// Methods return [`SessionError::Aborted`] if the parent session has been -/// aborted, and [`SessionError::AlreadyFinished`] if +/// Methods return [`XetError::Aborted`] if the parent session has been +/// aborted, and [`XetError::AlreadyFinished`] if /// [`finish`](Self::finish) has already been called. #[derive(Clone)] pub struct DownloadGroup { @@ -52,7 +52,7 @@ impl std::ops::Deref for DownloadGroup { impl DownloadGroup { /// Create a new download group from an **async** context. Initialisation logic shared by the sync and async /// constructors. - pub(super) async fn new(session: XetSession) -> Result { + pub(super) async fn new(session: XetSession) -> Result { let group_id = UniqueID::new(); let config = create_translator_config(&session)?; let download_session = FileDownloadSession::new(Arc::new(config)).await?; @@ -74,7 +74,7 @@ impl DownloadGroup { } /// Abort this download group. - pub(super) fn abort(&self) -> Result<(), SessionError> { + pub(super) fn abort(&self) -> Result<(), XetError> { self.inner.abort() } @@ -97,14 +97,14 @@ impl DownloadGroup { /// /// # Errors /// - /// Returns [`SessionError::Aborted`] if the session has been aborted, or - /// [`SessionError::AlreadyFinished`] if [`finish`](Self::finish) has already + /// Returns [`XetError::Aborted`] if the session has been aborted, or + /// [`XetError::AlreadyFinished`] if [`finish`](Self::finish) has already /// been called. pub async fn download_file_to_path( &self, file_info: XetFileInfo, dest_path: PathBuf, - ) -> Result { + ) -> Result { self.session.check_alive()?; // Use the absolute path in case the process current working directory changes @@ -114,7 +114,7 @@ impl DownloadGroup { } /// Return a snapshot of progress for every queued download. - pub fn get_progress(&self) -> Result { + pub fn get_progress(&self) -> Result { let Some(download_session) = self.download_session.lock()?.clone() else { return Ok(GroupProgressReport::default()); }; @@ -125,7 +125,7 @@ impl DownloadGroup { /// /// Returns a `HashMap` keyed by task ID where each value is /// [`DownloadResult`] (= `Arc>`). A single failed download + /// [`XetError`](crate::XetError)`>>`). A single failed download /// does not prevent the others from being collected. /// /// Per-task results can also be read directly from the @@ -133,8 +133,8 @@ impl DownloadGroup { /// [`result`](DownloadTaskHandle::result) after this method returns. /// /// Consumes `self` — subsequent calls on any clone will return - /// [`SessionError::AlreadyFinished`]. - pub async fn finish(self) -> Result, SessionError> { + /// [`XetError::AlreadyFinished`]. + pub async fn finish(self) -> Result, XetError> { let inner = self.inner.clone(); self.session .dispatch("finish", async move { inner.handle_finish().await }) @@ -154,7 +154,7 @@ impl DownloadGroup { /// /// # Errors /// - /// Returns [`SessionError::WrongRuntimeMode`] if the session was created with an external + /// Returns [`XetError::WrongRuntimeMode`] if the session was created with an external /// tokio runtime ([`XetSessionBuilder::with_tokio_handle`] / [`XetSessionBuilder::build_async`] /// inside a tokio context). Use [`download_file_to_path`](Self::download_file_to_path)`.await` /// instead. @@ -166,9 +166,9 @@ impl DownloadGroup { &self, file_info: XetFileInfo, dest_path: PathBuf, - ) -> Result { + ) -> Result { if matches!(self.session.runtime_mode, RuntimeMode::External) { - return Err(SessionError::wrong_mode( + return Err(XetError::wrong_mode( "download_file_to_path_blocking() cannot be called on a session using an \ external tokio runtime (with_tokio_handle() or tokio build_async()); \ use download_file_to_path().await instead", @@ -180,7 +180,7 @@ impl DownloadGroup { } /// Blocking version of [`get_progress`](Self::get_progress). - pub fn get_progress_blocking(&self) -> Result { + pub fn get_progress_blocking(&self) -> Result { self.get_progress() } @@ -189,7 +189,7 @@ impl DownloadGroup { /// # Panics /// /// Panics if called from within a tokio async runtime. - pub fn finish_blocking(self) -> Result, SessionError> { + pub fn finish_blocking(self) -> Result, XetError> { let group = self.clone(); self.runtime().external_run_async_task(group.finish())? } @@ -200,7 +200,7 @@ impl DownloadGroup { /// The `Arc` lets the same value be stored in both the `finish()` return map /// and the per-task [`DownloadTaskHandle`] without requiring the inner /// `Result` to be `Clone`. -pub type DownloadResult = Arc>; +pub type DownloadResult = Arc>; /// Handle for a single download task tracked internally by DownloadGroup. struct InnerDownloadTaskHandle { @@ -232,10 +232,10 @@ impl DownloadGroupInner { // ===== State helpers ===== /// Check whether the group is still accepting new tasks. - fn check_accepting_tasks(state: &MutexGuard) -> Result<(), SessionError> { + fn check_accepting_tasks(state: &MutexGuard) -> Result<(), XetError> { match **state { - GroupState::Finished => Err(SessionError::AlreadyFinished), - GroupState::Aborted => Err(SessionError::Aborted), + GroupState::Finished => Err(XetError::AlreadyFinished), + GroupState::Aborted => Err(XetError::Aborted), GroupState::Alive => Ok(()), } } @@ -244,13 +244,13 @@ impl DownloadGroupInner { self: &Arc, file_info: XetFileInfo, dest_path: PathBuf, - ) -> Result { + ) -> Result { let download_session = { let state = self.state.lock()?; Self::check_accepting_tasks(&state)?; let Some(download_session) = self.download_session.lock()?.clone() else { - return Err(SessionError::other("Download session not initialized")); + return Err(XetError::other("Download session not initialized")); }; download_session // state guard dropped here before the .await @@ -299,12 +299,12 @@ impl DownloadGroupInner { } /// Join all active download tasks and mark the group as finished. - pub(super) async fn handle_finish(&self) -> Result, SessionError> { + pub(super) async fn handle_finish(&self) -> Result, XetError> { // Mark as not accepting new tasks { let mut state_guard = self.state.lock()?; if *state_guard == GroupState::Finished { - return Err(SessionError::AlreadyFinished); + return Err(XetError::AlreadyFinished); } *state_guard = GroupState::Aborted; // stop new tasks while draining } @@ -343,7 +343,7 @@ impl DownloadGroupInner { TaskStatus::mark_terminal(&handle.status, TaskStatus::Failed); } if join_err.is_none() { - join_err = Some(SessionError::from(e)); + join_err = Some(XetError::from(e)); } }, } @@ -361,7 +361,7 @@ impl DownloadGroupInner { Ok(results) } - fn abort(&self) -> Result<(), SessionError> { + fn abort(&self) -> Result<(), XetError> { *self.state.lock()? = GroupState::Aborted; let active_tasks = std::mem::take(&mut *self.active_tasks.write()?); for (_tracking_id, inner_task_handle) in active_tasks { @@ -388,13 +388,14 @@ mod tests { use std::sync::mpsc; use std::time::Duration; + use anyhow::Result; use tempfile::{TempDir, tempdir}; use xet_data::processing::Sha256Policy; use super::*; use crate::xet_session::session::{RuntimeMode, XetSession, XetSessionBuilder}; - async fn local_session(temp: &TempDir) -> Result> { + async fn local_session(temp: &TempDir) -> Result { let cas_path = temp.path().join("cas"); Ok(XetSessionBuilder::new() .with_endpoint(format!("local://{}", cas_path.display())) @@ -402,11 +403,7 @@ mod tests { .await?) } - async fn upload_bytes( - session: &XetSession, - data: &[u8], - name: &str, - ) -> Result> { + async fn upload_bytes(session: &XetSession, data: &[u8], name: &str) -> Result { let commit = session.new_upload_commit().await?; let handle = commit .upload_bytes(data.to_vec(), Sha256Policy::Compute, Some(name.into())) @@ -430,7 +427,7 @@ mod tests { #[test] // finish() must block while download_file_to_path() holds the state lock. - fn test_finish_blocked_while_download_registration_holds_state_lock() -> Result<(), Box> { + fn test_finish_blocked_while_download_registration_holds_state_lock() -> Result<()> { let session = XetSessionBuilder::new().build()?; let runtime = session.runtime.clone(); // Create DownloadGroup directly so we can access its private state field @@ -513,7 +510,7 @@ mod tests { let g2 = g1.clone(); g1.finish().await.unwrap(); let err = g2.finish().await.unwrap_err(); - assert!(matches!(err, SessionError::AlreadyFinished | SessionError::Internal(_))); + assert!(matches!(err, XetError::AlreadyFinished | XetError::Internal(_))); } #[tokio::test(flavor = "multi_thread")] @@ -545,7 +542,7 @@ mod tests { ) .await .unwrap_err(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } #[tokio::test(flavor = "multi_thread")] @@ -566,7 +563,7 @@ mod tests { ) .await .unwrap_err(); - assert!(matches!(err, SessionError::AlreadyFinished)); + assert!(matches!(err, XetError::AlreadyFinished)); } #[tokio::test(flavor = "multi_thread")] @@ -586,7 +583,7 @@ mod tests { ) .await .unwrap_err(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } // ── Independence ───────────────────────────────────────────────────────── @@ -894,18 +891,14 @@ mod tests { // ── Blocking API tests ──────────────────────────────────────────────────── - fn local_session_sync(temp: &TempDir) -> Result> { + fn local_session_sync(temp: &TempDir) -> Result { let cas_path = temp.path().join("cas"); Ok(XetSessionBuilder::new() .with_endpoint(format!("local://{}", cas_path.display())) .build()?) } - fn upload_bytes_blocking( - session: &XetSession, - data: &[u8], - name: &str, - ) -> Result> { + fn upload_bytes_blocking(session: &XetSession, data: &[u8], name: &str) -> Result { let commit = session.new_upload_commit_blocking()?; let handle = commit.upload_bytes_blocking(data.to_vec(), Sha256Policy::Compute, Some(name.into()))?; let results = commit.commit_blocking()?; @@ -918,7 +911,7 @@ mod tests { } #[test] - fn test_blocking_download_file_round_trip() -> Result<(), Box> { + fn test_blocking_download_file_round_trip() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; let original = b"Hello, download round-trip!"; @@ -934,7 +927,7 @@ mod tests { } #[test] - fn test_blocking_download_multiple_files() -> Result<(), Box> { + fn test_blocking_download_multiple_files() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; @@ -968,7 +961,7 @@ mod tests { } #[test] - fn test_blocking_download_progress_reflects_bytes_after_finish() -> Result<(), Box> { + fn test_blocking_download_progress_reflects_bytes_after_finish() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; let original = b"download progress tracking data"; @@ -997,7 +990,7 @@ mod tests { } #[test] - fn test_blocking_download_result_access_patterns() -> Result<(), Box> { + fn test_blocking_download_result_access_patterns() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; let data = b"download result access patterns"; @@ -1073,7 +1066,7 @@ mod tests { .download_file_to_path_blocking(file_info, PathBuf::from("/nonexistent")) .err() .unwrap(); - assert!(matches!(err, SessionError::WrongRuntimeMode(_))); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); } // ── Owned-mode _blocking panic guard ───────────────────────────────────── diff --git a/xet_pkg/src/xet_session/errors.rs b/xet_pkg/src/xet_session/errors.rs deleted file mode 100644 index 0314d4c8..00000000 --- a/xet_pkg/src/xet_session/errors.rs +++ /dev/null @@ -1 +0,0 @@ -pub use crate::error::XetError as SessionError; diff --git a/xet_pkg/src/xet_session/mod.rs b/xet_pkg/src/xet_session/mod.rs index fdab569d..b08082e6 100644 --- a/xet_pkg/src/xet_session/mod.rs +++ b/xet_pkg/src/xet_session/mod.rs @@ -27,7 +27,7 @@ //! transfers to finish and receive a `HashMap<`[`UniqueID`]`, `[`UploadResult`]`>` //! keyed by task ID. //! -//! `UploadResult` = `Arc>`. +//! `UploadResult` = `Arc>`. //! Per-task results can also be read from the returned [`UploadTaskHandle`] //! via [`result`](UploadTaskHandle::result) after `commit()` returns. //! @@ -42,7 +42,7 @@ //! transfers to complete and receive a `HashMap<`[`UniqueID`]`, `[`DownloadResult`]`>` //! keyed by task ID. //! -//! `DownloadResult` = `Arc>`. +//! `DownloadResult` = `Arc>`. //! Per-task results can also be read from the returned [`DownloadTaskHandle`] //! via [`result`](DownloadTaskHandle::result) after `finish()` returns. //! @@ -56,7 +56,7 @@ //! //! ## Error handling //! -//! All public methods return `Result<_, `[`SessionError`]`>`. +//! All public methods return `Result<_, `[`XetError`]`>`. //! [`commit`](UploadCommit::commit) returns `HashMap<`[`UniqueID`]`, `[`UploadResult`]`>` //! keyed by task ID, and [`finish`](DownloadGroup::finish) returns //! `HashMap<`[`UniqueID`]`, `[`DownloadResult`]`>` keyed by task ID, so a single failed @@ -77,7 +77,7 @@ //! // 2. Upload — use the _blocking factory and _blocking methods //! let commit = session.new_upload_commit_blocking()?; //! let handle = commit.upload_from_path_blocking("file.bin".into(), Sha256Policy::Compute)?; -//! // UploadResult = Arc> +//! // UploadResult = Arc> //! let results = commit.commit_blocking()?; //! let m = results.values().next().unwrap().as_ref().as_ref().unwrap(); //! @@ -90,10 +90,10 @@ //! }; //! let dl_handle = group.download_file_to_path_blocking(info, "out/file.bin".into())?; //! let finish_results = group.finish_blocking()?; -//! // DownloadResult = Arc> +//! // DownloadResult = Arc> //! let r = finish_results.get(&dl_handle.task_id).unwrap().as_ref().as_ref().unwrap(); //! -//! # Ok::<(), xet::xet_session::SessionError>(()) +//! # Ok::<(), xet::XetError>(()) //! ``` //! //! # Quick start — async API @@ -101,7 +101,7 @@ //! ```rust,no_run //! use xet::xet_session::{Sha256Policy, XetFileInfo, XetSessionBuilder}; //! -//! # async fn example() -> Result<(), xet::xet_session::SessionError> { +//! # async fn example() -> Result<(), xet::XetError> { //! // 1. Build a session. build_async() auto-detects the executor: //! // - tokio (multi-thread): wraps the caller's handle, no second thread pool. //! // - non-tokio (smol, async-std, etc.): creates an owned thread pool. @@ -114,7 +114,7 @@ //! // 2. Upload — use the async factory and async methods //! let commit = session.new_upload_commit().await?; //! let handle = commit.upload_from_path("file.bin".into(), Sha256Policy::Compute).await?; -//! // UploadResult = Arc> +//! // UploadResult = Arc> //! let results = commit.commit().await?; //! let m = results.values().next().unwrap().as_ref().as_ref().unwrap(); //! @@ -127,7 +127,7 @@ //! }; //! let dl_handle = group.download_file_to_path(info, "out/file.bin".into()).await?; //! let finish_results = group.finish().await?; -//! // DownloadResult = Arc> +//! // DownloadResult = Arc> //! let r = finish_results.get(&dl_handle.task_id).unwrap().as_ref().as_ref().unwrap(); //! # Ok(()) //! # } @@ -135,13 +135,11 @@ mod common; mod download_group; -mod errors; mod session; mod tasks; mod upload_commit; pub use download_group::{DownloadGroup, DownloadResult, DownloadedFile}; -pub use errors::SessionError; pub use session::{XetSession, XetSessionBuilder}; pub use tasks::{DownloadTaskHandle, TaskHandle, TaskStatus, UploadTaskHandle}; pub use upload_commit::{FileMetadata, UploadCommit, UploadResult}; diff --git a/xet_pkg/src/xet_session/session.rs b/xet_pkg/src/xet_session/session.rs index ecdd1103..059ff6d7 100644 --- a/xet_pkg/src/xet_session/session.rs +++ b/xet_pkg/src/xet_session/session.rs @@ -16,8 +16,8 @@ use xet_runtime::config::XetConfig; use xet_runtime::core::XetRuntime; use super::download_group::DownloadGroup; -use super::errors::SessionError; use super::upload_commit::UploadCommit; +use crate::error::XetError; /// Session state enum SessionState { @@ -34,7 +34,7 @@ enum SessionState { /// /// - **`External`**: session wraps a caller-provided tokio handle via [`XetSessionBuilder::with_tokio_handle`] or /// [`XetSessionBuilder::build_async`] (tokio context). Only async methods may be called; `_blocking` methods return -/// [`SessionError::WrongRuntimeMode`]. No second thread pool is created. +/// [`XetError::WrongRuntimeMode`]. No second thread pool is created. #[derive(Clone, Copy, PartialEq, Eq, Debug)] pub(super) enum RuntimeMode { Owned, @@ -128,12 +128,12 @@ fn handle_meets_session_requirements(handle: &tokio::runtime::Handle) -> bool { /// .with_endpoint("https://cas.example.com".into()) /// .with_token_info("my-token".into(), 1_700_000_000) /// .build()?; -/// # Ok::<(), xet::xet_session::SessionError>(()) +/// # Ok::<(), xet::XetError>(()) /// ``` /// /// ```rust,no_run /// # use xet::xet_session::XetSessionBuilder; -/// # async fn example() -> Result<(), xet::xet_session::SessionError> { +/// # async fn example() -> Result<(), xet::XetError> { /// // Async context — wraps the caller's tokio handle (External mode) if inside tokio, /// // or creates an owned runtime (Owned mode) if called from a non-tokio executor: /// let session = XetSessionBuilder::new() @@ -221,7 +221,7 @@ impl XetSessionBuilder { /// If the handle meets session requirements (multi-thread flavor, time driver, IO driver), /// the session will wrap it — no second thread pool is created (External mode). Only async /// methods (`new_upload_commit`, `new_download_group`) may be called; `_blocking` variants - /// will return [`SessionError::WrongRuntimeMode`]. + /// will return [`XetError::WrongRuntimeMode`]. /// /// If the handle does **not** meet requirements (e.g. `current_thread` flavor or missing /// drivers), it is silently ignored and [`build`](Self::build) will fall back to creating @@ -253,7 +253,7 @@ impl XetSessionBuilder { /// `with_tokio_handle`; falls back to an owned thread pool — Owned mode. /// - **Non-tokio context** (smol, async-std, etc.): creates an owned thread pool — Owned mode; async methods use an /// internal bridge compatible with any executor. - pub async fn build_async(self) -> Result { + pub async fn build_async(self) -> Result { match tokio::runtime::Handle::try_current() { Ok(handle) => self.with_tokio_handle(handle).build(), Err(_) => self.build(), @@ -268,7 +268,7 @@ impl XetSessionBuilder { /// executor, and `_blocking` methods are available. /// /// For async contexts, prefer [`build_async`](Self::build_async). - pub fn build(self) -> Result { + pub fn build(self) -> Result { let (runtime, mode) = match self.tokio_handle { Some(handle) => (XetRuntime::from_external_with_config(handle, self.config.clone()), RuntimeMode::External), None => (XetRuntime::new_with_config(self.config.clone())?, RuntimeMode::Owned), @@ -360,18 +360,18 @@ impl XetSession { /// Create a new [`UploadCommit`] that groups related file uploads. /// - /// Returns `Err(SessionError::Aborted)` if the session has been aborted. + /// Returns `Err(XetError::Aborted)` if the session has been aborted. /// /// # Note /// /// This is an `async fn` and must be `.await`ed. For sync Rust or Python (PyO3) callers, /// use [`new_upload_commit_blocking`](Self::new_upload_commit_blocking). - pub async fn new_upload_commit(&self) -> Result { + pub async fn new_upload_commit(&self) -> Result { // Check state before the async init; drop the guard so it is not held across .await. { let state = self.state.lock()?; if matches!(*state, SessionState::Aborted) { - return Err(SessionError::Aborted); + return Err(XetError::Aborted); } } @@ -391,8 +391,8 @@ impl XetSession { /// The returned [`UploadCommit`] supports both async methods (`upload_from_path`, /// `commit`) and blocking methods (`upload_from_path_blocking`, `commit_blocking`). /// - /// Returns `Err(SessionError::Aborted)` if the session has been aborted. - /// Returns `Err(SessionError::WrongRuntimeMode)` if the session uses an external + /// Returns `Err(XetError::Aborted)` if the session has been aborted. + /// Returns `Err(XetError::WrongRuntimeMode)` if the session uses an external /// tokio runtime (from [`XetSessionBuilder::with_tokio_handle`] or tokio-detected /// [`XetSessionBuilder::build_async`]). /// @@ -403,9 +403,9 @@ impl XetSession { /// async-std, `futures::executor`) do not set this context, so calling from those is /// safe — it blocks the executor thread until the task completes. Use /// [`new_upload_commit`](Self::new_upload_commit) from async contexts instead. - pub fn new_upload_commit_blocking(&self) -> Result { + pub fn new_upload_commit_blocking(&self) -> Result { if matches!(self.runtime_mode, RuntimeMode::External) { - return Err(SessionError::wrong_mode( + return Err(XetError::wrong_mode( "new_upload_commit_blocking() cannot be called on a session using an \ external tokio runtime (with_tokio_handle() or tokio build_async()); \ use new_upload_commit().await instead", @@ -414,7 +414,7 @@ impl XetSession { { let state = self.state.lock()?; if matches!(*state, SessionState::Aborted) { - return Err(SessionError::Aborted); + return Err(XetError::Aborted); } } @@ -425,18 +425,18 @@ impl XetSession { /// Create a new [`DownloadGroup`] that groups related file downloads. /// - /// Returns `Err(SessionError::Aborted)` if the session has been aborted. + /// Returns `Err(XetError::Aborted)` if the session has been aborted. /// /// # Note /// /// This is an `async fn` and must be `.await`ed. For sync Rust or Python (PyO3) callers, /// use [`new_download_group_blocking`](Self::new_download_group_blocking). - pub async fn new_download_group(&self) -> Result { + pub async fn new_download_group(&self) -> Result { // Check state before the async init; drop the guard so it is not held across .await. { let state = self.state.lock()?; if matches!(*state, SessionState::Aborted) { - return Err(SessionError::Aborted); + return Err(XetError::Aborted); } } @@ -456,8 +456,8 @@ impl XetSession { /// The returned [`DownloadGroup`] supports both the async [`finish`](DownloadGroup::finish) /// and blocking [`finish_blocking`](DownloadGroup::finish_blocking) methods. /// - /// Returns `Err(SessionError::Aborted)` if the session has been aborted. - /// Returns `Err(SessionError::WrongRuntimeMode)` if the session uses an external + /// Returns `Err(XetError::Aborted)` if the session has been aborted. + /// Returns `Err(XetError::WrongRuntimeMode)` if the session uses an external /// tokio runtime (from [`XetSessionBuilder::with_tokio_handle`] or tokio-detected /// [`XetSessionBuilder::build_async`]). /// @@ -468,9 +468,9 @@ impl XetSession { /// async-std, `futures::executor`) do not set this context, so calling from those is /// safe — it blocks the executor thread until the task completes. Use /// [`new_download_group`](Self::new_download_group) from async contexts instead. - pub fn new_download_group_blocking(&self) -> Result { + pub fn new_download_group_blocking(&self) -> Result { if matches!(self.runtime_mode, RuntimeMode::External) { - return Err(SessionError::wrong_mode( + return Err(XetError::wrong_mode( "new_download_group_blocking() cannot be called on a session using an \ external tokio runtime (with_tokio_handle() or tokio build_async()); \ use new_download_group().await instead", @@ -479,7 +479,7 @@ impl XetSession { { let state = self.state.lock()?; if matches!(*state, SessionState::Aborted) { - return Err(SessionError::Aborted); + return Err(XetError::Aborted); } } @@ -492,7 +492,7 @@ impl XetSession { /// /// This performs a SIGINT-style shutdown, aborting all active upload and download tasks. /// Use this when a Ctrl+C signal is detected or when you need to immediately stop all operations. - pub fn abort(&self) -> Result<(), SessionError> { + pub fn abort(&self) -> Result<(), XetError> { // Mark as not accepting new work, hold the lock so no new task can be created when aborting let mut state = self.state.lock()?; *state = SessionState::Aborted; @@ -513,19 +513,19 @@ impl XetSession { Ok(()) } - pub(super) fn check_alive(&self) -> Result<(), SessionError> { + pub(super) fn check_alive(&self) -> Result<(), XetError> { if matches!(*self.state.lock()?, SessionState::Aborted) { - return Err(SessionError::Aborted); + return Err(XetError::Aborted); } Ok(()) } - pub(super) fn finish_upload_commit(&self, commit_id: UniqueID) -> Result<(), SessionError> { + pub(super) fn finish_upload_commit(&self, commit_id: UniqueID) -> Result<(), XetError> { self.active_upload_commits.lock()?.remove(&commit_id); Ok(()) } - pub(super) fn finish_download_group(&self, group_id: UniqueID) -> Result<(), SessionError> { + pub(super) fn finish_download_group(&self, group_id: UniqueID) -> Result<(), XetError> { self.active_download_groups.lock()?.remove(&group_id); Ok(()) } @@ -568,7 +568,7 @@ mod tests { let session = XetSessionBuilder::new().build().unwrap(); session.abort().unwrap(); let err = session.check_alive().unwrap_err(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } #[test] @@ -577,7 +577,7 @@ mod tests { let session = XetSessionBuilder::new().build().unwrap(); session.abort().unwrap(); let err = session.new_upload_commit_blocking().err().unwrap(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } #[test] @@ -586,7 +586,7 @@ mod tests { let session = XetSessionBuilder::new().build().unwrap(); session.abort().unwrap(); let err = session.new_download_group_blocking().err().unwrap(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } #[test] @@ -669,8 +669,8 @@ mod tests { session.abort().unwrap(); let commit_err = session.new_upload_commit().await.err().unwrap(); let group_err = session.new_download_group().await.err().unwrap(); - assert!(matches!(commit_err, SessionError::Aborted)); - assert!(matches!(group_err, SessionError::Aborted)); + assert!(matches!(commit_err, XetError::Aborted)); + assert!(matches!(group_err, XetError::Aborted)); } #[tokio::test(flavor = "multi_thread")] @@ -763,7 +763,7 @@ mod tests { let session = XetSessionBuilder::new().build_async().await.unwrap(); assert_eq!(session.runtime_mode, RuntimeMode::External); let err = session.new_upload_commit_blocking().err().unwrap(); - assert!(matches!(err, SessionError::WrongRuntimeMode(_))); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); } #[tokio::test(flavor = "multi_thread")] @@ -772,7 +772,7 @@ mod tests { let session = XetSessionBuilder::new().build_async().await.unwrap(); assert_eq!(session.runtime_mode, RuntimeMode::External); let err = session.new_download_group_blocking().err().unwrap(); - assert!(matches!(err, SessionError::WrongRuntimeMode(_))); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); } // ── Owned-mode _blocking panic guard ───────────────────────────────────── diff --git a/xet_pkg/src/xet_session/tasks.rs b/xet_pkg/src/xet_session/tasks.rs index c423f67f..b5d51342 100644 --- a/xet_pkg/src/xet_session/tasks.rs +++ b/xet_pkg/src/xet_session/tasks.rs @@ -5,9 +5,9 @@ use std::sync::{Arc, Mutex, OnceLock}; use xet_data::progress_tracking::UniqueID; -use super::SessionError; use super::download_group::DownloadResult; use super::upload_commit::UploadResult; +use crate::error::XetError; /// Lifecycle state of a single upload or download task. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -84,11 +84,11 @@ impl Deref for DownloadTaskHandle { } impl TaskHandle { - pub fn status(&self) -> Result { + pub fn status(&self) -> Result { if let Some(status) = &self.status { Ok(*status.lock()?) } else { - Err(SessionError::other("status not available")) + Err(XetError::other("status not available")) } } } diff --git a/xet_pkg/src/xet_session/upload_commit.rs b/xet_pkg/src/xet_session/upload_commit.rs index 7d1b3401..f944ed3c 100644 --- a/xet_pkg/src/xet_session/upload_commit.rs +++ b/xet_pkg/src/xet_session/upload_commit.rs @@ -11,9 +11,9 @@ use xet_data::progress_tracking::{GroupProgressReport, UniqueID}; use xet_runtime::core::XetRuntime; use super::common::{GroupState, create_translator_config}; -use super::errors::SessionError; use super::session::{RuntimeMode, XetSession}; use super::tasks::{TaskHandle, TaskStatus, UploadTaskHandle}; +use crate::error::XetError; /// API for grouping related file uploads into a single atomic commit. /// @@ -37,8 +37,8 @@ use super::tasks::{TaskHandle, TaskStatus, UploadTaskHandle}; /// /// # Errors /// -/// Methods return [`SessionError::Aborted`] if the parent session has been -/// aborted, and [`SessionError::AlreadyCommitted`] if [`commit`](Self::commit) +/// Methods return [`XetError::Aborted`] if the parent session has been +/// aborted, and [`XetError::AlreadyCommitted`] if [`commit`](Self::commit) /// has already been called. #[derive(Clone)] pub struct UploadCommit { @@ -55,7 +55,7 @@ impl std::ops::Deref for UploadCommit { impl UploadCommit { /// Create a new upload commit from an **async** context. Initialisation logic shared by the sync and async /// constructors. - pub(super) async fn new(session: XetSession) -> Result { + pub(super) async fn new(session: XetSession) -> Result { let commit_id = UniqueID::new(); let config = create_translator_config(&session)?; let upload_session = FileUploadSession::new(Arc::new(config)).await?; @@ -78,7 +78,7 @@ impl UploadCommit { } /// Abort this upload commit. - pub(super) fn abort(&self) -> Result<(), SessionError> { + pub(super) fn abort(&self) -> Result<(), XetError> { self.inner.abort() } @@ -101,14 +101,14 @@ impl UploadCommit { /// /// # Errors /// - /// Returns [`SessionError::Aborted`] if the session has been aborted, or - /// [`SessionError::AlreadyCommitted`] if [`commit`](Self::commit) has + /// Returns [`XetError::Aborted`] if the session has been aborted, or + /// [`XetError::AlreadyCommitted`] if [`commit`](Self::commit) has /// already been called. pub async fn upload_from_path( &self, file_path: PathBuf, sha256: Sha256Policy, - ) -> Result { + ) -> Result { self.session.check_alive()?; // Use the absolute path in case the process current working directory changes @@ -129,8 +129,8 @@ impl UploadCommit { /// ```rust,no_run /// # use std::fs::File; /// # use std::io::Read; - /// # use xet::xet_session::SessionError; - /// # async fn example(commit: xet::xet_session::UploadCommit, filename: &str, filesize: u64) -> Result<(), Box> { + /// # use xet::XetError; + /// # async fn example(commit: xet::xet_session::UploadCommit, filename: &str, filesize: u64) -> anyhow::Result<()> { /// # use xet::xet_session::Sha256Policy; /// let (handle, mut cleaner) = commit.upload_file(Some(filename.into()), filesize, Sha256Policy::Compute).await?; /// let mut reader = File::open(&filename)?; @@ -166,7 +166,7 @@ impl UploadCommit { file_name: Option, file_size: u64, sha256: Sha256Policy, - ) -> Result<(TaskHandle, SingleFileCleaner), SessionError> { + ) -> Result<(TaskHandle, SingleFileCleaner), XetError> { self.session.check_alive()?; let inner = self.inner.clone(); @@ -190,14 +190,14 @@ impl UploadCommit { /// /// # Errors /// - /// Returns [`SessionError::Aborted`] if the session has been aborted, or - /// [`SessionError::AlreadyCommitted`] if [`commit`](Self::commit) has already been called. + /// Returns [`XetError::Aborted`] if the session has been aborted, or + /// [`XetError::AlreadyCommitted`] if [`commit`](Self::commit) has already been called. pub async fn upload_bytes( &self, bytes: Vec, sha256: Sha256Policy, tracking_name: Option, - ) -> Result { + ) -> Result { self.session.check_alive()?; let inner = self.inner.clone(); @@ -207,7 +207,7 @@ impl UploadCommit { } /// Return a snapshot of progress for every queued upload. - pub fn get_progress(&self) -> Result { + pub fn get_progress(&self) -> Result { let session_opt = self.upload_session.lock()?.clone(); if let Some(upload_session) = session_opt { return Ok(upload_session.report()); @@ -221,12 +221,12 @@ impl UploadCommit { /// Wait for all uploads to complete and push metadata to the CAS server. /// /// Returns a `HashMap` keyed by task ID where each value is - /// [`UploadResult`] (= `Arc>`). + /// [`UploadResult`] (= `Arc>`). /// A single failed upload does not prevent the others from being collected. /// /// Consumes `self` — subsequent calls on any clone will return - /// [`SessionError::AlreadyCommitted`]. - pub async fn commit(self) -> Result, SessionError> { + /// [`XetError::AlreadyCommitted`]. + pub async fn commit(self) -> Result, XetError> { let inner = self.inner.clone(); self.session .dispatch("commit", async move { inner.handle_commit().await }) @@ -249,7 +249,7 @@ impl UploadCommit { /// /// # Errors /// - /// Returns [`SessionError::WrongRuntimeMode`] if the session was created with an external + /// Returns [`XetError::WrongRuntimeMode`] if the session was created with an external /// tokio runtime ([`XetSessionBuilder::with_tokio_handle`] / [`XetSessionBuilder::build_async`] /// inside a tokio context). Use [`upload_from_path`](Self::upload_from_path)`.await` instead. /// @@ -260,9 +260,9 @@ impl UploadCommit { &self, file_path: PathBuf, sha256: Sha256Policy, - ) -> Result { + ) -> Result { if matches!(self.session.runtime_mode, RuntimeMode::External) { - return Err(SessionError::wrong_mode( + return Err(XetError::wrong_mode( "upload_from_path_blocking() cannot be called on a session using an \ external tokio runtime (with_tokio_handle() or tokio build_async()); \ use upload_from_path().await instead", @@ -281,7 +281,7 @@ impl UploadCommit { /// /// # Errors /// - /// Returns [`SessionError::WrongRuntimeMode`] if the session was created with an external + /// Returns [`XetError::WrongRuntimeMode`] if the session was created with an external /// tokio runtime ([`XetSessionBuilder::with_tokio_handle`] / [`XetSessionBuilder::build_async`] /// inside a tokio context). Use [`upload_bytes`](Self::upload_bytes)`.await` instead. /// @@ -293,9 +293,9 @@ impl UploadCommit { bytes: Vec, sha256: Sha256Policy, tracking_name: Option, - ) -> Result { + ) -> Result { if matches!(self.session.runtime_mode, RuntimeMode::External) { - return Err(SessionError::wrong_mode( + return Err(XetError::wrong_mode( "upload_bytes_blocking() cannot be called on a session using an \ external tokio runtime (with_tokio_handle() or tokio build_async()); \ use upload_bytes().await instead", @@ -313,7 +313,7 @@ impl UploadCommit { /// /// # Errors /// - /// Returns [`SessionError::WrongRuntimeMode`] if the session was created with an external + /// Returns [`XetError::WrongRuntimeMode`] if the session was created with an external /// tokio runtime ([`XetSessionBuilder::with_tokio_handle`] / [`XetSessionBuilder::build_async`] /// inside a tokio context). Use [`upload_file`](Self::upload_file)`.await` instead. /// @@ -325,9 +325,9 @@ impl UploadCommit { file_name: Option, file_size: u64, sha256: Sha256Policy, - ) -> Result<(TaskHandle, SingleFileCleaner), SessionError> { + ) -> Result<(TaskHandle, SingleFileCleaner), XetError> { if matches!(self.session.runtime_mode, RuntimeMode::External) { - return Err(SessionError::wrong_mode( + return Err(XetError::wrong_mode( "upload_file_blocking() cannot be called on a session using an \ external tokio runtime (with_tokio_handle() or tokio build_async()); \ use upload_file().await instead", @@ -342,7 +342,7 @@ impl UploadCommit { } /// Blocking version of [`get_progress`](Self::get_progress). - pub fn get_progress_blocking(&self) -> Result { + pub fn get_progress_blocking(&self) -> Result { self.get_progress() } @@ -351,7 +351,7 @@ impl UploadCommit { /// # Panics /// /// Panics if called from within a tokio async runtime. - pub fn commit_blocking(self) -> Result, SessionError> { + pub fn commit_blocking(self) -> Result, XetError> { let commit = self.clone(); self.runtime().external_run_async_task(commit.commit())? } @@ -362,7 +362,7 @@ impl UploadCommit { /// The `Arc` lets the same value be stored in both the `commit()` return map /// and the per-task [`UploadTaskHandle`] without requiring the inner /// `Result` to be `Clone`. -pub type UploadResult = Arc>; +pub type UploadResult = Arc>; /// Handle for a single upload task tracked internally by UploadCommit. struct InnerUploadTaskHandle { @@ -398,10 +398,10 @@ impl UploadCommitInner { // ===== State helpers ===== /// Check whether the commit is still accepting new tasks. - fn check_accepting_tasks(state: &GroupState) -> Result<(), SessionError> { + fn check_accepting_tasks(state: &GroupState) -> Result<(), XetError> { match *state { - GroupState::Finished => Err(SessionError::AlreadyCommitted), - GroupState::Aborted => Err(SessionError::Aborted), + GroupState::Finished => Err(XetError::AlreadyCommitted), + GroupState::Aborted => Err(XetError::Aborted), GroupState::Alive => Ok(()), } } @@ -410,13 +410,13 @@ impl UploadCommitInner { &self, file_path: PathBuf, sha256: Sha256Policy, - ) -> Result { + ) -> Result { let upload_session = { let state = self.state.lock().await; Self::check_accepting_tasks(&state)?; let Some(upload_session) = self.upload_session.lock()?.clone() else { - return Err(SessionError::other("Upload session not initialized")); + return Err(XetError::other("Upload session not initialized")); }; upload_session }; @@ -457,12 +457,12 @@ impl UploadCommitInner { tracking_name: Option, file_size: u64, sha256: Sha256Policy, - ) -> Result<(TaskHandle, SingleFileCleaner), SessionError> { + ) -> Result<(TaskHandle, SingleFileCleaner), XetError> { let state = self.state.lock().await; Self::check_accepting_tasks(&state)?; let Some(upload_session) = self.upload_session.lock()?.clone() else { - return Err(SessionError::other("Upload session not initialized")); + return Err(XetError::other("Upload session not initialized")); }; let tracking_name: Option> = tracking_name.as_deref().map(Arc::from); @@ -481,13 +481,13 @@ impl UploadCommitInner { bytes: Vec, sha256: Sha256Policy, tracking_name: Option, - ) -> Result { + ) -> Result { let upload_session = { let state = self.state.lock().await; Self::check_accepting_tasks(&state)?; let Some(upload_session) = self.upload_session.lock()?.clone() else { - return Err(SessionError::other("Upload session not initialized")); + return Err(XetError::other("Upload session not initialized")); }; upload_session }; @@ -519,14 +519,14 @@ impl UploadCommitInner { } /// Join all active upload tasks and finalise the upload session. - pub(super) async fn handle_commit(&self) -> Result, SessionError> { + pub(super) async fn handle_commit(&self) -> Result, XetError> { // Mark as not accepting new tasks. The tokio state lock serialises this // against all three registration methods, including start_upload_file // which holds it across the start_clean await. { let mut state_guard = self.state.lock().await; if *state_guard == GroupState::Finished { - return Err(SessionError::AlreadyCommitted); + return Err(XetError::AlreadyCommitted); } *state_guard = GroupState::Aborted; // stop new tasks while draining } @@ -538,7 +538,7 @@ impl UploadCommitInner { let mut results = HashMap::new(); let mut join_err = None; for (task_id, handle) in active_tasks { - match handle.join_handle.await.map_err(SessionError::from) { + match handle.join_handle.await.map_err(XetError::from) { Ok(Ok(file_info)) => { TaskStatus::mark_terminal(&handle.status, TaskStatus::Completed); let result = Arc::new(Ok(FileMetadata { @@ -552,12 +552,12 @@ impl UploadCommitInner { }, Ok(Err(data_err)) => { TaskStatus::mark_terminal(&handle.status, TaskStatus::Failed); - let result = Arc::new(Err(SessionError::from(data_err))); + let result = Arc::new(Err(XetError::from(data_err))); results.insert(task_id, result.clone()); let _ = handle.result.set(result); }, Err(e) => { - if matches!(e, SessionError::Cancelled(_)) { + if matches!(e, XetError::Cancelled(_)) { TaskStatus::mark_cancelled(&handle.status); } else { TaskStatus::mark_terminal(&handle.status, TaskStatus::Failed); @@ -606,7 +606,7 @@ impl UploadCommitInner { /// obtaining a session and prevents `handle_commit` from calling `finalize`. /// It does not invalidate any `SingleFileCleaner` already in the caller's hands, /// since the cleaner holds its own `Arc` to the session. - fn abort(&self) -> Result<(), SessionError> { + fn abort(&self) -> Result<(), XetError> { if let Ok(mut guard) = self.state.try_lock() { *guard = GroupState::Aborted; } @@ -641,12 +641,13 @@ mod tests { use std::sync::mpsc; use std::time::Duration; + use anyhow::Result; use tempfile::{TempDir, tempdir}; use super::*; use crate::xet_session::session::{RuntimeMode, XetSession, XetSessionBuilder}; - async fn local_session(temp: &TempDir) -> Result> { + async fn local_session(temp: &TempDir) -> Result { let cas_path = temp.path().join("cas"); Ok(XetSessionBuilder::new() .with_endpoint(format!("local://{}", cas_path.display())) @@ -667,7 +668,7 @@ mod tests { #[test] // commit() must block while any enqueue method holds the state lock. - fn test_commit_blocked_while_upload_registration_holds_state_lock() -> Result<(), Box> { + fn test_commit_blocked_while_upload_registration_holds_state_lock() -> Result<()> { let temp = tempdir()?; let cas_path = temp.path().join("cas"); let session = XetSessionBuilder::new() @@ -762,7 +763,7 @@ mod tests { let c2 = c1.clone(); c1.commit().await.unwrap(); let err = c2.commit().await.unwrap_err(); - assert!(matches!(err, SessionError::AlreadyCommitted | SessionError::Internal(_))); + assert!(matches!(err, XetError::AlreadyCommitted | XetError::Internal(_))); } #[tokio::test(flavor = "multi_thread")] @@ -787,7 +788,7 @@ mod tests { .upload_from_path(PathBuf::from("nonexistent.bin"), Sha256Policy::Compute) .await .unwrap_err(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } #[tokio::test(flavor = "multi_thread")] @@ -800,7 +801,7 @@ mod tests { .upload_bytes(b"data".to_vec(), Sha256Policy::Compute, Some("bytes 1".into())) .await .unwrap_err(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } // ── Post-commit guards (AlreadyCommitted) ──────────────────────────────── @@ -816,7 +817,7 @@ mod tests { .upload_from_path(PathBuf::from("any.bin"), Sha256Policy::Compute) .await .unwrap_err(); - assert!(matches!(err, SessionError::AlreadyCommitted)); + assert!(matches!(err, XetError::AlreadyCommitted)); } #[tokio::test(flavor = "multi_thread")] @@ -830,7 +831,7 @@ mod tests { .upload_bytes(b"hello".to_vec(), Sha256Policy::Compute, None) .await .unwrap_err(); - assert!(matches!(err, SessionError::AlreadyCommitted)); + assert!(matches!(err, XetError::AlreadyCommitted)); } // ── API coverage & abort ───────────────────────────────────────────────── @@ -1222,7 +1223,7 @@ mod tests { // ── Blocking API tests ──────────────────────────────────────────────────── - fn local_session_sync(temp: &TempDir) -> Result> { + fn local_session_sync(temp: &TempDir) -> Result { let cas_path = temp.path().join("cas"); Ok(XetSessionBuilder::new() .with_endpoint(format!("local://{}", cas_path.display())) @@ -1230,7 +1231,7 @@ mod tests { } #[test] - fn test_blocking_upload_bytes_round_trip() -> Result<(), Box> { + fn test_blocking_upload_bytes_round_trip() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; let data = b"Hello, upload commit round-trip!"; @@ -1246,7 +1247,7 @@ mod tests { } #[test] - fn test_blocking_upload_from_path_round_trip() -> Result<(), Box> { + fn test_blocking_upload_from_path_round_trip() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; let src = temp.path().join("data.bin"); @@ -1263,7 +1264,7 @@ mod tests { } #[test] - fn test_blocking_upload_result_access_patterns() -> Result<(), Box> { + fn test_blocking_upload_result_access_patterns() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; let data = b"result access patterns"; @@ -1288,7 +1289,7 @@ mod tests { } #[test] - fn test_blocking_upload_streaming_round_trip() -> Result<(), Box> { + fn test_blocking_upload_streaming_round_trip() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; let data = b"streamed upload bytes"; @@ -1309,7 +1310,7 @@ mod tests { } #[test] - fn test_blocking_upload_multiple_files_in_one_commit() -> Result<(), Box> { + fn test_blocking_upload_multiple_files_in_one_commit() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; let commit = session.new_upload_commit_blocking()?; @@ -1322,7 +1323,7 @@ mod tests { } #[test] - fn test_blocking_upload_progress_reflects_bytes_after_commit() -> Result<(), Box> { + fn test_blocking_upload_progress_reflects_bytes_after_commit() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; let data = b"progress tracking upload data"; @@ -1339,7 +1340,7 @@ mod tests { } #[test] - fn test_blocking_upload_file_returns_handle_without_status() -> Result<(), Box> { + fn test_blocking_upload_file_returns_handle_without_status() -> Result<()> { let temp = tempdir()?; let session = local_session_sync(&temp)?; let commit = session.new_upload_commit_blocking()?; @@ -1395,7 +1396,7 @@ mod tests { .upload_from_path_blocking(PathBuf::from("/nonexistent"), Sha256Policy::Compute) .err() .unwrap(); - assert!(matches!(err, SessionError::WrongRuntimeMode(_))); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); } #[tokio::test(flavor = "multi_thread")] @@ -1405,7 +1406,7 @@ mod tests { assert_eq!(session.runtime_mode, RuntimeMode::External); let commit = session.new_upload_commit().await.unwrap(); let err = commit.upload_bytes_blocking(vec![], Sha256Policy::Compute, None).err().unwrap(); - assert!(matches!(err, SessionError::WrongRuntimeMode(_))); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); } #[tokio::test(flavor = "multi_thread")] @@ -1415,7 +1416,7 @@ mod tests { assert_eq!(session.runtime_mode, RuntimeMode::External); let commit = session.new_upload_commit().await.unwrap(); let err = commit.upload_file_blocking(None, 0, Sha256Policy::Compute).err().unwrap(); - assert!(matches!(err, SessionError::WrongRuntimeMode(_))); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); } // ── Owned-mode _blocking panic guard ───────────────────────────────────── diff --git a/xet_pkg/tests/test_xet_session.rs b/xet_pkg/tests/test_xet_session.rs index 13b58975..c9793288 100644 --- a/xet_pkg/tests/test_xet_session.rs +++ b/xet_pkg/tests/test_xet_session.rs @@ -18,9 +18,8 @@ use std::path::PathBuf; use std::pin::Pin; use tempfile::{TempDir, tempdir}; -use xet::xet_session::{ - FileMetadata, SessionError, Sha256Policy, TaskStatus, XetFileInfo, XetSession, XetSessionBuilder, -}; +use xet::XetError; +use xet::xet_session::{FileMetadata, Sha256Policy, TaskStatus, XetFileInfo, XetSession, XetSessionBuilder}; // ── Helpers ────────────────────────────────────────────────────────────── @@ -794,14 +793,14 @@ fn blocking_in_non_tokio_executor_upload_from_path() { async fn external_mode_blocking_upload_returns_wrong_mode() { let session = XetSessionBuilder::new().build_async().await.unwrap(); let err = session.new_upload_commit_blocking().err().unwrap(); - assert!(matches!(err, SessionError::WrongRuntimeMode(_))); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] async fn external_mode_blocking_download_returns_wrong_mode() { let session = XetSessionBuilder::new().build_async().await.unwrap(); let err = session.new_download_group_blocking().err().unwrap(); - assert!(matches!(err, SessionError::WrongRuntimeMode(_))); + assert!(matches!(err, XetError::WrongRuntimeMode(_))); } // ── 7. Abort behavior ─────────────────────────────────────────────────── @@ -811,7 +810,7 @@ async fn async_abort_prevents_new_commits() { let session = XetSessionBuilder::new().build_async().await.unwrap(); session.abort().unwrap(); let err = session.new_upload_commit().await.err().unwrap(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -819,7 +818,7 @@ async fn async_abort_prevents_new_groups() { let session = XetSessionBuilder::new().build_async().await.unwrap(); session.abort().unwrap(); let err = session.new_download_group().await.err().unwrap(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } #[test] @@ -827,7 +826,7 @@ fn blocking_abort_prevents_new_commits() { let session = XetSessionBuilder::new().build().unwrap(); session.abort().unwrap(); let err = session.new_upload_commit_blocking().err().unwrap(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } #[test] @@ -835,7 +834,7 @@ fn blocking_abort_prevents_new_groups() { let session = XetSessionBuilder::new().build().unwrap(); session.abort().unwrap(); let err = session.new_download_group_blocking().err().unwrap(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -848,7 +847,7 @@ async fn async_abort_rejects_upload_on_existing_commit() { .await .err() .unwrap(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] @@ -868,7 +867,7 @@ async fn async_abort_rejects_download_on_existing_group() { .await .err() .unwrap(); - assert!(matches!(err, SessionError::Aborted)); + assert!(matches!(err, XetError::Aborted)); } // ── 8. Deduplication (same content uploaded twice) ─────────────────────── diff --git a/xet_runtime/Cargo.toml b/xet_runtime/Cargo.toml index c8d90227..45bd821c 100644 --- a/xet_runtime/Cargo.toml +++ b/xet_runtime/Cargo.toml @@ -11,6 +11,7 @@ name = "xet_runtime" path = "src/lib.rs" [dependencies] +anyhow = { workspace = true } chrono = { workspace = true } colored = { workspace = true } const-str = "1.1" @@ -70,7 +71,6 @@ name = "log_test_executable" path = "tests/bin/log_test_executable.rs" [dev-dependencies] -anyhow = { workspace = true } xet-core-structures = { version = "1.4.0", path = "../xet_core_structures" } rand = { workspace = true } diff --git a/xet_runtime/src/core/errors.rs b/xet_runtime/src/core/errors.rs deleted file mode 100644 index 670bcdfa..00000000 --- a/xet_runtime/src/core/errors.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub use crate::error::RuntimeError as MultithreadedRuntimeError; -pub type Result = std::result::Result; diff --git a/xet_runtime/src/core/mod.rs b/xet_runtime/src/core/mod.rs index 2d460625..7a89db47 100644 --- a/xet_runtime/src/core/mod.rs +++ b/xet_runtime/src/core/mod.rs @@ -1,5 +1,4 @@ pub mod common; -pub mod errors; pub mod exports; pub mod runtime; diff --git a/xet_runtime/src/file_utils/privilege_context.rs b/xet_runtime/src/file_utils/privilege_context.rs index 5c3cb119..b76e517e 100644 --- a/xet_runtime/src/file_utils/privilege_context.rs +++ b/xet_runtime/src/file_utils/privilege_context.rs @@ -243,11 +243,13 @@ mod test { use std::os::unix::fs::MetadataExt; use std::path::Path; + use anyhow::Result; + use super::{PrivilegedExecutionContext, WARNING_PRINTED}; #[test] #[ignore = "run manually"] - fn test_create_dir_all() -> anyhow::Result<()> { + fn test_create_dir_all() -> Result<()> { // Run this test manually, steps: /* For Unix @@ -273,7 +275,7 @@ mod test { #[test] #[ignore = "run manually"] - fn test_create_dir_all_sudo() -> anyhow::Result<()> { + fn test_create_dir_all_sudo() -> Result<()> { // Run this test manually, steps: /* For Unix @@ -309,7 +311,7 @@ mod test { #[test] #[ignore = "run manually"] - fn test_create_file() -> anyhow::Result<()> { + fn test_create_file() -> Result<()> { // Run this test manually, steps: /* For Unix @@ -343,7 +345,7 @@ mod test { #[test] #[ignore = "run manually"] - fn test_create_file_sudo() -> anyhow::Result<()> { + fn test_create_file_sudo() -> Result<()> { // Run this test manually, steps: /* For Unix