diff --git a/xet_client/src/cas_client/adaptive_concurrency/controller.rs b/xet_client/src/cas_client/adaptive_concurrency/controller.rs
index d9670a50..39f57635 100644
--- a/xet_client/src/cas_client/adaptive_concurrency/controller.rs
+++ b/xet_client/src/cas_client/adaptive_concurrency/controller.rs
@@ -21,6 +21,9 @@ use super::rtt_prediction::RTTPredictor;
 const MIN_PARTIAL_REPORT_INTERVAL_MS: u64 = 200;
 const PARTIAL_REPORT_WEIGHT_RATIO: f64 = 0.2;
 
+const REFERENCE_SIZE_QUANTILE_Z: f64 = 1.645; // z-score for 95th percentile
+const MIN_SIZE_OBSERVATIONS_FOR_REFERENCE: u64 = 3;
+
 /// The network model state extracted from the concurrency controller.
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct CCSuccessModelState {
@@ -59,6 +62,14 @@ struct ConcurrencyControllerState {
 
     // The number of completed transmissions observed so far.
     completed_transmissions_count: u64,
+
+    // Exponentially-weighted trackers for estimating the transmission size distribution.
+    // Tracks log(size_bytes) for the mean and log(size_bytes)^2 for variance computation,
+    // enabling a log-normal 95th percentile estimate used as a dynamic reference size
+    // for the concurrency increase decision.
+    size_log_tracker: ExpWeightedMovingAvg,
+    size_log_sq_tracker: ExpWeightedMovingAvg,
+    size_observation_count: u64,
 }
 
 impl ConcurrencyControllerState {
@@ -74,6 +85,9 @@ impl ConcurrencyControllerState {
             last_logging_time: Instant::now(),
             bytes_sent_so_far: 0,
             completed_transmissions_count: 0,
+            size_log_tracker: ExpWeightedMovingAvg::new_count_decay(rtt_half_life_count),
+            size_log_sq_tracker: ExpWeightedMovingAvg::new_count_decay(rtt_half_life_count),
+            size_observation_count: 0,
         }
     }
 
@@ -119,7 +133,7 @@ impl ConcurrencyControllerState {
         let config = xet_config();
         let (predicted_max_rtt, prediction_max_rtt_standard_error) = self
             .rtt_predictor
-            .predict(config.client.ac_target_rtt_transmission_size, current_concurrency);
+            .predict(*config.client.ac_max_reference_transmission_size, current_concurrency);
 
         let predicted_bandwidth = self.rtt_predictor.predicted_bandwidth();
 
@@ -129,6 +143,38 @@ impl ConcurrencyControllerState {
             predicted_bandwidth: predicted_bandwidth.unwrap_or(0.),
         }
     }
+
+    /// Estimates a workload-appropriate reference transmission size as the 95th percentile of
+    /// observed transfer sizes (log-normal model); returns None until enough sizes are observed.
+    /// Bounded to [ac_min_reference_transmission_size, ac_max_reference_transmission_size]; max wins if min > max.
+    fn estimated_reference_transmission_size(&self) -> Option<u64> {
+        if self.size_observation_count < MIN_SIZE_OBSERVATIONS_FOR_REFERENCE {
+            return None;
+        }
+
+        let mu = self.size_log_tracker.value();
+        let mu_sq = self.size_log_sq_tracker.value();
+        let variance = (mu_sq - mu * mu).max(0.0);
+        let sigma = variance.sqrt();
+
+        let quantile_95 = (mu + REFERENCE_SIZE_QUANTILE_Z * sigma).exp();
+
+        let config = xet_config();
+        let min_size = *config.client.ac_min_reference_transmission_size;
+        let max_size = *config.client.ac_max_reference_transmission_size;
+
+        Some((quantile_95 as u64).max(min_size).min(max_size))
+    }
+
+    fn update_size_tracking(&mut self, n_bytes: u64) {
+        if n_bytes == 0 {
+            return;
+        }
+        let log_size = (n_bytes as f64).ln();
+        self.size_log_tracker.update(log_size);
+        self.size_log_sq_tracker.update(log_size * log_size);
+        self.size_observation_count += 1;
+    }
 }
 
 /// Controls dynamic adjustment of concurrency for upload and download operations.
@@ -448,8 +494,19 @@ impl AdaptiveConcurrencyController {
             state_lg.rtt_predictor.update(n_bytes, elapsed_time, avg_concurrency, weight);
         }
 
-        // Calculate common values once
-        let reference_size = config.client.ac_target_rtt_transmission_size;
+        // Track transmission sizes on final completion for dynamic reference size estimation.
+        // Failed transfers are intentionally included: the transfer-size mix is treated as
+        // workload context independent of the transfer outcome.
+        if !partial_update && let Some(n_bytes) = n_bytes_if_known {
+            state_lg.update_size_tracking(n_bytes);
+        }
+
+        // Use the dynamically estimated reference size when available, falling back to the
+        // configured value. This adapts the concurrency increase check to the actual workload:
+        // when most transfers are small, the reference size drops and concurrency can grow faster.
+        let reference_size = state_lg
+            .estimated_reference_transmission_size()
+            .unwrap_or(*config.client.ac_max_reference_transmission_size);
         let target_rtt_secs = config.client.ac_target_rtt.as_secs_f64();
 
         // If the success ratio is healthy and the predicted RTT is below the target RTT,
@@ -524,13 +581,16 @@ impl AdaptiveConcurrencyController {
         }
 
         if state_lg.last_logging_time.elapsed() > Duration::from_millis(config.client.ac_logging_interval_ms) {
+            state_lg.last_logging_time = Instant::now();
             let latency_state = state_lg.latency_model_state(self.concurrency_semaphore.active_permits() as f64);
+            let ref_size_mb = reference_size as f64 / (1024.0 * 1024.0);
             info!(
-                "Concurrency control for {}: Current concurrency = {}; predicted bandwidth = {}; success_ratio = {:.3}; observed bytes sent so far = {}; completed transmissions = {}",
+                "Concurrency control for {}: Current concurrency = {}; predicted bandwidth = {:.0}; success_ratio = {:.3}; reference_size = {:.1}MB; observed bytes sent so far = {}; completed transmissions = {}",
                 self.logging_tag,
                 self.concurrency_semaphore.total_permits(),
                 latency_state.predicted_bandwidth,
                 model_state.success_ratio,
+                ref_size_mb,
                 state_lg.bytes_sent_so_far,
                 state_lg.completed_transmissions_count
             );
@@ -688,6 +748,9 @@ impl ConcurrencyControllerState {
             last_logging_time: Instant::now(),
             bytes_sent_so_far: 0,
             completed_transmissions_count: 0,
+            size_log_tracker: ExpWeightedMovingAvg::new_count_decay(TR_HALF_LIFE_COUNT),
+            size_log_sq_tracker: ExpWeightedMovingAvg::new_count_decay(TR_HALF_LIFE_COUNT),
+            size_observation_count: 0,
         }
     }
 }
@@ -912,4 +975,160 @@ mod tests {
         let latency_state = controller.latency_model_state().await;
         assert!(latency_state.predicted_bandwidth >= 0.0);
     }
+
+    #[test]
+    fn test_reference_size_returns_none_with_insufficient_data() {
+        let state = ConcurrencyControllerState::new_testing();
+        assert!(state.estimated_reference_transmission_size().is_none());
+    }
+
+    #[test]
+    fn test_reference_size_with_uniform_sizes() {
+        let mut state = ConcurrencyControllerState::new_testing();
+
+        let size: u64 = 10 * 1024 * 1024; // 10 MB
+        for _ in 0..10 {
+            state.update_size_tracking(size);
+        }
+
+        let ref_size = state.estimated_reference_transmission_size().unwrap();
+        let config = xet_config();
+
+        // With zero variance, the 95th percentile should equal the mean (~10MB).
+        assert!(ref_size >= *config.client.ac_min_reference_transmission_size);
+        assert!(ref_size <= *config.client.ac_max_reference_transmission_size);
+        assert!((5 * 1024 * 1024..=12 * 1024 * 1024).contains(&ref_size));
+    }
+
+    #[test]
+    fn test_reference_size_bounded_by_minimum() {
+        let mut state = ConcurrencyControllerState::new_testing();
+
+        let size: u64 = 1024; // 1 KB
+        for _ in 0..10 {
+            state.update_size_tracking(size);
+        }
+
+        let config = xet_config();
+        let ref_size = state.estimated_reference_transmission_size().unwrap();
+        assert_eq!(ref_size, *config.client.ac_min_reference_transmission_size);
+    }
+
+    #[test]
+    fn test_reference_size_bounded_by_config_maximum() {
+        let mut state = ConcurrencyControllerState::new_testing();
+
+        let size: u64 = 200 * 1024 * 1024; // 200 MB (above the 64MB config default)
+        for _ in 0..10 {
+            state.update_size_tracking(size);
+        }
+
+        let ref_size = state.estimated_reference_transmission_size().unwrap();
+        let config = xet_config();
+        assert!(ref_size <= *config.client.ac_max_reference_transmission_size);
+    }
+
+    #[test]
+    fn test_reference_size_skips_zero_byte_transfers() {
+        let mut state = ConcurrencyControllerState::new_testing();
+
+        for _ in 0..10 {
+            state.update_size_tracking(0);
+        }
+
+        assert!(state.estimated_reference_transmission_size().is_none());
+        assert_eq!(state.size_observation_count, 0);
+    }
+
+    #[test]
+    fn test_reference_size_with_mixed_sizes() {
+        let config = xet_config();
+
+        let mut small_only_state = ConcurrencyControllerState::new_testing();
+        for _ in 0..10 {
+            small_only_state.update_size_tracking(512 * 1024); // 512 KB
+        }
+        let small_only_ref_size = small_only_state.estimated_reference_transmission_size().unwrap();
+
+        let mut state = ConcurrencyControllerState::new_testing();
+
+        // Mix of small and large transfers
+        for _ in 0..5 {
+            state.update_size_tracking(512 * 1024); // 512 KB
+        }
+        for _ in 0..5 {
+            state.update_size_tracking(32 * 1024 * 1024); // 32 MB
+        }
+
+        let ref_size = state.estimated_reference_transmission_size().unwrap();
+        assert!(ref_size >= *config.client.ac_min_reference_transmission_size);
+        assert!(ref_size <= *config.client.ac_max_reference_transmission_size);
+
+        // Mixed workloads should produce a larger reference than the small-only baseline.
+        assert!(ref_size > small_only_ref_size);
+    }
+
+    #[tokio::test]
+    async fn test_failed_transfers_still_update_size_tracking() {
+        time::pause();
+        let controller = AdaptiveConcurrencyController::new_testing(1, (1, 4));
+
+        for _ in 0..MIN_SIZE_OBSERVATIONS_FOR_REFERENCE {
+            let permit = controller.acquire_connection_permit().await.unwrap();
+            advance(Duration::from_millis(10)).await;
+            permit.report_completion(8 * 1024 * 1024, false).await;
+            advance(Duration::from_millis(DECR_SPACING_MS + 1)).await;
+        }
+
+        let state = controller.state.lock().await;
+        assert_eq!(state.size_observation_count, MIN_SIZE_OBSERVATIONS_FOR_REFERENCE);
+        assert!(state.estimated_reference_transmission_size().is_some());
+    }
+
+    /// Helper: run a sequence of transfers with varying sizes on a controller.
+    /// sizes_bytes are cycled through. bandwidth_bps is simulated bandwidth in bytes/sec.
+    /// Returns the final total_permits.
+    async fn train_controller(
+        controller: &Arc<AdaptiveConcurrencyController>,
+        sizes_bytes: &[u64],
+        bandwidth_bps: f64,
+        num_iterations: usize,
+    ) -> usize {
+        for i in 0..num_iterations {
+            let size = sizes_bytes[i % sizes_bytes.len()];
+            let permit = controller.acquire_connection_permit().await.unwrap();
+            let duration_ms = ((size as f64 / bandwidth_bps) * 1000.0) as u64 + 10;
+            advance(Duration::from_millis(duration_ms)).await;
+            permit.report_completion(size, true).await;
+            advance(Duration::from_millis(INCR_SPACING_MS + 1)).await;
+        }
+        controller.total_permits()
+    }
+
+    #[tokio::test]
+    async fn test_small_transfers_allow_higher_concurrency_than_large() {
+        time::pause();
+        advance(Duration::from_millis(INCR_SPACING_MS + 1)).await;
+
+        // Use varying sizes within each class so the OLR has diverse x_eff values.
+        let small_sizes: Vec<u64> = vec![256 * 1024, 512 * 1024, 1024 * 1024, 2 * 1024 * 1024];
+        let large_sizes: Vec<u64> = vec![10 * 1024 * 1024, 20 * 1024 * 1024, 40 * 1024 * 1024, 64 * 1024 * 1024];
+        let bandwidth = 5.0 * 1024.0 * 1024.0; // 5 MB/s
+
+        // Run small transfers first (separate time context for clean measurements).
+        let controller_small = AdaptiveConcurrencyController::new_testing(1, (1, 50));
+        let small_concurrency = train_controller(&controller_small, &small_sizes, bandwidth, 40).await;
+
+        // Run large transfers next.
+        let controller_large = AdaptiveConcurrencyController::new_testing(1, (1, 50));
+        let large_concurrency = train_controller(&controller_large, &large_sizes, bandwidth, 40).await;
+
+        // The small-transfer controller should reach at least as high a concurrency because its
+        // dynamic reference size is much smaller (~2MB vs ~64MB), making the predicted
+        // RTT at the reference size much lower and the increase check more permissive.
+        assert!(
+            small_concurrency >= large_concurrency,
+            "Small-transfer concurrency ({small_concurrency}) should be >= large-transfer concurrency ({large_concurrency})"
+        );
+    }
 }
diff --git a/xet_runtime/src/config/groups/client.rs b/xet_runtime/src/config/groups/client.rs
index 45ed002c..6bf59e7d 100644
--- a/xet_runtime/src/config/groups/client.rs
+++ b/xet_runtime/src/config/groups/client.rs
@@ -171,12 +171,22 @@ crate::config_group!({
     /// Use the environment variable `HF_XET_CLIENT_AC_UNHEALTHY_SUCCESS_RATIO_THRESHOLD` to set this value.
     ref ac_unhealthy_success_ratio_threshold: f64 = 0.5;
 
-    /// The reference size (64MB) used for bandwidth target checks.
+    /// The maximum reference transmission size used for bandwidth target checks.
+    /// The dynamic reference size (estimated from observed transfer sizes) is capped at this value.
     ///
     /// The default value is 64MB.
     ///
-    /// Use the environment variable `HF_XET_CLIENT_AC_TARGET_RTT_TRANSMISSION_SIZE` to set this value.
-    ref ac_target_rtt_transmission_size: u64 = 64 * 1024 * 1024;
+    /// Use the environment variable `HF_XET_CLIENT_AC_MAX_REFERENCE_TRANSMISSION_SIZE` to set this value.
+    ref ac_max_reference_transmission_size: ByteSize = ByteSize::from("64mb");
+
+    /// The minimum reference transmission size used for bandwidth target checks.
+    /// The dynamic reference size (estimated from observed transfer sizes) is floored at this value
+    /// to prevent excessively aggressive concurrency increases with very small transfers.
+    ///
+    /// The default value is 1MB.
+    ///
+    /// Use the environment variable `HF_XET_CLIENT_AC_MIN_REFERENCE_TRANSMISSION_SIZE` to set this value.
+    ref ac_min_reference_transmission_size: ByteSize = ByteSize::from("1mb");
 
     /// Log the concurrency on this interval.
     ///