Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions library/alloc/src/collections/binary_heap/extract_if.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
use core::fmt;
use core::iter::FusedIterator;
use core::mem::ManuallyDrop;

use super::BinaryHeap;
use crate::alloc::{Allocator, Global};

/// An iterator which uses a closure to determine if an element should be removed.
///
/// This struct is created by [`BinaryHeap::extract_if`].
/// See its documentation for more.
///
/// # Example
///
/// ```
/// #![feature(binary_heap_extract_if)]
/// use crate::alloc::collections::BinaryHeap;
///
/// let mut heap: BinaryHeap<u32> = (0..128).collect();
/// let iter: Vec<u32> = heap.extract_if(|x| *x % 2 == 0).collect();
#[unstable(feature = "binary_heap_extract_if", issue = "154721")]
#[must_use = "iterators are lazy and do nothing unless consumed; \
use `retain_mut` or `extract_if().for_each(drop)` to remove and discard elements"]
pub struct ExtractIf<
'a,
T: Ord,
F,
#[unstable(feature = "allocator_api", issue = "32838")] A: Allocator = Global,
> {
heap_ptr: *mut BinaryHeap<T, A>,
extract_if: ManuallyDrop<super::vec::ExtractIf<'a, T, F, A>>,
}

impl<T: Ord, F, A: Allocator> ExtractIf<'_, T, F, A>
where
F: FnMut(&mut T) -> bool,
{
pub(super) fn new<'a>(heap: &'a mut BinaryHeap<T, A>, predicate: F) -> ExtractIf<'a, T, F, A> {
// We need to keep a reference around to the heap so that we can
let heap_ptr: *mut BinaryHeap<T, A> = heap;
let extract_if = ManuallyDrop::new(heap.data.extract_if(.., predicate));

ExtractIf { heap_ptr, extract_if }
}
}

#[unstable(feature = "binary_heap_extract_if", issue = "154721")]
impl<T: Ord, F, A: Allocator> Iterator for ExtractIf<'_, T, F, A>
where
F: FnMut(&mut T) -> bool,
{
type Item = T;

fn next(&mut self) -> Option<Self::Item> {
self.extract_if.next()
}
}

#[unstable(feature = "binary_heap_extract_if", issue = "154721")]
impl<'a, T: Ord, F, A: Allocator> Drop for ExtractIf<'a, T, F, A> {
fn drop(&mut self) {
// SAFETY: We need to drop this before we rebuild the heap so that its descructor resets the vec info
// We also are only calling this hear during the drop of ExtractIf and then never using it again
unsafe {
ManuallyDrop::drop(&mut self.extract_if);
}

// SAFETY: We only generate this ptr from a reference so we know that it is never null
let heap = unsafe { self.heap_ptr.as_mut_unchecked() };

// Removing some items from the heap almost certainly has invalidated its invariants, we need to fix this up here
heap.rebuild();
}
}

#[unstable(feature = "binary_heap_extract_if", issue = "154721")]
impl<T: Ord, F, A: Allocator> FusedIterator for ExtractIf<'_, T, F, A> where F: FnMut(&mut T) -> bool
{}

#[unstable(feature = "binary_heap_extract_if", issue = "154721")]
impl<T: Ord, F, A> fmt::Debug for ExtractIf<'_, T, F, A>
where
T: fmt::Debug,
A: Allocator,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ExtractIf").finish_non_exhaustive()
}
}
18 changes: 18 additions & 0 deletions library/alloc/src/collections/binary_heap/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,11 @@ use core::num::NonZero;
use core::ops::{Deref, DerefMut};
use core::{fmt, ptr};

#[unstable(feature = "binary_heap_extract_if", issue = "154721")]
pub use self::extract_if::ExtractIf;

mod extract_if;

use crate::alloc::Global;
use crate::collections::TryReserveError;
use crate::slice;
Expand Down Expand Up @@ -1039,6 +1044,19 @@ impl<T: Ord, A: Allocator> BinaryHeap<T, A> {
DrainSorted { inner: self }
}

/// Creates an iterator which uses a closure to determine if an element should be removed.
/// The items are checked in sorted order
///
/// If the closure returns `true`, the element is marked to be removed and yielded
#[unstable(feature = "binary_heap_extract_if", issue = "154721")]
#[must_use]
pub fn extract_if<F>(&mut self, predicate: F) -> ExtractIf<'_, T, F, A>
Comment thread
Nokel81 marked this conversation as resolved.
where
F: FnMut(&mut T) -> bool,
{
ExtractIf::new(self, predicate)
}

/// Retains only the elements specified by the predicate.
///
/// In other words, remove all elements `e` for which `f(&e)` returns
Expand Down
62 changes: 62 additions & 0 deletions library/alloctests/tests/collections/binary_heap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,3 +590,65 @@ fn panic_safe() {
}
}
}

#[test]
fn given_a_binary_heap_can_create_an_extract_if_iterator() {
let mut heap: BinaryHeap<usize> = BinaryHeap::new();
let iter = heap.extract_if(|_| unreachable!("there's nothing to decide on"));

iter.for_each(drop);
assert!(heap.is_empty())
}

#[test]
fn given_some_binary_heap_with_one_item_when_extracting_if_true_extracts_all_items() {
let mut heap: BinaryHeap<usize> = BinaryHeap::new();
heap.push(10);
let v: Vec<usize> = heap.extract_if(|_| true).collect();

assert!(heap.is_empty());
assert_eq!(v, vec![10]);
}

#[test]
fn given_some_binary_heap_with_three_items_when_extracting_if_true_extracts_all_items_in_arbitrary_order()
{
let mut heap = BinaryHeap::new();
heap.push(10);
heap.push(15);
heap.push(11);
let v: Vec<_> = heap.extract_if(|_| true).collect();

assert!(heap.is_empty());
assert_eq!(v, vec![15, 10, 11]);
}

#[test]
fn given_some_binary_heap_with_some_items_when_extracting_if_even_extracts_just_even_items() {
let mut heap = BinaryHeap::new();
heap.push(10);
heap.push(15);
heap.push(11);
let v: Vec<_> = heap.extract_if(|&mut x| x % 2 == 0).collect();

assert_eq!(v, vec![10]);
assert_eq!(heap.pop(), Some(15));
assert_eq!(heap.pop(), Some(11));
assert_eq!(heap.pop(), None);
}

#[test]
fn given_some_binary_heap_with_some_items_when_extracting_if_when_dropping_without_iterating_leaves_heap_in_valid_state()
{
let mut heap = BinaryHeap::new();
heap.push(10);
heap.push(15);
heap.push(11);

drop(heap.extract_if(|&mut x| x % 2 == 0));

assert_eq!(heap.pop(), Some(15));
assert_eq!(heap.pop(), Some(11));
assert_eq!(heap.pop(), Some(10));
assert_eq!(heap.pop(), None);
}
1 change: 1 addition & 0 deletions library/alloctests/tests/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#![feature(allocator_api)]
#![feature(binary_heap_extract_if)]
#![feature(binary_heap_pop_if)]
#![feature(const_heap)]
#![feature(deque_extend_front)]
Expand Down
Loading