From f87c46c10b50989f1b9970785287ec653fc3f5ae Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Thu, 7 May 2026 16:17:15 -0400 Subject: [PATCH 1/2] Fix agg filter NULL handling --- .../groups_accumulator/accumulate.rs | 117 +++++++++++++----- .../src/aggregate/groups_accumulator/nulls.rs | 27 ++-- .../functions-aggregate/src/array_agg.rs | 2 +- .../sqllogictest/test_files/aggregate.slt | 20 +++ 4 files changed, 127 insertions(+), 39 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 25f52df61136f..8f0712a83f04a 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -23,6 +23,7 @@ use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray}; use arrow::buffer::NullBuffer; use arrow::datatypes::ArrowPrimitiveType; +use crate::aggregate::groups_accumulator::nulls::filter_to_validity; use datafusion_expr_common::groups_accumulator::EmitTo; /// If the input has nulls, then the accumulator must potentially @@ -471,13 +472,14 @@ pub fn accumulate( /// /// This method assumes that for any input record index, if any of the value column /// is null, or it's filtered out by `opt_filter`, then the record would be ignored. -/// (won't be accumulated by `value_fn`) +/// (Won't be accumulated by `value_fn`) /// /// # Arguments /// /// * `group_indices` - To which groups do the rows in `value_columns` belong /// * `value_columns` - The input arrays to accumulate -/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included +/// * `opt_filter` - Optional filter array. If present, only rows where filter +/// is `Some(true)` are included /// * `value_fn` - Callback function for each valid row, with parameters: /// * `group_idx`: The group index for the current row /// * `batch_idx`: The index of the current row in the input arrays @@ -491,35 +493,28 @@ pub fn accumulate_multiple( T: ArrowPrimitiveType + Send, F: FnMut(usize, usize, &[&PrimitiveArray]) + Send, { - // Calculate `valid_indices` to accumulate, non-valid indices are ignored. - // `valid_indices` is a bit mask corresponding to the `group_indices`. An index - // is considered valid if: - // 1. All columns are non-null at this index. - // 2. Not filtered out by `opt_filter` - - // Take AND from all null buffers of `value_columns`. - let combined_nulls = value_columns - .iter() - .map(|arr| arr.logical_nulls()) - .fold(None, |acc, nulls| { - NullBuffer::union(acc.as_ref(), nulls.as_ref()) - }); - - // Take AND from previous combined nulls and `opt_filter`. - let valid_indices = match (combined_nulls, opt_filter) { - (None, None) => None, - (None, Some(filter)) => Some(filter.clone()), - (Some(nulls), None) => Some(BooleanArray::new(nulls.inner().clone(), None)), - (Some(nulls), Some(filter)) => { - let combined = nulls.inner() & filter.values(); - Some(BooleanArray::new(combined, None)) - } - }; - for col in value_columns.iter() { debug_assert_eq!(col.len(), group_indices.len()); } + // Start with rows where all value columns are non-null. + let mut valid_indices = + NullBuffer::union_many(value_columns.iter().map(|arr| arr.nulls())) + .map(NullBuffer::into_inner); + + // Restrict to rows where the optional filter is Some(true). Keep the filter + // as a raw BooleanBuffer to avoid computing a NullBuffer null_count just to + // test row validity below. + if let Some(filter) = opt_filter { + debug_assert_eq!(filter.len(), group_indices.len()); + let filter_validity = filter_to_validity(filter); + if let Some(valid_indices) = valid_indices.as_mut() { + *valid_indices &= &filter_validity; + } else { + valid_indices = Some(filter_validity); + } + } + match valid_indices { None => { for (batch_idx, &group_idx) in group_indices.iter().enumerate() { @@ -562,7 +557,8 @@ pub fn accumulate_indices( (None, Some(filter)) => { debug_assert_eq!(filter.len(), group_indices.len()); let group_indices_chunks = group_indices.chunks_exact(64); - let bit_chunks = filter.values().bit_chunks(); + let filter_validity = filter_to_validity(filter); + let bit_chunks = filter_validity.bit_chunks(); let group_indices_remainder = group_indices_chunks.remainder(); @@ -636,7 +632,8 @@ pub fn accumulate_indices( let group_indices_chunks = group_indices.chunks_exact(64); let valid_bit_chunks = valids.inner().bit_chunks(); - let filter_bit_chunks = filter.values().bit_chunks(); + let filter_validity = filter_to_validity(filter); + let filter_bit_chunks = filter_validity.bit_chunks(); let group_indices_remainder = group_indices_chunks.remainder(); @@ -1188,6 +1185,68 @@ mod test { assert_eq!(accumulated, expected); } + #[test] + fn test_accumulate_indices_with_null_filter() { + let group_indices = vec![0, 1, 0, 1]; + let filter = BooleanArray::new( + BooleanBuffer::from(vec![true, true, true, false]), + Some(NullBuffer::from(vec![true, false, true, true])), + ); + + let mut accumulated = vec![]; + accumulate_indices(&group_indices, None, Some(&filter), |group_idx| { + accumulated.push(group_idx); + }); + + // A NULL filter value should be treated the same as false, even if the + // underlying BooleanBuffer value is true. + let expected = vec![0, 0]; + assert_eq!(accumulated, expected); + + let value_validity = NullBuffer::from(vec![true, true, false, true]); + let mut accumulated = vec![]; + accumulate_indices( + &group_indices, + Some(&value_validity), + Some(&filter), + |group_idx| { + accumulated.push(group_idx); + }, + ); + + let expected = vec![0]; + assert_eq!(accumulated, expected); + } + + #[test] + fn test_accumulate_multiple_with_null_filter() { + let group_indices = vec![0, 1, 0, 1]; + let values1 = Int32Array::from(vec![1, 2, 3, 4]); + let values2 = Int32Array::from(vec![10, 20, 30, 40]); + let value_columns = [values1, values2]; + + let filter = BooleanArray::new( + BooleanBuffer::from(vec![true, true, true, false]), + Some(NullBuffer::from(vec![true, false, true, true])), + ); + + let mut accumulated = vec![]; + accumulate_multiple( + &group_indices, + &value_columns.iter().collect::>(), + Some(&filter), + |group_idx, batch_idx, columns| { + let values = columns.iter().map(|col| col.value(batch_idx)).collect(); + accumulated.push((group_idx, values)); + }, + ); + + // A NULL filter value should be treated the same as false, even if the + // underlying BooleanBuffer value is true. + let expected = vec![(0, vec![1, 10]), (0, vec![3, 30])]; + assert_eq!(accumulated, expected); + } + #[test] fn test_accumulate_multiple_with_nulls_and_filter() { let group_indices = vec![0, 1, 0, 1]; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs index 5b56b77e11d3f..d524afe43a5a3 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/nulls.rs @@ -22,7 +22,7 @@ use arrow::array::{ BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, StringArray, StringViewArray, StructArray, }; -use arrow::buffer::NullBuffer; +use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::datatypes::DataType; use datafusion_common::{Result, not_impl_err}; use std::sync::Arc; @@ -39,15 +39,24 @@ pub fn set_nulls( PrimitiveArray::::new(values, nulls).with_data_type(dt) } -/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer. +/// Converts an aggregate filter expression to a validity bitmap. +/// +/// The output is `true` for rows where the filter is `Some(true)`, and `false` +/// for rows where the filter is `Some(false)` or `None`. +pub(crate) fn filter_to_validity(filter: &BooleanArray) -> BooleanBuffer { + let Some(filter_nulls) = filter.nulls() else { + return filter.values().clone(); + }; + filter.values() & filter_nulls.inner() +} + +/// Converts an aggregate filter expression to a `NullBuffer`. /// /// The `NullBuffer` is -/// * `true` (representing valid) for values that were `true` in filter -/// * `false` (representing null) for values that were `false` or `null` in filter -pub fn filter_to_nulls(filter: &BooleanArray) -> Option { - let (filter_bools, filter_nulls) = filter.clone().into_parts(); - let filter_bools = NullBuffer::from(filter_bools); - NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref()) +/// * `true` (representing valid) for filter values that were `Some(true)` +/// * `false` (representing null) for filter values that were `Some(false)` or `None` +pub fn filter_to_nulls(filter: &BooleanArray) -> NullBuffer { + NullBuffer::new(filter_to_validity(filter)) } /// Compute an output validity mask for an array that has been filtered @@ -97,7 +106,7 @@ pub fn filtered_null_mask( opt_filter: Option<&BooleanArray>, input: &dyn Array, ) -> Option { - let opt_filter = opt_filter.and_then(filter_to_nulls); + let opt_filter = opt_filter.map(filter_to_nulls); NullBuffer::union(opt_filter.as_ref(), input.nulls()) } diff --git a/datafusion/functions-aggregate/src/array_agg.rs b/datafusion/functions-aggregate/src/array_agg.rs index 861d7712ba1b0..182d77699cad3 100644 --- a/datafusion/functions-aggregate/src/array_agg.rs +++ b/datafusion/functions-aggregate/src/array_agg.rs @@ -720,7 +720,7 @@ impl GroupsAccumulator for ArrayAggGroupsAccumulator { let offsets = OffsetBuffer::from_repeated_length(1, input.len()); // Filtered rows become null list entries, which merge_batch will skip. - let filter_nulls = opt_filter.and_then(filter_to_nulls); + let filter_nulls = opt_filter.map(filter_to_nulls); // With ignore_nulls, null values also become null list entries. Without // ignore_nulls, null values stay as [NULL] so merge_batch retains them. diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 70acff3cb7b9a..b8009dfd57cec 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -693,6 +693,18 @@ from data ---- 1 +# correlation_with_group_by_and_nullable_filter +query IR rowsort +SELECT g, corr(x, y) FILTER (WHERE b < 1) AS r +FROM (VALUES + (0, 1.0, 1.0, CAST(NULL AS INT)), + (0, 2.0, 2.0, CAST(NULL AS INT)), + (0, 3.0, 4.0, 2) +) AS t(g, x, y, b) +GROUP BY g +---- +0 NULL + # group correlation_query_with_nans_f32 query IR select id, corr(f, b) @@ -6177,6 +6189,14 @@ FROM test_table ---- 2 +# count_with_group_by_and_nullable_filter +query II rowsort +SELECT g, COUNT(a) FILTER (WHERE b < 1) AS count_a +FROM (VALUES (0, 1, CAST(NULL AS INT)), (0, 2, 2)) AS t(g, a, b) +GROUP BY g +---- +0 0 + # query_with_and_without_filter query III rowsort SELECT From 1bbbf12d686aa284f22a4f6651142ecb3716a56e Mon Sep 17 00:00:00 2001 From: Neil Conway Date: Thu, 7 May 2026 16:35:57 -0400 Subject: [PATCH 2/2] Fix clippy --- .../src/aggregate/groups_accumulator/accumulate.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index 8f0712a83f04a..09e1df4eae70c 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -478,8 +478,7 @@ pub fn accumulate( /// /// * `group_indices` - To which groups do the rows in `value_columns` belong /// * `value_columns` - The input arrays to accumulate -/// * `opt_filter` - Optional filter array. If present, only rows where filter -/// is `Some(true)` are included +/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included /// * `value_fn` - Callback function for each valid row, with parameters: /// * `group_idx`: The group index for the current row /// * `batch_idx`: The index of the current row in the input arrays