From 343ac39a1ad546e1f7ca1772f8adecdefd958fc7 Mon Sep 17 00:00:00 2001 From: Will Hopkins Date: Sun, 29 Jun 2025 10:44:27 -0700 Subject: [PATCH] feat: add `drain` method --- src/lib.rs | 4 +- src/map.rs | 75 ++++++++++++++++++ src/raw/mod.rs | 181 +++++++++++++++++++++++++++++++++++++++++++ src/raw/utils/mod.rs | 6 ++ tests/basic.rs | 72 +++++++++++++++++ 5 files changed, 336 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 51998bd..9bd0848 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -236,8 +236,8 @@ mod serde_impls; pub use equivalent::Equivalent; pub use map::{ - Compute, HashMap, HashMapBuilder, HashMapRef, Iter, Keys, OccupiedError, Operation, ResizeMode, - Values, + Compute, Drain, HashMap, HashMapBuilder, HashMapRef, Iter, Keys, OccupiedError, Operation, + ResizeMode, Values, }; pub use seize::{Guard, LocalGuard, OwnedGuard}; pub use set::{HashSet, HashSetBuilder, HashSetRef}; diff --git a/src/map.rs b/src/map.rs index f149a25..f172c83 100644 --- a/src/map.rs +++ b/src/map.rs @@ -946,6 +946,38 @@ where self.raw.clear(self.raw.verify(guard)) } + /// Drains the map, removing all key-value pairs and returning an iterator + /// over the removed values. + /// + /// The iterator yields values in arbitrary order. The key-value pairs are + /// removed even if the iterator is not consumed entirely. + /// + /// Note that this method will block until any in-progress resizes are + /// completed before proceeding. See the [consistency](crate#consistency) + /// section for details. 
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use papaya::HashMap;
+    ///
+    /// let map = HashMap::new();
+    /// map.pin().insert(1, "a");
+    /// map.pin().insert(2, "b");
+    ///
+    /// let values: Vec<_> = map.pin().drain().collect();
+    /// assert_eq!(values.len(), 2);
+    /// assert!(values.contains(&"a"));
+    /// assert!(values.contains(&"b"));
+    /// assert!(map.is_empty());
+    /// ```
+    #[inline]
+    pub fn drain<'g, G: Guard>(&'g self, guard: &'g G) -> Drain<'g, K, V, S, G> {
+        Drain {
+            raw: self.raw.drain(self.raw.verify(guard)),
+        }
+    }
+
     /// Retains only the elements specified by the predicate.
     ///
     /// In other words, remove all pairs `(k, v)` for which `f(&k, &v)` returns `false`.
@@ -1496,6 +1528,17 @@ where
         self.map.raw.clear(&self.guard)
     }
 
+    /// Drains the map, removing all key-value pairs and returning an iterator
+    /// over the removed values.
+    ///
+    /// See [`HashMap::drain`] for details.
+    #[inline]
+    pub fn drain(&self) -> Drain<'_, K, V, S, G> {
+        Drain {
+            raw: self.map.raw.drain(self.map.raw.verify(&self.guard)),
+        }
+    }
+
     /// Retains only the elements specified by the predicate.
     ///
     /// See [`HashMap::retain`] for details.
@@ -1667,3 +1710,35 @@ where
         f.debug_tuple("Values").field(&self.iter).finish()
     }
 }
+
+/// An iterator over the values of a drained map.
+///
+/// This struct is created by the [`drain`](HashMap::drain) method on [`HashMap`]. See its documentation for details.
+pub struct Drain<'a, K, V, S, G> {
+    raw: raw::Drain<'a, K, V, S, MapGuard<G>>,
+}
+
+impl<'a, K, V, S, G> Iterator for Drain<'a, K, V, S, G>
+where
+    K: Hash + Eq,
+    S: BuildHasher,
+    G: Guard,
+{
+    type Item = V;
+
+    #[inline]
+    fn next(&mut self) -> Option<Self::Item> {
+        self.raw.next()
+    }
+}
+
+impl<K, V, S, G> fmt::Debug for Drain<'_, K, V, S, G>
+where
+    K: fmt::Debug,
+    V: fmt::Debug,
+    G: Guard,
+{
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_tuple("Drain").field(&"...").finish()
+    }
+}
diff --git a/src/raw/mod.rs b/src/raw/mod.rs
index 9aa8f75..c8f8b21 100644
--- a/src/raw/mod.rs
+++ b/src/raw/mod.rs
@@ -1098,6 +1098,35 @@ where
         }
     }
 
+    /// Drains the map, removing all key-value pairs and returning an iterator
+    /// over the removed values.
+    #[inline]
+    pub fn drain<'g, G: VerifiedGuard>(&'g self, guard: &'g G) -> Drain<'g, K, V, S, G> {
+        let root = self.root(guard);
+
+        // The table has not been initialized yet, return a dummy iterator.
+        if root.raw.is_null() {
+            return Drain {
+                map: self,
+                guard,
+                current_table: root,
+                i: 0,
+                finished: true,
+            };
+        }
+
+        // Get a clean copy of the table to iterate over.
+        let table = self.linearize(root, guard);
+
+        Drain {
+            map: self,
+            guard,
+            current_table: table,
+            i: 0,
+            finished: false,
+        }
+    }
+
     /// Retains only the elements specified by the predicate.
     #[inline]
     pub fn retain<F>(&self, mut f: F, guard: &impl VerifiedGuard)
@@ -2838,3 +2867,155 @@ mod meta {
         (top7 & 0x7f) as u8
     }
 }
+
+/// An iterator that drains entries from a table.
+pub struct Drain<'a, K, V, S, G> {
+    map: &'a HashMap<K, V, S>,
+    guard: &'a G,
+    current_table: Table<Entry<K, V>>,
+    i: usize,
+    finished: bool,
+}
+
+// Safety: An iterator holds a shared reference to the HashMap
+// and Guard, and outputs shared references to keys and values.
+// Thus everything must be `Sync` for the iterator to be `Send`
+// or `Sync`.
+//
+// It is not possible to obtain an owned key, value, or guard
+// from an iterator, so `Send` is not a required bound.
+unsafe impl<K, V, S, G> Send for Drain<'_, K, V, S, G>
+where
+    K: Sync,
+    V: Sync,
+    S: Sync,
+    G: Sync,
+{
+}
+
+unsafe impl<K, V, S, G> Sync for Drain<'_, K, V, S, G>
+where
+    K: Sync,
+    V: Sync,
+    S: Sync,
+    G: Sync,
+{
+}
+
+impl<'a, K, V, S, G> Iterator for Drain<'a, K, V, S, G>
+where
+    K: Hash + Eq,
+    S: BuildHasher,
+    G: VerifiedGuard,
+{
+    type Item = V;
+
+    #[inline]
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.finished {
+            return None;
+        }
+
+        let mut needs_resize_completion = false;
+
+        loop {
+            let table = self.current_table;
+
+            // Check if we've gone through all entries
+            if self.i >= table.len() {
+                // Check if there are any entries being copied that we need to handle
+                if needs_resize_completion {
+                    // Complete the resize and switch to the new table
+                    let new_table = self.map.help_copy(true, &table, self.guard);
+                    self.current_table = new_table;
+                    needs_resize_completion = false;
+                    self.i = 0;
+                    continue;
+                } else {
+                    self.finished = true;
+                    return None;
+                }
+            }
+
+            // Load the entry metadata first to ensure consistency
+            //
+            // Safety: We verified that `self.i` is in-bounds above.
+            let meta = unsafe { table.meta(self.i) }.load(Ordering::Acquire);
+
+            // The entry is empty or deleted.
+            if matches!(meta, meta::EMPTY | meta::TOMBSTONE) {
+                self.i += 1;
+                continue;
+            }
+
+            // Load the entry.
+            //
+            // Safety: We verified that `self.i` is in-bounds above.
+            let entry = self
+                .guard
+                .protect(unsafe { table.entry(self.i) }, Ordering::Acquire)
+                .unpack();
+
+            // The entry was deleted.
+            if entry.ptr.is_null() {
+                self.i += 1;
+                continue;
+            }
+
+            // Found a non-empty entry being copied.
+            if entry.tag() & Entry::<(), ()>::COPYING != 0 {
+                // Skip this entry for now, we'll handle it after the resize
+                self.i += 1;
+                needs_resize_completion = true;
+                continue;
+            }
+
+            // Try to delete the entry atomically.
+            //
+            // Safety: `self.i` is in bounds for the table length.
+            let result = unsafe {
+                table.entry(self.i).compare_exchange(
+                    entry.raw,
+                    Entry::TOMBSTONE,
+                    Ordering::Release,
+                    Ordering::Acquire,
+                )
+            };
+
+            match result {
+                // Successfully deleted the entry.
+                Ok(_) => {
+                    // Update the metadata table.
+                    //
+                    // Safety: `self.i` is in bounds for the table length.
+                    unsafe { table.meta(self.i).store(meta::TOMBSTONE, Ordering::Release) };
+
+                    // Decrement the table length.
+                    self.map
+                        .count
+                        .get(self.guard)
+                        .fetch_sub(1, Ordering::Relaxed);
+
+                    // Safety: We performed a protected load of the pointer using a verified guard with
+                    // `Acquire` and ensured that it is non-null, meaning it is valid for reads as long
+                    // as we hold the guard.
+                    let entry_ref = unsafe { &(*entry.ptr) };
+
+                    // Extract the value before retiring the entry.
+                    let value = unsafe { ptr::read(&entry_ref.value) };
+
+                    // Safety: The entry is now unreachable from this table due to the CAS above.
+                    unsafe { self.map.defer_retire(entry, &table, self.guard) };
+
+                    self.i += 1;
+                    return Some(value);
+                }
+
+                // Lost to a concurrent update, retry without advancing i.
+                Err(_) => {
+                    continue;
+                }
+            }
+        }
+    }
+}
diff --git a/src/raw/utils/mod.rs b/src/raw/utils/mod.rs
index ce36dc8..aaa3fca 100644
--- a/src/raw/utils/mod.rs
+++ b/src/raw/utils/mod.rs
@@ -67,6 +67,12 @@ where
     }
 }
 
+impl<G> core::fmt::Debug for MapGuard<G> {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.debug_struct("MapGuard").finish()
+    }
+}
+
 /// Pads and aligns a value to the length of a cache line.
 ///
 // Source: https://github.com/crossbeam-rs/crossbeam/blob/0f81a6957588ddca9973e32e92e7e94abdad801e/crossbeam-utils/src/cache_padded.rs#L63.
diff --git a/tests/basic.rs b/tests/basic.rs
index 5e48414..0d12457 100644
--- a/tests/basic.rs
+++ b/tests/basic.rs
@@ -30,6 +30,78 @@ fn clear() {
     });
 }
 
+#[test]
+fn drain() {
+    with_map::<usize, usize>(|map| {
+        let map = map();
+        let guard = map.guard();
+
+        // Insert some values
+        map.insert(0, 10, &guard);
+        map.insert(1, 11, &guard);
+        map.insert(2, 12, &guard);
+        map.insert(3, 13, &guard);
+        map.insert(4, 14, &guard);
+
+        assert_eq!(map.len(), 5);
+
+        // Drain all values
+        let drained: Vec<_> = map.drain(&guard).collect();
+
+        // Map should be empty after drain
+        assert!(map.is_empty());
+        assert_eq!(map.len(), 0);
+
+        // Check that all values were drained
+        assert_eq!(drained.len(), 5);
+        assert!(drained.contains(&10));
+        assert!(drained.contains(&11));
+        assert!(drained.contains(&12));
+        assert!(drained.contains(&13));
+        assert!(drained.contains(&14));
+    });
+}
+
+#[test]
+fn drain_empty() {
+    with_map::<usize, usize>(|map| {
+        let map = map();
+
+        // Drain empty map
+        let drained: Vec<_> = map.drain(&map.guard()).collect();
+
+        assert!(map.is_empty());
+        assert_eq!(drained.len(), 0);
+    });
+}
+
+#[test]
+fn drain_pinned() {
+    with_map::<usize, usize>(|map| {
+        let map = map();
+
+        // Insert some values using pinned API
+        map.pin().insert(0, 100);
+        map.pin().insert(1, 200);
+        map.pin().insert(2, 300);
+
+        assert_eq!(map.len(), 3);
+
+        // Drain all values using pinned API
+        let drained: Vec<_> = map.pin().drain().collect();
+
+        // Map should be empty after drain
+        assert!(map.is_empty());
+        assert_eq!(map.len(), 0);
+
+        // Check that all values were drained
+        assert_eq!(drained.len(), 3);
+        assert!(drained.contains(&100));
+        assert!(drained.contains(&200));
+        assert!(drained.contains(&300));
+    });
+}
+
 #[test]
 fn insert() {
     with_map::<usize, usize>(|map| {