diff --git a/library/alloc/src/collections/binary_heap/extract_if.rs b/library/alloc/src/collections/binary_heap/extract_if.rs new file mode 100644 index 0000000000000..3aa50fc67e097 --- /dev/null +++ b/library/alloc/src/collections/binary_heap/extract_if.rs @@ -0,0 +1,89 @@ +use core::fmt; +use core::iter::FusedIterator; +use core::mem::ManuallyDrop; + +use super::BinaryHeap; +use crate::alloc::{Allocator, Global}; + +/// An iterator which uses a closure to determine if an element should be removed. +/// +/// This struct is created by [`BinaryHeap::extract_if`]. +/// See its documentation for more. +/// +/// # Example +/// +/// ``` +/// #![feature(binary_heap_extract_if)] +/// use crate::alloc::collections::BinaryHeap; +/// +/// let mut heap: BinaryHeap = (0..128).collect(); +/// let iter: Vec = heap.extract_if(|x| *x % 2 == 0).collect(); +#[unstable(feature = "binary_heap_extract_if", issue = "154721")] +#[must_use = "iterators are lazy and do nothing unless consumed; \ + use `retain_mut` or `extract_if().for_each(drop)` to remove and discard elements"] +pub struct ExtractIf< + 'a, + T: Ord, + F, + #[unstable(feature = "allocator_api", issue = "32838")] A: Allocator = Global, +> { + heap_ptr: *mut BinaryHeap, + extract_if: ManuallyDrop>, +} + +impl ExtractIf<'_, T, F, A> +where + F: FnMut(&mut T) -> bool, +{ + pub(super) fn new<'a>(heap: &'a mut BinaryHeap, predicate: F) -> ExtractIf<'a, T, F, A> { + // We need to keep a reference around to the heap so that we can + let heap_ptr: *mut BinaryHeap = heap; + let extract_if = ManuallyDrop::new(heap.data.extract_if(.., predicate)); + + ExtractIf { heap_ptr, extract_if } + } +} + +#[unstable(feature = "binary_heap_extract_if", issue = "154721")] +impl Iterator for ExtractIf<'_, T, F, A> +where + F: FnMut(&mut T) -> bool, +{ + type Item = T; + + fn next(&mut self) -> Option { + self.extract_if.next() + } +} + +#[unstable(feature = "binary_heap_extract_if", issue = "154721")] +impl<'a, T: Ord, F, A: Allocator> Drop for ExtractIf<'a, T, F, A> { + fn drop(&mut self) { + // SAFETY: We need to drop this before we rebuild the heap so that its descructor resets the vec info + // We also are only calling this hear during the drop of ExtractIf and then never using it again + unsafe { + ManuallyDrop::drop(&mut self.extract_if); + } + + // SAFETY: We only generate this ptr from a reference so we know that it is never null + let heap = unsafe { self.heap_ptr.as_mut_unchecked() }; + + // Removing some items from the heap almost certainly has invalidated its invariants, we need to fix this up here + heap.rebuild(); + } +} + +#[unstable(feature = "binary_heap_extract_if", issue = "154721")] +impl FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut(&mut T) -> bool +{} + +#[unstable(feature = "binary_heap_extract_if", issue = "154721")] +impl fmt::Debug for ExtractIf<'_, T, F, A> +where + T: fmt::Debug, + A: Allocator, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ExtractIf").finish_non_exhaustive() + } +} diff --git a/library/alloc/src/collections/binary_heap/mod.rs b/library/alloc/src/collections/binary_heap/mod.rs index 4ddfcde57280e..c7ad4430d71c3 100644 --- a/library/alloc/src/collections/binary_heap/mod.rs +++ b/library/alloc/src/collections/binary_heap/mod.rs @@ -150,6 +150,11 @@ use core::num::NonZero; use core::ops::{Deref, DerefMut}; use core::{fmt, ptr}; +#[unstable(feature = "binary_heap_extract_if", issue = "154721")] +pub use self::extract_if::ExtractIf; + +mod extract_if; + use crate::alloc::Global; use crate::collections::TryReserveError; use crate::slice; @@ -1039,6 +1044,19 @@ impl BinaryHeap { DrainSorted { inner: self } } + /// Creates an iterator which uses a closure to determine if an element should be removed. + /// The items are checked in sorted order + /// + /// If the closure returns `true`, the element is marked to be removed and yielded + #[unstable(feature = "binary_heap_extract_if", issue = "154721")] + #[must_use] + pub fn extract_if(&mut self, predicate: F) -> ExtractIf<'_, T, F, A> + where + F: FnMut(&mut T) -> bool, + { + ExtractIf::new(self, predicate) + } + /// Retains only the elements specified by the predicate. /// /// In other words, remove all elements `e` for which `f(&e)` returns diff --git a/library/alloctests/tests/collections/binary_heap.rs b/library/alloctests/tests/collections/binary_heap.rs index e1484c32a4f8a..c712918a98b35 100644 --- a/library/alloctests/tests/collections/binary_heap.rs +++ b/library/alloctests/tests/collections/binary_heap.rs @@ -590,3 +590,65 @@ fn panic_safe() { } } } + +#[test] +fn given_a_binary_heap_can_create_an_extract_if_iterator() { + let mut heap: BinaryHeap = BinaryHeap::new(); + let iter = heap.extract_if(|_| unreachable!("there's nothing to decide on")); + + iter.for_each(drop); + assert!(heap.is_empty()) +} + +#[test] +fn given_some_binary_heap_with_one_item_when_extracting_if_true_extracts_all_items() { + let mut heap: BinaryHeap = BinaryHeap::new(); + heap.push(10); + let v: Vec = heap.extract_if(|_| true).collect(); + + assert!(heap.is_empty()); + assert_eq!(v, vec![10]); +} + +#[test] +fn given_some_binary_heap_with_three_items_when_extracting_if_true_extracts_all_items_in_arbitrary_order() + { + let mut heap = BinaryHeap::new(); + heap.push(10); + heap.push(15); + heap.push(11); + let v: Vec<_> = heap.extract_if(|_| true).collect(); + + assert!(heap.is_empty()); + assert_eq!(v, vec![15, 10, 11]); +} + +#[test] +fn given_some_binary_heap_with_some_items_when_extracting_if_even_extracts_just_even_items() { + let mut heap = BinaryHeap::new(); + heap.push(10); + heap.push(15); + heap.push(11); + let v: Vec<_> = heap.extract_if(|&mut x| x % 2 == 0).collect(); + + assert_eq!(v, vec![10]); + assert_eq!(heap.pop(), Some(15)); + assert_eq!(heap.pop(), Some(11)); + assert_eq!(heap.pop(), None); +} + +#[test] +fn given_some_binary_heap_with_some_items_when_extracting_if_when_dropping_without_iterating_leaves_heap_in_valid_state() + { + let mut heap = BinaryHeap::new(); + heap.push(10); + heap.push(15); + heap.push(11); + + drop(heap.extract_if(|&mut x| x % 2 == 0)); + + assert_eq!(heap.pop(), Some(15)); + assert_eq!(heap.pop(), Some(11)); + assert_eq!(heap.pop(), Some(10)); + assert_eq!(heap.pop(), None); +} diff --git a/library/alloctests/tests/lib.rs b/library/alloctests/tests/lib.rs index 699a5010282b0..61d696b33020a 100644 --- a/library/alloctests/tests/lib.rs +++ b/library/alloctests/tests/lib.rs @@ -1,4 +1,5 @@ #![feature(allocator_api)] +#![feature(binary_heap_extract_if)] #![feature(binary_heap_pop_if)] #![feature(const_heap)] #![feature(deque_extend_front)]