From 71e2b1c64d7ebea5b7680df01ce6d4837c4c0f37 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Thu, 16 Apr 2026 11:07:19 +0800 Subject: [PATCH 1/5] Optimize distinct hash aggregation for clickbench q4/q5 --- .../src/binary_view_map.rs | 359 ++++++++++----- .../src/aggregates/group_values/mod.rs | 90 +++- .../group_values/single_group_by/boolean.rs | 26 ++ .../group_values/single_group_by/bytes.rs | 205 ++++++--- .../single_group_by/bytes_view.rs | 209 ++++++--- .../group_values/single_group_by/primitive.rs | 425 ++++++++++++++---- .../physical-plan/src/aggregates/row_hash.rs | 149 +++--- .../physical-plan/src/recursive_query.rs | 2 +- 8 files changed, 1079 insertions(+), 386 deletions(-) diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index abc3e28f82627..b106bcbbf81c3 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -21,12 +21,12 @@ use crate::binary_map::OutputType; use arrow::array::NullBufferBuilder; use arrow::array::cast::AsArray; use arrow::array::{Array, ArrayRef, BinaryViewArray, ByteView, make_view}; -use arrow::buffer::{Buffer, ScalarBuffer}; +use arrow::buffer::{Buffer, NullBuffer, ScalarBuffer}; use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; use datafusion_common::hash_utils::RandomState; -use datafusion_common::hash_utils::create_hashes; -use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; +use datafusion_common::utils::proxy::HashTableAllocExt; use std::fmt::Debug; +use std::hash::BuildHasher; use std::mem::size_of; use std::sync::Arc; @@ -139,8 +139,8 @@ where /// random state used to generate hashes random_state: RandomState, - /// buffer that stores hash values (reused across batches to save allocations) - hashes_buffer: Vec, + /// Number of new non-null entries inserted by the previous batch. + last_batch_new_entries: usize, /// `(payload, null_index)` for the 'null' value, if any /// NOTE null_index is the logical index in the final array, not the index /// in the buffer @@ -164,7 +164,7 @@ where completed: Vec::new(), nulls: NullBufferBuilder::new(0), random_state: RandomState::default(), - hashes_buffer: vec![], + last_batch_new_entries: 0, null: None, } } @@ -252,27 +252,15 @@ where OP: FnMut(V), B: ByteViewType, { - // step 1: compute hashes - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes([values], &self.random_state, batch_hashes) - // hash is supported for all types and create_hashes only - // returns errors for unsupported types - .unwrap(); - - // step 2: insert each value into the set, if not already present let values = values.as_byte_view::(); - - // Get raw views buffer for direct comparison let input_views = values.views(); - - // Ensure lengths are equivalent - assert_eq!(values.len(), self.hashes_buffer.len()); + let input_buffers = values.data_buffers(); + let starting_len = self.map.len(); + self.reserve_for_batch(values.len() - values.null_count()); + let mut adopted_buffer_start = None; for i in 0..values.len() { let view_u128 = input_views[i]; - let hash = self.hashes_buffer[i]; // handle null value via validity bitmap check if values.is_null(i) { @@ -290,47 +278,29 @@ where continue; } - // Extract length from the view (first 4 bytes of u128 in little-endian) let len = view_u128 as u32; + let input_value = (len > 12).then(|| { + let value = values.value(i); + let bytes: &[u8] = value.as_ref(); + bytes + }); + let hash = hash_input_view(&self.random_state, view_u128, input_value); - // Check if value already exists let maybe_payload = { - // Borrow completed and in_progress for comparison let completed = &self.completed; let in_progress = &self.in_progress; self.map .find(hash, |header| { - if header.hash != hash { - return false; - } - - // Fast path: inline strings can be compared directly - if len <= 12 { - return header.view == view_u128; - } - - // For larger strings: first compare the 4-byte prefix - let stored_prefix = (header.view >> 32) as u32; - let input_prefix = (view_u128 >> 32) as u32; - if stored_prefix != input_prefix { - return false; - } - - // Prefix matched - compare full bytes - let byte_view = ByteView::from(header.view); - let stored_len = byte_view.length as usize; - let buffer_index = byte_view.buffer_index as usize; - let offset = byte_view.offset as usize; - - let stored_value = if buffer_index < completed.len() { - &completed[buffer_index].as_slice() - [offset..offset + stored_len] - } else { - &in_progress[offset..offset + stored_len] - }; - let input_value: &[u8] = values.value(i).as_ref(); - stored_value == input_value + view_matches( + hash, + view_u128, + input_value, + header.hash, + header.view, + completed, + in_progress, + ) }) .map(|entry| entry.payload) }; @@ -338,12 +308,40 @@ where let payload = if let Some(payload) = maybe_payload { payload } else { - // no existing value, make a new one - let value: &[u8] = values.value(i).as_ref(); + let value = input_value.unwrap_or_else(|| { + let value = values.value(i); + let bytes: &[u8] = value.as_ref(); + bytes + }); let payload = make_payload_fn(Some(value)); - // Create view pointing to our buffers - let new_view = self.append_value(value); + let new_view = if len <= 12 { + append_view(view_u128, Some(&mut self.views), Some(&mut self.nulls)) + } else if let Some(starting_buffer) = adopted_buffer_start.or_else(|| { + (!input_buffers.is_empty()).then(|| { + adopt_input_buffers( + input_buffers, + &mut self.in_progress, + &mut self.completed, + ) + }) + }) { + adopted_buffer_start = Some(starting_buffer); + append_input_view( + view_u128, + starting_buffer, + Some(&mut self.views), + Some(&mut self.nulls), + ) + } else { + append_value( + value, + &mut self.in_progress, + &mut self.completed, + Some(&mut self.views), + Some(&mut self.nulls), + ) + }; let new_header = Entry { view: new_view, hash, @@ -356,64 +354,62 @@ where }; observe_payload_fn(payload); } + + self.last_batch_new_entries = self.map.len() - starting_len; } - /// Converts this set into a `StringViewArray`, or `BinaryViewArray`, - /// containing each distinct value - /// that was inserted. This is done without copying the values. - /// - /// The values are guaranteed to be returned in the same order in which - /// they were first seen. - pub fn into_state(mut self) -> ArrayRef { - // Flush any remaining in-progress buffer - if !self.in_progress.is_empty() { - let flushed = std::mem::take(&mut self.in_progress); - self.completed.push(Buffer::from_vec(flushed)); + fn flush_in_progress(&mut self) { + if self.in_progress.is_empty() { + return; } - // Build null buffer if we have any nulls - let null_buffer = self.nulls.finish(); - - let views = ScalarBuffer::from(self.views); - let array = - unsafe { BinaryViewArray::new_unchecked(views, self.completed, null_buffer) }; + let flushed = std::mem::replace( + &mut self.in_progress, + Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE), + ); + self.completed.push(Buffer::from_vec(flushed)); + } - match self.output_type { - OutputType::BinaryView => Arc::new(array), - OutputType::Utf8View => { - // SAFETY: all input was valid utf8 - let array = unsafe { array.to_string_view_unchecked() }; - Arc::new(array) - } - _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), + fn reserve_for_batch(&mut self, non_null_rows: usize) { + if non_null_rows == 0 { + return; } - } - /// Append a value to our buffers and return the view pointing to it - fn append_value(&mut self, value: &[u8]) -> u128 { - let len = value.len(); - let view = if len <= 12 { - make_view(value, 0, 0) + let expected_new_entries = if self.last_batch_new_entries == 0 { + non_null_rows } else { - // Ensure buffer is big enough - if self.in_progress.len() + len > BYTE_VIEW_MAX_BLOCK_SIZE { - let flushed = std::mem::replace( - &mut self.in_progress, - Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE), - ); - self.completed.push(Buffer::from_vec(flushed)); - } + self.last_batch_new_entries + .saturating_mul(2) + .min(non_null_rows) + }; - let buffer_index = self.completed.len() as u32; - let offset = self.in_progress.len() as u32; - self.in_progress.extend_from_slice(value); + let remaining_capacity = self.map.capacity().saturating_sub(self.map.len()); + let additional = expected_new_entries.saturating_sub(remaining_capacity); + if additional == 0 { + return; + } - make_view(value, buffer_index, offset) - }; + let previous_capacity = self.map.capacity(); + self.map.reserve(additional, |h| h.hash); + self.map_size += + (self.map.capacity() - previous_capacity) * size_of::>(); + } - self.views.push(view); - self.nulls.append_non_null(); - view + /// Converts this set into a `StringViewArray`, or `BinaryViewArray`, + /// containing each distinct value + /// that was inserted. This is done without copying the values. + /// + /// The values are guaranteed to be returned in the same order in which + /// they were first seen. + pub fn into_state(mut self) -> ArrayRef { + self.flush_in_progress(); + + build_view_array( + self.output_type, + self.views, + self.completed, + self.nulls.finish(), + ) } /// Total number of entries (including null, if present) @@ -439,12 +435,7 @@ where let completed_size: usize = self.completed.iter().map(|b| b.len()).sum(); let nulls_size = self.nulls.allocated_size(); - self.map_size - + views_size - + in_progress_size - + completed_size - + nulls_size - + self.hashes_buffer.allocated_size() + self.map_size + views_size + in_progress_size + completed_size + nulls_size } } @@ -459,7 +450,6 @@ where .field("views_len", &self.views.len()) .field("completed_buffers", &self.completed.len()) .field("random_state", &self.random_state) - .field("hashes_buffer", &self.hashes_buffer) .finish() } } @@ -486,6 +476,149 @@ where payload: V, } +fn hash_input_view( + random_state: &RandomState, + view_u128: u128, + input_value: Option<&[u8]>, +) -> u64 { + input_value.map_or_else( + || random_state.hash_one(view_u128), + |value| random_state.hash_one(value), + ) +} + +fn view_matches( + input_hash: u64, + input_view: u128, + input_value: Option<&[u8]>, + stored_hash: u64, + stored_view: u128, + completed: &[Buffer], + in_progress: &[u8], +) -> bool { + if stored_hash != input_hash { + return false; + } + + let len = input_view as u32; + if len <= 12 { + return stored_view == input_view; + } + + if stored_view as u32 != len { + return false; + } + + let stored_prefix = (stored_view >> 32) as u32; + let input_prefix = (input_view >> 32) as u32; + if stored_prefix != input_prefix { + return false; + } + + let byte_view = ByteView::from(stored_view); + let stored_len = byte_view.length as usize; + let buffer_index = byte_view.buffer_index as usize; + let offset = byte_view.offset as usize; + + let stored_value = if buffer_index < completed.len() { + &completed[buffer_index].as_slice()[offset..offset + stored_len] + } else { + &in_progress[offset..offset + stored_len] + }; + stored_value == input_value.expect("non-inline value") +} + +fn append_value( + value: &[u8], + in_progress: &mut Vec, + completed: &mut Vec, + views: Option<&mut Vec>, + nulls: Option<&mut NullBufferBuilder>, +) -> u128 { + let len = value.len(); + let view = if len <= 12 { + make_view(value, 0, 0) + } else { + if in_progress.len() + len > BYTE_VIEW_MAX_BLOCK_SIZE { + let flushed = std::mem::replace( + in_progress, + Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE), + ); + completed.push(Buffer::from_vec(flushed)); + } + + let buffer_index = completed.len() as u32; + let offset = in_progress.len() as u32; + in_progress.extend_from_slice(value); + + make_view(value, buffer_index, offset) + }; + + append_view(view, views, nulls) +} + +fn adopt_input_buffers( + input_buffers: &[Buffer], + in_progress: &mut Vec, + completed: &mut Vec, +) -> u32 { + if !in_progress.is_empty() { + let flushed = + std::mem::replace(in_progress, Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE)); + completed.push(Buffer::from_vec(flushed)); + } + + let starting_buffer = completed.len().try_into().expect("too many buffers"); + completed.extend(input_buffers.iter().cloned()); + starting_buffer +} + +fn append_input_view( + input_view: u128, + starting_buffer: u32, + views: Option<&mut Vec>, + nulls: Option<&mut NullBufferBuilder>, +) -> u128 { + let byte_view = ByteView::from(input_view); + let view = byte_view + .with_buffer_index(byte_view.buffer_index + starting_buffer) + .as_u128(); + append_view(view, views, nulls) +} + +fn append_view( + view: u128, + views: Option<&mut Vec>, + nulls: Option<&mut NullBufferBuilder>, +) -> u128 { + if let Some(views) = views { + views.push(view); + } + if let Some(nulls) = nulls { + nulls.append_non_null(); + } + view +} + +fn build_view_array( + output_type: OutputType, + views: Vec, + completed: Vec, + null_buffer: Option, +) -> ArrayRef { + let views = ScalarBuffer::from(views); + let array = unsafe { BinaryViewArray::new_unchecked(views, completed, null_buffer) }; + + match output_type { + OutputType::BinaryView => Arc::new(array), + OutputType::Utf8View => { + let array = unsafe { array.to_string_view_unchecked() }; + Arc::new(array) + } + _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), + } +} + #[cfg(test)] mod tests { use arrow::array::{GenericByteViewArray, StringViewArray}; diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 2f3b1a19e7d73..1839192c05782 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -99,6 +99,16 @@ pub trait GroupValues: Send { /// assigned. fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; + /// Interns the rows from `cols` without materializing a group id per input + /// row. + /// + /// This is useful for hash aggregate operators that only need the set of + /// distinct group keys and have no per-row accumulator updates to perform. + fn intern_no_group_ids(&mut self, cols: &[ArrayRef]) -> Result<()> { + let mut groups = Vec::new(); + self.intern(cols, &mut groups) + } + /// Returns the number of bytes of memory used by this [`GroupValues`] fn size(&self) -> usize; @@ -134,64 +144,104 @@ pub trait GroupValues: Send { pub fn new_group_values( schema: SchemaRef, group_ordering: &GroupOrdering, + require_group_indices: bool, ) -> Result> { if schema.fields.len() == 1 { let d = schema.fields[0].data_type(); + let track_group_ids = + require_group_indices || !matches!(group_ordering, GroupOrdering::None); macro_rules! downcast_helper { - ($t:ty, $d:ident) => { - return Ok(Box::new(GroupValuesPrimitive::<$t>::new($d.clone()))) + ($t:ty, $d:ident, $track_group_ids:expr) => { + return Ok(Box::new(GroupValuesPrimitive::<$t>::new( + $d.clone(), + $track_group_ids, + ))) }; } downcast_primitive! { - d => (downcast_helper, d), + d => (downcast_helper, d, track_group_ids), _ => {} } match d { DataType::Date32 => { - downcast_helper!(Date32Type, d); + downcast_helper!(Date32Type, d, track_group_ids); } DataType::Date64 => { - downcast_helper!(Date64Type, d); + downcast_helper!(Date64Type, d, track_group_ids); } DataType::Time32(t) => match t { - TimeUnit::Second => downcast_helper!(Time32SecondType, d), - TimeUnit::Millisecond => downcast_helper!(Time32MillisecondType, d), + TimeUnit::Second => { + downcast_helper!(Time32SecondType, d, track_group_ids) + } + TimeUnit::Millisecond => { + downcast_helper!(Time32MillisecondType, d, track_group_ids) + } _ => {} }, DataType::Time64(t) => match t { - TimeUnit::Microsecond => downcast_helper!(Time64MicrosecondType, d), - TimeUnit::Nanosecond => downcast_helper!(Time64NanosecondType, d), + TimeUnit::Microsecond => { + downcast_helper!(Time64MicrosecondType, d, track_group_ids) + } + TimeUnit::Nanosecond => { + downcast_helper!(Time64NanosecondType, d, track_group_ids) + } _ => {} }, DataType::Timestamp(t, _tz) => match t { - TimeUnit::Second => downcast_helper!(TimestampSecondType, d), - TimeUnit::Millisecond => downcast_helper!(TimestampMillisecondType, d), - TimeUnit::Microsecond => downcast_helper!(TimestampMicrosecondType, d), - TimeUnit::Nanosecond => downcast_helper!(TimestampNanosecondType, d), + TimeUnit::Second => { + downcast_helper!(TimestampSecondType, d, track_group_ids) + } + TimeUnit::Millisecond => { + downcast_helper!(TimestampMillisecondType, d, track_group_ids) + } + TimeUnit::Microsecond => { + downcast_helper!(TimestampMicrosecondType, d, track_group_ids) + } + TimeUnit::Nanosecond => { + downcast_helper!(TimestampNanosecondType, d, track_group_ids) + } }, DataType::Decimal128(_, _) => { - downcast_helper!(Decimal128Type, d); + downcast_helper!(Decimal128Type, d, track_group_ids); } DataType::Utf8 => { - return Ok(Box::new(GroupValuesBytes::::new(OutputType::Utf8))); + return Ok(Box::new(GroupValuesBytes::::new( + OutputType::Utf8, + track_group_ids, + ))); } DataType::LargeUtf8 => { - return Ok(Box::new(GroupValuesBytes::::new(OutputType::Utf8))); + return Ok(Box::new(GroupValuesBytes::::new( + OutputType::Utf8, + track_group_ids, + ))); } DataType::Utf8View => { - return Ok(Box::new(GroupValuesBytesView::new(OutputType::Utf8View))); + return Ok(Box::new(GroupValuesBytesView::new( + OutputType::Utf8View, + track_group_ids, + ))); } DataType::Binary => { - return Ok(Box::new(GroupValuesBytes::::new(OutputType::Binary))); + return Ok(Box::new(GroupValuesBytes::::new( + OutputType::Binary, + track_group_ids, + ))); } DataType::LargeBinary => { - return Ok(Box::new(GroupValuesBytes::::new(OutputType::Binary))); + return Ok(Box::new(GroupValuesBytes::::new( + OutputType::Binary, + track_group_ids, + ))); } DataType::BinaryView => { - return Ok(Box::new(GroupValuesBytesView::new(OutputType::BinaryView))); + return Ok(Box::new(GroupValuesBytesView::new( + OutputType::BinaryView, + track_group_ids, + ))); } DataType::Boolean => { return Ok(Box::new(GroupValuesBoolean::new())); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs index e993c0c53d199..7e9e965e9826b 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs @@ -83,6 +83,32 @@ impl GroupValues for GroupValuesBoolean { Ok(()) } + fn intern_no_group_ids(&mut self, cols: &[ArrayRef]) -> Result<()> { + let array = cols[0].as_boolean(); + + for value in array.iter() { + match value { + Some(false) => { + if self.false_group.is_none() { + self.false_group = Some(self.len()); + } + } + Some(true) => { + if self.true_group.is_none() { + self.true_group = Some(self.len()); + } + } + None => { + if self.null_group.is_none() { + self.null_group = Some(self.len()); + } + } + } + } + + Ok(()) + } + fn size(&self) -> usize { size_of::() } diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs index b881a51b25474..40f64a336e3e2 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs @@ -22,49 +22,137 @@ use crate::aggregates::group_values::GroupValues; use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use datafusion_common::Result; use datafusion_expr::EmitTo; -use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; +use datafusion_physical_expr_common::binary_map::{ + ArrowBytesMap, ArrowBytesSet, OutputType, +}; + +enum GroupValuesBytesState { + GroupIds { + map: ArrowBytesMap, + num_groups: usize, + }, + DistinctOnly(ArrowBytesSet), +} /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// /// This specialization is significantly faster than using the more general /// purpose `Row`s format pub struct GroupValuesBytes { - /// Map string/binary values to group index - map: ArrowBytesMap, - /// The total number of groups so far (used to assign group_index) - num_groups: usize, + output_type: OutputType, + state: GroupValuesBytesState, } impl GroupValuesBytes { - pub fn new(output_type: OutputType) -> Self { - Self { - map: ArrowBytesMap::new(output_type), - num_groups: 0, + pub fn new(output_type: OutputType, track_group_ids: bool) -> Self { + let state = if track_group_ids { + GroupValuesBytesState::GroupIds { + map: ArrowBytesMap::new(output_type), + num_groups: 0, + } + } else { + GroupValuesBytesState::DistinctOnly(ArrowBytesSet::new(output_type)) + }; + + Self { output_type, state } + } + + fn ensure_group_id_tracking(&mut self) { + if matches!(self.state, GroupValuesBytesState::GroupIds { .. }) { + return; + } + + let GroupValuesBytesState::DistinctOnly(set) = &mut self.state else { + unreachable!(); + }; + let contents = set.take().into_state(); + let mut map = ArrowBytesMap::new(self.output_type); + let mut num_groups = 0; + map.insert_if_new( + &contents, + |_value| { + let group_idx = num_groups; + num_groups += 1; + group_idx + }, + |_group_idx| {}, + ); + self.state = GroupValuesBytesState::GroupIds { map, num_groups }; + } + + fn emit_group_ids( + map: &mut ArrowBytesMap, + num_groups: &mut usize, + emit_to: EmitTo, + ) -> ArrayRef { + let map_contents = map.take().into_state(); + + match emit_to { + EmitTo::All => { + *num_groups -= map_contents.len(); + map_contents + } + EmitTo::First(n) if n == *num_groups => { + *num_groups -= map_contents.len(); + map_contents + } + EmitTo::First(n) => { + let emit_group_values = map_contents.slice(0, n); + let remaining_group_values = + map_contents.slice(n, map_contents.len() - n); + + *num_groups = 0; + map.insert_if_new( + &remaining_group_values, + |_value| { + let group_idx = *num_groups; + *num_groups += 1; + group_idx + }, + |_group_idx| {}, + ); + + emit_group_values + } + } + } + + fn emit_distinct_only(set: &mut ArrowBytesSet, emit_to: EmitTo) -> ArrayRef { + let set_contents = set.take().into_state(); + match emit_to { + EmitTo::All => set_contents, + EmitTo::First(n) if n == set_contents.len() => set_contents, + EmitTo::First(n) => { + let emit_group_values = set_contents.slice(0, n); + let remaining_group_values = + set_contents.slice(n, set_contents.len() - n); + set.insert(&remaining_group_values); + emit_group_values + } } } } impl GroupValues for GroupValuesBytes { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + self.ensure_group_id_tracking(); assert_eq!(cols.len(), 1); // look up / add entries in the table let arr = &cols[0]; + let GroupValuesBytesState::GroupIds { map, num_groups } = &mut self.state else { + unreachable!(); + }; groups.clear(); - self.map.insert_if_new( + map.insert_if_new( arr, - // called for each new group |_value| { - // assign new group index on each insert - let group_idx = self.num_groups; - self.num_groups += 1; + let group_idx = *num_groups; + *num_groups += 1; group_idx }, - // called for each group - |group_idx| { - groups.push(group_idx); - }, + |group_idx| groups.push(group_idx), ); // ensure we assigned a group to for each row @@ -72,48 +160,55 @@ impl GroupValues for GroupValuesBytes { Ok(()) } + fn intern_no_group_ids(&mut self, cols: &[ArrayRef]) -> Result<()> { + assert_eq!(cols.len(), 1); + + let arr = &cols[0]; + match &mut self.state { + GroupValuesBytesState::GroupIds { map, num_groups } => map.insert_if_new( + arr, + |_value| { + let group_idx = *num_groups; + *num_groups += 1; + group_idx + }, + |_group_idx| {}, + ), + GroupValuesBytesState::DistinctOnly(set) => set.insert(arr), + } + + Ok(()) + } + fn size(&self) -> usize { - self.map.size() + size_of::() + size_of::() + + match &self.state { + GroupValuesBytesState::GroupIds { map, .. } => map.size(), + GroupValuesBytesState::DistinctOnly(set) => set.size(), + } } fn is_empty(&self) -> bool { - self.num_groups == 0 + match &self.state { + GroupValuesBytesState::GroupIds { num_groups, .. } => *num_groups == 0, + GroupValuesBytesState::DistinctOnly(set) => set.is_empty(), + } } fn len(&self) -> usize { - self.num_groups + match &self.state { + GroupValuesBytesState::GroupIds { num_groups, .. } => *num_groups, + GroupValuesBytesState::DistinctOnly(set) => set.len(), + } } fn emit(&mut self, emit_to: EmitTo) -> Result> { - // Reset the map to default, and convert it into a single array - let map_contents = self.map.take().into_state(); - - let group_values = match emit_to { - EmitTo::All => { - self.num_groups -= map_contents.len(); - map_contents - } - EmitTo::First(n) if n == self.len() => { - self.num_groups -= map_contents.len(); - map_contents + let group_values = match &mut self.state { + GroupValuesBytesState::GroupIds { map, num_groups } => { + Self::emit_group_ids(map, num_groups, emit_to) } - EmitTo::First(n) => { - // if we only wanted to take the first n, insert the rest back - // into the map we could potentially avoid this reallocation, at - // the expense of much more complex code. - // see https://github.com/apache/datafusion/issues/9195 - let emit_group_values = map_contents.slice(0, n); - let remaining_group_values = - map_contents.slice(n, map_contents.len() - n); - - self.num_groups = 0; - let mut group_indexes = vec![]; - self.intern(&[remaining_group_values], &mut group_indexes)?; - - // Verify that the group indexes were assigned in the correct order - assert_eq!(0, group_indexes[0]); - - emit_group_values + GroupValuesBytesState::DistinctOnly(set) => { + Self::emit_distinct_only(set, emit_to) } }; @@ -121,8 +216,14 @@ impl GroupValues for GroupValuesBytes { } fn clear_shrink(&mut self, _num_rows: usize) { - // in theory we could potentially avoid this reallocation and clear the - // contents of the maps, but for now we just reset the map from the beginning - self.map.take(); + match &mut self.state { + GroupValuesBytesState::GroupIds { map, num_groups } => { + *num_groups = 0; + map.take(); + } + GroupValuesBytesState::DistinctOnly(set) => { + set.take(); + } + } } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs index 7a56f7c52c11a..47dceddb47c58 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs @@ -19,25 +19,114 @@ use crate::aggregates::group_values::GroupValues; use arrow::array::{Array, ArrayRef}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; -use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; +use datafusion_physical_expr_common::binary_view_map::{ + ArrowBytesViewMap, ArrowBytesViewSet, +}; use std::mem::size_of; +enum GroupValuesBytesViewState { + GroupIds { + map: ArrowBytesViewMap, + num_groups: usize, + }, + DistinctOnly(ArrowBytesViewSet), +} + /// A [`GroupValues`] storing single column of Utf8View/BinaryView values /// /// This specialization is significantly faster than using the more general /// purpose `Row`s format pub struct GroupValuesBytesView { - /// Map string/binary values to group index - map: ArrowBytesViewMap, - /// The total number of groups so far (used to assign group_index) - num_groups: usize, + output_type: OutputType, + state: GroupValuesBytesViewState, } impl GroupValuesBytesView { - pub fn new(output_type: OutputType) -> Self { - Self { - map: ArrowBytesViewMap::new(output_type), - num_groups: 0, + pub fn new(output_type: OutputType, track_group_ids: bool) -> Self { + let state = if track_group_ids { + GroupValuesBytesViewState::GroupIds { + map: ArrowBytesViewMap::new(output_type), + num_groups: 0, + } + } else { + GroupValuesBytesViewState::DistinctOnly(ArrowBytesViewSet::new(output_type)) + }; + + Self { output_type, state } + } + + fn ensure_group_id_tracking(&mut self) { + if matches!(self.state, GroupValuesBytesViewState::GroupIds { .. }) { + return; + } + + let GroupValuesBytesViewState::DistinctOnly(set) = &mut self.state else { + unreachable!(); + }; + let contents = set.take().into_state(); + let mut map = ArrowBytesViewMap::new(self.output_type); + let mut num_groups = 0; + map.insert_if_new( + &contents, + |_value| { + let group_idx = num_groups; + num_groups += 1; + group_idx + }, + |_group_idx| {}, + ); + self.state = GroupValuesBytesViewState::GroupIds { map, num_groups }; + } + + fn emit_group_ids( + map: &mut ArrowBytesViewMap, + num_groups: &mut usize, + emit_to: EmitTo, + ) -> ArrayRef { + let map_contents = map.take().into_state(); + + match emit_to { + EmitTo::All => { + *num_groups -= map_contents.len(); + map_contents + } + EmitTo::First(n) if n == *num_groups => { + *num_groups -= map_contents.len(); + map_contents + } + EmitTo::First(n) => { + let emit_group_values = map_contents.slice(0, n); + let remaining_group_values = + map_contents.slice(n, map_contents.len() - n); + + *num_groups = 0; + map.insert_if_new( + &remaining_group_values, + |_value| { + let group_idx = *num_groups; + *num_groups += 1; + group_idx + }, + |_group_idx| {}, + ); + + emit_group_values + } + } + } + + fn emit_distinct_only(set: &mut ArrowBytesViewSet, emit_to: EmitTo) -> ArrayRef { + let set_contents = set.take().into_state(); + match emit_to { + EmitTo::All => set_contents, + EmitTo::First(n) if n == set_contents.len() => set_contents, + EmitTo::First(n) => { + let emit_group_values = set_contents.slice(0, n); + let remaining_group_values = + set_contents.slice(n, set_contents.len() - n); + set.insert(&remaining_group_values); + emit_group_values + } } } } @@ -48,25 +137,25 @@ impl GroupValues for GroupValuesBytesView { cols: &[ArrayRef], groups: &mut Vec, ) -> datafusion_common::Result<()> { + self.ensure_group_id_tracking(); assert_eq!(cols.len(), 1); // look up / add entries in the table let arr = &cols[0]; + let GroupValuesBytesViewState::GroupIds { map, num_groups } = &mut self.state + else { + unreachable!(); + }; groups.clear(); - self.map.insert_if_new( + map.insert_if_new( arr, - // called for each new group |_value| { - // assign new group index on each insert - let group_idx = self.num_groups; - self.num_groups += 1; + let group_idx = *num_groups; + *num_groups += 1; group_idx }, - // called for each group - |group_idx| { - groups.push(group_idx); - }, + |group_idx| groups.push(group_idx), ); // ensure we assigned a group to for each row @@ -74,48 +163,58 @@ impl GroupValues for GroupValuesBytesView { Ok(()) } + fn intern_no_group_ids( + &mut self, + cols: &[ArrayRef], + ) -> datafusion_common::Result<()> { + assert_eq!(cols.len(), 1); + + let arr = &cols[0]; + match &mut self.state { + GroupValuesBytesViewState::GroupIds { map, num_groups } => map.insert_if_new( + arr, + |_value| { + let group_idx = *num_groups; + *num_groups += 1; + group_idx + }, + |_group_idx| {}, + ), + GroupValuesBytesViewState::DistinctOnly(set) => set.insert(arr), + } + + Ok(()) + } + fn size(&self) -> usize { - self.map.size() + size_of::() + size_of::() + + match &self.state { + GroupValuesBytesViewState::GroupIds { map, .. } => map.size(), + GroupValuesBytesViewState::DistinctOnly(set) => set.size(), + } } fn is_empty(&self) -> bool { - self.num_groups == 0 + match &self.state { + GroupValuesBytesViewState::GroupIds { num_groups, .. } => *num_groups == 0, + GroupValuesBytesViewState::DistinctOnly(set) => set.is_empty(), + } } fn len(&self) -> usize { - self.num_groups + match &self.state { + GroupValuesBytesViewState::GroupIds { num_groups, .. } => *num_groups, + GroupValuesBytesViewState::DistinctOnly(set) => set.len(), + } } fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { - // Reset the map to default, and convert it into a single array - let map_contents = self.map.take().into_state(); - - let group_values = match emit_to { - EmitTo::All => { - self.num_groups -= map_contents.len(); - map_contents + let group_values = match &mut self.state { + GroupValuesBytesViewState::GroupIds { map, num_groups } => { + Self::emit_group_ids(map, num_groups, emit_to) } - EmitTo::First(n) if n == self.len() => { - self.num_groups -= map_contents.len(); - map_contents - } - EmitTo::First(n) => { - // if we only wanted to take the first n, insert the rest back - // into the map we could potentially avoid this reallocation, at - // the expense of much more complex code. - // see https://github.com/apache/datafusion/issues/9195 - let emit_group_values = map_contents.slice(0, n); - let remaining_group_values = - map_contents.slice(n, map_contents.len() - n); - - self.num_groups = 0; - let mut group_indexes = vec![]; - self.intern(&[remaining_group_values], &mut group_indexes)?; - - // Verify that the group indexes were assigned in the correct order - assert_eq!(0, group_indexes[0]); - - emit_group_values + GroupValuesBytesViewState::DistinctOnly(set) => { + Self::emit_distinct_only(set, emit_to) } }; @@ -123,8 +222,14 @@ impl GroupValues for GroupValuesBytesView { } fn clear_shrink(&mut self, _num_rows: usize) { - // in theory we could potentially avoid this reallocation and clear the - // contents of the maps, but for now we just reset the map from the beginning - self.map.take(); + match &mut self.state { + GroupValuesBytesViewState::GroupIds { map, num_groups } => { + *num_groups = 0; + map.take(); + } + GroupValuesBytesViewState::DistinctOnly(set) => { + set.take(); + } + } } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index efaf7eba0f1b5..4e6d50d6a3662 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -79,147 +79,390 @@ hash_float!(f16, f32, f64); /// /// This specialization is significantly faster than using the more general /// purpose `Row`s format +enum GroupValuesPrimitiveState { + GroupIds { + /// Stores the `(group_index, hash)` based on the hash of its value + /// + /// We also store `hash` is for reducing cost of rehashing. Such cost + /// is obvious in high cardinality group by situation. + /// More details can see: + /// + map: HashTable<(usize, u64)>, + /// The group index of the null value if any + null_group: Option, + /// The values for each group index + values: Vec, + }, + DistinctOnly { + /// Stores the distinct primitive values. + map: HashTable, + has_null: bool, + }, +} + pub struct GroupValuesPrimitive { /// The data type of the output array data_type: DataType, - /// Stores the `(group_index, hash)` based on the hash of its value - /// - /// We also store `hash` is for reducing cost of rehashing. Such cost - /// is obvious in high cardinality group by situation. - /// More details can see: - /// - map: HashTable<(usize, u64)>, - /// The group index of the null value if any - null_group: Option, - /// The values for each group index - values: Vec, + state: GroupValuesPrimitiveState, /// The random state used to generate hashes random_state: RandomState, } -impl GroupValuesPrimitive { - pub fn new(data_type: DataType) -> Self { +impl GroupValuesPrimitive +where + T::Native: HashValue, +{ + pub fn new(data_type: DataType, track_group_ids: bool) -> Self { assert!(PrimitiveArray::::is_compatible(&data_type)); + let state = if track_group_ids { + GroupValuesPrimitiveState::GroupIds { + map: HashTable::with_capacity(128), + values: Vec::with_capacity(128), + null_group: None, + } + } else { + GroupValuesPrimitiveState::DistinctOnly { + map: HashTable::with_capacity(128), + has_null: false, + } + }; Self { data_type, - map: HashTable::with_capacity(128), - values: Vec::with_capacity(128), - null_group: None, + state, random_state: crate::aggregates::AGGREGATION_HASH_SEED, } } -} -impl GroupValues for GroupValuesPrimitive -where - T::Native: HashValue, -{ - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - assert_eq!(cols.len(), 1); - groups.clear(); + fn build_primitive( + values: Vec, + null_idx: Option, + ) -> PrimitiveArray { + let nulls = null_idx.map(|null_idx| { + let mut buffer = NullBufferBuilder::new(values.len()); + buffer.append_n_non_nulls(null_idx); + buffer.append_null(); + buffer.append_n_non_nulls(values.len() - null_idx - 1); + // NOTE: The inner builder must be constructed as there is at least one null + buffer.finish().unwrap() + }); + PrimitiveArray::::new(values.into(), nulls) + } - for v in cols[0].as_primitive::() { - let group_id = match v { - None => *self.null_group.get_or_insert_with(|| { - let group_id = self.values.len(); - self.values.push(Default::default()); - group_id - }), - Some(key) => { - let state = &self.random_state; - let hash = key.hash(state); - let insert = self.map.entry( - hash, - |&(g, h)| unsafe { - hash == h && self.values.get_unchecked(g).is_eq(key) - }, - |&(_, h)| h, - ); + fn ensure_group_id_tracking(&mut self) { + if matches!(self.state, GroupValuesPrimitiveState::GroupIds { .. }) { + return; + } - match insert { - hashbrown::hash_table::Entry::Occupied(o) => o.get().0, - hashbrown::hash_table::Entry::Vacant(v) => { - let g = self.values.len(); - v.insert((g, hash)); - self.values.push(key); - g - } - } - } - }; - groups.push(group_id) + let GroupValuesPrimitiveState::DistinctOnly { map, has_null } = std::mem::replace( + &mut self.state, + GroupValuesPrimitiveState::GroupIds { + map: HashTable::with_capacity(128), + null_group: None, + values: Vec::with_capacity(128), + }, + ) else { + unreachable!(); + }; + + let mut values = Vec::with_capacity(map.len() + usize::from(has_null)); + let null_group = has_null.then(|| { + values.push(Default::default()); + 0 + }); + let mut group_map = HashTable::with_capacity(map.len()); + for value in map { + let group_idx = values.len(); + values.push(value); + let hash = value.hash(&self.random_state); + group_map + .insert_unique(hash, (group_idx, hash), |&(_, stored_hash)| stored_hash); } - Ok(()) + self.state = GroupValuesPrimitiveState::GroupIds { + map: group_map, + null_group, + values, + }; } - fn size(&self) -> usize { - self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size() - } + fn insert_group_id( + random_state: &RandomState, + map: &mut HashTable<(usize, u64)>, + values: &mut Vec, + null_group: &mut Option, + value: Option, + ) -> usize { + match value { + None => *null_group.get_or_insert_with(|| { + let group_id = values.len(); + values.push(Default::default()); + group_id + }), + Some(key) => { + let hash = key.hash(random_state); + let insert = map.entry( + hash, + |&(g, h)| unsafe { hash == h && values.get_unchecked(g).is_eq(key) }, + |&(_, h)| h, + ); - fn is_empty(&self) -> bool { - self.values.is_empty() + match insert { + hashbrown::hash_table::Entry::Occupied(o) => o.get().0, + hashbrown::hash_table::Entry::Vacant(v) => { + let g = values.len(); + v.insert((g, hash)); + values.push(key); + g + } + } + } + } } - fn len(&self) -> usize { - self.values.len() + fn insert_distinct_only( + random_state: &RandomState, + map: &mut HashTable, + has_null: &mut bool, + value: Option, + ) { + match value { + None => { + *has_null = true; + } + Some(key) => { + Self::insert_distinct_only_hashed( + random_state, + map, + key, + key.hash(random_state), + ); + } + } } - fn emit(&mut self, emit_to: EmitTo) -> Result> { - fn build_primitive( - values: Vec, - null_idx: Option, - ) -> PrimitiveArray { - let nulls = null_idx.map(|null_idx| { - let mut buffer = NullBufferBuilder::new(values.len()); - buffer.append_n_non_nulls(null_idx); - buffer.append_null(); - buffer.append_n_non_nulls(values.len() - null_idx - 1); - // NOTE: The inner builder must be constructed as there is at least one null - buffer.finish().unwrap() - }); - PrimitiveArray::::new(values.into(), nulls) + fn insert_distinct_only_hashed( + random_state: &RandomState, + map: &mut HashTable, + key: T::Native, + hash: u64, + ) { + let insert = map.entry( + hash, + |stored| stored.is_eq(key), + |stored| stored.hash(random_state), + ); + + if let hashbrown::hash_table::Entry::Vacant(v) = insert { + v.insert(key); } + } + fn emit_group_ids( + data_type: &DataType, + map: &mut HashTable<(usize, u64)>, + values: &mut Vec, + null_group: &mut Option, + emit_to: EmitTo, + ) -> ArrayRef { let array: PrimitiveArray = match emit_to { EmitTo::All => { - self.map.clear(); - build_primitive(std::mem::take(&mut self.values), self.null_group.take()) + map.clear(); + Self::build_primitive(std::mem::take(values), null_group.take()) } EmitTo::First(n) => { - self.map.retain(|entry| { - // Decrement group index by n + map.retain(|entry| { let group_idx = entry.0; match group_idx.checked_sub(n) { - // Group index was >= n, shift value down Some(sub) => { entry.0 = sub; true } - // Group index was < n, so remove from table None => false, } }); - let null_group = match &mut self.null_group { + let null_idx = match null_group { Some(v) if *v >= n => { *v -= n; None } - Some(_) => self.null_group.take(), + Some(_) => null_group.take(), None => None, }; - let mut split = self.values.split_off(n); - std::mem::swap(&mut self.values, &mut split); - build_primitive(split, null_group) + let mut split = values.split_off(n); + std::mem::swap(values, &mut split); + Self::build_primitive(split, null_idx) + } + }; + + Arc::new(array.with_data_type(data_type.clone())) + } + + fn emit_distinct_only( + data_type: &DataType, + random_state: &RandomState, + map: &mut HashTable, + has_null: &mut bool, + emit_to: EmitTo, + ) -> ArrayRef { + let total_len = map.len() + usize::from(*has_null); + let mut values = Vec::with_capacity(total_len); + if *has_null { + values.push(Default::default()); + } + values.extend(map.iter().copied()); + map.clear(); + + let (emitted_values, emitted_null_idx, remaining_values, remaining_has_null) = + match emit_to { + EmitTo::All => (values, (*has_null).then_some(0), Vec::new(), false), + EmitTo::First(n) if n >= total_len => { + (values, (*has_null).then_some(0), Vec::new(), false) + } + EmitTo::First(n) => { + let mut remaining_values = values.split_off(n); + let emitted_values = values; + let emitted_null_idx = (*has_null && n > 0).then_some(0); + let remaining_has_null = *has_null && n == 0; + if remaining_has_null { + remaining_values.remove(0); + } + ( + emitted_values, + emitted_null_idx, + remaining_values, + remaining_has_null, + ) + } + }; + + *has_null = remaining_has_null; + for value in remaining_values { + Self::insert_distinct_only(random_state, map, has_null, Some(value)); + } + + Arc::new( + Self::build_primitive(emitted_values, emitted_null_idx) + .with_data_type(data_type.clone()), + ) + } +} + +impl GroupValues for GroupValuesPrimitive +where + T::Native: HashValue, +{ + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + self.ensure_group_id_tracking(); + assert_eq!(cols.len(), 1); + groups.clear(); + let GroupValuesPrimitiveState::GroupIds { + map, + null_group, + values, + } = &mut self.state + else { + unreachable!(); + }; + + for v in cols[0].as_primitive::() { + let group_id = + Self::insert_group_id(&self.random_state, map, values, null_group, v); + groups.push(group_id) + } + Ok(()) + } + + fn intern_no_group_ids(&mut self, cols: &[ArrayRef]) -> Result<()> { + assert_eq!(cols.len(), 1); + + match &mut self.state { + GroupValuesPrimitiveState::GroupIds { + map, + null_group, + values, + } => { + for v in cols[0].as_primitive::() { + let _ = Self::insert_group_id( + &self.random_state, + map, + values, + null_group, + v, + ); + } + } + GroupValuesPrimitiveState::DistinctOnly { map, has_null } => { + for v in cols[0].as_primitive::() { + Self::insert_distinct_only(&self.random_state, map, has_null, v); + } + } + } + Ok(()) + } + + fn size(&self) -> usize { + size_of::() + + match &self.state { + GroupValuesPrimitiveState::GroupIds { map, values, .. } => { + map.capacity() * size_of::<(usize, u64)>() + values.allocated_size() + } + GroupValuesPrimitiveState::DistinctOnly { map, .. } => { + map.capacity() * size_of::() + } + } + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn len(&self) -> usize { + match &self.state { + GroupValuesPrimitiveState::GroupIds { values, .. } => values.len(), + GroupValuesPrimitiveState::DistinctOnly { map, has_null, .. } => { + map.len() + usize::from(*has_null) + } + } + } + + fn emit(&mut self, emit_to: EmitTo) -> Result> { + let array = match &mut self.state { + GroupValuesPrimitiveState::GroupIds { + map, + null_group, + values, + } => Self::emit_group_ids(&self.data_type, map, values, null_group, emit_to), + GroupValuesPrimitiveState::DistinctOnly { map, has_null, .. } => { + Self::emit_distinct_only( + &self.data_type, + &self.random_state, + map, + has_null, + emit_to, + ) } }; - Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))]) + Ok(vec![array]) } fn clear_shrink(&mut self, num_rows: usize) { - self.values.clear(); - self.values.shrink_to(num_rows); - self.map.clear(); - self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared + match &mut self.state { + GroupValuesPrimitiveState::GroupIds { + map, + null_group, + values, + } => { + *null_group = None; + values.clear(); + values.shrink_to(num_rows); + map.clear(); + map.shrink_to(num_rows, |_| 0); + } + GroupValuesPrimitiveState::DistinctOnly { map, has_null } => { + *has_null = false; + map.clear(); + map.shrink_to(num_rows, |_| 0); + } + } } } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 056a7f171a516..6b2a9bc174176 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -587,7 +587,10 @@ impl GroupedHashAggregateStream { _ => OutOfMemoryMode::ReportError, }; - let group_values = new_group_values(group_schema, &group_ordering)?; + let require_group_indices = !accumulators.is_empty() + || matches!(group_ordering, GroupOrdering::Partial(_)); + let group_values = + new_group_values(group_schema, &group_ordering, require_group_indices)?; let reservation = MemoryConsumer::new(name) // We interpret 'can spill' as 'can handle memory back pressure'. // This value needs to be set to true for the default memory pool implementations @@ -947,62 +950,87 @@ impl GroupedHashAggregateStream { for group_values in &group_by_values { let groups_start_time = Instant::now(); - // calculate the group indices for each input row let starting_num_groups = self.group_values.len(); - self.group_values - .intern(group_values, &mut self.current_group_indices)?; - let group_indices = &self.current_group_indices; - - // Update ordering information if necessary - let total_num_groups = self.group_values.len(); - if total_num_groups > starting_num_groups { - self.group_ordering.new_groups( - group_values, - group_indices, - total_num_groups, - )?; - } - - // Use this instant for both measurements to save a syscall - let agg_start_time = Instant::now(); - self.group_by_metrics - .time_calculating_group_ids - .add_duration(agg_start_time - groups_start_time); - - // Gather the inputs to call the actual accumulator - let t = self - .accumulators - .iter_mut() - .zip(input_values.iter()) - .zip(filter_values.iter()); - - for ((acc, values), opt_filter) in t { - let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()); - - // Call the appropriate method on each aggregator with - // the entire input row and the relevant group indexes - if self.mode.input_mode() == AggregateInputMode::Raw - && !self.spill_state.is_stream_merging - { - acc.update_batch( - values, + let needs_group_indices = !self.accumulators.is_empty() + || matches!(self.group_ordering, GroupOrdering::Partial(_)); + + if needs_group_indices { + // calculate the group indices for each input row + self.group_values + .intern(group_values, &mut self.current_group_indices)?; + let group_indices = &self.current_group_indices; + + // Update ordering information if necessary + let total_num_groups = self.group_values.len(); + if total_num_groups > starting_num_groups { + self.group_ordering.new_groups( + group_values, group_indices, - opt_filter, total_num_groups, )?; - } else { - assert_or_internal_err!( - opt_filter.is_none(), - "aggregate filter should be applied in partial stage, there should be no filter in final stage" - ); - - // if aggregation is over intermediate states, - // use merge - acc.merge_batch(values, group_indices, None, total_num_groups)?; } + + // Use this instant for both measurements to save a syscall + let agg_start_time = Instant::now(); + self.group_by_metrics + .time_calculating_group_ids + .add_duration(agg_start_time - groups_start_time); + + // Gather the inputs to call the actual accumulator + let t = self + .accumulators + .iter_mut() + .zip(input_values.iter()) + .zip(filter_values.iter()); + + for ((acc, values), opt_filter) in t { + let opt_filter = + opt_filter.as_ref().map(|filter| filter.as_boolean()); + + // Call the appropriate method on each aggregator with + // the entire input row and the relevant group indexes + if self.mode.input_mode() == AggregateInputMode::Raw + && !self.spill_state.is_stream_merging + { + acc.update_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } else { + assert_or_internal_err!( + opt_filter.is_none(), + "aggregate filter should be applied in partial stage, there should be no filter in final stage" + ); + + // if aggregation is over intermediate states, + // use merge + acc.merge_batch(values, group_indices, None, total_num_groups)?; + } + self.group_by_metrics + .aggregation_time + .add_elapsed(agg_start_time); + } + } else { + self.group_values.intern_no_group_ids(group_values)?; + + let total_num_groups = self.group_values.len(); + if total_num_groups > starting_num_groups { + match &mut self.group_ordering { + GroupOrdering::None => {} + GroupOrdering::Full(ordering) => { + ordering.new_groups(total_num_groups); + } + GroupOrdering::Partial(_) => unreachable!( + "partial ordering requires per-row group indices" + ), + } + } + self.group_by_metrics - .aggregation_time - .add_elapsed(agg_start_time); + .time_calculating_group_ids + .add_duration(Instant::now() - groups_start_time); } } @@ -1260,16 +1288,23 @@ impl GroupedHashAggregateStream { // on the grouping columns. self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); - // Recreate `group_values` for streaming merge so group ids are assigned - // in first-seen order, as required by `GroupOrderingFull`. - // The pre-spill multi-column collector may use `vectorized_intern`, which - // can assign new group ids out of input order under hash collisions. + // Recreate `group_values` for streaming merge when the previous collector + // could emit groups out of first-seen order. This is required for: + // - the multi-column collector, which may use `vectorized_intern` + // - the unordered distinct-only collectors, which deliberately do not + // preserve first-seen order while building the hash table let group_schema = self .spill_state .merging_group_by .group_schema(&self.spill_state.spill_schema)?; - if group_schema.fields().len() > 1 { - self.group_values = new_group_values(group_schema, &self.group_ordering)?; + let require_group_indices = !self.accumulators.is_empty() + || matches!(self.group_ordering, GroupOrdering::Partial(_)); + if group_schema.fields().len() > 1 || !require_group_indices { + self.group_values = new_group_values( + group_schema, + &self.group_ordering, + require_group_indices, + )?; } // Use `OutOfMemoryMode::ReportError` from this point on diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 35b787759441c..49b4c42b7bff6 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -441,7 +441,7 @@ struct DistinctDeduplicator { impl DistinctDeduplicator { fn new(schema: SchemaRef, task_context: &TaskContext) -> Result { - let group_values = new_group_values(schema, &GroupOrdering::None)?; + let group_values = new_group_values(schema, &GroupOrdering::None, true)?; let reservation = MemoryConsumer::new("RecursiveQueryHashTable") .register(task_context.memory_pool()); Ok(Self { From d5b6f1e89604a4211708852348a39c838f7e61d4 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sat, 18 Apr 2026 11:48:22 +0800 Subject: [PATCH 2/5] Keep only primitive distinct group-id fast path for q4 --- datafusion/expr/src/logical_plan/plan.rs | 1 + .../src/binary_view_map.rs | 359 ++++++------------ .../src/aggregates/group_values/mod.rs | 30 +- .../group_values/single_group_by/boolean.rs | 26 -- .../group_values/single_group_by/bytes.rs | 205 +++------- .../single_group_by/bytes_view.rs | 209 +++------- .../physical-plan/src/aggregates/row_hash.rs | 20 +- 7 files changed, 231 insertions(+), 619 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4f73169ad2827..45df05a69972b 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -3768,6 +3768,7 @@ impl PartialOrd for Aggregate { /// index among identical entries. For example, if the same set appears three /// times, the ordinals are 0, 1, 2 and this function returns 2. /// Returns 0 when no grouping set is duplicated. +#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key fn max_grouping_set_duplicate_ordinal(group_expr: &[Expr]) -> usize { if let Some(Expr::GroupingSet(GroupingSet::GroupingSets(sets))) = group_expr.first() { let mut counts: HashMap<&[Expr], usize> = HashMap::new(); diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index b106bcbbf81c3..abc3e28f82627 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -21,12 +21,12 @@ use crate::binary_map::OutputType; use arrow::array::NullBufferBuilder; use arrow::array::cast::AsArray; use arrow::array::{Array, ArrayRef, BinaryViewArray, ByteView, make_view}; -use arrow::buffer::{Buffer, NullBuffer, ScalarBuffer}; +use arrow::buffer::{Buffer, ScalarBuffer}; use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; use datafusion_common::hash_utils::RandomState; -use datafusion_common::utils::proxy::HashTableAllocExt; +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::utils::proxy::{HashTableAllocExt, VecAllocExt}; use std::fmt::Debug; -use std::hash::BuildHasher; use std::mem::size_of; use std::sync::Arc; @@ -139,8 +139,8 @@ where /// random state used to generate hashes random_state: RandomState, - /// Number of new non-null entries inserted by the previous batch. - last_batch_new_entries: usize, + /// buffer that stores hash values (reused across batches to save allocations) + hashes_buffer: Vec, /// `(payload, null_index)` for the 'null' value, if any /// NOTE null_index is the logical index in the final array, not the index /// in the buffer @@ -164,7 +164,7 @@ where completed: Vec::new(), nulls: NullBufferBuilder::new(0), random_state: RandomState::default(), - last_batch_new_entries: 0, + hashes_buffer: vec![], null: None, } } @@ -252,15 +252,27 @@ where OP: FnMut(V), B: ByteViewType, { + // step 1: compute hashes + let batch_hashes = &mut self.hashes_buffer; + batch_hashes.clear(); + batch_hashes.resize(values.len(), 0); + create_hashes([values], &self.random_state, batch_hashes) + // hash is supported for all types and create_hashes only + // returns errors for unsupported types + .unwrap(); + + // step 2: insert each value into the set, if not already present let values = values.as_byte_view::(); + + // Get raw views buffer for direct comparison let input_views = values.views(); - let input_buffers = values.data_buffers(); - let starting_len = self.map.len(); - self.reserve_for_batch(values.len() - values.null_count()); - let mut adopted_buffer_start = None; + + // Ensure lengths are equivalent + assert_eq!(values.len(), self.hashes_buffer.len()); for i in 0..values.len() { let view_u128 = input_views[i]; + let hash = self.hashes_buffer[i]; // handle null value via validity bitmap check if values.is_null(i) { @@ -278,29 +290,47 @@ where continue; } + // Extract length from the view (first 4 bytes of u128 in little-endian) let len = view_u128 as u32; - let input_value = (len > 12).then(|| { - let value = values.value(i); - let bytes: &[u8] = value.as_ref(); - bytes - }); - let hash = hash_input_view(&self.random_state, view_u128, input_value); + // Check if value already exists let maybe_payload = { + // Borrow completed and in_progress for comparison let completed = &self.completed; let in_progress = &self.in_progress; self.map .find(hash, |header| { - view_matches( - hash, - view_u128, - input_value, - header.hash, - header.view, - completed, - in_progress, - ) + if header.hash != hash { + return false; + } + + // Fast path: inline strings can be compared directly + if len <= 12 { + return header.view == view_u128; + } + + // For larger strings: first compare the 4-byte prefix + let stored_prefix = (header.view >> 32) as u32; + let input_prefix = (view_u128 >> 32) as u32; + if stored_prefix != input_prefix { + return false; + } + + // Prefix matched - compare full bytes + let byte_view = ByteView::from(header.view); + let stored_len = byte_view.length as usize; + let buffer_index = byte_view.buffer_index as usize; + let offset = byte_view.offset as usize; + + let stored_value = if buffer_index < completed.len() { + &completed[buffer_index].as_slice() + [offset..offset + stored_len] + } else { + &in_progress[offset..offset + stored_len] + }; + let input_value: &[u8] = values.value(i).as_ref(); + stored_value == input_value }) .map(|entry| entry.payload) }; @@ -308,40 +338,12 @@ where let payload = if let Some(payload) = maybe_payload { payload } else { - let value = input_value.unwrap_or_else(|| { - let value = values.value(i); - let bytes: &[u8] = value.as_ref(); - bytes - }); + // no existing value, make a new one + let value: &[u8] = values.value(i).as_ref(); let payload = make_payload_fn(Some(value)); - let new_view = if len <= 12 { - append_view(view_u128, Some(&mut self.views), Some(&mut self.nulls)) - } else if let Some(starting_buffer) = adopted_buffer_start.or_else(|| { - (!input_buffers.is_empty()).then(|| { - adopt_input_buffers( - input_buffers, - &mut self.in_progress, - &mut self.completed, - ) - }) - }) { - adopted_buffer_start = Some(starting_buffer); - append_input_view( - view_u128, - starting_buffer, - Some(&mut self.views), - Some(&mut self.nulls), - ) - } else { - append_value( - value, - &mut self.in_progress, - &mut self.completed, - Some(&mut self.views), - Some(&mut self.nulls), - ) - }; + // Create view pointing to our buffers + let new_view = self.append_value(value); let new_header = Entry { view: new_view, hash, @@ -354,62 +356,64 @@ where }; observe_payload_fn(payload); } - - self.last_batch_new_entries = self.map.len() - starting_len; } - fn flush_in_progress(&mut self) { - if self.in_progress.is_empty() { - return; + /// Converts this set into a `StringViewArray`, or `BinaryViewArray`, + /// containing each distinct value + /// that was inserted. This is done without copying the values. + /// + /// The values are guaranteed to be returned in the same order in which + /// they were first seen. + pub fn into_state(mut self) -> ArrayRef { + // Flush any remaining in-progress buffer + if !self.in_progress.is_empty() { + let flushed = std::mem::take(&mut self.in_progress); + self.completed.push(Buffer::from_vec(flushed)); } - let flushed = std::mem::replace( - &mut self.in_progress, - Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE), - ); - self.completed.push(Buffer::from_vec(flushed)); - } + // Build null buffer if we have any nulls + let null_buffer = self.nulls.finish(); - fn reserve_for_batch(&mut self, non_null_rows: usize) { - if non_null_rows == 0 { - return; + let views = ScalarBuffer::from(self.views); + let array = + unsafe { BinaryViewArray::new_unchecked(views, self.completed, null_buffer) }; + + match self.output_type { + OutputType::BinaryView => Arc::new(array), + OutputType::Utf8View => { + // SAFETY: all input was valid utf8 + let array = unsafe { array.to_string_view_unchecked() }; + Arc::new(array) + } + _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), } + } - let expected_new_entries = if self.last_batch_new_entries == 0 { - non_null_rows + /// Append a value to our buffers and return the view pointing to it + fn append_value(&mut self, value: &[u8]) -> u128 { + let len = value.len(); + let view = if len <= 12 { + make_view(value, 0, 0) } else { - self.last_batch_new_entries - .saturating_mul(2) - .min(non_null_rows) - }; + // Ensure buffer is big enough + if self.in_progress.len() + len > BYTE_VIEW_MAX_BLOCK_SIZE { + let flushed = std::mem::replace( + &mut self.in_progress, + Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE), + ); + self.completed.push(Buffer::from_vec(flushed)); + } - let remaining_capacity = self.map.capacity().saturating_sub(self.map.len()); - let additional = expected_new_entries.saturating_sub(remaining_capacity); - if additional == 0 { - return; - } + let buffer_index = self.completed.len() as u32; + let offset = self.in_progress.len() as u32; + self.in_progress.extend_from_slice(value); - let previous_capacity = self.map.capacity(); - self.map.reserve(additional, |h| h.hash); - self.map_size += - (self.map.capacity() - previous_capacity) * size_of::>(); - } + make_view(value, buffer_index, offset) + }; - /// Converts this set into a `StringViewArray`, or `BinaryViewArray`, - /// containing each distinct value - /// that was inserted. This is done without copying the values. - /// - /// The values are guaranteed to be returned in the same order in which - /// they were first seen. - pub fn into_state(mut self) -> ArrayRef { - self.flush_in_progress(); - - build_view_array( - self.output_type, - self.views, - self.completed, - self.nulls.finish(), - ) + self.views.push(view); + self.nulls.append_non_null(); + view } /// Total number of entries (including null, if present) @@ -435,7 +439,12 @@ where let completed_size: usize = self.completed.iter().map(|b| b.len()).sum(); let nulls_size = self.nulls.allocated_size(); - self.map_size + views_size + in_progress_size + completed_size + nulls_size + self.map_size + + views_size + + in_progress_size + + completed_size + + nulls_size + + self.hashes_buffer.allocated_size() } } @@ -450,6 +459,7 @@ where .field("views_len", &self.views.len()) .field("completed_buffers", &self.completed.len()) .field("random_state", &self.random_state) + .field("hashes_buffer", &self.hashes_buffer) .finish() } } @@ -476,149 +486,6 @@ where payload: V, } -fn hash_input_view( - random_state: &RandomState, - view_u128: u128, - input_value: Option<&[u8]>, -) -> u64 { - input_value.map_or_else( - || random_state.hash_one(view_u128), - |value| random_state.hash_one(value), - ) -} - -fn view_matches( - input_hash: u64, - input_view: u128, - input_value: Option<&[u8]>, - stored_hash: u64, - stored_view: u128, - completed: &[Buffer], - in_progress: &[u8], -) -> bool { - if stored_hash != input_hash { - return false; - } - - let len = input_view as u32; - if len <= 12 { - return stored_view == input_view; - } - - if stored_view as u32 != len { - return false; - } - - let stored_prefix = (stored_view >> 32) as u32; - let input_prefix = (input_view >> 32) as u32; - if stored_prefix != input_prefix { - return false; - } - - let byte_view = ByteView::from(stored_view); - let stored_len = byte_view.length as usize; - let buffer_index = byte_view.buffer_index as usize; - let offset = byte_view.offset as usize; - - let stored_value = if buffer_index < completed.len() { - &completed[buffer_index].as_slice()[offset..offset + stored_len] - } else { - &in_progress[offset..offset + stored_len] - }; - stored_value == input_value.expect("non-inline value") -} - -fn append_value( - value: &[u8], - in_progress: &mut Vec, - completed: &mut Vec, - views: Option<&mut Vec>, - nulls: Option<&mut NullBufferBuilder>, -) -> u128 { - let len = value.len(); - let view = if len <= 12 { - make_view(value, 0, 0) - } else { - if in_progress.len() + len > BYTE_VIEW_MAX_BLOCK_SIZE { - let flushed = std::mem::replace( - in_progress, - Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE), - ); - completed.push(Buffer::from_vec(flushed)); - } - - let buffer_index = completed.len() as u32; - let offset = in_progress.len() as u32; - in_progress.extend_from_slice(value); - - make_view(value, buffer_index, offset) - }; - - append_view(view, views, nulls) -} - -fn adopt_input_buffers( - input_buffers: &[Buffer], - in_progress: &mut Vec, - completed: &mut Vec, -) -> u32 { - if !in_progress.is_empty() { - let flushed = - std::mem::replace(in_progress, Vec::with_capacity(BYTE_VIEW_MAX_BLOCK_SIZE)); - completed.push(Buffer::from_vec(flushed)); - } - - let starting_buffer = completed.len().try_into().expect("too many buffers"); - completed.extend(input_buffers.iter().cloned()); - starting_buffer -} - -fn append_input_view( - input_view: u128, - starting_buffer: u32, - views: Option<&mut Vec>, - nulls: Option<&mut NullBufferBuilder>, -) -> u128 { - let byte_view = ByteView::from(input_view); - let view = byte_view - .with_buffer_index(byte_view.buffer_index + starting_buffer) - .as_u128(); - append_view(view, views, nulls) -} - -fn append_view( - view: u128, - views: Option<&mut Vec>, - nulls: Option<&mut NullBufferBuilder>, -) -> u128 { - if let Some(views) = views { - views.push(view); - } - if let Some(nulls) = nulls { - nulls.append_non_null(); - } - view -} - -fn build_view_array( - output_type: OutputType, - views: Vec, - completed: Vec, - null_buffer: Option, -) -> ArrayRef { - let views = ScalarBuffer::from(views); - let array = unsafe { BinaryViewArray::new_unchecked(views, completed, null_buffer) }; - - match output_type { - OutputType::BinaryView => Arc::new(array), - OutputType::Utf8View => { - let array = unsafe { array.to_string_view_unchecked() }; - Arc::new(array) - } - _ => unreachable!("Utf8/Binary should use `ArrowBytesMap`"), - } -} - #[cfg(test)] mod tests { use arrow::array::{GenericByteViewArray, StringViewArray}; diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 1839192c05782..ad27d30cd6ee6 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -208,40 +208,22 @@ pub fn new_group_values( downcast_helper!(Decimal128Type, d, track_group_ids); } DataType::Utf8 => { - return Ok(Box::new(GroupValuesBytes::::new( - OutputType::Utf8, - track_group_ids, - ))); + return Ok(Box::new(GroupValuesBytes::::new(OutputType::Utf8))); } DataType::LargeUtf8 => { - return Ok(Box::new(GroupValuesBytes::::new( - OutputType::Utf8, - track_group_ids, - ))); + return Ok(Box::new(GroupValuesBytes::::new(OutputType::Utf8))); } DataType::Utf8View => { - return Ok(Box::new(GroupValuesBytesView::new( - OutputType::Utf8View, - track_group_ids, - ))); + return Ok(Box::new(GroupValuesBytesView::new(OutputType::Utf8View))); } DataType::Binary => { - return Ok(Box::new(GroupValuesBytes::::new( - OutputType::Binary, - track_group_ids, - ))); + return Ok(Box::new(GroupValuesBytes::::new(OutputType::Binary))); } DataType::LargeBinary => { - return Ok(Box::new(GroupValuesBytes::::new( - OutputType::Binary, - track_group_ids, - ))); + return Ok(Box::new(GroupValuesBytes::::new(OutputType::Binary))); } DataType::BinaryView => { - return Ok(Box::new(GroupValuesBytesView::new( - OutputType::BinaryView, - track_group_ids, - ))); + return Ok(Box::new(GroupValuesBytesView::new(OutputType::BinaryView))); } DataType::Boolean => { return Ok(Box::new(GroupValuesBoolean::new())); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs index 7e9e965e9826b..e993c0c53d199 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/boolean.rs @@ -83,32 +83,6 @@ impl GroupValues for GroupValuesBoolean { Ok(()) } - fn intern_no_group_ids(&mut self, cols: &[ArrayRef]) -> Result<()> { - let array = cols[0].as_boolean(); - - for value in array.iter() { - match value { - Some(false) => { - if self.false_group.is_none() { - self.false_group = Some(self.len()); - } - } - Some(true) => { - if self.true_group.is_none() { - self.true_group = Some(self.len()); - } - } - None => { - if self.null_group.is_none() { - self.null_group = Some(self.len()); - } - } - } - } - - Ok(()) - } - fn size(&self) -> usize { size_of::() } diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs index 40f64a336e3e2..b881a51b25474 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes.rs @@ -22,137 +22,49 @@ use crate::aggregates::group_values::GroupValues; use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use datafusion_common::Result; use datafusion_expr::EmitTo; -use datafusion_physical_expr_common::binary_map::{ - ArrowBytesMap, ArrowBytesSet, OutputType, -}; - -enum GroupValuesBytesState { - GroupIds { - map: ArrowBytesMap, - num_groups: usize, - }, - DistinctOnly(ArrowBytesSet), -} +use datafusion_physical_expr_common::binary_map::{ArrowBytesMap, OutputType}; /// A [`GroupValues`] storing single column of Utf8/LargeUtf8/Binary/LargeBinary values /// /// This specialization is significantly faster than using the more general /// purpose `Row`s format pub struct GroupValuesBytes { - output_type: OutputType, - state: GroupValuesBytesState, + /// Map string/binary values to group index + map: ArrowBytesMap, + /// The total number of groups so far (used to assign group_index) + num_groups: usize, } impl GroupValuesBytes { - pub fn new(output_type: OutputType, track_group_ids: bool) -> Self { - let state = if track_group_ids { - GroupValuesBytesState::GroupIds { - map: ArrowBytesMap::new(output_type), - num_groups: 0, - } - } else { - GroupValuesBytesState::DistinctOnly(ArrowBytesSet::new(output_type)) - }; - - Self { output_type, state } - } - - fn ensure_group_id_tracking(&mut self) { - if matches!(self.state, GroupValuesBytesState::GroupIds { .. }) { - return; - } - - let GroupValuesBytesState::DistinctOnly(set) = &mut self.state else { - unreachable!(); - }; - let contents = set.take().into_state(); - let mut map = ArrowBytesMap::new(self.output_type); - let mut num_groups = 0; - map.insert_if_new( - &contents, - |_value| { - let group_idx = num_groups; - num_groups += 1; - group_idx - }, - |_group_idx| {}, - ); - self.state = GroupValuesBytesState::GroupIds { map, num_groups }; - } - - fn emit_group_ids( - map: &mut ArrowBytesMap, - num_groups: &mut usize, - emit_to: EmitTo, - ) -> ArrayRef { - let map_contents = map.take().into_state(); - - match emit_to { - EmitTo::All => { - *num_groups -= map_contents.len(); - map_contents - } - EmitTo::First(n) if n == *num_groups => { - *num_groups -= map_contents.len(); - map_contents - } - EmitTo::First(n) => { - let emit_group_values = map_contents.slice(0, n); - let remaining_group_values = - map_contents.slice(n, map_contents.len() - n); - - *num_groups = 0; - map.insert_if_new( - &remaining_group_values, - |_value| { - let group_idx = *num_groups; - *num_groups += 1; - group_idx - }, - |_group_idx| {}, - ); - - emit_group_values - } - } - } - - fn emit_distinct_only(set: &mut ArrowBytesSet, emit_to: EmitTo) -> ArrayRef { - let set_contents = set.take().into_state(); - match emit_to { - EmitTo::All => set_contents, - EmitTo::First(n) if n == set_contents.len() => set_contents, - EmitTo::First(n) => { - let emit_group_values = set_contents.slice(0, n); - let remaining_group_values = - set_contents.slice(n, set_contents.len() - n); - set.insert(&remaining_group_values); - emit_group_values - } + pub fn new(output_type: OutputType) -> Self { + Self { + map: ArrowBytesMap::new(output_type), + num_groups: 0, } } } impl GroupValues for GroupValuesBytes { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - self.ensure_group_id_tracking(); assert_eq!(cols.len(), 1); // look up / add entries in the table let arr = &cols[0]; - let GroupValuesBytesState::GroupIds { map, num_groups } = &mut self.state else { - unreachable!(); - }; groups.clear(); - map.insert_if_new( + self.map.insert_if_new( arr, + // called for each new group |_value| { - let group_idx = *num_groups; - *num_groups += 1; + // assign new group index on each insert + let group_idx = self.num_groups; + self.num_groups += 1; group_idx }, - |group_idx| groups.push(group_idx), + // called for each group + |group_idx| { + groups.push(group_idx); + }, ); // ensure we assigned a group to for each row @@ -160,55 +72,48 @@ impl GroupValues for GroupValuesBytes { Ok(()) } - fn intern_no_group_ids(&mut self, cols: &[ArrayRef]) -> Result<()> { - assert_eq!(cols.len(), 1); - - let arr = &cols[0]; - match &mut self.state { - GroupValuesBytesState::GroupIds { map, num_groups } => map.insert_if_new( - arr, - |_value| { - let group_idx = *num_groups; - *num_groups += 1; - group_idx - }, - |_group_idx| {}, - ), - GroupValuesBytesState::DistinctOnly(set) => set.insert(arr), - } - - Ok(()) - } - fn size(&self) -> usize { - size_of::() - + match &self.state { - GroupValuesBytesState::GroupIds { map, .. } => map.size(), - GroupValuesBytesState::DistinctOnly(set) => set.size(), - } + self.map.size() + size_of::() } fn is_empty(&self) -> bool { - match &self.state { - GroupValuesBytesState::GroupIds { num_groups, .. } => *num_groups == 0, - GroupValuesBytesState::DistinctOnly(set) => set.is_empty(), - } + self.num_groups == 0 } fn len(&self) -> usize { - match &self.state { - GroupValuesBytesState::GroupIds { num_groups, .. } => *num_groups, - GroupValuesBytesState::DistinctOnly(set) => set.len(), - } + self.num_groups } fn emit(&mut self, emit_to: EmitTo) -> Result> { - let group_values = match &mut self.state { - GroupValuesBytesState::GroupIds { map, num_groups } => { - Self::emit_group_ids(map, num_groups, emit_to) + // Reset the map to default, and convert it into a single array + let map_contents = self.map.take().into_state(); + + let group_values = match emit_to { + EmitTo::All => { + self.num_groups -= map_contents.len(); + map_contents + } + EmitTo::First(n) if n == self.len() => { + self.num_groups -= map_contents.len(); + map_contents } - GroupValuesBytesState::DistinctOnly(set) => { - Self::emit_distinct_only(set, emit_to) + EmitTo::First(n) => { + // if we only wanted to take the first n, insert the rest back + // into the map we could potentially avoid this reallocation, at + // the expense of much more complex code. + // see https://github.com/apache/datafusion/issues/9195 + let emit_group_values = map_contents.slice(0, n); + let remaining_group_values = + map_contents.slice(n, map_contents.len() - n); + + self.num_groups = 0; + let mut group_indexes = vec![]; + self.intern(&[remaining_group_values], &mut group_indexes)?; + + // Verify that the group indexes were assigned in the correct order + assert_eq!(0, group_indexes[0]); + + emit_group_values } }; @@ -216,14 +121,8 @@ impl GroupValues for GroupValuesBytes { } fn clear_shrink(&mut self, _num_rows: usize) { - match &mut self.state { - GroupValuesBytesState::GroupIds { map, num_groups } => { - *num_groups = 0; - map.take(); - } - GroupValuesBytesState::DistinctOnly(set) => { - set.take(); - } - } + // in theory we could potentially avoid this reallocation and clear the + // contents of the maps, but for now we just reset the map from the beginning + self.map.take(); } } diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs index 47dceddb47c58..7a56f7c52c11a 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/bytes_view.rs @@ -19,114 +19,25 @@ use crate::aggregates::group_values::GroupValues; use arrow::array::{Array, ArrayRef}; use datafusion_expr::EmitTo; use datafusion_physical_expr::binary_map::OutputType; -use datafusion_physical_expr_common::binary_view_map::{ - ArrowBytesViewMap, ArrowBytesViewSet, -}; +use datafusion_physical_expr_common::binary_view_map::ArrowBytesViewMap; use std::mem::size_of; -enum GroupValuesBytesViewState { - GroupIds { - map: ArrowBytesViewMap, - num_groups: usize, - }, - DistinctOnly(ArrowBytesViewSet), -} - /// A [`GroupValues`] storing single column of Utf8View/BinaryView values /// /// This specialization is significantly faster than using the more general /// purpose `Row`s format pub struct GroupValuesBytesView { - output_type: OutputType, - state: GroupValuesBytesViewState, + /// Map string/binary values to group index + map: ArrowBytesViewMap, + /// The total number of groups so far (used to assign group_index) + num_groups: usize, } impl GroupValuesBytesView { - pub fn new(output_type: OutputType, track_group_ids: bool) -> Self { - let state = if track_group_ids { - GroupValuesBytesViewState::GroupIds { - map: ArrowBytesViewMap::new(output_type), - num_groups: 0, - } - } else { - GroupValuesBytesViewState::DistinctOnly(ArrowBytesViewSet::new(output_type)) - }; - - Self { output_type, state } - } - - fn ensure_group_id_tracking(&mut self) { - if matches!(self.state, GroupValuesBytesViewState::GroupIds { .. }) { - return; - } - - let GroupValuesBytesViewState::DistinctOnly(set) = &mut self.state else { - unreachable!(); - }; - let contents = set.take().into_state(); - let mut map = ArrowBytesViewMap::new(self.output_type); - let mut num_groups = 0; - map.insert_if_new( - &contents, - |_value| { - let group_idx = num_groups; - num_groups += 1; - group_idx - }, - |_group_idx| {}, - ); - self.state = GroupValuesBytesViewState::GroupIds { map, num_groups }; - } - - fn emit_group_ids( - map: &mut ArrowBytesViewMap, - num_groups: &mut usize, - emit_to: EmitTo, - ) -> ArrayRef { - let map_contents = map.take().into_state(); - - match emit_to { - EmitTo::All => { - *num_groups -= map_contents.len(); - map_contents - } - EmitTo::First(n) if n == *num_groups => { - *num_groups -= map_contents.len(); - map_contents - } - EmitTo::First(n) => { - let emit_group_values = map_contents.slice(0, n); - let remaining_group_values = - map_contents.slice(n, map_contents.len() - n); - - *num_groups = 0; - map.insert_if_new( - &remaining_group_values, - |_value| { - let group_idx = *num_groups; - *num_groups += 1; - group_idx - }, - |_group_idx| {}, - ); - - emit_group_values - } - } - } - - fn emit_distinct_only(set: &mut ArrowBytesViewSet, emit_to: EmitTo) -> ArrayRef { - let set_contents = set.take().into_state(); - match emit_to { - EmitTo::All => set_contents, - EmitTo::First(n) if n == set_contents.len() => set_contents, - EmitTo::First(n) => { - let emit_group_values = set_contents.slice(0, n); - let remaining_group_values = - set_contents.slice(n, set_contents.len() - n); - set.insert(&remaining_group_values); - emit_group_values - } + pub fn new(output_type: OutputType) -> Self { + Self { + map: ArrowBytesViewMap::new(output_type), + num_groups: 0, } } } @@ -137,25 +48,25 @@ impl GroupValues for GroupValuesBytesView { cols: &[ArrayRef], groups: &mut Vec, ) -> datafusion_common::Result<()> { - self.ensure_group_id_tracking(); assert_eq!(cols.len(), 1); // look up / add entries in the table let arr = &cols[0]; - let GroupValuesBytesViewState::GroupIds { map, num_groups } = &mut self.state - else { - unreachable!(); - }; groups.clear(); - map.insert_if_new( + self.map.insert_if_new( arr, + // called for each new group |_value| { - let group_idx = *num_groups; - *num_groups += 1; + // assign new group index on each insert + let group_idx = self.num_groups; + self.num_groups += 1; group_idx }, - |group_idx| groups.push(group_idx), + // called for each group + |group_idx| { + groups.push(group_idx); + }, ); // ensure we assigned a group to for each row @@ -163,58 +74,48 @@ impl GroupValues for GroupValuesBytesView { Ok(()) } - fn intern_no_group_ids( - &mut self, - cols: &[ArrayRef], - ) -> datafusion_common::Result<()> { - assert_eq!(cols.len(), 1); - - let arr = &cols[0]; - match &mut self.state { - GroupValuesBytesViewState::GroupIds { map, num_groups } => map.insert_if_new( - arr, - |_value| { - let group_idx = *num_groups; - *num_groups += 1; - group_idx - }, - |_group_idx| {}, - ), - GroupValuesBytesViewState::DistinctOnly(set) => set.insert(arr), - } - - Ok(()) - } - fn size(&self) -> usize { - size_of::() - + match &self.state { - GroupValuesBytesViewState::GroupIds { map, .. } => map.size(), - GroupValuesBytesViewState::DistinctOnly(set) => set.size(), - } + self.map.size() + size_of::() } fn is_empty(&self) -> bool { - match &self.state { - GroupValuesBytesViewState::GroupIds { num_groups, .. } => *num_groups == 0, - GroupValuesBytesViewState::DistinctOnly(set) => set.is_empty(), - } + self.num_groups == 0 } fn len(&self) -> usize { - match &self.state { - GroupValuesBytesViewState::GroupIds { num_groups, .. } => *num_groups, - GroupValuesBytesViewState::DistinctOnly(set) => set.len(), - } + self.num_groups } fn emit(&mut self, emit_to: EmitTo) -> datafusion_common::Result> { - let group_values = match &mut self.state { - GroupValuesBytesViewState::GroupIds { map, num_groups } => { - Self::emit_group_ids(map, num_groups, emit_to) + // Reset the map to default, and convert it into a single array + let map_contents = self.map.take().into_state(); + + let group_values = match emit_to { + EmitTo::All => { + self.num_groups -= map_contents.len(); + map_contents } - GroupValuesBytesViewState::DistinctOnly(set) => { - Self::emit_distinct_only(set, emit_to) + EmitTo::First(n) if n == self.len() => { + self.num_groups -= map_contents.len(); + map_contents + } + EmitTo::First(n) => { + // if we only wanted to take the first n, insert the rest back + // into the map we could potentially avoid this reallocation, at + // the expense of much more complex code. + // see https://github.com/apache/datafusion/issues/9195 + let emit_group_values = map_contents.slice(0, n); + let remaining_group_values = + map_contents.slice(n, map_contents.len() - n); + + self.num_groups = 0; + let mut group_indexes = vec![]; + self.intern(&[remaining_group_values], &mut group_indexes)?; + + // Verify that the group indexes were assigned in the correct order + assert_eq!(0, group_indexes[0]); + + emit_group_values } }; @@ -222,14 +123,8 @@ impl GroupValues for GroupValuesBytesView { } fn clear_shrink(&mut self, _num_rows: usize) { - match &mut self.state { - GroupValuesBytesViewState::GroupIds { map, num_groups } => { - *num_groups = 0; - map.take(); - } - GroupValuesBytesViewState::DistinctOnly(set) => { - set.take(); - } - } + // in theory we could potentially avoid this reallocation and clear the + // contents of the maps, but for now we just reset the map from the beginning + self.map.take(); } } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 6b2a9bc174176..bc8e532d3090a 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1288,23 +1288,17 @@ impl GroupedHashAggregateStream { // on the grouping columns. self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new()); - // Recreate `group_values` for streaming merge when the previous collector - // could emit groups out of first-seen order. This is required for: - // - the multi-column collector, which may use `vectorized_intern` - // - the unordered distinct-only collectors, which deliberately do not - // preserve first-seen order while building the hash table + // Recreate `group_values` for streaming merge so group ids are assigned + // in first-seen order, as required by `GroupOrderingFull`. + // The pre-spill multi-column collector may use `vectorized_intern`, which + // can assign new group ids out of input order under hash collisions. let group_schema = self .spill_state .merging_group_by .group_schema(&self.spill_state.spill_schema)?; - let require_group_indices = !self.accumulators.is_empty() - || matches!(self.group_ordering, GroupOrdering::Partial(_)); - if group_schema.fields().len() > 1 || !require_group_indices { - self.group_values = new_group_values( - group_schema, - &self.group_ordering, - require_group_indices, - )?; + if group_schema.fields().len() > 1 { + self.group_values = + new_group_values(group_schema, &self.group_ordering, true)?; } // Use `OutOfMemoryMode::ReportError` from this point on From 4b7264c749c7306557fc33fa9177868b960a8b37 Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Sat, 18 Apr 2026 12:19:55 +0800 Subject: [PATCH 3/5] Trim distinct-only API surface for q4 fast path --- .../src/aggregates/group_values/mod.rs | 17 +++--- .../group_values/single_group_by/primitive.rs | 60 +------------------ .../physical-plan/src/aggregates/row_hash.rs | 18 ++++-- .../physical-plan/src/recursive_query.rs | 2 +- 4 files changed, 22 insertions(+), 75 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index ad27d30cd6ee6..9db2dd7878e21 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -99,16 +99,6 @@ pub trait GroupValues: Send { /// assigned. fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; - /// Interns the rows from `cols` without materializing a group id per input - /// row. - /// - /// This is useful for hash aggregate operators that only need the set of - /// distinct group keys and have no per-row accumulator updates to perform. - fn intern_no_group_ids(&mut self, cols: &[ArrayRef]) -> Result<()> { - let mut groups = Vec::new(); - self.intern(cols, &mut groups) - } - /// Returns the number of bytes of memory used by this [`GroupValues`] fn size(&self) -> usize; @@ -144,6 +134,13 @@ pub trait GroupValues: Send { pub fn new_group_values( schema: SchemaRef, group_ordering: &GroupOrdering, +) -> Result> { + new_group_values_with_group_indices(schema, group_ordering, true) +} + +pub(crate) fn new_group_values_with_group_indices( + schema: SchemaRef, + group_ordering: &GroupOrdering, require_group_indices: bool, ) -> Result> { if schema.fields.len() == 1 { diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index 4e6d50d6a3662..03f740e7841e3 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -148,42 +148,6 @@ where PrimitiveArray::::new(values.into(), nulls) } - fn ensure_group_id_tracking(&mut self) { - if matches!(self.state, GroupValuesPrimitiveState::GroupIds { .. }) { - return; - } - - let GroupValuesPrimitiveState::DistinctOnly { map, has_null } = std::mem::replace( - &mut self.state, - GroupValuesPrimitiveState::GroupIds { - map: HashTable::with_capacity(128), - null_group: None, - values: Vec::with_capacity(128), - }, - ) else { - unreachable!(); - }; - - let mut values = Vec::with_capacity(map.len() + usize::from(has_null)); - let null_group = has_null.then(|| { - values.push(Default::default()); - 0 - }); - let mut group_map = HashTable::with_capacity(map.len()); - for value in map { - let group_idx = values.len(); - values.push(value); - let hash = value.hash(&self.random_state); - group_map - .insert_unique(hash, (group_idx, hash), |&(_, stored_hash)| stored_hash); - } - self.state = GroupValuesPrimitiveState::GroupIds { - map: group_map, - null_group, - values, - }; - } - fn insert_group_id( random_state: &RandomState, map: &mut HashTable<(usize, u64)>, @@ -351,29 +315,8 @@ where T::Native: HashValue, { fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - self.ensure_group_id_tracking(); assert_eq!(cols.len(), 1); groups.clear(); - let GroupValuesPrimitiveState::GroupIds { - map, - null_group, - values, - } = &mut self.state - else { - unreachable!(); - }; - - for v in cols[0].as_primitive::() { - let group_id = - Self::insert_group_id(&self.random_state, map, values, null_group, v); - groups.push(group_id) - } - Ok(()) - } - - fn intern_no_group_ids(&mut self, cols: &[ArrayRef]) -> Result<()> { - assert_eq!(cols.len(), 1); - match &mut self.state { GroupValuesPrimitiveState::GroupIds { map, @@ -381,13 +324,14 @@ where values, } => { for v in cols[0].as_primitive::() { - let _ = Self::insert_group_id( + let group_id = Self::insert_group_id( &self.random_state, map, values, null_group, v, ); + groups.push(group_id); } } GroupValuesPrimitiveState::DistinctOnly { map, has_null } => { diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index bc8e532d3090a..cd93b019ca674 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -23,7 +23,9 @@ use std::vec; use super::AggregateExec; use super::order::GroupOrdering; -use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values}; +use crate::aggregates::group_values::{ + GroupByMetrics, GroupValues, new_group_values, new_group_values_with_group_indices, +}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ AggregateInputMode, AggregateMode, AggregateOutputMode, PhysicalGroupBy, @@ -589,8 +591,11 @@ impl GroupedHashAggregateStream { let require_group_indices = !accumulators.is_empty() || matches!(group_ordering, GroupOrdering::Partial(_)); - let group_values = - new_group_values(group_schema, &group_ordering, require_group_indices)?; + let group_values = new_group_values_with_group_indices( + group_schema, + &group_ordering, + require_group_indices, + )?; let reservation = MemoryConsumer::new(name) // We interpret 'can spill' as 'can handle memory back pressure'. // This value needs to be set to true for the default memory pool implementations @@ -1013,7 +1018,9 @@ impl GroupedHashAggregateStream { .add_elapsed(agg_start_time); } } else { - self.group_values.intern_no_group_ids(group_values)?; + self.current_group_indices.clear(); + self.group_values + .intern(group_values, &mut self.current_group_indices)?; let total_num_groups = self.group_values.len(); if total_num_groups > starting_num_groups { @@ -1297,8 +1304,7 @@ impl GroupedHashAggregateStream { .merging_group_by .group_schema(&self.spill_state.spill_schema)?; if group_schema.fields().len() > 1 { - self.group_values = - new_group_values(group_schema, &self.group_ordering, true)?; + self.group_values = new_group_values(group_schema, &self.group_ordering)?; } // Use `OutOfMemoryMode::ReportError` from this point on diff --git a/datafusion/physical-plan/src/recursive_query.rs b/datafusion/physical-plan/src/recursive_query.rs index 49b4c42b7bff6..35b787759441c 100644 --- a/datafusion/physical-plan/src/recursive_query.rs +++ b/datafusion/physical-plan/src/recursive_query.rs @@ -441,7 +441,7 @@ struct DistinctDeduplicator { impl DistinctDeduplicator { fn new(schema: SchemaRef, task_context: &TaskContext) -> Result { - let group_values = new_group_values(schema, &GroupOrdering::None, true)?; + let group_values = new_group_values(schema, &GroupOrdering::None)?; let reservation = MemoryConsumer::new("RecursiveQueryHashTable") .register(task_context.memory_pool()); Ok(Self { From 2aa55a4a20e982cca833bdb23bae190766495bfb Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Tue, 12 May 2026 16:36:24 +0800 Subject: [PATCH 4/5] handle spill case Signed-off-by: Ruihang Xia --- .../physical-plan/src/aggregates/row_hash.rs | 85 ++++++++++++++++++- 1 file changed, 82 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index cd93b019ca674..4ea841d6142c6 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1297,13 +1297,15 @@ impl GroupedHashAggregateStream { // Recreate `group_values` for streaming merge so group ids are assigned // in first-seen order, as required by `GroupOrderingFull`. - // The pre-spill multi-column collector may use `vectorized_intern`, which - // can assign new group ids out of input order under hash collisions. + // The pre-spill collector may not track group ids for DISTINCT-only + // aggregation. The pre-spill multi-column collector may use + // `vectorized_intern`, which can assign new group ids out of input + // order under hash collisions. let group_schema = self .spill_state .merging_group_by .group_schema(&self.spill_state.spill_schema)?; - if group_schema.fields().len() > 1 { + if self.accumulators.is_empty() || group_schema.fields().len() > 1 { self.group_values = new_group_values(group_schema, &self.group_ordering)?; } @@ -1393,9 +1395,12 @@ mod tests { use super::*; use crate::InputOrderMode; use crate::execution_plan::ExecutionPlan; + use crate::metrics::MetricValue; use crate::test::TestMemoryExec; use arrow::array::{Int32Array, Int64Array}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::aggregate::AggregateExprBuilder; @@ -1507,6 +1512,80 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_spill_distinct_single_primitive_group_by() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "group_col", + DataType::Int32, + false, + )])); + + let num_distinct = 512; + let num_spills = 24; + let input_partitions = vec![ + (0..num_spills) + .map(|_| { + Ok(RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from_iter_values(0..num_distinct))], + )?) + }) + .collect::>>()?, + ]; + + let session_config = SessionConfig::new().with_batch_size(4); + let runtime = RuntimeEnvBuilder::new() + .with_memory_pool(Arc::new(FairSpillPool::new(6000))) + .build_arc()?; + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())]; + let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + let aggregate_exec = AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(group_expr), + Vec::>::new(), + vec![], + exec, + Arc::clone(&schema), + )?; + + let mut stream = + GroupedHashAggregateStream::new(&aggregate_exec, &Arc::clone(&task_ctx), 0)?; + let mut values = Vec::new(); + + while let Some(result) = stream.next().await { + let batch = result?; + let group_col = batch + .column(0) + .as_primitive::(); + values.extend(group_col.values().iter().copied()); + } + + let spill_count = aggregate_exec + .metrics() + .unwrap() + .iter() + .find_map(|metric| match metric.value() { + MetricValue::SpillCount(count) => Some(count.value()), + _ => None, + }) + .unwrap_or_default(); + assert!(spill_count > 0, "expected test to exercise spilling"); + + values.sort_unstable(); + let expected = (0..num_distinct).collect::>(); + assert_eq!(values, expected); + + Ok(()) + } + #[tokio::test] async fn test_skip_aggregation_probe_not_locked_until_skip() -> Result<()> { // Test that the probe is not locked until we actually decide to skip. From 688efd869504ff779a1545f09878ddfad156860d Mon Sep 17 00:00:00 2001 From: Ruihang Xia Date: Thu, 14 May 2026 11:39:42 +0800 Subject: [PATCH 5/5] sort nondeterministic test Signed-off-by: Ruihang Xia --- .../limited_distinct_aggregation.rs | 67 +++---------------- 1 file changed, 8 insertions(+), 59 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index c523b4a752a82..1f41848d81999 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -25,8 +25,8 @@ use crate::physical_optimizer::test_utils::{ schema, }; +use arrow::compute::SortOptions; use arrow::datatypes::DataType; -use arrow::{compute::SortOptions, util::pretty::pretty_format_batches}; use datafusion::prelude::SessionContext; use datafusion_common::Result; use datafusion_execution::config::SessionConfig; @@ -40,12 +40,12 @@ use datafusion_physical_plan::{ limit::{GlobalLimitExec, LocalLimitExec}, }; -async fn run_plan_and_format(plan: Arc) -> Result { +async fn run_plan_and_count_rows(plan: Arc) -> Result { let cfg = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(cfg); let batches = collect(plan, ctx.task_ctx()).await?; - let actual = format!("{}", pretty_format_batches(&batches)?); - Ok(actual) + // These plans have LIMIT without ORDER BY, so the row order is not stable. + Ok(batches.iter().map(|batch| batch.num_rows()).sum()) } #[tokio::test] @@ -86,20 +86,7 @@ async fn test_partial_final() -> Result<()> { DataSourceExec: partitions=1, partition_sizes=[1] " ); - let expected = run_plan_and_format(plan).await?; - assert_snapshot!( - expected, - @r" - +---+ - | a | - +---+ - | 1 | - | 2 | - | | - | 4 | - +---+ - " - ); + assert_eq!(run_plan_and_count_rows(plan).await?, 4); Ok(()) } @@ -134,20 +121,7 @@ async fn test_single_local() -> Result<()> { DataSourceExec: partitions=1, partition_sizes=[1] " ); - let expected = run_plan_and_format(plan).await?; - assert_snapshot!( - expected, - @r" - +---+ - | a | - +---+ - | 1 | - | 2 | - | | - | 4 | - +---+ - " - ); + assert_eq!(run_plan_and_count_rows(plan).await?, 4); Ok(()) } @@ -182,19 +156,7 @@ async fn test_single_global() -> Result<()> { DataSourceExec: partitions=1, partition_sizes=[1] " ); - let expected = run_plan_and_format(plan).await?; - assert_snapshot!( - expected, - @r" - +---+ - | a | - +---+ - | 2 | - | | - | 4 | - +---+ - " - ); + assert_eq!(run_plan_and_count_rows(plan).await?, 3); Ok(()) } @@ -237,20 +199,7 @@ async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { DataSourceExec: partitions=1, partition_sizes=[1] " ); - let expected = run_plan_and_format(plan).await?; - assert_snapshot!( - expected, - @r" - +---+ - | a | - +---+ - | 1 | - | 2 | - | | - | 4 | - +---+ - " - ); + assert_eq!(run_plan_and_count_rows(plan).await?, 4); Ok(()) }