Adjust RTT prediction determining concurrency by transmission size. (#708)

Currently, the condition for increasing connection concurrency is gated
on the model predicting that a 64MB transmission will complete within 90
seconds. However, when the transmissions are primarily composed of small
packets, this can drastically overestimate the round trip, artificially
suppressing the connection concurrency.

This PR fixes this issue by also modeling the average predicted packet
size, using the 95% quantile of that (bounded by two config variables)
to predict the round trip time when considering a concurrency increase.
This commit is contained in:
Hoyt Koepke
2026-03-13 10:47:45 -07:00
committed by GitHub
parent bcce76be63
commit 3390bdc716
2 changed files with 236 additions and 7 deletions

View File

@@ -21,6 +21,9 @@ use super::rtt_prediction::RTTPredictor;
const MIN_PARTIAL_REPORT_INTERVAL_MS: u64 = 200;
const PARTIAL_REPORT_WEIGHT_RATIO: f64 = 0.2;
const REFERENCE_SIZE_QUANTILE_Z: f64 = 1.645; // z-score for 95th percentile
const MIN_SIZE_OBSERVATIONS_FOR_REFERENCE: u64 = 3;
/// The network model state extracted from the concurrency controller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CCSuccessModelState {
@@ -59,6 +62,14 @@ struct ConcurrencyControllerState {
// The number of completed transmissions observed so far.
completed_transmissions_count: u64,
// Exponentially-weighted trackers for estimating the transmission size distribution.
// Tracks log(size_bytes) for the mean and log(size_bytes)^2 for variance computation,
// enabling a log-normal 95th percentile estimate used as a dynamic reference size
// for the concurrency increase decision.
size_log_tracker: ExpWeightedMovingAvg,
size_log_sq_tracker: ExpWeightedMovingAvg,
size_observation_count: u64,
}
impl ConcurrencyControllerState {
@@ -74,6 +85,9 @@ impl ConcurrencyControllerState {
last_logging_time: Instant::now(),
bytes_sent_so_far: 0,
completed_transmissions_count: 0,
size_log_tracker: ExpWeightedMovingAvg::new_count_decay(rtt_half_life_count),
size_log_sq_tracker: ExpWeightedMovingAvg::new_count_decay(rtt_half_life_count),
size_observation_count: 0,
}
}
@@ -119,7 +133,7 @@ impl ConcurrencyControllerState {
let config = xet_config();
let (predicted_max_rtt, prediction_max_rtt_standard_error) = self
.rtt_predictor
.predict(config.client.ac_target_rtt_transmission_size, current_concurrency);
.predict(*config.client.ac_max_reference_transmission_size, current_concurrency);
let predicted_bandwidth = self.rtt_predictor.predicted_bandwidth();
@@ -129,6 +143,38 @@ impl ConcurrencyControllerState {
predicted_bandwidth: predicted_bandwidth.unwrap_or(0.),
}
}
/// Estimates a workload-appropriate reference transmission size using the 95th percentile
/// of observed transfer sizes (log-normal model). Returns None if insufficient data.
/// The result is clamped to [ac_min_reference_transmission_size, ac_max_reference_transmission_size].
fn estimated_reference_transmission_size(&self) -> Option<u64> {
if self.size_observation_count < MIN_SIZE_OBSERVATIONS_FOR_REFERENCE {
return None;
}
let mu = self.size_log_tracker.value();
let mu_sq = self.size_log_sq_tracker.value();
let variance = (mu_sq - mu * mu).max(0.0);
let sigma = variance.sqrt();
let quantile_95 = (mu + REFERENCE_SIZE_QUANTILE_Z * sigma).exp();
let config = xet_config();
let min_size = *config.client.ac_min_reference_transmission_size;
let max_size = *config.client.ac_max_reference_transmission_size;
Some((quantile_95 as u64).clamp(min_size, max_size))
}
fn update_size_tracking(&mut self, n_bytes: u64) {
if n_bytes == 0 {
return;
}
let log_size = (n_bytes as f64).ln();
self.size_log_tracker.update(log_size);
self.size_log_sq_tracker.update(log_size * log_size);
self.size_observation_count += 1;
}
}
/// Controls dynamic adjustment of concurrency for upload and download operations.
@@ -448,8 +494,19 @@ impl AdaptiveConcurrencyController {
state_lg.rtt_predictor.update(n_bytes, elapsed_time, avg_concurrency, weight);
}
// Calculate common values once
let reference_size = config.client.ac_target_rtt_transmission_size;
// Track transmission sizes on final completion for dynamic reference size estimation.
// This intentionally includes failed transfers (option A): transfer-size mix is treated
// as workload context independent of the transfer outcome.
if !partial_update && let Some(n_bytes) = n_bytes_if_known {
state_lg.update_size_tracking(n_bytes);
}
// Use the dynamically estimated reference size when available, falling back to the
// configured value. This adapts the concurrency increase check to the actual workload:
// when most transfers are small, the reference size drops and concurrency can grow faster.
let reference_size = state_lg
.estimated_reference_transmission_size()
.unwrap_or(*config.client.ac_max_reference_transmission_size);
let target_rtt_secs = config.client.ac_target_rtt.as_secs_f64();
// If the success ratio is healthy and the predicted RTT is below the target RTT,
@@ -524,13 +581,16 @@ impl AdaptiveConcurrencyController {
}
if state_lg.last_logging_time.elapsed() > Duration::from_millis(config.client.ac_logging_interval_ms) {
state_lg.last_logging_time = Instant::now();
let latency_state = state_lg.latency_model_state(self.concurrency_semaphore.active_permits() as f64);
let ref_size_mb = reference_size as f64 / (1024.0 * 1024.0);
info!(
"Concurrency control for {}: Current concurrency = {}; predicted bandwidth = {}; success_ratio = {:.3}; observed bytes sent so far = {}; completed transmissions = {}",
"Concurrency control for {}: Current concurrency = {}; predicted bandwidth = {:.0}; success_ratio = {:.3}; reference_size = {:.1}MB; observed bytes sent so far = {}; completed transmissions = {}",
self.logging_tag,
self.concurrency_semaphore.total_permits(),
latency_state.predicted_bandwidth,
model_state.success_ratio,
ref_size_mb,
state_lg.bytes_sent_so_far,
state_lg.completed_transmissions_count
);
@@ -688,6 +748,9 @@ impl ConcurrencyControllerState {
last_logging_time: Instant::now(),
bytes_sent_so_far: 0,
completed_transmissions_count: 0,
size_log_tracker: ExpWeightedMovingAvg::new_count_decay(TR_HALF_LIFE_COUNT),
size_log_sq_tracker: ExpWeightedMovingAvg::new_count_decay(TR_HALF_LIFE_COUNT),
size_observation_count: 0,
}
}
}
@@ -912,4 +975,160 @@ mod tests {
let latency_state = controller.latency_model_state().await;
assert!(latency_state.predicted_bandwidth >= 0.0);
}
#[test]
fn test_reference_size_returns_none_with_insufficient_data() {
let state = ConcurrencyControllerState::new_testing();
assert!(state.estimated_reference_transmission_size().is_none());
}
#[test]
fn test_reference_size_with_uniform_sizes() {
let mut state = ConcurrencyControllerState::new_testing();
let size: u64 = 10 * 1024 * 1024; // 10 MB
for _ in 0..10 {
state.update_size_tracking(size);
}
let ref_size = state.estimated_reference_transmission_size().unwrap();
let config = xet_config();
// With zero variance, the 95th percentile should equal the mean (~10MB).
debug_assert!(ref_size >= *config.client.ac_min_reference_transmission_size);
debug_assert_le!(ref_size, *config.client.ac_max_reference_transmission_size);
assert!((5 * 1024 * 1024..=12 * 1024 * 1024).contains(&ref_size));
}
#[test]
fn test_reference_size_bounded_by_minimum() {
let mut state = ConcurrencyControllerState::new_testing();
let size: u64 = 1024; // 1 KB
for _ in 0..10 {
state.update_size_tracking(size);
}
let config = xet_config();
let ref_size = state.estimated_reference_transmission_size().unwrap();
assert_eq!(ref_size, *config.client.ac_min_reference_transmission_size);
}
#[test]
fn test_reference_size_bounded_by_config_maximum() {
let mut state = ConcurrencyControllerState::new_testing();
let size: u64 = 200 * 1024 * 1024; // 200 MB (above the 64MB config default)
for _ in 0..10 {
state.update_size_tracking(size);
}
let ref_size = state.estimated_reference_transmission_size().unwrap();
let config = xet_config();
assert!(ref_size <= *config.client.ac_max_reference_transmission_size);
}
#[test]
fn test_reference_size_skips_zero_byte_transfers() {
let mut state = ConcurrencyControllerState::new_testing();
for _ in 0..10 {
state.update_size_tracking(0);
}
assert!(state.estimated_reference_transmission_size().is_none());
assert_eq!(state.size_observation_count, 0);
}
#[test]
fn test_reference_size_with_mixed_sizes() {
let config = xet_config();
let mut small_only_state = ConcurrencyControllerState::new_testing();
for _ in 0..10 {
small_only_state.update_size_tracking(512 * 1024); // 512 KB
}
let small_only_ref_size = small_only_state.estimated_reference_transmission_size().unwrap();
let mut state = ConcurrencyControllerState::new_testing();
// Mix of small and large transfers
for _ in 0..5 {
state.update_size_tracking(512 * 1024); // 512 KB
}
for _ in 0..5 {
state.update_size_tracking(32 * 1024 * 1024); // 32 MB
}
let ref_size = state.estimated_reference_transmission_size().unwrap();
debug_assert!(ref_size >= *config.client.ac_min_reference_transmission_size);
debug_assert_le!(ref_size, *config.client.ac_max_reference_transmission_size);
// Mixed workloads should produce a larger reference than the small-only baseline.
assert!(ref_size > small_only_ref_size);
}
#[tokio::test]
async fn test_failed_transfers_still_update_size_tracking() {
time::pause();
let controller = AdaptiveConcurrencyController::new_testing(1, (1, 4));
for _ in 0..MIN_SIZE_OBSERVATIONS_FOR_REFERENCE {
let permit = controller.acquire_connection_permit().await.unwrap();
advance(Duration::from_millis(10)).await;
permit.report_completion(8 * 1024 * 1024, false).await;
advance(Duration::from_millis(DECR_SPACING_MS + 1)).await;
}
let state = controller.state.lock().await;
assert_eq!(state.size_observation_count, MIN_SIZE_OBSERVATIONS_FOR_REFERENCE);
assert!(state.estimated_reference_transmission_size().is_some());
}
/// Helper: run a sequence of transfers with varying sizes on a controller.
/// sizes_bytes are cycled through. bandwidth_bps is simulated bandwidth in bytes/sec.
/// Returns the final total_permits.
async fn train_controller(
controller: &Arc<AdaptiveConcurrencyController>,
sizes_bytes: &[u64],
bandwidth_bps: f64,
num_iterations: usize,
) -> usize {
for i in 0..num_iterations {
let size = sizes_bytes[i % sizes_bytes.len()];
let permit = controller.acquire_connection_permit().await.unwrap();
let duration_ms = ((size as f64 / bandwidth_bps) * 1000.0) as u64 + 10;
advance(Duration::from_millis(duration_ms)).await;
permit.report_completion(size, true).await;
advance(Duration::from_millis(INCR_SPACING_MS + 1)).await;
}
controller.total_permits()
}
#[tokio::test]
async fn test_small_transfers_allow_higher_concurrency_than_large() {
time::pause();
advance(Duration::from_millis(INCR_SPACING_MS + 1)).await;
// Use varying sizes within each class so the OLR has diverse x_eff values.
let small_sizes: Vec<u64> = vec![256 * 1024, 512 * 1024, 1024 * 1024, 2 * 1024 * 1024];
let large_sizes: Vec<u64> = vec![10 * 1024 * 1024, 20 * 1024 * 1024, 40 * 1024 * 1024, 64 * 1024 * 1024];
let bandwidth = 5.0 * 1024.0 * 1024.0; // 5 MB/s
// Run small transfers first (separate time context for clean measurements).
let controller_small = AdaptiveConcurrencyController::new_testing(1, (1, 50));
let small_concurrency = train_controller(&controller_small, &small_sizes, bandwidth, 40).await;
// Run large transfers next.
let controller_large = AdaptiveConcurrencyController::new_testing(1, (1, 50));
let large_concurrency = train_controller(&controller_large, &large_sizes, bandwidth, 40).await;
// The small-transfer controller should reach higher concurrency because its
// dynamic reference size is much smaller (~2MB vs ~64MB), making the predicted
// RTT at the reference size much lower and the increase check more permissive.
assert!(
small_concurrency >= large_concurrency,
"Small-transfer concurrency ({small_concurrency}) should be >= large-transfer concurrency ({large_concurrency})"
);
}
}

View File

@@ -171,12 +171,22 @@ crate::config_group!({
/// Use the environment variable `HF_XET_CLIENT_AC_UNHEALTHY_SUCCESS_RATIO_THRESHOLD` to set this value.
ref ac_unhealthy_success_ratio_threshold: f64 = 0.5;
/// The reference size (64MB) used for bandwidth target checks.
/// The maximum reference transmission size used for bandwidth target checks.
/// The dynamic reference size (estimated from observed transfer sizes) is capped at this value.
///
/// The default value is 64MB.
///
/// Use the environment variable `HF_XET_CLIENT_AC_TARGET_RTT_TRANSMISSION_SIZE` to set this value.
ref ac_target_rtt_transmission_size: u64 = 64 * 1024 * 1024;
/// Use the environment variable `HF_XET_CLIENT_AC_MAX_REFERENCE_TRANSMISSION_SIZE` to set this value.
ref ac_max_reference_transmission_size: ByteSize = ByteSize::from("64mb");
/// The minimum reference transmission size used for bandwidth target checks.
/// The dynamic reference size (estimated from observed transfer sizes) is floored at this value
/// to prevent excessively aggressive concurrency increases with very small transfers.
///
/// The default value is 1MB.
///
/// Use the environment variable `HF_XET_CLIENT_AC_MIN_REFERENCE_TRANSMISSION_SIZE` to set this value.
ref ac_min_reference_transmission_size: ByteSize = ByteSize::from("1mb");
/// Log the concurrency on this interval.
///