diff --git a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs index c523b4a752a82..1f41848d81999 100644 --- a/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs +++ b/datafusion/core/tests/physical_optimizer/limited_distinct_aggregation.rs @@ -25,8 +25,8 @@ use crate::physical_optimizer::test_utils::{ schema, }; +use arrow::compute::SortOptions; use arrow::datatypes::DataType; -use arrow::{compute::SortOptions, util::pretty::pretty_format_batches}; use datafusion::prelude::SessionContext; use datafusion_common::Result; use datafusion_execution::config::SessionConfig; @@ -40,12 +40,12 @@ use datafusion_physical_plan::{ limit::{GlobalLimitExec, LocalLimitExec}, }; -async fn run_plan_and_format(plan: Arc) -> Result { +async fn run_plan_and_count_rows(plan: Arc) -> Result { let cfg = SessionConfig::new().with_target_partitions(1); let ctx = SessionContext::new_with_config(cfg); let batches = collect(plan, ctx.task_ctx()).await?; - let actual = format!("{}", pretty_format_batches(&batches)?); - Ok(actual) + // These plans have LIMIT without ORDER BY, so the row order is not stable. + Ok(batches.iter().map(|batch| batch.num_rows()).sum()) } #[tokio::test] @@ -86,20 +86,7 @@ async fn test_partial_final() -> Result<()> { DataSourceExec: partitions=1, partition_sizes=[1] " ); - let expected = run_plan_and_format(plan).await?; - assert_snapshot!( - expected, - @r" - +---+ - | a | - +---+ - | 1 | - | 2 | - | | - | 4 | - +---+ - " - ); + assert_eq!(run_plan_and_count_rows(plan).await?, 4); Ok(()) } @@ -134,20 +121,7 @@ async fn test_single_local() -> Result<()> { DataSourceExec: partitions=1, partition_sizes=[1] " ); - let expected = run_plan_and_format(plan).await?; - assert_snapshot!( - expected, - @r" - +---+ - | a | - +---+ - | 1 | - | 2 | - | | - | 4 | - +---+ - " - ); + assert_eq!(run_plan_and_count_rows(plan).await?, 4); Ok(()) } @@ -182,19 +156,7 @@ async fn test_single_global() -> Result<()> { DataSourceExec: partitions=1, partition_sizes=[1] " ); - let expected = run_plan_and_format(plan).await?; - assert_snapshot!( - expected, - @r" - +---+ - | a | - +---+ - | 2 | - | | - | 4 | - +---+ - " - ); + assert_eq!(run_plan_and_count_rows(plan).await?, 3); Ok(()) } @@ -237,20 +199,7 @@ async fn test_distinct_cols_different_than_group_by_cols() -> Result<()> { DataSourceExec: partitions=1, partition_sizes=[1] " ); - let expected = run_plan_and_format(plan).await?; - assert_snapshot!( - expected, - @r" - +---+ - | a | - +---+ - | 1 | - | 2 | - | | - | 4 | - +---+ - " - ); + assert_eq!(run_plan_and_count_rows(plan).await?, 4); Ok(()) } diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 2f3b1a19e7d73..9db2dd7878e21 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -134,46 +134,75 @@ pub trait GroupValues: Send { pub fn new_group_values( schema: SchemaRef, group_ordering: &GroupOrdering, +) -> Result> { + new_group_values_with_group_indices(schema, group_ordering, true) +} + +pub(crate) fn new_group_values_with_group_indices( + schema: SchemaRef, + group_ordering: &GroupOrdering, + require_group_indices: bool, ) -> Result> { if schema.fields.len() == 1 { let d = schema.fields[0].data_type(); + let track_group_ids = + require_group_indices || !matches!(group_ordering, GroupOrdering::None); macro_rules! downcast_helper { - ($t:ty, $d:ident) => { - return Ok(Box::new(GroupValuesPrimitive::<$t>::new($d.clone()))) + ($t:ty, $d:ident, $track_group_ids:expr) => { + return Ok(Box::new(GroupValuesPrimitive::<$t>::new( + $d.clone(), + $track_group_ids, + ))) }; } downcast_primitive! { - d => (downcast_helper, d), + d => (downcast_helper, d, track_group_ids), _ => {} } match d { DataType::Date32 => { - downcast_helper!(Date32Type, d); + downcast_helper!(Date32Type, d, track_group_ids); } DataType::Date64 => { - downcast_helper!(Date64Type, d); + downcast_helper!(Date64Type, d, track_group_ids); } DataType::Time32(t) => match t { - TimeUnit::Second => downcast_helper!(Time32SecondType, d), - TimeUnit::Millisecond => downcast_helper!(Time32MillisecondType, d), + TimeUnit::Second => { + downcast_helper!(Time32SecondType, d, track_group_ids) + } + TimeUnit::Millisecond => { + downcast_helper!(Time32MillisecondType, d, track_group_ids) + } _ => {} }, DataType::Time64(t) => match t { - TimeUnit::Microsecond => downcast_helper!(Time64MicrosecondType, d), - TimeUnit::Nanosecond => downcast_helper!(Time64NanosecondType, d), + TimeUnit::Microsecond => { + downcast_helper!(Time64MicrosecondType, d, track_group_ids) + } + TimeUnit::Nanosecond => { + downcast_helper!(Time64NanosecondType, d, track_group_ids) + } _ => {} }, DataType::Timestamp(t, _tz) => match t { - TimeUnit::Second => downcast_helper!(TimestampSecondType, d), - TimeUnit::Millisecond => downcast_helper!(TimestampMillisecondType, d), - TimeUnit::Microsecond => downcast_helper!(TimestampMicrosecondType, d), - TimeUnit::Nanosecond => downcast_helper!(TimestampNanosecondType, d), + TimeUnit::Second => { + downcast_helper!(TimestampSecondType, d, track_group_ids) + } + TimeUnit::Millisecond => { + downcast_helper!(TimestampMillisecondType, d, track_group_ids) + } + TimeUnit::Microsecond => { + downcast_helper!(TimestampMicrosecondType, d, track_group_ids) + } + TimeUnit::Nanosecond => { + downcast_helper!(TimestampNanosecondType, d, track_group_ids) + } }, DataType::Decimal128(_, _) => { - downcast_helper!(Decimal128Type, d); + downcast_helper!(Decimal128Type, d, track_group_ids); } DataType::Utf8 => { return Ok(Box::new(GroupValuesBytes::::new(OutputType::Utf8))); diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs index efaf7eba0f1b5..03f740e7841e3 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/primitive.rs @@ -79,147 +79,334 @@ hash_float!(f16, f32, f64); /// /// This specialization is significantly faster than using the more general /// purpose `Row`s format +enum GroupValuesPrimitiveState { + GroupIds { + /// Stores the `(group_index, hash)` based on the hash of its value + /// + /// We also store `hash` is for reducing cost of rehashing. Such cost + /// is obvious in high cardinality group by situation. + /// More details can see: + /// + map: HashTable<(usize, u64)>, + /// The group index of the null value if any + null_group: Option, + /// The values for each group index + values: Vec, + }, + DistinctOnly { + /// Stores the distinct primitive values. + map: HashTable, + has_null: bool, + }, +} + pub struct GroupValuesPrimitive { /// The data type of the output array data_type: DataType, - /// Stores the `(group_index, hash)` based on the hash of its value - /// - /// We also store `hash` is for reducing cost of rehashing. Such cost - /// is obvious in high cardinality group by situation. - /// More details can see: - /// - map: HashTable<(usize, u64)>, - /// The group index of the null value if any - null_group: Option, - /// The values for each group index - values: Vec, + state: GroupValuesPrimitiveState, /// The random state used to generate hashes random_state: RandomState, } -impl GroupValuesPrimitive { - pub fn new(data_type: DataType) -> Self { +impl GroupValuesPrimitive +where + T::Native: HashValue, +{ + pub fn new(data_type: DataType, track_group_ids: bool) -> Self { assert!(PrimitiveArray::::is_compatible(&data_type)); + let state = if track_group_ids { + GroupValuesPrimitiveState::GroupIds { + map: HashTable::with_capacity(128), + values: Vec::with_capacity(128), + null_group: None, + } + } else { + GroupValuesPrimitiveState::DistinctOnly { + map: HashTable::with_capacity(128), + has_null: false, + } + }; Self { data_type, - map: HashTable::with_capacity(128), - values: Vec::with_capacity(128), - null_group: None, + state, random_state: crate::aggregates::AGGREGATION_HASH_SEED, } } -} -impl GroupValues for GroupValuesPrimitive -where - T::Native: HashValue, -{ - fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { - assert_eq!(cols.len(), 1); - groups.clear(); + fn build_primitive( + values: Vec, + null_idx: Option, + ) -> PrimitiveArray { + let nulls = null_idx.map(|null_idx| { + let mut buffer = NullBufferBuilder::new(values.len()); + buffer.append_n_non_nulls(null_idx); + buffer.append_null(); + buffer.append_n_non_nulls(values.len() - null_idx - 1); + // NOTE: The inner builder must be constructed as there is at least one null + buffer.finish().unwrap() + }); + PrimitiveArray::::new(values.into(), nulls) + } - for v in cols[0].as_primitive::() { - let group_id = match v { - None => *self.null_group.get_or_insert_with(|| { - let group_id = self.values.len(); - self.values.push(Default::default()); - group_id - }), - Some(key) => { - let state = &self.random_state; - let hash = key.hash(state); - let insert = self.map.entry( - hash, - |&(g, h)| unsafe { - hash == h && self.values.get_unchecked(g).is_eq(key) - }, - |&(_, h)| h, - ); + fn insert_group_id( + random_state: &RandomState, + map: &mut HashTable<(usize, u64)>, + values: &mut Vec, + null_group: &mut Option, + value: Option, + ) -> usize { + match value { + None => *null_group.get_or_insert_with(|| { + let group_id = values.len(); + values.push(Default::default()); + group_id + }), + Some(key) => { + let hash = key.hash(random_state); + let insert = map.entry( + hash, + |&(g, h)| unsafe { hash == h && values.get_unchecked(g).is_eq(key) }, + |&(_, h)| h, + ); - match insert { - hashbrown::hash_table::Entry::Occupied(o) => o.get().0, - hashbrown::hash_table::Entry::Vacant(v) => { - let g = self.values.len(); - v.insert((g, hash)); - self.values.push(key); - g - } + match insert { + hashbrown::hash_table::Entry::Occupied(o) => o.get().0, + hashbrown::hash_table::Entry::Vacant(v) => { + let g = values.len(); + v.insert((g, hash)); + values.push(key); + g } } - }; - groups.push(group_id) + } } - Ok(()) - } - - fn size(&self) -> usize { - self.map.capacity() * size_of::<(usize, u64)>() + self.values.allocated_size() } - fn is_empty(&self) -> bool { - self.values.is_empty() + fn insert_distinct_only( + random_state: &RandomState, + map: &mut HashTable, + has_null: &mut bool, + value: Option, + ) { + match value { + None => { + *has_null = true; + } + Some(key) => { + Self::insert_distinct_only_hashed( + random_state, + map, + key, + key.hash(random_state), + ); + } + } } - fn len(&self) -> usize { - self.values.len() - } + fn insert_distinct_only_hashed( + random_state: &RandomState, + map: &mut HashTable, + key: T::Native, + hash: u64, + ) { + let insert = map.entry( + hash, + |stored| stored.is_eq(key), + |stored| stored.hash(random_state), + ); - fn emit(&mut self, emit_to: EmitTo) -> Result> { - fn build_primitive( - values: Vec, - null_idx: Option, - ) -> PrimitiveArray { - let nulls = null_idx.map(|null_idx| { - let mut buffer = NullBufferBuilder::new(values.len()); - buffer.append_n_non_nulls(null_idx); - buffer.append_null(); - buffer.append_n_non_nulls(values.len() - null_idx - 1); - // NOTE: The inner builder must be constructed as there is at least one null - buffer.finish().unwrap() - }); - PrimitiveArray::::new(values.into(), nulls) + if let hashbrown::hash_table::Entry::Vacant(v) = insert { + v.insert(key); } + } + fn emit_group_ids( + data_type: &DataType, + map: &mut HashTable<(usize, u64)>, + values: &mut Vec, + null_group: &mut Option, + emit_to: EmitTo, + ) -> ArrayRef { let array: PrimitiveArray = match emit_to { EmitTo::All => { - self.map.clear(); - build_primitive(std::mem::take(&mut self.values), self.null_group.take()) + map.clear(); + Self::build_primitive(std::mem::take(values), null_group.take()) } EmitTo::First(n) => { - self.map.retain(|entry| { - // Decrement group index by n + map.retain(|entry| { let group_idx = entry.0; match group_idx.checked_sub(n) { - // Group index was >= n, shift value down Some(sub) => { entry.0 = sub; true } - // Group index was < n, so remove from table None => false, } }); - let null_group = match &mut self.null_group { + let null_idx = match null_group { Some(v) if *v >= n => { *v -= n; None } - Some(_) => self.null_group.take(), + Some(_) => null_group.take(), None => None, }; - let mut split = self.values.split_off(n); - std::mem::swap(&mut self.values, &mut split); - build_primitive(split, null_group) + let mut split = values.split_off(n); + std::mem::swap(values, &mut split); + Self::build_primitive(split, null_idx) + } + }; + + Arc::new(array.with_data_type(data_type.clone())) + } + + fn emit_distinct_only( + data_type: &DataType, + random_state: &RandomState, + map: &mut HashTable, + has_null: &mut bool, + emit_to: EmitTo, + ) -> ArrayRef { + let total_len = map.len() + usize::from(*has_null); + let mut values = Vec::with_capacity(total_len); + if *has_null { + values.push(Default::default()); + } + values.extend(map.iter().copied()); + map.clear(); + + let (emitted_values, emitted_null_idx, remaining_values, remaining_has_null) = + match emit_to { + EmitTo::All => (values, (*has_null).then_some(0), Vec::new(), false), + EmitTo::First(n) if n >= total_len => { + (values, (*has_null).then_some(0), Vec::new(), false) + } + EmitTo::First(n) => { + let mut remaining_values = values.split_off(n); + let emitted_values = values; + let emitted_null_idx = (*has_null && n > 0).then_some(0); + let remaining_has_null = *has_null && n == 0; + if remaining_has_null { + remaining_values.remove(0); + } + ( + emitted_values, + emitted_null_idx, + remaining_values, + remaining_has_null, + ) + } + }; + + *has_null = remaining_has_null; + for value in remaining_values { + Self::insert_distinct_only(random_state, map, has_null, Some(value)); + } + + Arc::new( + Self::build_primitive(emitted_values, emitted_null_idx) + .with_data_type(data_type.clone()), + ) + } +} + +impl GroupValues for GroupValuesPrimitive +where + T::Native: HashValue, +{ + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + assert_eq!(cols.len(), 1); + groups.clear(); + match &mut self.state { + GroupValuesPrimitiveState::GroupIds { + map, + null_group, + values, + } => { + for v in cols[0].as_primitive::() { + let group_id = Self::insert_group_id( + &self.random_state, + map, + values, + null_group, + v, + ); + groups.push(group_id); + } + } + GroupValuesPrimitiveState::DistinctOnly { map, has_null } => { + for v in cols[0].as_primitive::() { + Self::insert_distinct_only(&self.random_state, map, has_null, v); + } + } + } + Ok(()) + } + + fn size(&self) -> usize { + size_of::() + + match &self.state { + GroupValuesPrimitiveState::GroupIds { map, values, .. } => { + map.capacity() * size_of::<(usize, u64)>() + values.allocated_size() + } + GroupValuesPrimitiveState::DistinctOnly { map, .. } => { + map.capacity() * size_of::() + } + } + } + + fn is_empty(&self) -> bool { + self.len() == 0 + } + + fn len(&self) -> usize { + match &self.state { + GroupValuesPrimitiveState::GroupIds { values, .. } => values.len(), + GroupValuesPrimitiveState::DistinctOnly { map, has_null, .. } => { + map.len() + usize::from(*has_null) + } + } + } + + fn emit(&mut self, emit_to: EmitTo) -> Result> { + let array = match &mut self.state { + GroupValuesPrimitiveState::GroupIds { + map, + null_group, + values, + } => Self::emit_group_ids(&self.data_type, map, values, null_group, emit_to), + GroupValuesPrimitiveState::DistinctOnly { map, has_null, .. } => { + Self::emit_distinct_only( + &self.data_type, + &self.random_state, + map, + has_null, + emit_to, + ) } }; - Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))]) + Ok(vec![array]) } fn clear_shrink(&mut self, num_rows: usize) { - self.values.clear(); - self.values.shrink_to(num_rows); - self.map.clear(); - self.map.shrink_to(num_rows, |_| 0); // hasher does not matter since the map is cleared + match &mut self.state { + GroupValuesPrimitiveState::GroupIds { + map, + null_group, + values, + } => { + *null_group = None; + values.clear(); + values.shrink_to(num_rows); + map.clear(); + map.shrink_to(num_rows, |_| 0); + } + GroupValuesPrimitiveState::DistinctOnly { map, has_null } => { + *has_null = false; + map.clear(); + map.shrink_to(num_rows, |_| 0); + } + } } } diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index a55cf09c79b0a..5aafa56818ca8 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -23,7 +23,9 @@ use std::vec; use super::AggregateExec; use super::order::GroupOrdering; -use crate::aggregates::group_values::{GroupByMetrics, GroupValues, new_group_values}; +use crate::aggregates::group_values::{ + GroupByMetrics, GroupValues, new_group_values, new_group_values_with_group_indices, +}; use crate::aggregates::order::GroupOrderingFull; use crate::aggregates::{ AggregateInputMode, AggregateMode, AggregateOutputMode, PhysicalGroupBy, @@ -590,7 +592,13 @@ impl GroupedHashAggregateStream { _ => OutOfMemoryMode::ReportError, }; - let group_values = new_group_values(group_schema, &group_ordering)?; + let require_group_indices = !accumulators.is_empty() + || matches!(group_ordering, GroupOrdering::Partial(_)); + let group_values = new_group_values_with_group_indices( + group_schema, + &group_ordering, + require_group_indices, + )?; let reservation = MemoryConsumer::new(name) // We interpret 'can spill' as 'can handle memory back pressure'. // This value needs to be set to true for the default memory pool implementations @@ -955,62 +963,89 @@ impl GroupedHashAggregateStream { for group_values in &group_by_values { let groups_start_time = Instant::now(); - // calculate the group indices for each input row let starting_num_groups = self.group_values.len(); - self.group_values - .intern(group_values, &mut self.current_group_indices)?; - let group_indices = &self.current_group_indices; - - // Update ordering information if necessary - let total_num_groups = self.group_values.len(); - if total_num_groups > starting_num_groups { - self.group_ordering.new_groups( - group_values, - group_indices, - total_num_groups, - )?; - } - - // Use this instant for both measurements to save a syscall - let agg_start_time = Instant::now(); - self.group_by_metrics - .time_calculating_group_ids - .add_duration(agg_start_time - groups_start_time); - - // Gather the inputs to call the actual accumulator - let t = self - .accumulators - .iter_mut() - .zip(input_values.iter()) - .zip(filter_values.iter()); - - for ((acc, values), opt_filter) in t { - let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()); - - // Call the appropriate method on each aggregator with - // the entire input row and the relevant group indexes - if self.mode.input_mode() == AggregateInputMode::Raw - && !self.spill_state.is_stream_merging - { - acc.update_batch( - values, + let needs_group_indices = !self.accumulators.is_empty() + || matches!(self.group_ordering, GroupOrdering::Partial(_)); + + if needs_group_indices { + // calculate the group indices for each input row + self.group_values + .intern(group_values, &mut self.current_group_indices)?; + let group_indices = &self.current_group_indices; + + // Update ordering information if necessary + let total_num_groups = self.group_values.len(); + if total_num_groups > starting_num_groups { + self.group_ordering.new_groups( + group_values, group_indices, - opt_filter, total_num_groups, )?; - } else { - assert_or_internal_err!( - opt_filter.is_none(), - "aggregate filter should be applied in partial stage, there should be no filter in final stage" - ); - - // if aggregation is over intermediate states, - // use merge - acc.merge_batch(values, group_indices, None, total_num_groups)?; } + + // Use this instant for both measurements to save a syscall + let agg_start_time = Instant::now(); self.group_by_metrics - .aggregation_time - .add_elapsed(agg_start_time); + .time_calculating_group_ids + .add_duration(agg_start_time - groups_start_time); + + // Gather the inputs to call the actual accumulator + let t = self + .accumulators + .iter_mut() + .zip(input_values.iter()) + .zip(filter_values.iter()); + + for ((acc, values), opt_filter) in t { + let opt_filter = + opt_filter.as_ref().map(|filter| filter.as_boolean()); + + // Call the appropriate method on each aggregator with + // the entire input row and the relevant group indexes + if self.mode.input_mode() == AggregateInputMode::Raw + && !self.spill_state.is_stream_merging + { + acc.update_batch( + values, + group_indices, + opt_filter, + total_num_groups, + )?; + } else { + assert_or_internal_err!( + opt_filter.is_none(), + "aggregate filter should be applied in partial stage, there should be no filter in final stage" + ); + + // if aggregation is over intermediate states, + // use merge + acc.merge_batch(values, group_indices, None, total_num_groups)?; + } + self.group_by_metrics + .aggregation_time + .add_elapsed(agg_start_time); + } + } else { + self.current_group_indices.clear(); + self.group_values + .intern(group_values, &mut self.current_group_indices)?; + + let total_num_groups = self.group_values.len(); + if total_num_groups > starting_num_groups { + match &mut self.group_ordering { + GroupOrdering::None => {} + GroupOrdering::Full(ordering) => { + ordering.new_groups(total_num_groups); + } + GroupOrdering::Partial(_) => unreachable!( + "partial ordering requires per-row group indices" + ), + } + } + + self.group_by_metrics + .time_calculating_group_ids + .add_duration(Instant::now() - groups_start_time); } } @@ -1374,13 +1409,15 @@ impl GroupedHashAggregateStream { // Recreate `group_values` for streaming merge so group ids are assigned // in first-seen order, as required by `GroupOrderingFull`. - // The pre-spill multi-column collector may use `vectorized_intern`, which - // can assign new group ids out of input order under hash collisions. + // The pre-spill collector may not track group ids for DISTINCT-only + // aggregation. The pre-spill multi-column collector may use + // `vectorized_intern`, which can assign new group ids out of input + // order under hash collisions. let group_schema = self .spill_state .merging_group_by .group_schema(&self.spill_state.spill_schema)?; - if group_schema.fields().len() > 1 { + if self.accumulators.is_empty() || group_schema.fields().len() > 1 { self.group_values = new_group_values(group_schema, &self.group_ordering)?; } @@ -1470,9 +1507,12 @@ mod tests { use super::*; use crate::InputOrderMode; use crate::execution_plan::ExecutionPlan; + use crate::metrics::MetricValue; use crate::test::TestMemoryExec; use arrow::array::{Int32Array, Int64Array}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_execution::config::SessionConfig; + use datafusion_execution::memory_pool::FairSpillPool; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_functions_aggregate::count::count_udaf; use datafusion_physical_expr::aggregate::AggregateExprBuilder; @@ -1584,6 +1624,80 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_spill_distinct_single_primitive_group_by() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new( + "group_col", + DataType::Int32, + false, + )])); + + let num_distinct = 512; + let num_spills = 24; + let input_partitions = vec![ + (0..num_spills) + .map(|_| { + Ok(RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int32Array::from_iter_values(0..num_distinct))], + )?) + }) + .collect::>>()?, + ]; + + let session_config = SessionConfig::new().with_batch_size(4); + let runtime = RuntimeEnvBuilder::new() + .with_memory_pool(Arc::new(FairSpillPool::new(6000))) + .build_arc()?; + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(runtime), + ); + + let group_expr = vec![(col("group_col", &schema)?, "group_col".to_string())]; + let exec = TestMemoryExec::try_new(&input_partitions, Arc::clone(&schema), None)?; + let exec = Arc::new(TestMemoryExec::update_cache(&Arc::new(exec))); + + let aggregate_exec = AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new_single(group_expr), + Vec::>::new(), + vec![], + exec, + Arc::clone(&schema), + )?; + + let mut stream = + GroupedHashAggregateStream::new(&aggregate_exec, &Arc::clone(&task_ctx), 0)?; + let mut values = Vec::new(); + + while let Some(result) = stream.next().await { + let batch = result?; + let group_col = batch + .column(0) + .as_primitive::(); + values.extend(group_col.values().iter().copied()); + } + + let spill_count = aggregate_exec + .metrics() + .unwrap() + .iter() + .find_map(|metric| match metric.value() { + MetricValue::SpillCount(count) => Some(count.value()), + _ => None, + }) + .unwrap_or_default(); + assert!(spill_count > 0, "expected test to exercise spilling"); + + values.sort_unstable(); + let expected = (0..num_distinct).collect::>(); + assert_eq!(values, expected); + + Ok(()) + } + #[tokio::test] async fn test_skip_aggregation_probe_not_locked_until_skip() -> Result<()> { // Test that the probe is not locked until we actually decide to skip.