Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions datasketches/src/bloom/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
//!
//! # Usage
//!
//! ```rust
//! ```
//! use datasketches::bloom::BloomFilter;
//! use datasketches::bloom::BloomFilterBuilder;
//!
Expand Down Expand Up @@ -60,7 +60,7 @@
//!
//! Automatically calculates optimal size and hash functions:
//!
//! ```rust
//! ```
//! # use datasketches::bloom::BloomFilterBuilder;
//! let filter = BloomFilterBuilder::with_accuracy(
//! 10_000, // Expected max items
Expand All @@ -74,7 +74,7 @@
//!
//! Specify requested bit count and hash functions (rounded up to a multiple of 64 bits):
//!
//! ```rust
//! ```
//! # use datasketches::bloom::BloomFilterBuilder;
//! let filter = BloomFilterBuilder::with_size(
//! 95_851, // Number of bits
Expand All @@ -87,7 +87,7 @@
//!
//! Bloom filters support efficient set operations:
//!
//! ```rust
//! ```
//! # use datasketches::bloom::BloomFilterBuilder;
//! let mut filter1 = BloomFilterBuilder::with_accuracy(100, 0.01).build();
//! let mut filter2 = BloomFilterBuilder::with_accuracy(100, 0.01).build();
Expand Down
35 changes: 14 additions & 21 deletions datasketches/src/bloom/sketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use std::hash::Hasher;

use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::codec::assert::ensure_preamble_longs_in_range;
use crate::codec::assert::ensure_serial_version_is;
use crate::codec::assert::insufficient_data;
use crate::codec::family::Family;
use crate::codec::utility::ensure_preamble_longs_in_range;
use crate::codec::utility::ensure_serial_version_is;
use crate::error::Error;
use crate::hash::XxHash64;

Expand Down Expand Up @@ -399,18 +400,14 @@ impl BloomFilter {
// Read preamble
let preamble_longs = cursor
.read_u8()
.map_err(|_| Error::insufficient_data("preamble_longs"))?;
.map_err(insufficient_data("preamble_longs"))?;
let serial_version = cursor
.read_u8()
.map_err(|_| Error::insufficient_data("serial_version"))?;
let family_id = cursor
.read_u8()
.map_err(|_| Error::insufficient_data("family_id"))?;
.map_err(insufficient_data("serial_version"))?;
let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?;

// Byte 3: flags byte (directly after family_id)
let flags = cursor
.read_u8()
.map_err(|_| Error::insufficient_data("flags"))?;
let flags = cursor.read_u8().map_err(insufficient_data("flags"))?;

// Validate
Family::BLOOMFILTER.validate_id(family_id)?;
Expand All @@ -425,7 +422,7 @@ impl BloomFilter {
// Bytes 4-5: num_hashes (u16)
let num_hashes = cursor
.read_u16_le()
.map_err(|_| Error::insufficient_data("num_hashes"))?;
.map_err(insufficient_data("num_hashes"))?;
if num_hashes == 0 || num_hashes > i16::MAX as u16 {
return Err(Error::deserial(format!(
"invalid num_hashes: expected [1, {}], got {}",
Expand All @@ -436,18 +433,14 @@ impl BloomFilter {
// Bytes 6-7: unused (u16)
let _unused = cursor
.read_u16_le()
.map_err(|_| Error::insufficient_data("unused_header"))?;
let seed = cursor
.read_u64_le()
.map_err(|_| Error::insufficient_data("seed"))?;
.map_err(insufficient_data("unused_header"))?;
let seed = cursor.read_u64_le().map_err(insufficient_data("seed"))?;

// Bit array capacity is stored as number of 64-bit words (int32) + unused padding (uint32).
let num_longs = cursor
.read_i32_le()
.map_err(|_| Error::insufficient_data("num_longs"))?;
let _unused = cursor
.read_u32_le()
.map_err(|_| Error::insufficient_data("unused"))?;
.map_err(insufficient_data("num_longs"))?;
let _unused = cursor.read_u32_le().map_err(insufficient_data("unused"))?;

if num_longs <= 0 {
return Err(Error::deserial(format!(
Expand All @@ -465,12 +458,12 @@ impl BloomFilter {
} else {
let raw_num_bits_set = cursor
.read_u64_le()
.map_err(|_| Error::insufficient_data("num_bits_set"))?;
.map_err(insufficient_data("num_bits_set"))?;

for word in &mut bit_array {
*word = cursor
.read_u64_le()
.map_err(|_| Error::insufficient_data("bit_array"))?;
.map_err(insufficient_data("bit_array"))?;
}

// Handle "dirty" state: 0xFFFFFFFFFFFFFFFF indicates bits need recounting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ use std::ops::RangeBounds;

use crate::error::Error;

pub(crate) fn insufficient_data(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error {
move |_| Error::insufficient_data(tag)
}

pub(crate) fn ensure_serial_version_is(expected: u8, actual: u8) -> Result<(), Error> {
if expected == actual {
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion datasketches/src/codec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ pub use self::decode::SketchSlice;
pub use self::encode::SketchBytes;

// private to datasketches crate
pub(crate) mod assert;
pub(crate) mod family;
pub(crate) mod utility;
4 changes: 2 additions & 2 deletions datasketches/src/countmin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
//!
//! # Usage
//!
//! ```rust
//! ```
//! # use datasketches::countmin::CountMinSketch;
//! let mut sketch = CountMinSketch::<i64>::new(5, 256);
//! sketch.update("apple");
Expand All @@ -32,7 +32,7 @@
//!
//! # Configuration Helpers
//!
//! ```rust
//! ```
//! # use datasketches::countmin::CountMinSketch;
//! let buckets = CountMinSketch::<i64>::suggest_num_buckets(0.01);
//! let hashes = CountMinSketch::<i64>::suggest_num_hashes(0.99);
Expand Down
61 changes: 34 additions & 27 deletions datasketches/src/countmin/sketch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@ use std::hash::Hasher;

use crate::codec::SketchBytes;
use crate::codec::SketchSlice;
use crate::codec::assert::ensure_preamble_longs_in;
use crate::codec::assert::ensure_serial_version_is;
use crate::codec::assert::insufficient_data;
use crate::codec::family::Family;
use crate::codec::utility::ensure_preamble_longs_in;
use crate::codec::utility::ensure_serial_version_is;
use crate::countmin::CountMinValue;
use crate::countmin::UnsignedCountMinValue;
use crate::countmin::serialization::FLAGS_IS_EMPTY;
Expand Down Expand Up @@ -61,7 +62,7 @@ impl<T: CountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// let sketch = CountMinSketch::<i64>::new(4, 128);
/// assert_eq!(sketch.num_buckets(), 128);
Expand All @@ -82,7 +83,7 @@ impl<T: CountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// let sketch = CountMinSketch::<i64>::with_seed(4, 64, 42);
/// assert_eq!(sketch.seed(), 42);
Expand Down Expand Up @@ -153,7 +154,7 @@ impl<T: CountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// let mut sketch = CountMinSketch::<i64>::new(4, 128);
/// sketch.update("apple");
Expand All @@ -167,7 +168,7 @@ impl<T: CountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// let mut sketch = CountMinSketch::<i64>::new(4, 128);
/// sketch.update_with_weight("banana", 3);
Expand All @@ -191,7 +192,7 @@ impl<T: CountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// let mut sketch = CountMinSketch::<i64>::new(4, 128);
/// sketch.update_with_weight("pear", 2);
Expand Down Expand Up @@ -231,7 +232,7 @@ impl<T: CountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// let mut left = CountMinSketch::<i64>::new(4, 128);
/// let mut right = CountMinSketch::<i64>::new(4, 128);
Expand Down Expand Up @@ -261,7 +262,7 @@ impl<T: CountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// # let mut sketch = CountMinSketch::<i64>::new(4, 128);
/// # sketch.update("apple");
Expand Down Expand Up @@ -306,7 +307,7 @@ impl<T: CountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// # let mut sketch = CountMinSketch::<i64>::new(4, 64);
/// # sketch.update("apple");
Expand All @@ -322,7 +323,7 @@ impl<T: CountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// # let mut sketch = CountMinSketch::<i64>::with_seed(4, 64, 7);
/// # sketch.update("apple");
Expand All @@ -331,34 +332,40 @@ impl<T: CountMinValue> CountMinSketch<T> {
/// assert!(decoded.estimate("apple") >= 1);
/// ```
pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result<Self, Error> {
fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error {
move |_| Error::insufficient_data(tag)
}

fn read_value<T: CountMinValue>(
cursor: &mut SketchSlice<'_>,
tag: &'static str,
) -> Result<T, Error> {
let mut bs = [0u8; 8];
cursor.read_exact(&mut bs).map_err(make_error(tag))?;
cursor.read_exact(&mut bs).map_err(insufficient_data(tag))?;
T::try_from_bytes(bs)
}

let mut cursor = SketchSlice::new(bytes);
let preamble_longs = cursor.read_u8().map_err(make_error("preamble_longs"))?;
let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?;
let family_id = cursor.read_u8().map_err(make_error("family_id"))?;
let flags = cursor.read_u8().map_err(make_error("flags"))?;
cursor.read_u32_le().map_err(make_error("<unused>"))?;
let preamble_longs = cursor
.read_u8()
.map_err(insufficient_data("preamble_longs"))?;
let serial_version = cursor
.read_u8()
.map_err(insufficient_data("serial_version"))?;
let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?;
let flags = cursor.read_u8().map_err(insufficient_data("flags"))?;
cursor
.read_u32_le()
.map_err(insufficient_data("<unused>"))?;

Family::COUNTMIN.validate_id(family_id)?;
ensure_serial_version_is(SERIAL_VERSION, serial_version)?;
ensure_preamble_longs_in(&[PREAMBLE_LONGS_SHORT], preamble_longs)?;

let num_buckets = cursor.read_u32_le().map_err(make_error("num_buckets"))?;
let num_hashes = cursor.read_u8().map_err(make_error("num_hashes"))?;
let seed_hash = cursor.read_u16_le().map_err(make_error("seed_hash"))?;
cursor.read_u8().map_err(make_error("unused8"))?;
let num_buckets = cursor
.read_u32_le()
.map_err(insufficient_data("num_buckets"))?;
let num_hashes = cursor.read_u8().map_err(insufficient_data("num_hashes"))?;
let seed_hash = cursor
.read_u16_le()
.map_err(insufficient_data("seed_hash"))?;
cursor.read_u8().map_err(insufficient_data("unused8"))?;

let expected_seed_hash = compute_seed_hash(seed);
if seed_hash != expected_seed_hash {
Expand Down Expand Up @@ -410,7 +417,7 @@ impl<T: UnsignedCountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// let mut sketch = CountMinSketch::<u64>::new(4, 128);
/// sketch.update_with_weight("apple", 3);
Expand All @@ -431,7 +438,7 @@ impl<T: UnsignedCountMinValue> CountMinSketch<T> {
///
/// # Examples
///
/// ```rust
/// ```
/// # use datasketches::countmin::CountMinSketch;
/// let mut sketch = CountMinSketch::<u64>::new(4, 128);
/// sketch.update_with_weight("apple", 3);
Expand Down
Loading