diff --git a/crates/teamtalk/src/utils/backoff.rs b/crates/teamtalk/src/utils/backoff.rs index a407d65..67d0a99 100644 --- a/crates/teamtalk/src/utils/backoff.rs +++ b/crates/teamtalk/src/utils/backoff.rs @@ -1,13 +1,38 @@ //! Exponential backoff helper. +//! +//! Implements an AWS-style exponential-backoff schedule with a +//! parameterised jitter factor (see "Exponential Backoff And Jitter", +//! AWS Architecture Blog, 2015). Given an attempt count `n`, a base +//! delay `b`, a growth factor `f`, a cap `m`, and a jitter knob +//! `j in [0, 1]`, the next delay is: +//! +//! ```text +//! cap_n = min(b * f^n, m) +//! delay = cap_n * (1 - j) + rand(0, cap_n * j) +//! ``` +//! +//! The extremes are: +//! +//! * `j = 1.0` — full jitter: `delay = rand(0, cap_n)`. +//! * `j = 0.0` — no jitter: `delay = cap_n`. +//! * `j = 0.5` — equal jitter: `delay in [0.5*cap_n, cap_n]`. +//! +//! Full jitter is the default and is usually the safest choice for +//! thundering-herd avoidance on reconnect storms. + use rand::{RngExt, rng}; use std::time::Duration; +/// Default jitter factor used by [`Default`]. +const DEFAULT_JITTER: f32 = 1.0; + /// Exponential backoff with jitter and a maximum cap. #[derive(Debug, Clone)] pub struct ExponentialBackoff { initial_delay: Duration, max_delay: Duration, factor: f32, + jitter: f32, attempts: u32, current_val: Duration, } @@ -18,6 +43,7 @@ impl Default for ExponentialBackoff { initial_delay: Duration::from_secs(1), max_delay: Duration::from_secs(120), factor: 1.6, + jitter: DEFAULT_JITTER, attempts: 0, current_val: Duration::ZERO, } @@ -26,17 +52,34 @@ impl Default for ExponentialBackoff { impl ExponentialBackoff { /// Creates a new backoff schedule. + /// + /// * `initial` — base delay for attempt 0. + /// * `max` — upper cap on any single delay (pre-jitter). + /// * `factor` — growth factor per attempt; the cap at attempt `n` + /// is `initial * factor^n`, then clamped to `max`. + /// * `jitter` — jitter factor, clamped to `[0.0, 1.0]`. A value + /// of `1.0` picks uniformly in `[0, cap]`; `0.0` returns the + /// deterministic cap; `0.5` picks uniformly in + /// `[cap/2, cap]`. Values outside `[0, 1]` are silently + /// clamped. #[must_use] - pub fn new(initial: Duration, max: Duration, factor: f32, _jitter: f32) -> Self { + pub fn new(initial: Duration, max: Duration, factor: f32, jitter: f32) -> Self { Self { initial_delay: initial, max_delay: max, factor, + jitter: jitter.clamp(0.0, 1.0), attempts: 0, current_val: Duration::ZERO, } } + /// Returns the configured jitter factor, always in `[0.0, 1.0]`. + #[must_use] + pub fn jitter(&self) -> f32 { + self.jitter + } + /// Returns the next delay in the schedule. #[must_use] pub fn next_delay(&mut self) -> Duration { @@ -52,19 +95,48 @@ impl ExponentialBackoff { self.initial_delay }; - let exponent = self.attempts as f32; - let cap_secs = base.as_secs_f32() * self.factor.powf(exponent); - let cap = Duration::from_secs_f32(cap_secs).min(self.max_delay); + // Compute the cap in integer milliseconds to avoid panics + // from `Duration::from_secs_f32` on overflow / NaN and to + // keep exact integer growth for typical millisecond-scale + // backoffs (no floating-point drift for bases like 10 ms). + let max_millis_u128 = self.max_delay.as_millis(); + let mut cap_millis_u128 = base.as_millis().min(max_millis_u128); + for _ in 0..self.attempts { + if cap_millis_u128 >= max_millis_u128 { + cap_millis_u128 = max_millis_u128; + break; + } + let factor = f64::from(self.factor); + if !factor.is_finite() || factor <= 1.0 { + break; + } + let next = ((cap_millis_u128 as f64) * factor) as u128; + if next <= cap_millis_u128 { + break; + } + cap_millis_u128 = next.min(max_millis_u128); + } + let cap_millis = cap_millis_u128.min(u128::from(u64::MAX)) as u64; self.attempts += 1; - let max_millis = cap.as_millis() as u64; - if max_millis == 0 { + if cap_millis == 0 { self.current_val = Duration::ZERO; return Duration::ZERO; } - let jittered = rng().random_range(0..=max_millis); + // delay = cap * (1 - j) + rand(0, cap * j) + // + // Compute the random span in integer milliseconds so that we + // stay deterministic in ordering with the pre-existing tests + // and avoid floating-point drift on the cap boundary. + let random_span_millis = ((cap_millis as f64) * f64::from(self.jitter)) as u64; + let fixed_millis = cap_millis.saturating_sub(random_span_millis); + let jittered = if random_span_millis == 0 { + fixed_millis + } else { + fixed_millis + rng().random_range(0..=random_span_millis) + }; self.current_val = Duration::from_millis(jittered); self.current_val } diff --git a/crates/teamtalk/tests/backoff_jitter_tests.rs b/crates/teamtalk/tests/backoff_jitter_tests.rs new file mode 100644 index 0000000..7fd2d0b --- /dev/null +++ b/crates/teamtalk/tests/backoff_jitter_tests.rs @@ -0,0 +1,188 @@ +//! Integration tests for [`ExponentialBackoff`] jitter semantics. +//! +//! The core contract (repeated here so tests read top-to-bottom): +//! +//! ```text +//! cap_n = min(initial * factor^n, max) +//! delay = cap_n * (1 - jitter) + rand(0, cap_n * jitter) +//! ``` +//! +//! * `jitter = 1.0` — uniform `[0, cap_n]` (full jitter, AWS default). +//! * `jitter = 0.0` — deterministic `cap_n`. +//! * `jitter = 0.5` — uniform `[cap_n/2, cap_n]`. +//! * Values outside `[0, 1]` are clamped. + +use std::time::Duration; +use teamtalk::utils::backoff::ExponentialBackoff; + +#[test] +fn jitter_zero_is_deterministic_cap() { + let mut backoff = ExponentialBackoff::new( + Duration::from_millis(50), + Duration::from_millis(50), + 2.0, + 0.0, + ); + // With jitter = 0 and a flat cap we must always return the cap. + for _ in 0..10 { + let delay = backoff.next_delay(); + assert_eq!( + delay, + Duration::from_millis(50), + "jitter = 0 must produce the deterministic cap each time", + ); + } + assert_eq!(backoff.jitter(), 0.0); +} + +#[test] +fn jitter_one_spans_full_range() { + let mut backoff = ExponentialBackoff::new( + Duration::from_millis(100), + Duration::from_millis(100), + 2.0, + 1.0, + ); + let mut min_seen = Duration::from_millis(u64::MAX); + let mut max_seen = Duration::ZERO; + for _ in 0..1_000 { + let delay = backoff.next_delay(); + assert!( + delay <= Duration::from_millis(100), + "jitter = 1 must stay at or below cap", + ); + if delay < min_seen { + min_seen = delay; + } + if delay > max_seen { + max_seen = delay; + } + } + // Over 1 000 samples the spread should be wide; we assert a + // generous lower bound so the test is not flaky but still + // detects a regression to "always cap". + assert!( + min_seen <= Duration::from_millis(20), + "jitter = 1 should produce low values over 1 000 samples, got min {min_seen:?}", + ); + assert!( + max_seen >= Duration::from_millis(80), + "jitter = 1 should produce high values over 1 000 samples, got max {max_seen:?}", + ); +} + +#[test] +fn jitter_half_stays_in_upper_half_of_cap() { + let mut backoff = ExponentialBackoff::new( + Duration::from_millis(100), + Duration::from_millis(100), + 2.0, + 0.5, + ); + for _ in 0..500 { + let delay = backoff.next_delay(); + assert!( + delay >= Duration::from_millis(50), + "jitter = 0.5 must stay at or above cap/2, got {delay:?}", + ); + assert!( + delay <= Duration::from_millis(100), + "jitter = 0.5 must stay at or below cap, got {delay:?}", + ); + } +} + +#[test] +fn jitter_above_one_is_clamped() { + let backoff = ExponentialBackoff::new( + Duration::from_millis(10), + Duration::from_millis(10), + 2.0, + 5.0, + ); + assert_eq!( + backoff.jitter(), + 1.0, + "jitter > 1 must be clamped to the full-jitter upper bound", + ); +} + +#[test] +fn jitter_below_zero_is_clamped() { + let backoff = ExponentialBackoff::new( + Duration::from_millis(10), + Duration::from_millis(10), + 2.0, + -0.5, + ); + assert_eq!( + backoff.jitter(), + 0.0, + "jitter < 0 must be clamped to the no-jitter lower bound", + ); +} + +#[test] +fn jitter_nan_is_clamped_to_zero() { + // f32::clamp(NaN, 0, 1) returns NaN in the spec; use the + // observable behaviour instead: a NaN jitter must never produce + // a delay outside the cap range. + let mut backoff = ExponentialBackoff::new( + Duration::from_millis(100), + Duration::from_millis(100), + 2.0, + f32::NAN, + ); + let delay = backoff.next_delay(); + assert!( + delay <= Duration::from_millis(100), + "NaN jitter must not produce delays above the cap, got {delay:?}", + ); +} + +#[test] +fn next_delay_advances_cap_until_max() { + // With jitter = 0 we can observe the raw cap sequence. + let mut backoff = ExponentialBackoff::new( + Duration::from_millis(10), + Duration::from_millis(1_000), + 2.0, + 0.0, + ); + let d0 = backoff.next_delay(); + let d1 = backoff.next_delay(); + let d2 = backoff.next_delay(); + assert_eq!(d0, Duration::from_millis(10)); + assert_eq!(d1, Duration::from_millis(20)); + assert_eq!(d2, Duration::from_millis(40)); + for _ in 0..20 { + let _ = backoff.next_delay(); + } + // Eventually the cap must hit the configured max. + assert_eq!(backoff.next_delay(), Duration::from_millis(1_000)); +} + +#[test] +fn reset_clears_current_delay() { + let mut backoff = ExponentialBackoff::new( + Duration::from_millis(10), + Duration::from_millis(100), + 2.0, + 0.5, + ); + let _ = backoff.next_delay(); + assert!(backoff.attempts() >= 1); + backoff.reset(); + assert_eq!(backoff.attempts(), 0); + assert_eq!(backoff.current_delay(), Duration::ZERO); +} + +#[test] +fn default_is_full_jitter() { + let backoff = ExponentialBackoff::default(); + assert_eq!( + backoff.jitter(), + 1.0, + "Default should be full jitter for safe thundering-herd avoidance", + ); +}