diff --git a/datasketches/src/bloom/mod.rs b/datasketches/src/bloom/mod.rs index 8e58139..3638ee2 100644 --- a/datasketches/src/bloom/mod.rs +++ b/datasketches/src/bloom/mod.rs @@ -30,7 +30,7 @@ //! //! # Usage //! -//! ```rust +//! ``` //! use datasketches::bloom::BloomFilter; //! use datasketches::bloom::BloomFilterBuilder; //! @@ -60,7 +60,7 @@ //! //! Automatically calculates optimal size and hash functions: //! -//! ```rust +//! ``` //! # use datasketches::bloom::BloomFilterBuilder; //! let filter = BloomFilterBuilder::with_accuracy( //! 10_000, // Expected max items @@ -74,7 +74,7 @@ //! //! Specify requested bit count and hash functions (rounded up to a multiple of 64 bits): //! -//! ```rust +//! ``` //! # use datasketches::bloom::BloomFilterBuilder; //! let filter = BloomFilterBuilder::with_size( //! 95_851, // Number of bits @@ -87,7 +87,7 @@ //! //! Bloom filters support efficient set operations: //! -//! ```rust +//! ``` //! # use datasketches::bloom::BloomFilterBuilder; //! let mut filter1 = BloomFilterBuilder::with_accuracy(100, 0.01).build(); //! let mut filter2 = BloomFilterBuilder::with_accuracy(100, 0.01).build(); diff --git a/datasketches/src/bloom/sketch.rs b/datasketches/src/bloom/sketch.rs index d7332e0..19142a3 100644 --- a/datasketches/src/bloom/sketch.rs +++ b/datasketches/src/bloom/sketch.rs @@ -20,9 +20,10 @@ use std::hash::Hasher; use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::codec::assert::ensure_preamble_longs_in_range; +use crate::codec::assert::ensure_serial_version_is; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; -use crate::codec::utility::ensure_preamble_longs_in_range; -use crate::codec::utility::ensure_serial_version_is; use crate::error::Error; use crate::hash::XxHash64; @@ -399,18 +400,14 @@ impl BloomFilter { // Read preamble let preamble_longs = cursor .read_u8() - .map_err(|_| Error::insufficient_data("preamble_longs"))?; + .map_err(insufficient_data("preamble_longs"))?; let serial_version = cursor .read_u8() - .map_err(|_| Error::insufficient_data("serial_version"))?; - let family_id = cursor - .read_u8() - .map_err(|_| Error::insufficient_data("family_id"))?; + .map_err(insufficient_data("serial_version"))?; + let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?; // Byte 3: flags byte (directly after family_id) - let flags = cursor - .read_u8() - .map_err(|_| Error::insufficient_data("flags"))?; + let flags = cursor.read_u8().map_err(insufficient_data("flags"))?; // Validate Family::BLOOMFILTER.validate_id(family_id)?; @@ -425,7 +422,7 @@ impl BloomFilter { // Bytes 4-5: num_hashes (u16) let num_hashes = cursor .read_u16_le() - .map_err(|_| Error::insufficient_data("num_hashes"))?; + .map_err(insufficient_data("num_hashes"))?; if num_hashes == 0 || num_hashes > i16::MAX as u16 { return Err(Error::deserial(format!( "invalid num_hashes: expected [1, {}], got {}", @@ -436,18 +433,14 @@ impl BloomFilter { // Bytes 6-7: unused (u16) let _unused = cursor .read_u16_le() - .map_err(|_| Error::insufficient_data("unused_header"))?; - let seed = cursor - .read_u64_le() - .map_err(|_| Error::insufficient_data("seed"))?; + .map_err(insufficient_data("unused_header"))?; + let seed = cursor.read_u64_le().map_err(insufficient_data("seed"))?; // Bit array capacity is stored as number of 64-bit words (int32) + unused padding (uint32). let num_longs = cursor .read_i32_le() - .map_err(|_| Error::insufficient_data("num_longs"))?; - let _unused = cursor - .read_u32_le() - .map_err(|_| Error::insufficient_data("unused"))?; + .map_err(insufficient_data("num_longs"))?; + let _unused = cursor.read_u32_le().map_err(insufficient_data("unused"))?; if num_longs <= 0 { return Err(Error::deserial(format!( @@ -465,12 +458,12 @@ impl BloomFilter { } else { let raw_num_bits_set = cursor .read_u64_le() - .map_err(|_| Error::insufficient_data("num_bits_set"))?; + .map_err(insufficient_data("num_bits_set"))?; for word in &mut bit_array { *word = cursor .read_u64_le() - .map_err(|_| Error::insufficient_data("bit_array"))?; + .map_err(insufficient_data("bit_array"))?; } // Handle "dirty" state: 0xFFFFFFFFFFFFFFFF indicates bits need recounting diff --git a/datasketches/src/codec/utility.rs b/datasketches/src/codec/assert.rs similarity index 94% rename from datasketches/src/codec/utility.rs rename to datasketches/src/codec/assert.rs index e098fd6..2ef5ee5 100644 --- a/datasketches/src/codec/utility.rs +++ b/datasketches/src/codec/assert.rs @@ -20,6 +20,10 @@ use std::ops::RangeBounds; use crate::error::Error; +pub(crate) fn insufficient_data(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { + move |_| Error::insufficient_data(tag) +} + pub(crate) fn ensure_serial_version_is(expected: u8, actual: u8) -> Result<(), Error> { if expected == actual { Ok(()) diff --git a/datasketches/src/codec/mod.rs b/datasketches/src/codec/mod.rs index de4648a..28008ff 100644 --- a/datasketches/src/codec/mod.rs +++ b/datasketches/src/codec/mod.rs @@ -24,5 +24,5 @@ pub use self::decode::SketchSlice; pub use self::encode::SketchBytes; // private to datasketches crate +pub(crate) mod assert; pub(crate) mod family; -pub(crate) mod utility; diff --git a/datasketches/src/countmin/mod.rs b/datasketches/src/countmin/mod.rs index 254907e..cea20ae 100644 --- a/datasketches/src/countmin/mod.rs +++ b/datasketches/src/countmin/mod.rs @@ -22,7 +22,7 @@ //! //! # Usage //! -//! ```rust +//! ``` //! # use datasketches::countmin::CountMinSketch; //! let mut sketch = CountMinSketch::::new(5, 256); //! sketch.update("apple"); @@ -32,7 +32,7 @@ //! //! # Configuration Helpers //! -//! ```rust +//! ``` //! # use datasketches::countmin::CountMinSketch; //! let buckets = CountMinSketch::::suggest_num_buckets(0.01); //! let hashes = CountMinSketch::::suggest_num_hashes(0.99); diff --git a/datasketches/src/countmin/sketch.rs b/datasketches/src/countmin/sketch.rs index 2116b75..48b7a3f 100644 --- a/datasketches/src/countmin/sketch.rs +++ b/datasketches/src/countmin/sketch.rs @@ -20,9 +20,10 @@ use std::hash::Hasher; use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::codec::assert::ensure_preamble_longs_in; +use crate::codec::assert::ensure_serial_version_is; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; -use crate::codec::utility::ensure_preamble_longs_in; -use crate::codec::utility::ensure_serial_version_is; use crate::countmin::CountMinValue; use crate::countmin::UnsignedCountMinValue; use crate::countmin::serialization::FLAGS_IS_EMPTY; @@ -61,7 +62,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// let sketch = CountMinSketch::::new(4, 128); /// assert_eq!(sketch.num_buckets(), 128); @@ -82,7 +83,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// let sketch = CountMinSketch::::with_seed(4, 64, 42); /// assert_eq!(sketch.seed(), 42); @@ -153,7 +154,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// let mut sketch = CountMinSketch::::new(4, 128); /// sketch.update("apple"); @@ -167,7 +168,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// let mut sketch = CountMinSketch::::new(4, 128); /// sketch.update_with_weight("banana", 3); @@ -191,7 +192,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// let mut sketch = CountMinSketch::::new(4, 128); /// sketch.update_with_weight("pear", 2); @@ -231,7 +232,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// let mut left = CountMinSketch::::new(4, 128); /// let mut right = CountMinSketch::::new(4, 128); @@ -261,7 +262,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// # let mut sketch = CountMinSketch::::new(4, 128); /// # sketch.update("apple"); @@ -306,7 +307,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// # let mut sketch = CountMinSketch::::new(4, 64); /// # sketch.update("apple"); @@ -322,7 +323,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// # let mut sketch = CountMinSketch::::with_seed(4, 64, 7); /// # sketch.update("apple"); @@ -331,34 +332,40 @@ impl CountMinSketch { /// assert!(decoded.estimate("apple") >= 1); /// ``` pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - fn read_value( cursor: &mut SketchSlice<'_>, tag: &'static str, ) -> Result { let mut bs = [0u8; 8]; - cursor.read_exact(&mut bs).map_err(make_error(tag))?; + cursor.read_exact(&mut bs).map_err(insufficient_data(tag))?; T::try_from_bytes(bs) } let mut cursor = SketchSlice::new(bytes); - let preamble_longs = cursor.read_u8().map_err(make_error("preamble_longs"))?; - let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; - let family_id = cursor.read_u8().map_err(make_error("family_id"))?; - let flags = cursor.read_u8().map_err(make_error("flags"))?; - cursor.read_u32_le().map_err(make_error(""))?; + let preamble_longs = cursor + .read_u8() + .map_err(insufficient_data("preamble_longs"))?; + let serial_version = cursor + .read_u8() + .map_err(insufficient_data("serial_version"))?; + let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?; + let flags = cursor.read_u8().map_err(insufficient_data("flags"))?; + cursor + .read_u32_le() + .map_err(insufficient_data(""))?; Family::COUNTMIN.validate_id(family_id)?; ensure_serial_version_is(SERIAL_VERSION, serial_version)?; ensure_preamble_longs_in(&[PREAMBLE_LONGS_SHORT], preamble_longs)?; - let num_buckets = cursor.read_u32_le().map_err(make_error("num_buckets"))?; - let num_hashes = cursor.read_u8().map_err(make_error("num_hashes"))?; - let seed_hash = cursor.read_u16_le().map_err(make_error("seed_hash"))?; - cursor.read_u8().map_err(make_error("unused8"))?; + let num_buckets = cursor + .read_u32_le() + .map_err(insufficient_data("num_buckets"))?; + let num_hashes = cursor.read_u8().map_err(insufficient_data("num_hashes"))?; + let seed_hash = cursor + .read_u16_le() + .map_err(insufficient_data("seed_hash"))?; + cursor.read_u8().map_err(insufficient_data("unused8"))?; let expected_seed_hash = compute_seed_hash(seed); if seed_hash != expected_seed_hash { @@ -410,7 +417,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// let mut sketch = CountMinSketch::::new(4, 128); /// sketch.update_with_weight("apple", 3); @@ -431,7 +438,7 @@ impl CountMinSketch { /// /// # Examples /// - /// ```rust + /// ``` /// # use datasketches::countmin::CountMinSketch; /// let mut sketch = CountMinSketch::::new(4, 128); /// sketch.update_with_weight("apple", 3); diff --git a/datasketches/src/cpc/sketch.rs b/datasketches/src/cpc/sketch.rs index 15c818f..d11d59b 100644 --- a/datasketches/src/cpc/sketch.rs +++ b/datasketches/src/cpc/sketch.rs @@ -19,9 +19,10 @@ use std::hash::Hash; use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::codec::assert::ensure_preamble_longs_in; +use crate::codec::assert::ensure_serial_version_is; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; -use crate::codec::utility::ensure_preamble_longs_in; -use crate::codec::utility::ensure_serial_version_is; use crate::common::NumStdDev; use crate::common::canonical_double; use crate::common::inv_pow2_table::INVERSE_POWERS_OF_2; @@ -514,24 +515,26 @@ impl CpcSketch { /// Deserializes a CpcSketch from bytes with the provided seed. pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let mut cursor = SketchSlice::new(bytes); - let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; - let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; - let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + let preamble_ints = cursor + .read_u8() + .map_err(insufficient_data("preamble_ints"))?; + let serial_version = cursor + .read_u8() + .map_err(insufficient_data("serial_version"))?; + let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?; Family::CPC.validate_id(family_id)?; ensure_serial_version_is(SERIAL_VERSION, serial_version)?; - let lg_k = cursor.read_u8().map_err(make_error("lg_k"))?; + let lg_k = cursor.read_u8().map_err(insufficient_data("lg_k"))?; let first_interesting_column = cursor .read_u8() - .map_err(make_error("first_interesting_column"))?; + .map_err(insufficient_data("first_interesting_column"))?; - let flags = cursor.read_u8().map_err(make_error("flags"))?; - let seed_hash = cursor.read_u16_le().map_err(make_error("seed_hash"))?; + let flags = cursor.read_u8().map_err(insufficient_data("flags"))?; + let seed_hash = cursor + .read_u16_le() + .map_err(insufficient_data("seed_hash"))?; let is_compressed = flags & (1 << FLAG_COMPRESSED) != 0; if !is_compressed { return Err(Error::new( @@ -549,41 +552,51 @@ impl CpcSketch { let mut hip_est_accum = 0.0; if has_table || has_window { - num_coupons = cursor.read_u32_le().map_err(make_error("num_coupons"))?; + num_coupons = cursor + .read_u32_le() + .map_err(insufficient_data("num_coupons"))?; if has_table && has_window { compressed.table_num_entries = cursor .read_u32_le() - .map_err(make_error("table_num_entries"))?; + .map_err(insufficient_data("table_num_entries"))?; if has_hip { - kxp = cursor.read_f64_le().map_err(make_error("kxp"))?; - hip_est_accum = cursor.read_f64_le().map_err(make_error("hip_est_accum"))?; + kxp = cursor.read_f64_le().map_err(insufficient_data("kxp"))?; + hip_est_accum = cursor + .read_f64_le() + .map_err(insufficient_data("hip_est_accum"))?; } } if has_table { compressed.table_data_words = cursor .read_u32_le() - .map_err(make_error("table_data_words"))? + .map_err(insufficient_data("table_data_words"))? as usize; } if has_window { compressed.window_data_words = cursor .read_u32_le() - .map_err(make_error("window_data_words"))? + .map_err(insufficient_data("window_data_words"))? as usize; } if has_hip && !(has_table && has_window) { - kxp = cursor.read_f64_le().map_err(make_error("kxp"))?; - hip_est_accum = cursor.read_f64_le().map_err(make_error("hip_est_accum"))?; + kxp = cursor.read_f64_le().map_err(insufficient_data("kxp"))?; + hip_est_accum = cursor + .read_f64_le() + .map_err(insufficient_data("hip_est_accum"))?; } if has_window { for _ in 0..compressed.window_data_words { - let word = cursor.read_u32_le().map_err(make_error("window_data"))?; + let word = cursor + .read_u32_le() + .map_err(insufficient_data("window_data"))?; compressed.window_data.push(word); } } if has_table { for _ in 0..compressed.table_data_words { - let word = cursor.read_u32_le().map_err(make_error("table_data"))?; + let word = cursor + .read_u32_le() + .map_err(insufficient_data("table_data"))?; compressed.table_data.push(word); } } diff --git a/datasketches/src/cpc/wrapper.rs b/datasketches/src/cpc/wrapper.rs index 2b1000e..622d5ed 100644 --- a/datasketches/src/cpc/wrapper.rs +++ b/datasketches/src/cpc/wrapper.rs @@ -16,9 +16,10 @@ // under the License. use crate::codec::SketchSlice; +use crate::codec::assert::ensure_preamble_longs_in; +use crate::codec::assert::ensure_serial_version_is; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; -use crate::codec::utility::ensure_preamble_longs_in; -use crate::codec::utility::ensure_serial_version_is; use crate::common::NumStdDev; use crate::cpc::MAX_LG_K; use crate::cpc::MIN_LG_K; @@ -46,21 +47,21 @@ pub struct CpcWrapper { impl CpcWrapper { /// Creates a new `CpcWrapper` from the given byte slice without copying bytes. pub fn new(bytes: &[u8]) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let mut cursor = SketchSlice::new(bytes); - let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; - let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; - let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + let preamble_ints = cursor + .read_u8() + .map_err(insufficient_data("preamble_ints"))?; + let serial_version = cursor + .read_u8() + .map_err(insufficient_data("serial_version"))?; + let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?; Family::CPC.validate_id(family_id)?; ensure_serial_version_is(SERIAL_VERSION, serial_version)?; - let lg_k = cursor.read_u8().map_err(make_error("lg_k"))?; + let lg_k = cursor.read_u8().map_err(insufficient_data("lg_k"))?; let first_interesting_column = cursor .read_u8() - .map_err(make_error("first_interesting_column"))?; + .map_err(insufficient_data("first_interesting_column"))?; if !(MIN_LG_K..=MAX_LG_K).contains(&lg_k) { return Err(Error::invalid_argument(format!( "lg_k out of range; got {}", @@ -74,7 +75,7 @@ impl CpcWrapper { ))); } - let flags = cursor.read_u8().map_err(make_error("flags"))?; + let flags = cursor.read_u8().map_err(insufficient_data("flags"))?; let is_compressed = flags & (1 << FLAG_COMPRESSED) != 0; if !is_compressed { return Err(Error::new( @@ -86,35 +87,43 @@ impl CpcWrapper { let has_table = flags & (1 << FLAG_HAS_TABLE) != 0; let has_window = flags & (1 << FLAG_HAS_WINDOW) != 0; - cursor.read_u16_le().map_err(make_error("seed_hash"))?; + cursor + .read_u16_le() + .map_err(insufficient_data("seed_hash"))?; let mut num_coupons = 0; let mut hip_est_accum = 0.0; if has_table || has_window { - num_coupons = cursor.read_u32_le().map_err(make_error("num_coupons"))?; + num_coupons = cursor + .read_u32_le() + .map_err(insufficient_data("num_coupons"))?; if has_table && has_window { cursor .read_u32_le() - .map_err(make_error("table_num_entries"))?; + .map_err(insufficient_data("table_num_entries"))?; if has_hip { - cursor.read_f64_le().map_err(make_error("kxp"))?; - hip_est_accum = cursor.read_f64_le().map_err(make_error("hip_est_accum"))?; + cursor.read_f64_le().map_err(insufficient_data("kxp"))?; + hip_est_accum = cursor + .read_f64_le() + .map_err(insufficient_data("hip_est_accum"))?; } } if has_table { cursor .read_u32_le() - .map_err(make_error("table_data_words"))?; + .map_err(insufficient_data("table_data_words"))?; } if has_window { cursor .read_u32_le() - .map_err(make_error("window_data_words"))?; + .map_err(insufficient_data("window_data_words"))?; } if has_hip && !(has_table && has_window) { - cursor.read_f64_le().map_err(make_error("kxp"))?; - hip_est_accum = cursor.read_f64_le().map_err(make_error("hip_est_accum"))?; + cursor.read_f64_le().map_err(insufficient_data("kxp"))?; + hip_est_accum = cursor + .read_f64_le() + .map_err(insufficient_data("hip_est_accum"))?; } } diff --git a/datasketches/src/frequencies/sketch.rs b/datasketches/src/frequencies/sketch.rs index 83de1cf..fc2b297 100644 --- a/datasketches/src/frequencies/sketch.rs +++ b/datasketches/src/frequencies/sketch.rs @@ -21,12 +21,17 @@ use std::hash::Hash; use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::codec::assert::ensure_preamble_longs_in; +use crate::codec::assert::ensure_serial_version_is; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; -use crate::codec::utility::ensure_preamble_longs_in; -use crate::codec::utility::ensure_serial_version_is; use crate::error::Error; +use crate::frequencies::FrequentItemValue; use crate::frequencies::reverse_purge_item_hash_map::ReversePurgeItemHashMap; -use crate::frequencies::serialization::*; +use crate::frequencies::serialization::EMPTY_FLAG_MASK; +use crate::frequencies::serialization::PREAMBLE_LONGS_EMPTY; +use crate::frequencies::serialization::PREAMBLE_LONGS_NONEMPTY; +use crate::frequencies::serialization::SERIAL_VERSION; type CountSerializeSize = fn(&[T]) -> usize; type SerializeItems = fn(&mut SketchBytes, &[T]); @@ -451,19 +456,23 @@ impl FrequentItemsSketch { bytes: &[u8], deserialize_items: DeserializeItems, ) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let mut cursor = SketchSlice::new(bytes); - let pre_longs = cursor.read_u8().map_err(make_error("pre_longs"))?; + let pre_longs = cursor.read_u8().map_err(insufficient_data("pre_longs"))?; let pre_longs = pre_longs & 0x3F; - let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; - let family = cursor.read_u8().map_err(make_error("family"))?; - let lg_max = cursor.read_u8().map_err(make_error("lg_max_map_size"))?; - let lg_cur = cursor.read_u8().map_err(make_error("lg_cur_map_size"))?; - let flags = cursor.read_u8().map_err(make_error("flags"))?; - cursor.read_u16_le().map_err(make_error(""))?; + let serial_version = cursor + .read_u8() + .map_err(insufficient_data("serial_version"))?; + let family = cursor.read_u8().map_err(insufficient_data("family"))?; + let lg_max = cursor + .read_u8() + .map_err(insufficient_data("lg_max_map_size"))?; + let lg_cur = cursor + .read_u8() + .map_err(insufficient_data("lg_cur_map_size"))?; + let flags = cursor.read_u8().map_err(insufficient_data("flags"))?; + cursor + .read_u16_le() + .map_err(insufficient_data(""))?; Family::FREQUENCY.validate_id(family)?; ensure_serial_version_is(SERIAL_VERSION, serial_version)?; @@ -478,11 +487,17 @@ impl FrequentItemsSketch { } ensure_preamble_longs_in(&[PREAMBLE_LONGS_NONEMPTY], pre_longs)?; - let active_items = cursor.read_u32_le().map_err(make_error("active_items"))?; + let active_items = cursor + .read_u32_le() + .map_err(insufficient_data("active_items"))?; let active_items = active_items as usize; - cursor.read_u32_le().map_err(make_error(""))?; - let stream_weight = cursor.read_u64_le().map_err(make_error("stream_weight"))?; - let offset_val = cursor.read_u64_le().map_err(make_error("offset"))?; + cursor + .read_u32_le() + .map_err(insufficient_data(""))?; + let stream_weight = cursor + .read_u64_le() + .map_err(insufficient_data("stream_weight"))?; + let offset_val = cursor.read_u64_le().map_err(insufficient_data("offset"))?; let mut values = Vec::with_capacity(active_items); for i in 0..active_items { diff --git a/datasketches/src/hll/array4.rs b/datasketches/src/hll/array4.rs index 073c335..9fc400d 100644 --- a/datasketches/src/hll/array4.rs +++ b/datasketches/src/hll/array4.rs @@ -23,12 +23,22 @@ use super::aux_map::AuxMap; use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; use crate::common::NumStdDev; use crate::error::Error; use crate::hll::estimator::HipEstimator; use crate::hll::get_slot; use crate::hll::get_value; +use crate::hll::pack_coupon; +use crate::hll::serialization::COUPON_SIZE_BYTES; +use crate::hll::serialization::CUR_MODE_HLL; +use crate::hll::serialization::HLL_PREAMBLE_SIZE; +use crate::hll::serialization::HLL_PREINTS; +use crate::hll::serialization::OUT_OF_ORDER_FLAG_MASK; +use crate::hll::serialization::SERIAL_VERSION; +use crate::hll::serialization::TGT_HLL4; +use crate::hll::serialization::encode_mode_byte; const AUX_TOKEN: u8 = 15; @@ -294,28 +304,29 @@ impl Array4 { compact: bool, ooo: bool, ) -> Result { - use crate::hll::get_slot; - use crate::hll::get_value; - - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let num_bytes = 1 << (lg_config_k - 1); // k/2 bytes for 4-bit packing // Read HIP estimator values from preamble - let hip_accum = cursor.read_f64_le().map_err(make_error("hip_accum"))?; - let kxq0 = cursor.read_f64_le().map_err(make_error("kxq0"))?; - let kxq1 = cursor.read_f64_le().map_err(make_error("kxq1"))?; + let hip_accum = cursor + .read_f64_le() + .map_err(insufficient_data("hip_accum"))?; + let kxq0 = cursor.read_f64_le().map_err(insufficient_data("kxq0"))?; + let kxq1 = cursor.read_f64_le().map_err(insufficient_data("kxq1"))?; // Read num_at_cur_min and aux_count - let num_at_cur_min = cursor.read_u32_le().map_err(make_error("num_at_cur_min"))?; - let aux_count = cursor.read_u32_le().map_err(make_error("aux_count"))?; + let num_at_cur_min = cursor + .read_u32_le() + .map_err(insufficient_data("num_at_cur_min"))?; + let aux_count = cursor + .read_u32_le() + .map_err(insufficient_data("aux_count"))?; // Read packed 4-bit byte array let mut data = vec![0u8; num_bytes]; if !compact { - cursor.read_exact(&mut data).map_err(make_error("data"))?; + cursor + .read_exact(&mut data) + .map_err(insufficient_data("data"))?; } else { cursor.advance(num_bytes as u64); } @@ -358,9 +369,6 @@ impl Array4 { /// /// Produces full HLL preamble (40 bytes) followed by packed 4-bit data and optional aux map. pub fn serialize(&self, lg_config_k: u8) -> Vec { - use crate::hll::pack_coupon; - use crate::hll::serialization::*; - let num_bytes = 1 << (lg_config_k - 1); // k/2 bytes for 4-bit packing // Collect aux map entries if present diff --git a/datasketches/src/hll/array6.rs b/datasketches/src/hll/array6.rs index 0bdb6eb..abaf620 100644 --- a/datasketches/src/hll/array6.rs +++ b/datasketches/src/hll/array6.rs @@ -23,12 +23,20 @@ use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; use crate::common::NumStdDev; use crate::error::Error; use crate::hll::estimator::HipEstimator; use crate::hll::get_slot; use crate::hll::get_value; +use crate::hll::serialization::CUR_MODE_HLL; +use crate::hll::serialization::HLL_PREAMBLE_SIZE; +use crate::hll::serialization::HLL_PREINTS; +use crate::hll::serialization::OUT_OF_ORDER_FLAG_MASK; +use crate::hll::serialization::SERIAL_VERSION; +use crate::hll::serialization::TGT_HLL6; +use crate::hll::serialization::encode_mode_byte; const VAL_MASK_6: u16 = 0x3F; // 6 bits: 0b0011_1111 @@ -177,26 +185,30 @@ impl Array6 { compact: bool, ooo: bool, ) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let k = 1 << lg_config_k; let num_bytes = num_bytes_for_k(k); // Read HIP estimator values from preamble - let hip_accum = cursor.read_f64_le().map_err(make_error("hip_accum"))?; - let kxq0 = cursor.read_f64_le().map_err(make_error("kxq0"))?; - let kxq1 = cursor.read_f64_le().map_err(make_error("kxq1"))?; + let hip_accum = cursor + .read_f64_le() + .map_err(insufficient_data("hip_accum"))?; + let kxq0 = cursor.read_f64_le().map_err(insufficient_data("kxq0"))?; + let kxq1 = cursor.read_f64_le().map_err(insufficient_data("kxq1"))?; // Read num_at_cur_min (for Array6, this is num_zeros since cur_min=0) - let num_zeros = cursor.read_u32_le().map_err(make_error("num_zeros"))?; - let _aux_count = cursor.read_u32_le().map_err(make_error("aux_count"))?; // always 0 + let num_zeros = cursor + .read_u32_le() + .map_err(insufficient_data("num_zeros"))?; + let _aux_count = cursor + .read_u32_le() + .map_err(insufficient_data("aux_count"))?; // always 0 // Read packed byte array from offset HLL_BYTE_ARR_START let mut data = vec![0u8; num_bytes]; if !compact { - cursor.read_exact(&mut data).map_err(make_error("data"))?; + cursor + .read_exact(&mut data) + .map_err(insufficient_data("data"))?; } else { cursor.advance(num_bytes as u64); } @@ -220,8 +232,6 @@ impl Array6 { /// /// Produces full HLL preamble (40 bytes) followed by packed 6-bit data. pub fn serialize(&self, lg_config_k: u8) -> Vec { - use crate::hll::serialization::*; - let k = 1 << lg_config_k; let num_bytes = num_bytes_for_k(k); let total_size = HLL_PREAMBLE_SIZE + num_bytes; diff --git a/datasketches/src/hll/array8.rs b/datasketches/src/hll/array8.rs index 2bd1509..1a0f1f0 100644 --- a/datasketches/src/hll/array8.rs +++ b/datasketches/src/hll/array8.rs @@ -22,12 +22,20 @@ use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; use crate::common::NumStdDev; use crate::error::Error; use crate::hll::estimator::HipEstimator; use crate::hll::get_slot; use crate::hll::get_value; +use crate::hll::serialization::CUR_MODE_HLL; +use crate::hll::serialization::HLL_PREAMBLE_SIZE; +use crate::hll::serialization::HLL_PREINTS; +use crate::hll::serialization::OUT_OF_ORDER_FLAG_MASK; +use crate::hll::serialization::SERIAL_VERSION; +use crate::hll::serialization::TGT_HLL8; +use crate::hll::serialization::encode_mode_byte; /// Core Array8 data structure - one byte per slot, no packing #[derive(Debug, Clone, PartialEq)] @@ -251,25 +259,29 @@ impl Array8 { compact: bool, ooo: bool, ) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let k = 1usize << lg_config_k; // Read HIP estimator values from preamble - let hip_accum = cursor.read_f64_le().map_err(make_error("hip_accum"))?; - let kxq0 = cursor.read_f64_le().map_err(make_error("kxq0"))?; - let kxq1 = cursor.read_f64_le().map_err(make_error("kxq1"))?; + let hip_accum = cursor + .read_f64_le() + .map_err(insufficient_data("hip_accum"))?; + let kxq0 = cursor.read_f64_le().map_err(insufficient_data("kxq0"))?; + let kxq1 = cursor.read_f64_le().map_err(insufficient_data("kxq1"))?; // Read num_at_cur_min (for Array8, this is num_zeros since cur_min=0) - let num_zeros = cursor.read_u32_le().map_err(make_error("num_zeros"))?; - let _aux_count = cursor.read_u32_le().map_err(make_error("aux_count"))?; // always 0 + let num_zeros = cursor + .read_u32_le() + .map_err(insufficient_data("num_zeros"))?; + let _aux_count = cursor + .read_u32_le() + .map_err(insufficient_data("aux_count"))?; // always 0 // Read byte array from offset HLL_BYTE_ARR_START let mut data = vec![0u8; k]; if !compact { - cursor.read_exact(&mut data).map_err(make_error("data"))?; + cursor + .read_exact(&mut data) + .map_err(insufficient_data("data"))?; } else { cursor.advance(k as u64); } @@ -293,8 +305,6 @@ impl Array8 { /// /// Produces full HLL preamble (40 bytes) followed by k bytes of data. pub fn serialize(&self, lg_config_k: u8) -> Vec { - use crate::hll::serialization::*; - let k = 1 << lg_config_k; let total_size = HLL_PREAMBLE_SIZE + k as usize; let mut bytes = SketchBytes::with_capacity(total_size); diff --git a/datasketches/src/hll/hash_set.rs b/datasketches/src/hll/hash_set.rs index cbe99ff..c04f89b 100644 --- a/datasketches/src/hll/hash_set.rs +++ b/datasketches/src/hll/hash_set.rs @@ -22,13 +22,19 @@ use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; use crate::error::Error; use crate::hll::HllType; use crate::hll::KEY_MASK_26; use crate::hll::container::COUPON_EMPTY; use crate::hll::container::Container; -use crate::hll::serialization::*; +use crate::hll::serialization::COMPACT_FLAG_MASK; +use crate::hll::serialization::CUR_MODE_SET; +use crate::hll::serialization::HASH_SET_PREINTS; +use crate::hll::serialization::SERIAL_VERSION; +use crate::hll::serialization::SET_PREAMBLE_SIZE; +use crate::hll::serialization::encode_mode_byte; /// Hash set for efficient coupon storage with collision handling #[derive(Debug, Clone, PartialEq)] @@ -95,7 +101,7 @@ impl HashSet { // Read coupon count from bytes 8-11 let coupon_count = cursor .read_u32_le() - .map_err(|_| Error::insufficient_data("coupon_count"))?; + .map_err(insufficient_data("coupon_count"))?; let coupon_count = coupon_count as usize; if compact { diff --git a/datasketches/src/hll/list.rs b/datasketches/src/hll/list.rs index 6cd92f8..5895882 100644 --- a/datasketches/src/hll/list.rs +++ b/datasketches/src/hll/list.rs @@ -27,7 +27,13 @@ use crate::error::Error; use crate::hll::HllType; use crate::hll::container::COUPON_EMPTY; use crate::hll::container::Container; -use crate::hll::serialization::*; +use crate::hll::serialization::COMPACT_FLAG_MASK; +use crate::hll::serialization::CUR_MODE_LIST; +use crate::hll::serialization::EMPTY_FLAG_MASK; +use crate::hll::serialization::LIST_PREAMBLE_SIZE; +use crate::hll::serialization::LIST_PREINTS; +use crate::hll::serialization::SERIAL_VERSION; +use crate::hll::serialization::encode_mode_byte; /// List for sequential coupon storage with duplicate detection #[derive(Debug, Clone, PartialEq)] diff --git a/datasketches/src/hll/mod.rs b/datasketches/src/hll/mod.rs index 6f99a49..b3e1b36 100644 --- a/datasketches/src/hll/mod.rs +++ b/datasketches/src/hll/mod.rs @@ -74,7 +74,7 @@ //! //! # Usage //! -//! ```rust +//! ``` //! # use datasketches::hll::HllSketch; //! # use datasketches::hll::HllType; //! # use datasketches::common::NumStdDev; @@ -86,7 +86,7 @@ //! //! # Union //! -//! ```rust +//! ``` //! # use datasketches::hll::HllSketch; //! # use datasketches::hll::HllType; //! # use datasketches::hll::HllUnion; diff --git a/datasketches/src/hll/sketch.rs b/datasketches/src/hll/sketch.rs index ecf3ff1..9e9c5a1 100644 --- a/datasketches/src/hll/sketch.rs +++ b/datasketches/src/hll/sketch.rs @@ -23,8 +23,9 @@ use std::hash::Hash; use crate::codec::SketchSlice; +use crate::codec::assert::ensure_serial_version_is; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; -use crate::codec::utility::ensure_serial_version_is; use crate::common::NumStdDev; use crate::error::Error; use crate::hll::HllType; @@ -38,7 +39,21 @@ use crate::hll::coupon; use crate::hll::hash_set::HashSet; use crate::hll::list::List; use crate::hll::mode::Mode; -use crate::hll::serialization::*; +use crate::hll::serialization::COMPACT_FLAG_MASK; +use crate::hll::serialization::CUR_MODE_HLL; +use crate::hll::serialization::CUR_MODE_LIST; +use crate::hll::serialization::CUR_MODE_SET; +use crate::hll::serialization::EMPTY_FLAG_MASK; +use crate::hll::serialization::HASH_SET_PREINTS; +use crate::hll::serialization::HLL_PREINTS; +use crate::hll::serialization::LIST_PREINTS; +use crate::hll::serialization::OUT_OF_ORDER_FLAG_MASK; +use crate::hll::serialization::SERIAL_VERSION; +use crate::hll::serialization::TGT_HLL4; +use crate::hll::serialization::TGT_HLL6; +use crate::hll::serialization::TGT_HLL8; +use crate::hll::serialization::extract_cur_mode; +use crate::hll::serialization::extract_tgt_hll_type; /// A HyperLogLog sketch. /// @@ -257,26 +272,26 @@ impl HllSketch { /// assert!(decoded.estimate() >= 1.0); /// ``` pub fn deserialize(bytes: &[u8]) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let mut cursor = SketchSlice::new(bytes); // Read and validate preamble - let preamble_ints = cursor.read_u8().map_err(make_error("preamble_ints"))?; - let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; - let family_id = cursor.read_u8().map_err(make_error("family_id"))?; - let lg_config_k = cursor.read_u8().map_err(make_error("lg_config_k"))?; + let preamble_ints = cursor + .read_u8() + .map_err(insufficient_data("preamble_ints"))?; + let serial_version = cursor + .read_u8() + .map_err(insufficient_data("serial_version"))?; + let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?; + let lg_config_k = cursor.read_u8().map_err(insufficient_data("lg_config_k"))?; // lg_arr used in List/Set modes - let lg_arr = cursor.read_u8().map_err(make_error("lg_arr"))?; - let flags = cursor.read_u8().map_err(make_error("flags"))?; + let lg_arr = cursor.read_u8().map_err(insufficient_data("lg_arr"))?; + let flags = cursor.read_u8().map_err(insufficient_data("flags"))?; // The contextual state byte: // * coupon count in LIST mode // * cur_min in HLL mode // * unused in SET mode - let state = cursor.read_u8().map_err(make_error("state"))?; - let mode_byte = cursor.read_u8().map_err(make_error("mode"))?; + let state = cursor.read_u8().map_err(insufficient_data("state"))?; + let mode_byte = cursor.read_u8().map_err(insufficient_data("mode"))?; // Verify family ID Family::HLL.validate_id(family_id)?; diff --git a/datasketches/src/tdigest/mod.rs b/datasketches/src/tdigest/mod.rs index d1a80c5..b3bcf00 100644 --- a/datasketches/src/tdigest/mod.rs +++ b/datasketches/src/tdigest/mod.rs @@ -50,7 +50,7 @@ //! //! # Usage //! -//! ```rust +//! ``` //! # use datasketches::tdigest::TDigestMut; //! let mut sketch = TDigestMut::new(100); //! sketch.update(1.0); diff --git a/datasketches/src/tdigest/sketch.rs b/datasketches/src/tdigest/sketch.rs index d831f72..51e2b62 100644 --- a/datasketches/src/tdigest/sketch.rs +++ b/datasketches/src/tdigest/sketch.rs @@ -21,11 +21,19 @@ use std::num::NonZeroU64; use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::codec::assert::ensure_preamble_longs_in; +use crate::codec::assert::ensure_serial_version_is; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; -use crate::codec::utility::ensure_preamble_longs_in; -use crate::codec::utility::ensure_serial_version_is; use crate::error::Error; -use crate::tdigest::serialization::*; +use crate::tdigest::serialization::COMPAT_DOUBLE; +use crate::tdigest::serialization::COMPAT_FLOAT; +use crate::tdigest::serialization::FLAGS_IS_EMPTY; +use crate::tdigest::serialization::FLAGS_IS_SINGLE_VALUE; +use crate::tdigest::serialization::FLAGS_REVERSE_MERGE; +use crate::tdigest::serialization::PREAMBLE_LONGS_EMPTY_OR_SINGLE; +use crate::tdigest::serialization::PREAMBLE_LONGS_MULTIPLE; +use crate::tdigest::serialization::SERIAL_VERSION; /// The default value of K if one is not specified. const DEFAULT_K: u16 = 200; @@ -487,15 +495,15 @@ impl TDigestMut { /// assert_eq!(decoded.max_value(), Some(2.0)); /// ``` pub fn deserialize(bytes: &[u8], is_f32: bool) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let mut cursor = SketchSlice::new(bytes); - let preamble_longs = cursor.read_u8().map_err(make_error("preamble_longs"))?; - let serial_version = cursor.read_u8().map_err(make_error("serial_version"))?; - let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + let preamble_longs = cursor + .read_u8() + .map_err(insufficient_data("preamble_longs"))?; + let serial_version = cursor + .read_u8() + .map_err(insufficient_data("serial_version"))?; + let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?; if let Err(err) = Family::TDIGEST.validate_id(family_id) { return if preamble_longs == 0 && serial_version == 0 && family_id == 0 { Self::deserialize_compat(bytes) @@ -504,11 +512,11 @@ impl TDigestMut { }; } ensure_serial_version_is(SERIAL_VERSION, serial_version)?; - let k = cursor.read_u16_le().map_err(make_error("k"))?; + let k = cursor.read_u16_le().map_err(insufficient_data("k"))?; if k < 10 { return Err(Error::deserial(format!("k must be at least 10, got {k}"))); } - let flags = cursor.read_u8().map_err(make_error("flags"))?; + let flags = cursor.read_u8().map_err(insufficient_data("flags"))?; let is_empty = (flags & FLAGS_IS_EMPTY) != 0; let is_single_value = (flags & FLAGS_IS_SINGLE_VALUE) != 0; let expected_preamble_longs = if is_empty || is_single_value { @@ -517,7 +525,9 @@ impl TDigestMut { PREAMBLE_LONGS_MULTIPLE }; ensure_preamble_longs_in(&[expected_preamble_longs], preamble_longs)?; - cursor.read_u16_le().map_err(make_error(""))?; // unused + cursor + .read_u16_le() + .map_err(insufficient_data(""))?; // unused if is_empty { return Ok(TDigestMut::new(k)); } @@ -525,9 +535,13 @@ impl TDigestMut { let reverse_merge = (flags & FLAGS_REVERSE_MERGE) != 0; if is_single_value { let value = if is_f32 { - cursor.read_f32_le().map_err(make_error("single_value"))? as f64 + cursor + .read_f32_le() + .map_err(insufficient_data("single_value"))? as f64 } else { - cursor.read_f64_le().map_err(make_error("single_value"))? + cursor + .read_f64_le() + .map_err(insufficient_data("single_value"))? }; check_non_nan(value, "single_value")?; check_finite(value, "single_value")?; @@ -544,17 +558,21 @@ impl TDigestMut { vec![], )); } - let num_centroids = cursor.read_u32_le().map_err(make_error("num_centroids"))? as usize; - let num_buffered = cursor.read_u32_le().map_err(make_error("num_buffered"))? as usize; + let num_centroids = cursor + .read_u32_le() + .map_err(insufficient_data("num_centroids"))? as usize; + let num_buffered = cursor + .read_u32_le() + .map_err(insufficient_data("num_buffered"))? as usize; let (min, max) = if is_f32 { ( - cursor.read_f32_le().map_err(make_error("min"))? as f64, - cursor.read_f32_le().map_err(make_error("max"))? as f64, + cursor.read_f32_le().map_err(insufficient_data("min"))? as f64, + cursor.read_f32_le().map_err(insufficient_data("max"))? as f64, ) } else { ( - cursor.read_f64_le().map_err(make_error("min"))?, - cursor.read_f64_le().map_err(make_error("max"))?, + cursor.read_f64_le().map_err(insufficient_data("min"))?, + cursor.read_f64_le().map_err(insufficient_data("max"))?, ) }; check_non_nan(min, "min")?; @@ -564,13 +582,13 @@ impl TDigestMut { for _ in 0..num_centroids { let (mean, weight) = if is_f32 { ( - cursor.read_f32_le().map_err(make_error("mean"))? as f64, - cursor.read_u32_le().map_err(make_error("weight"))? as u64, + cursor.read_f32_le().map_err(insufficient_data("mean"))? as f64, + cursor.read_u32_le().map_err(insufficient_data("weight"))? as u64, ) } else { ( - cursor.read_f64_le().map_err(make_error("mean"))?, - cursor.read_u64_le().map_err(make_error("weight"))?, + cursor.read_f64_le().map_err(insufficient_data("mean"))?, + cursor.read_u64_le().map_err(insufficient_data("weight"))?, ) }; check_non_nan(mean, "centroid mean")?; @@ -582,9 +600,13 @@ impl TDigestMut { let mut buffer = Vec::with_capacity(num_buffered); for _ in 0..num_buffered { let value = if is_f32 { - cursor.read_f32_le().map_err(make_error("buffered_value"))? as f64 + cursor + .read_f32_le() + .map_err(insufficient_data("buffered_value"))? as f64 } else { - cursor.read_f64_le().map_err(make_error("buffered_value"))? + cursor + .read_f64_le() + .map_err(insufficient_data("buffered_value"))? }; check_non_nan(value, "buffered_value mean")?; check_finite(value, "buffered_value mean")?; diff --git a/datasketches/src/theta/mod.rs b/datasketches/src/theta/mod.rs index fdde037..4c96a3f 100644 --- a/datasketches/src/theta/mod.rs +++ b/datasketches/src/theta/mod.rs @@ -32,7 +32,7 @@ //! //! # Usage //! -//! ```rust +//! ``` //! # use datasketches::theta::ThetaSketch; //! let mut sketch = ThetaSketch::builder().build(); //! sketch.update("apple"); diff --git a/datasketches/src/theta/sketch.rs b/datasketches/src/theta/sketch.rs index 32f6e9a..1e8cbb4 100644 --- a/datasketches/src/theta/sketch.rs +++ b/datasketches/src/theta/sketch.rs @@ -24,8 +24,9 @@ use std::hash::Hash; use crate::codec::SketchBytes; use crate::codec::SketchSlice; +use crate::codec::assert::ensure_preamble_longs_in_range; +use crate::codec::assert::insufficient_data; use crate::codec::family::Family; -use crate::codec::utility::ensure_preamble_longs_in_range; use crate::common::NumStdDev; use crate::common::ResizeFactor; use crate::common::binomial_bounds; @@ -558,14 +559,14 @@ impl CompactThetaSketch { /// Deserializes a compact theta sketch from bytes using the provided expected seed. pub fn deserialize_with_seed(bytes: &[u8], seed: u64) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let mut cursor = SketchSlice::new(bytes); - let pre_longs = cursor.read_u8().map_err(make_error("preamble_longs"))?; - let ser_ver = cursor.read_u8().map_err(make_error("serial_version"))?; - let family_id = cursor.read_u8().map_err(make_error("family_id"))?; + let pre_longs = cursor + .read_u8() + .map_err(insufficient_data("preamble_longs"))?; + let ser_ver = cursor + .read_u8() + .map_err(insufficient_data("serial_version"))?; + let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?; Family::THETA.validate_id(family_id)?; @@ -593,9 +594,7 @@ impl CompactThetaSketch { ) -> Result, Error> { let mut entries = Vec::with_capacity(num_entries); for _ in 0..num_entries { - let hash = cursor - .read_u64_le() - .map_err(|_| Error::insufficient_data("entries"))?; + let hash = cursor.read_u64_le().map_err(insufficient_data("entries"))?; if hash == 0 || hash >= theta { return Err(Error::deserial("corrupted: invalid retained hash value")); } @@ -605,16 +604,20 @@ impl CompactThetaSketch { } fn deserialize_v1(mut cursor: SketchSlice<'_>, expected_seed: u64) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let seed_hash = compute_seed_hash(expected_seed); - cursor.read_u8().map_err(make_error(""))?; - cursor.read_u32_le().map_err(make_error(""))?; - let num_entries = cursor.read_u32_le().map_err(make_error("num_entries"))? as usize; - cursor.read_u32_le().map_err(make_error(""))?; - let theta = cursor.read_u64_le().map_err(make_error("theta_long"))?; + cursor.read_u8().map_err(insufficient_data(""))?; + cursor + .read_u32_le() + .map_err(insufficient_data(""))?; + let num_entries = cursor + .read_u32_le() + .map_err(insufficient_data("num_entries"))? as usize; + cursor + .read_u32_le() + .map_err(insufficient_data(""))?; + let theta = cursor + .read_u64_le() + .map_err(insufficient_data("theta_long"))?; let empty = num_entries == 0 && theta == MAX_THETA; if empty { @@ -643,13 +646,13 @@ impl CompactThetaSketch { mut cursor: SketchSlice<'_>, expected_seed: u64, ) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - - cursor.read_u8().map_err(make_error(""))?; - cursor.read_u16_le().map_err(make_error(""))?; - let seed_hash = cursor.read_u16_le().map_err(make_error("seed_hash"))?; + cursor.read_u8().map_err(insufficient_data(""))?; + cursor + .read_u16_le() + .map_err(insufficient_data(""))?; + let seed_hash = cursor + .read_u16_le() + .map_err(insufficient_data("seed_hash"))?; let expected_seed_hash = compute_seed_hash(expected_seed); if seed_hash != expected_seed_hash { return Err(Error::deserial(format!( @@ -666,8 +669,13 @@ impl CompactThetaSketch { empty: true, }), V2_PREAMBLE_PRECISE => { - let num_entries = cursor.read_u32_le().map_err(make_error("num_entries"))? as usize; - cursor.read_u32_le().map_err(make_error(""))?; + let num_entries = cursor + .read_u32_le() + .map_err(insufficient_data("num_entries"))? + as usize; + cursor + .read_u32_le() + .map_err(insufficient_data(""))?; let entries = Self::read_entries(&mut cursor, num_entries, MAX_THETA)?; Ok(Self { entries, @@ -678,9 +686,16 @@ impl CompactThetaSketch { }) } V2_PREAMBLE_ESTIMATE => { - let num_entries = cursor.read_u32_le().map_err(make_error("num_entries"))? as usize; - cursor.read_u32_le().map_err(make_error(""))?; - let theta = cursor.read_u64_le().map_err(make_error("theta_long"))?; + let num_entries = cursor + .read_u32_le() + .map_err(insufficient_data("num_entries"))? + as usize; + cursor + .read_u32_le() + .map_err(insufficient_data(""))?; + let theta = cursor + .read_u64_le() + .map_err(insufficient_data("theta_long"))?; let empty = (num_entries == 0) && (theta == MAX_THETA); let entries = Self::read_entries(&mut cursor, num_entries, theta)?; Ok(Self { @@ -700,12 +715,13 @@ impl CompactThetaSketch { mut cursor: SketchSlice<'_>, expected_seed: u64, ) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - cursor.read_u16_le().map_err(make_error(""))?; - let flags = cursor.read_u8().map_err(make_error("flags"))?; - let seed_hash = cursor.read_u16_le().map_err(make_error("seed_hash"))?; + cursor + .read_u16_le() + .map_err(insufficient_data(""))?; + let flags = cursor.read_u8().map_err(insufficient_data("flags"))?; + let seed_hash = cursor + .read_u16_le() + .map_err(insufficient_data("seed_hash"))?; let empty = (flags & serialization::FLAGS_IS_EMPTY) != 0; let mut theta = MAX_THETA; @@ -721,10 +737,16 @@ impl CompactThetaSketch { if pre_longs == 1 { num_entries = 1; } else { - num_entries = cursor.read_u32_le().map_err(make_error("num_entries"))?; - cursor.read_u32_le().map_err(make_error(""))?; + num_entries = cursor + .read_u32_le() + .map_err(insufficient_data("num_entries"))?; + cursor + .read_u32_le() + .map_err(insufficient_data(""))?; if pre_longs > 2 { - theta = cursor.read_u64_le().map_err(make_error("theta_long"))?; + theta = cursor + .read_u64_le() + .map_err(insufficient_data("theta_long"))?; } } entries = Self::read_entries(&mut cursor, num_entries as usize, theta)?; @@ -744,13 +766,12 @@ impl CompactThetaSketch { mut cursor: SketchSlice<'_>, expected_seed: u64, ) -> Result { - fn make_error(tag: &'static str) -> impl FnOnce(std::io::Error) -> Error { - move |_| Error::insufficient_data(tag) - } - let entry_bits = cursor.read_u8().map_err(make_error("entry_bits"))?; - let num_entries_bytes = cursor.read_u8().map_err(make_error("num_entries"))?; - let flags = cursor.read_u8().map_err(make_error("flags"))?; - let seed_hash = cursor.read_u16_le().map_err(make_error("seed_hash"))?; + let entry_bits = cursor.read_u8().map_err(insufficient_data("entry_bits"))?; + let num_entries_bytes = cursor.read_u8().map_err(insufficient_data("num_entries"))?; + let flags = cursor.read_u8().map_err(insufficient_data("flags"))?; + let seed_hash = cursor + .read_u16_le() + .map_err(insufficient_data("seed_hash"))?; let empty = (flags & serialization::FLAGS_IS_EMPTY) != 0; if !empty { let expected_seed_hash = compute_seed_hash(expected_seed); @@ -761,7 +782,9 @@ impl CompactThetaSketch { } } let theta = if pre_longs > 1 { - cursor.read_u64_le().map_err(make_error("theta_long"))? + cursor + .read_u64_le() + .map_err(insufficient_data("theta_long"))? } else { MAX_THETA }; @@ -769,7 +792,9 @@ impl CompactThetaSketch { // unpack num_entries let mut num_entries = 0usize; for i in 0..num_entries_bytes { - let entry_count_byte = cursor.read_u8().map_err(make_error("num_entries_byte"))?; + let entry_count_byte = cursor + .read_u8() + .map_err(insufficient_data("num_entries_byte"))?; num_entries |= (entry_count_byte as usize) << ((i as usize) << 3); } @@ -780,7 +805,7 @@ impl CompactThetaSketch { let mut block = vec![0u8; entry_bits as usize]; cursor .read_exact(&mut block) - .map_err(make_error("delta_block"))?; + .map_err(insufficient_data("delta_block"))?; unpack_bits_block(&mut entries[i..i + BLOCK_WIDTH], &block, entry_bits); i += BLOCK_WIDTH; } @@ -793,7 +818,7 @@ impl CompactThetaSketch { let mut tail = vec![0u8; bytes_needed]; cursor .read_exact(&mut tail) - .map_err(make_error("delta_tail"))?; + .map_err(insufficient_data("delta_tail"))?; let mut unpacker = BitUnpacker::new(&tail); for slot in entries.iter_mut().take(num_entries).skip(i) {