diff --git a/hf_xet/src/lib.rs b/hf_xet/src/lib.rs index 41c8d118..4ed06d38 100644 --- a/hf_xet/src/lib.rs +++ b/hf_xet/src/lib.rs @@ -1,7 +1,6 @@ mod logging; mod progress_update; mod runtime; -mod telemetry; mod token_refresh; use std::fmt::Debug; @@ -20,6 +19,7 @@ use runtime::async_run; use token_refresh::WrappedTokenRefresher; use tracing::debug; +use crate::logging::init_logging; use crate::progress_update::WrappedProgressUpdater; // For profiling @@ -311,6 +311,9 @@ pub fn hf_xet(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> { // huggingface_hub. m.add_class::()?; + // Make sure the logger is set up. + init_logging(py); + #[cfg(feature = "profiling")] { profiling::start_profiler(); diff --git a/hf_xet/src/logging.rs b/hf_xet/src/logging.rs index f9aebc5a..a1276fd9 100644 --- a/hf_xet/src/logging.rs +++ b/hf_xet/src/logging.rs @@ -1,18 +1,15 @@ use std::env; use std::path::Path; -use std::sync::{Mutex, OnceLock}; +use std::sync::OnceLock; use pyo3::types::PyAnyMethods; use pyo3::Python; -use tracing::{error, info}; -use tracing_subscriber::filter::FilterFn; +use tracing::info; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::util::SubscriberInitExt; -use tracing_subscriber::{EnvFilter, Layer}; +use tracing_subscriber::EnvFilter; use utils::normalized_path_from_user_string; -use crate::telemetry::{init_telemetry_logging, restart_telemetry_task_after_spawn}; - /// Default log level for the library to use. Override using `RUST_LOG` env variable. #[cfg(not(debug_assertions))] const DEFAULT_LOG_LEVEL: &str = "warn"; @@ -112,7 +109,7 @@ fn get_version_info_string(py: Python<'_>) -> String { version_info } -fn init_global_logging(py: Python) { +pub fn init_logging(py: Python) { let version_info = get_version_info_string(py); if let Ok(log_path_s) = env::var("HF_XET_LOG_FILE") { @@ -137,27 +134,6 @@ fn init_global_logging(py: Python) { .or_else(|_| EnvFilter::try_new(DEFAULT_LOG_LEVEL)) .unwrap_or_default(); - // Do we use telemetry? - if env::var("HF_HUB_ENABLE_TELEMETRY").is_ok() { - match init_telemetry_logging(version_info.clone()) { - Ok(tl) => { - let telemetry_filter_layer = tl.with_filter(FilterFn::new(|meta| meta.target() == "client_telemetry")); - - tracing_subscriber::registry() - .with(filter_layer) - .with(fmt_layer_base.json()) - .with(telemetry_filter_layer) - .init(); - - return; - }, - - Err(e) => { - eprintln!("Error initializing telemetry process : {e:?}. Reverting to logging to console."); - }, - } - } - // Now, just use basic console logging. let tr_sub = tracing_subscriber::registry().with(filter_layer); @@ -169,23 +145,3 @@ fn init_global_logging(py: Python) { info!("hf_xet version info: {version_info}"); } - -static INITIALIZED_LOGGING_ID: Mutex = Mutex::new(0); - -pub fn check_logging_state(py: Python<'_>) { - let Ok(mut logger_pid) = INITIALIZED_LOGGING_ID.lock() else { - return; - }; - - let pid = std::process::id(); - - if *logger_pid == 0 { - init_global_logging(py); - } else if *logger_pid != pid { - if let Err(e) = restart_telemetry_task_after_spawn() { - error!("Error restarting telemetry task in subprocess; telemetry may not work: {e:?}"); - } - } - - *logger_pid = pid; -} diff --git a/hf_xet/src/runtime.rs b/hf_xet/src/runtime.rs index 215a0507..b1be8364 100644 --- a/hf_xet/src/runtime.rs +++ b/hf_xet/src/runtime.rs @@ -10,8 +10,6 @@ use xet_threadpool::errors::MultithreadedRuntimeError; use xet_threadpool::sync_primatives::spawn_os_thread; use xet_threadpool::ThreadPool; -use crate::logging::check_logging_state; - lazy_static! { static ref SIGINT_DETECTED: Arc = Arc::new(AtomicBool::new(false)); static ref SIGINT_HANDLER_INSTALL_PID: (AtomicU32, Mutex<()>) = (AtomicU32::new(0), Mutex::new(())); @@ -199,9 +197,6 @@ where F::Output: Into> + Send + Sync, Out: Send + Sync + 'static, { - // Make sure the logger is set up. - check_logging_state(py); - let result: PyResult = py.allow_threads(move || { // Now, without the GIL, spawn the task on a new OS thread. This avoids having tokio cache stuff in // thread-local storage that is invalidated after a fork-exec. diff --git a/hf_xet/src/telemetry.rs b/hf_xet/src/telemetry.rs deleted file mode 100644 index eca31e1d..00000000 --- a/hf_xet/src/telemetry.rs +++ /dev/null @@ -1,372 +0,0 @@ -use std::collections::HashMap; -use std::env; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, Mutex, OnceLock}; -use std::time::Duration; - -use bipbuffer::BipBuffer; -use cas_client::exports::reqwest; -use cas_client::exports::reqwest::header::{HeaderMap, HeaderName, HeaderValue}; -use serde::{Deserialize, Serialize}; -use tracing::{debug, Subscriber}; -use tracing_subscriber::Layer; -use xet_threadpool::errors::MultithreadedRuntimeError; -use xet_threadpool::exports::tokio; - -pub const TELEMETRY_PRE_ALLOC_BYTES: usize = 2 * 1024 * 1024; -pub const TELEMETRY_PERIOD_MS: u64 = 100; -pub const HF_DEFAULT_ENDPOINT: &str = "https://huggingface.co"; -pub const HF_DEFAULT_STAGING_ENDPOINT: &str = "https://hub-ci.huggingface.co"; -pub const TELEMETRY_SUFFIX: &str = "api/telemetry/xet/cli"; - -#[derive(Debug)] -pub struct LoggingStats { - pub records_written: AtomicU64, - pub records_refused: AtomicU64, - pub bytes_written: AtomicU64, - pub records_read: AtomicU64, - pub records_corrupted: AtomicU64, - pub bytes_read: AtomicU64, - pub records_transmitted: AtomicU64, - pub records_dropped: AtomicU64, - pub bytes_refused: AtomicU64, -} - -impl Default for LoggingStats { - fn default() -> Self { - Self { - records_written: AtomicU64::new(0), - records_refused: AtomicU64::new(0), - bytes_written: AtomicU64::new(0), - records_read: AtomicU64::new(0), - records_corrupted: AtomicU64::new(0), - bytes_read: AtomicU64::new(0), - records_transmitted: AtomicU64::new(0), - records_dropped: AtomicU64::new(0), - bytes_refused: AtomicU64::new(0), - } - } -} - -fn is_staging_mode() -> bool { - matches!(env::var("HUGGINGFACE_CO_STAGING").as_deref(), Ok("1")) -} - -pub fn get_telemetry_endpoint() -> String { - env::var("HF_ENDPOINT").unwrap_or_else(|_| { - if is_staging_mode() { - HF_DEFAULT_STAGING_ENDPOINT.to_string() - } else { - HF_DEFAULT_ENDPOINT.to_string() - } - }) -} - -#[derive(Serialize, Deserialize, Debug)] -struct SerializableHeaders { - headers: HashMap, -} - -impl From<&HeaderMap> for SerializableHeaders { - fn from(header_map: &HeaderMap) -> Self { - let headers = header_map - .iter() - .filter_map(|(name, value)| { - let name = name.to_string(); - let value = value.to_str().ok()?.to_string(); - Some((name, value)) - }) - .collect(); - - SerializableHeaders { headers } - } -} - -impl TryFrom for HeaderMap { - type Error = reqwest::header::InvalidHeaderValue; - - fn try_from(serializable: SerializableHeaders) -> Result { - let mut header_map = HeaderMap::new(); - for (key, value) in serializable.headers { - let name = HeaderName::from_bytes(key.as_bytes()).unwrap(); - let val = HeaderValue::from_str(&value)?; - header_map.insert(name, val); - } - Ok(header_map) - } -} - -pub struct TelemetryLogger { - log_buffer: Mutex>, - stats: LoggingStats, - version_info: String, -} - -#[derive(Clone)] -pub struct TelemetryLoggerPtr(Arc); - -impl TelemetryLogger { - pub(crate) fn init(version_info: String) -> Result { - let log_buffer = Mutex::new(BipBuffer::new(TELEMETRY_PRE_ALLOC_BYTES)); - let stats = LoggingStats::default(); - - // Start up the background process. - let s = Arc::new(Self { - log_buffer, - stats, - version_info, - }); - - s.spawn_telemetry_task()?; - - Ok(TelemetryLoggerPtr(s)) - } - - fn spawn_telemetry_task(self: &Arc) -> Result<(), MultithreadedRuntimeError> { - let client = reqwest::Client::new(); - let telemetry_url = format!("{}/{}", get_telemetry_endpoint(), TELEMETRY_SUFFIX); - - let s = self.clone(); - - // Set up the task. - let telemetry_send_task = async move { - let mut interval = tokio::time::interval(Duration::from_millis(TELEMETRY_PERIOD_MS)); - - loop { - // Use tokio tick to run this at regular intervals - interval.tick().await; - - let mut read_len: usize = 0; - let mut http_header_map: HeaderMap = HeaderMap::new(); - - { - let mut buffer = s.log_buffer.lock().unwrap(); - - if let Some(block) = buffer.read() { - read_len = block.len(); - s.stats.bytes_read.fetch_add(read_len as u64, Ordering::Relaxed); - - if let Ok(deserialized) = serde_json::from_slice::(block) { - if let Ok(http_header_map_deserialized) = deserialized.try_into() { - s.stats.records_read.fetch_add(1, Ordering::Relaxed); - http_header_map = http_header_map_deserialized; - } else { - s.stats.records_corrupted.fetch_add(1, Ordering::Relaxed); - } - } else { - s.stats.records_corrupted.fetch_add(1, Ordering::Relaxed); - } - } - } - - if read_len > 0 { - let mut buffer = s.log_buffer.lock().unwrap(); - buffer.decommit(read_len); - } - - if !http_header_map.is_empty() { - if let Ok(response) = client.head(telemetry_url.clone()).headers(http_header_map).send().await { - if response.status().is_success() { - s.stats.records_transmitted.fetch_add(1, Ordering::Relaxed); - } else { - debug!( - "Failed to transmit telemetry to {}: HTTP status {}", - telemetry_url, - response.status() - ); - s.stats.records_dropped.fetch_add(1, Ordering::Relaxed); - } - } else { - debug!("Failed to send HEAD request to {}: Error occurred during transmission", telemetry_url); - s.stats.records_dropped.fetch_add(1, Ordering::Relaxed); - } - } - debug!("Stats from telemetry {:?}", s.stats); - } - }; - - // Spawn the background telemetry task on it's own tokio runtime on the current thread; that way it will remain - // isolated and not exist in a limbo state through spawns. We can cleanly restart it in the child - // process. - - // Create a oneshot token to send back the result of starting the runtime. - let (rt_status_sender, rt_status) = tokio::sync::oneshot::channel(); - - std::thread::spawn(move || { - // Get the single threaded runtime to simply poll the log buffers and send them to python. - match tokio::runtime::Builder::new_current_thread().enable_all().build() { - Ok(rt) => { - // Okay, runtime started successfully, start the telemetry send task. - if rt_status_sender.send(Ok(())).is_err() { - eprintln!("Error in reporting ok logging status; pipe closed"); - } - - // Now have this runtime simply run the telemetry task, which should just run in a loop. This - rt.block_on(telemetry_send_task); - }, - Err(e) => { - if let Err(e) = rt_status_sender.send(Err(MultithreadedRuntimeError::Other(format!( - "Initialization Error: Failed to create single threaded runtime for telemetry task {e:?}" - )))) { - eprintln!("Error in reporting Err logging status; pipe closed ({e:?})"); - } - }, - }; - }); - - rt_status.blocking_recv().map_err(|e| { - MultithreadedRuntimeError::Other(format!( - "Initialization Error: Failed to connect with telemetry background thread: {e:?}" - )) - })? - } -} - -impl Layer for TelemetryLoggerPtr -where - S: Subscriber, -{ - fn on_event(&self, event: &tracing::Event<'_>, _ctx: tracing_subscriber::layer::Context<'_, S>) { - let tl = &self.0; - - let mut http_headers = HeaderMap::new(); - { - let mut user_agent = tl.version_info.clone(); - let mut visitor = |field: &tracing::field::Field, value: &dyn std::fmt::Debug| { - user_agent.push_str(&format!("{}/{:?}; ", field.name(), value)); - }; - event.record(&mut visitor); - user_agent = user_agent.replace("\"", ""); - if let Ok(header_value) = HeaderValue::from_str(&user_agent) { - http_headers.insert("User-Agent", header_value); - } else { - tl.stats.records_refused.fetch_add(1, Ordering::Relaxed); - return; - } - } - - let serializable: SerializableHeaders = (&http_headers).into(); - if let Ok(serialized_headers) = serde_json::to_string(&serializable) { - let mut buffer = tl.log_buffer.lock().unwrap(); - if let Ok(reserved) = buffer.reserve(serialized_headers.len()) { - if reserved.len() < serialized_headers.len() { - // log goes to /dev/null if not enough free space - tl.stats.records_refused.fetch_add(1, Ordering::Relaxed); - tl.stats - .bytes_refused - .fetch_add(serialized_headers.len() as u64, Ordering::Relaxed); - buffer.commit(0); - } else { - tl.stats.records_written.fetch_add(1, Ordering::Relaxed); - tl.stats - .bytes_written - .fetch_add(serialized_headers.len() as u64, Ordering::Relaxed); - reserved[..serialized_headers.len()].copy_from_slice(serialized_headers.as_bytes()); - buffer.commit(serialized_headers.len()); - } - } else { - tl.stats.records_refused.fetch_add(1, Ordering::Relaxed); - tl.stats - .bytes_refused - .fetch_add(serialized_headers.len() as u64, Ordering::Relaxed); - } - } else { - tl.stats.records_refused.fetch_add(1, Ordering::Relaxed); - } - } -} - -lazy_static::lazy_static! { - static ref global_telemetry_logger_info : OnceLock> = OnceLock::default(); -} - -/// Restarts the telemetry background task after a spawn has been detected. -pub fn restart_telemetry_task_after_spawn() -> Result<(), MultithreadedRuntimeError> { - if let Some(ref current_tl) = global_telemetry_logger_info.get_or_init(|| None) { - current_tl.0.spawn_telemetry_task()?; - } - - Ok(()) -} - -/// Initializes the telemetry logging; should be called only once. -pub fn init_telemetry_logging(version_info: String) -> Result { - let mut maybe_error = None; - - let tl = global_telemetry_logger_info.get_or_init(|| match TelemetryLogger::init(version_info) { - Err(e) => { - maybe_error = Some(e); - None - }, - Ok(tl) => Some(tl), - }); - - if let Some(e) = maybe_error { - Err(e) - } else { - Ok(tl.clone().expect("Only None if no error.")) - } -} - -#[cfg(test)] -mod tests { - use std::sync::atomic::Ordering; - use std::sync::Arc; - - use bipbuffer::BipBuffer; - use tracing_subscriber::layer::SubscriberExt; - - use super::*; - - #[test] - fn test_buffer_layer() { - let layer = TelemetryLoggerPtr(Arc::new(TelemetryLogger { - log_buffer: Mutex::new(BipBuffer::new(50 * 2)), - stats: LoggingStats::default(), - version_info: "Testing".to_owned(), - })); - - let subscriber = tracing_subscriber::registry().with(layer.clone()); - tracing::subscriber::with_default(subscriber, || { - let stats = &layer.0.stats; - - tracing::info!(target: "client_telemetry", "50 b event"); - assert_eq!(stats.records_written.load(Ordering::Relaxed), 1); - assert_eq!(stats.records_refused.load(Ordering::Relaxed), 0); - assert_eq!(stats.bytes_written.load(Ordering::Relaxed), 50); - assert_eq!(stats.bytes_refused.load(Ordering::Relaxed), 0); - - for _ in 0..9 { - tracing::info!(target: "client_telemetry", "test event"); - } - assert_eq!(stats.records_written.load(Ordering::Relaxed), 2); - assert_eq!(stats.records_refused.load(Ordering::Relaxed), 8); - assert_eq!(stats.bytes_written.load(Ordering::Relaxed), 50 * 2); - assert_eq!(stats.bytes_refused.load(Ordering::Relaxed), 50 * 8); - }); - } - - #[test] - fn test_serializable() { - let mut header_map = HeaderMap::new(); - header_map.insert("Content-Type", HeaderValue::from_static("application/json")); - header_map.insert("Authorization", HeaderValue::from_static("Bearer token")); - - let serializable: SerializableHeaders = (&header_map).into(); - - assert_eq!(serializable.headers.get("content-type"), Some(&"application/json".to_string())); - assert_eq!(serializable.headers.get("authorization"), Some(&"Bearer token".to_string())); - - let mut headers = HashMap::new(); - headers.insert("Content-Type".to_string(), "application/json".to_string()); - headers.insert("Authorization".to_string(), "Bearer token".to_string()); - - let serializable = SerializableHeaders { headers }; - let header_map: Result = HeaderMap::try_from(serializable); - - assert!(header_map.is_ok()); - let header_map = header_map.unwrap(); - assert_eq!(header_map.get("Content-Type").unwrap(), "application/json"); - assert_eq!(header_map.get("Authorization").unwrap(), "Bearer token"); - } -}