Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 79 additions & 7 deletions crates/teamtalk/src/utils/backoff.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,38 @@
//! Exponential backoff helper.
//!
//! Implements an AWS-style exponential-backoff schedule with a
//! parameterised jitter factor (see "Exponential Backoff And Jitter",
//! AWS Architecture Blog, 2015). Given an attempt count `n`, a base
//! delay `b`, a growth factor `f`, a cap `m`, and a jitter knob
//! `j in [0, 1]`, the next delay is:
//!
//! ```text
//! cap_n = min(b * f^n, m)
//! delay = cap_n * (1 - j) + rand(0, cap_n * j)
//! ```
//!
//! Representative settings of `j`:
//!
//! * `j = 1.0` — full jitter: `delay = rand(0, cap_n)`.
//! * `j = 0.0` — no jitter: `delay = cap_n`.
//! * `j = 0.5` — equal jitter: `delay in [0.5*cap_n, cap_n]`.
//!
//! Full jitter is the default and is usually the safest choice for
//! thundering-herd avoidance on reconnect storms.

use rand::{RngExt, rng};
use std::time::Duration;

/// Default jitter factor used by [`Default`].
const DEFAULT_JITTER: f32 = 1.0;

/// Exponential backoff with jitter and a maximum cap.
///
/// The value is stateful: it records how many delays have been handed out
/// and the most recent delay, so successive `next_delay` calls walk the
/// schedule.
#[derive(Debug, Clone)]
pub struct ExponentialBackoff {
/// Base delay used for attempt 0.
initial_delay: Duration,
/// Upper cap on any single delay, applied before jitter.
max_delay: Duration,
/// Growth factor applied once per attempt.
factor: f32,
/// Jitter factor `j`; the constructor clamps it to `[0.0, 1.0]`.
jitter: f32,
/// Number of delays produced so far; incremented by each `next_delay` call.
attempts: u32,
/// Most recent delay returned by `next_delay`; `Duration::ZERO` initially.
current_val: Duration,
}
Expand All @@ -18,6 +43,7 @@ impl Default for ExponentialBackoff {
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(120),
factor: 1.6,
jitter: DEFAULT_JITTER,
attempts: 0,
current_val: Duration::ZERO,
}
Expand All @@ -26,17 +52,34 @@ impl Default for ExponentialBackoff {

impl ExponentialBackoff {
/// Creates a new backoff schedule.
///
/// * `initial` — base delay for attempt 0.
/// * `max` — upper cap on any single delay (pre-jitter).
/// * `factor` — growth factor per attempt; the cap at attempt `n`
/// is `initial * factor^n`, then clamped to `max`.
/// * `jitter` — jitter factor, clamped to `[0.0, 1.0]`. A value
/// of `1.0` picks uniformly in `[0, cap]`; `0.0` returns the
/// deterministic cap; `0.5` picks uniformly in
/// `[cap/2, cap]`. Values outside `[0, 1]` are silently
/// clamped.
#[must_use]
pub fn new(initial: Duration, max: Duration, factor: f32, _jitter: f32) -> Self {
pub fn new(initial: Duration, max: Duration, factor: f32, jitter: f32) -> Self {
    // `f32::clamp` returns NaN for a NaN input, which would store a NaN and
    // break the `[0.0, 1.0]` range promised for the jitter factor. Map NaN
    // to 0.0 ("no jitter"): the delay computation casts `cap * jitter` to an
    // integer, and that cast turns NaN into 0, so a NaN jitter already
    // behaved like 0.0 when delays were produced. Only the stored value
    // (and therefore the `jitter()` accessor) changes.
    let jitter = if jitter.is_nan() {
        0.0
    } else {
        jitter.clamp(0.0, 1.0)
    };
    Self {
        initial_delay: initial,
        max_delay: max,
        factor,
        jitter,
        attempts: 0,
        current_val: Duration::ZERO,
    }
}

/// Returns the jitter factor this schedule was built with.
///
/// For schedules created through [`ExponentialBackoff::new`] this is the
/// value after clamping to `[0.0, 1.0]`; the [`Default`] schedule stores
/// `DEFAULT_JITTER`.
///
/// NOTE(review): `f32::clamp` propagates NaN, so confirm how the
/// constructor treats a NaN input before relying on the range.
#[must_use]
pub fn jitter(&self) -> f32 {
self.jitter
}

/// Returns the next delay in the schedule.
#[must_use]
pub fn next_delay(&mut self) -> Duration {
Expand All @@ -52,19 +95,48 @@ impl ExponentialBackoff {
self.initial_delay
};

let exponent = self.attempts as f32;
let cap_secs = base.as_secs_f32() * self.factor.powf(exponent);
let cap = Duration::from_secs_f32(cap_secs).min(self.max_delay);
// Compute the cap in integer milliseconds to avoid panics
// from `Duration::from_secs_f32` on overflow / NaN and to
// keep exact integer growth for typical millisecond-scale
// backoffs (no floating-point drift for bases like 10 ms).
let max_millis_u128 = self.max_delay.as_millis();
let mut cap_millis_u128 = base.as_millis().min(max_millis_u128);
for _ in 0..self.attempts {
if cap_millis_u128 >= max_millis_u128 {
cap_millis_u128 = max_millis_u128;
break;
}
let factor = f64::from(self.factor);
if !factor.is_finite() || factor <= 1.0 {
break;
}
let next = ((cap_millis_u128 as f64) * factor) as u128;
if next <= cap_millis_u128 {
break;
}
cap_millis_u128 = next.min(max_millis_u128);
}
let cap_millis = cap_millis_u128.min(u128::from(u64::MAX)) as u64;

self.attempts += 1;

let max_millis = cap.as_millis() as u64;
if max_millis == 0 {
if cap_millis == 0 {
self.current_val = Duration::ZERO;
return Duration::ZERO;
}

let jittered = rng().random_range(0..=max_millis);
// delay = cap * (1 - j) + rand(0, cap * j)
//
// Compute the random span in integer milliseconds so that we
// stay deterministic in ordering with the pre-existing tests
// and avoid floating-point drift on the cap boundary.
let random_span_millis = ((cap_millis as f64) * f64::from(self.jitter)) as u64;
let fixed_millis = cap_millis.saturating_sub(random_span_millis);
let jittered = if random_span_millis == 0 {
fixed_millis
} else {
fixed_millis + rng().random_range(0..=random_span_millis)
};
self.current_val = Duration::from_millis(jittered);
self.current_val
}
Expand Down
188 changes: 188 additions & 0 deletions crates/teamtalk/tests/backoff_jitter_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
//! Integration tests for [`ExponentialBackoff`] jitter semantics.
//!
//! The core contract (repeated here so tests read top-to-bottom):
//!
//! ```text
//! cap_n = min(initial * factor^n, max)
//! delay = cap_n * (1 - jitter) + rand(0, cap_n * jitter)
//! ```
//!
//! * `jitter = 1.0` — uniform `[0, cap_n]` (full jitter, AWS default).
//! * `jitter = 0.0` — deterministic `cap_n`.
//! * `jitter = 0.5` — uniform `[cap_n/2, cap_n]`.
//! * Values outside `[0, 1]` are clamped.

use std::time::Duration;
use teamtalk::utils::backoff::ExponentialBackoff;

#[test]
fn jitter_zero_is_deterministic_cap() {
    // Initial delay and max are the same value, so the cap is flat; with
    // jitter disabled every call has to return exactly that cap.
    let flat_cap = Duration::from_millis(50);
    let mut schedule = ExponentialBackoff::new(flat_cap, flat_cap, 2.0, 0.0);

    for _ in 0..10 {
        assert_eq!(
            schedule.next_delay(),
            Duration::from_millis(50),
            "jitter = 0 must produce the deterministic cap each time",
        );
    }

    // The accessor reports the factor that was passed in.
    assert_eq!(schedule.jitter(), 0.0);
}

#[test]
fn jitter_one_spans_full_range() {
    let cap = Duration::from_millis(100);
    let mut schedule = ExponentialBackoff::new(cap, cap, 2.0, 1.0);

    // Running extremes over all samples; start from the widest possible
    // bounds so the first sample always replaces them.
    let mut min_seen = Duration::from_millis(u64::MAX);
    let mut max_seen = Duration::ZERO;

    for _ in 0..1_000 {
        let sample = schedule.next_delay();
        assert!(
            sample <= Duration::from_millis(100),
            "jitter = 1 must stay at or below cap",
        );
        min_seen = min_seen.min(sample);
        max_seen = max_seen.max(sample);
    }

    // The thresholds are deliberately loose: tight enough to catch a
    // regression to "always return the cap", loose enough that random
    // variation over 1 000 samples does not make the test flaky.
    assert!(
        min_seen <= Duration::from_millis(20),
        "jitter = 1 should produce low values over 1 000 samples, got min {min_seen:?}",
    );
    assert!(
        max_seen >= Duration::from_millis(80),
        "jitter = 1 should produce high values over 1 000 samples, got max {max_seen:?}",
    );
}

#[test]
fn jitter_half_stays_in_upper_half_of_cap() {
    let cap = Duration::from_millis(100);
    let mut schedule = ExponentialBackoff::new(cap, cap, 2.0, 0.5);

    // Draw 500 samples; each one must land in `[cap / 2, cap]`.
    for delay in std::iter::repeat_with(|| schedule.next_delay()).take(500) {
        assert!(
            delay >= Duration::from_millis(50),
            "jitter = 0.5 must stay at or above cap/2, got {delay:?}",
        );
        assert!(
            delay <= Duration::from_millis(100),
            "jitter = 0.5 must stay at or below cap, got {delay:?}",
        );
    }
}

#[test]
fn jitter_above_one_is_clamped() {
    // Only the stored jitter factor matters here; the durations are arbitrary.
    let ten_ms = Duration::from_millis(10);
    let stored = ExponentialBackoff::new(ten_ms, ten_ms, 2.0, 5.0).jitter();
    assert_eq!(
        stored, 1.0,
        "jitter > 1 must be clamped to the full-jitter upper bound",
    );
}

#[test]
fn jitter_below_zero_is_clamped() {
    // Only the stored jitter factor matters here; the durations are arbitrary.
    let ten_ms = Duration::from_millis(10);
    let stored = ExponentialBackoff::new(ten_ms, ten_ms, 2.0, -0.5).jitter();
    assert_eq!(
        stored, 0.0,
        "jitter < 0 must be clamped to the no-jitter lower bound",
    );
}

#[test]
fn jitter_nan_is_clamped_to_zero() {
    // `f32::clamp` hands a NaN input straight back, so the stored factor is
    // not inspected here. The check is on observable behaviour only: a NaN
    // jitter must never push a delay past the cap.
    let cap = Duration::from_millis(100);
    let mut schedule = ExponentialBackoff::new(cap, cap, 2.0, f32::NAN);

    let delay = schedule.next_delay();
    assert!(
        delay <= Duration::from_millis(100),
        "NaN jitter must not produce delays above the cap, got {delay:?}",
    );
}

#[test]
fn next_delay_advances_cap_until_max() {
    // Jitter is disabled so each returned delay is the raw cap.
    let mut schedule = ExponentialBackoff::new(
        Duration::from_millis(10),
        Duration::from_millis(1_000),
        2.0,
        0.0,
    );

    // Take the first three delays up front, then check that they double.
    let first_three = [
        schedule.next_delay(),
        schedule.next_delay(),
        schedule.next_delay(),
    ];
    for (observed, expected_ms) in first_three.into_iter().zip([10_u64, 20, 40]) {
        assert_eq!(observed, Duration::from_millis(expected_ms));
    }

    // Burn through enough attempts for the growth to saturate.
    (0..20).for_each(|_| {
        let _ = schedule.next_delay();
    });

    // Once saturated, the delay is pinned at the configured max.
    assert_eq!(schedule.next_delay(), Duration::from_millis(1_000));
}

#[test]
fn reset_clears_current_delay() {
    let (initial, max) = (Duration::from_millis(10), Duration::from_millis(100));
    let mut schedule = ExponentialBackoff::new(initial, max, 2.0, 0.5);

    // Advance once so there is state to clear.
    let _ = schedule.next_delay();
    assert!(schedule.attempts() >= 1);

    // After a reset both the attempt counter and the last delay are zeroed.
    schedule.reset();
    assert_eq!(schedule.attempts(), 0);
    assert_eq!(schedule.current_delay(), Duration::ZERO);
}

#[test]
fn default_is_full_jitter() {
    // The default schedule is expected to use the full-jitter setting.
    let default_jitter = ExponentialBackoff::default().jitter();
    assert_eq!(
        default_jitter, 1.0,
        "Default should be full jitter for safe thundering-herd avoidance",
    );
}
Loading