Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 25 additions & 49 deletions datasketches/src/common/binomial_bounds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ static UB_EQUIV_TABLE: [f64; 363] = [
/// # Arguments
///
/// * `num_samples`: The number of samples in the sample set.
/// * `theta`: The sampling probability. Must be in the range (0.0, 1.0].
/// * `theta`: The sampling probability. Must be in the range `(0.0, 1.0]`.
/// * `num_std_dev`: The number of standard deviations for confidence bounds.
///
/// # Returns
Expand All @@ -284,13 +284,17 @@ static UB_EQUIV_TABLE: [f64; 363] = [
///
/// # Errors
///
/// Returns an error if `theta` is not in the range (0.0, 1.0].
/// Returns an error if `theta` is not in the range `(0.0, 1.0]`.
pub(crate) fn lower_bound(
num_samples: u64,
theta: f64,
num_std_dev: NumStdDev,
) -> Result<f64, Error> {
check_theta(theta)?;
if theta <= 0.0 || theta > 1.0 {
return Err(Error::invalid_argument(format!(
"theta must be in the range (0.0, 1.0], got {theta}"
)));
}

let estimate = num_samples as f64 / theta;
let lb = compute_approx_binomial_lower_bound(num_samples, theta, num_std_dev);
Expand Down Expand Up @@ -325,7 +329,12 @@ pub(crate) fn upper_bound(
if no_data_seen {
return Ok(0.0);
}
check_theta(theta)?;

if theta <= 0.0 || theta > 1.0 {
return Err(Error::invalid_argument(format!(
"theta must be in the range (0.0, 1.0], got {theta}"
)));
}

let estimate = num_samples as f64 / theta;
let ub = compute_approx_binomial_upper_bound(num_samples, theta, num_std_dev);
Expand Down Expand Up @@ -360,6 +369,7 @@ fn cont_classic_ub(num_samples: u64, theta: f64, num_std_devs: f64) -> f64 {
/// # Limitations
///
/// Outside of the valid input range, two different bad things will happen:
///
/// 1. Because we are not using logarithms, the values of intermediate quantities will exceed the
/// dynamic range of doubles.
/// 2. Even if that problem were fixed, the running time of this procedure is essentially linear in
Expand Down Expand Up @@ -548,21 +558,6 @@ fn compute_approx_binomial_upper_bound(
special_n_prime_f(num_samples, theta, delta).unwrap_or(num_samples + 1) as f64 // no need to round
}

/// Validates that theta is in the valid range [0.0, 1.0].
///
/// # Errors
///
/// Returns an error if theta < 0.0 or theta > 1.0.
fn check_theta(theta: f64) -> Result<(), Error> {
if (theta <= 0.0) || (theta > 1.0) {
return Err(Error::invalid_argument(format!(
"theta must be in the range [0.0, 1.0]: {}",
theta
)));
}
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -679,53 +674,34 @@ mod tests {
fn check_bounds() {
let mut i = 0;

fn assert_approx_equal(ci: NumStdDev, j: usize, expected: f64, actual: f64) {
let ratio = actual / expected;
assert!(
(ratio - 1.0).abs() < TOL,
"ci={ci:?}, j={j}: expected {expected}, got {actual}, ratio={ratio}",
);
}

for ci in [NumStdDev::One, NumStdDev::Two, NumStdDev::Three] {
let arr = run_test_aux(20, ci, 1e-3);
for j in 0..5 {
let ratio = arr[j] / STD[i][j];
assert!(
(ratio - 1.0).abs() < TOL,
"ci={:?}, j={}: expected {}, got {}, ratio={}",
ci,
j,
STD[i][j],
arr[j],
ratio
);
assert_approx_equal(ci, j, STD[i][j], arr[j]);
}
i += 1;
}

for ci in [NumStdDev::One, NumStdDev::Two, NumStdDev::Three] {
let arr = run_test_aux(200, ci, 1e-5);
for j in 0..5 {
let ratio = arr[j] / STD[i][j];
assert!(
(ratio - 1.0) < TOL,
"ci={:?}, j={}: expected {}, got {}, ratio={}",
ci,
j,
STD[i][j],
arr[j],
ratio
);
assert_approx_equal(ci, j, STD[i][j], arr[j]);
}
i += 1;
}

for ci in [NumStdDev::One, NumStdDev::Two, NumStdDev::Three] {
let arr = run_test_aux(2000, ci, 1e-7);
for j in 0..5 {
let ratio = arr[j] / STD[i][j];
assert!(
(ratio - 1.0).abs() < TOL,
"ci={:?}, j={}: expected {}, got {}, ratio={}",
ci,
j,
STD[i][j],
arr[j],
ratio
);
assert_approx_equal(ci, j, STD[i][j], arr[j]);
}
i += 1;
}
Expand Down