diff --git a/datasketches/src/common/binomial_bounds.rs b/datasketches/src/common/binomial_bounds.rs index c2b7d74..2950d1d 100644 --- a/datasketches/src/common/binomial_bounds.rs +++ b/datasketches/src/common/binomial_bounds.rs @@ -275,7 +275,7 @@ static UB_EQUIV_TABLE: [f64; 363] = [ /// # Arguments /// /// * `num_samples`: The number of samples in the sample set. -/// * `theta`: The sampling probability. Must be in the range (0.0, 1.0]. +/// * `theta`: The sampling probability. Must be in the range `(0.0, 1.0]`. /// * `num_std_dev`: The number of standard deviations for confidence bounds. /// /// # Returns @@ -284,13 +284,17 @@ static UB_EQUIV_TABLE: [f64; 363] = [ /// /// # Errors /// -/// Returns an error if `theta` is not in the range (0.0, 1.0]. +/// Returns an error if `theta` is not in the range `(0.0, 1.0]`. pub(crate) fn lower_bound( num_samples: u64, theta: f64, num_std_dev: NumStdDev, ) -> Result { - check_theta(theta)?; + if theta <= 0.0 || theta > 1.0 { + return Err(Error::invalid_argument(format!( + "theta must be in the range (0.0, 1.0], got {theta}" + ))); + } let estimate = num_samples as f64 / theta; let lb = compute_approx_binomial_lower_bound(num_samples, theta, num_std_dev); @@ -325,7 +329,12 @@ pub(crate) fn upper_bound( if no_data_seen { return Ok(0.0); } - check_theta(theta)?; + + if theta <= 0.0 || theta > 1.0 { + return Err(Error::invalid_argument(format!( + "theta must be in the range (0.0, 1.0], got {theta}" + ))); + } let estimate = num_samples as f64 / theta; let ub = compute_approx_binomial_upper_bound(num_samples, theta, num_std_dev); @@ -360,6 +369,7 @@ fn cont_classic_ub(num_samples: u64, theta: f64, num_std_devs: f64) -> f64 { /// # Limitations /// /// Outside of the valid input range, two different bad things will happen: +/// /// 1. Because we are not using logarithms, the values of intermediate quantities will exceed the /// dynamic range of doubles. /// 2. Even if that problem were fixed, the running time of this procedure is essentially linear in @@ -548,21 +558,6 @@ fn compute_approx_binomial_upper_bound( special_n_prime_f(num_samples, theta, delta).unwrap_or(num_samples + 1) as f64 // no need to round } -/// Validates that theta is in the valid range [0.0, 1.0]. -/// -/// # Errors -/// -/// Returns an error if theta < 0.0 or theta > 1.0. -fn check_theta(theta: f64) -> Result<(), Error> { - if (theta <= 0.0) || (theta > 1.0) { - return Err(Error::invalid_argument(format!( - "theta must be in the range [0.0, 1.0]: {}", - theta - ))); - } - Ok(()) -} - #[cfg(test)] mod tests { use super::*; @@ -679,19 +674,18 @@ mod tests { fn check_bounds() { let mut i = 0; + fn assert_approx_equal(ci: NumStdDev, j: usize, expected: f64, actual: f64) { + let ratio = actual / expected; + assert!( + (ratio - 1.0).abs() < TOL, + "ci={ci:?}, j={j}: expected {expected}, got {actual}, ratio={ratio}", + ); + } + for ci in [NumStdDev::One, NumStdDev::Two, NumStdDev::Three] { let arr = run_test_aux(20, ci, 1e-3); for j in 0..5 { - let ratio = arr[j] / STD[i][j]; - assert!( - (ratio - 1.0).abs() < TOL, - "ci={:?}, j={}: expected {}, got {}, ratio={}", - ci, - j, - STD[i][j], - arr[j], - ratio - ); + assert_approx_equal(ci, j, STD[i][j], arr[j]); } i += 1; } @@ -699,16 +693,7 @@ mod tests { for ci in [NumStdDev::One, NumStdDev::Two, NumStdDev::Three] { let arr = run_test_aux(200, ci, 1e-5); for j in 0..5 { - let ratio = arr[j] / STD[i][j]; - assert!( - (ratio - 1.0) < TOL, - "ci={:?}, j={}: expected {}, got {}, ratio={}", - ci, - j, - STD[i][j], - arr[j], - ratio - ); + assert_approx_equal(ci, j, STD[i][j], arr[j]); } i += 1; } @@ -716,16 +701,7 @@ mod tests { for ci in [NumStdDev::One, NumStdDev::Two, NumStdDev::Three] { let arr = run_test_aux(2000, ci, 1e-7); for j in 0..5 { - let ratio = arr[j] / STD[i][j]; - assert!( - (ratio - 1.0).abs() < TOL, - "ci={:?}, j={}: expected {}, got {}, ratio={}", - ci, - j, - STD[i][j], - arr[j], - ratio - ); + assert_approx_equal(ci, j, STD[i][j], arr[j]); } i += 1; }