Skip to content

Commit 579d562

Browse files
committed
fix: UNIQUE constraint with NULLs incorrectly collapses GROUP BY groups
1 parent 32f51ec commit 579d562

6 files changed

Lines changed: 193 additions & 29 deletions

File tree

datafusion/common/src/functional_dependencies.rs

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -196,15 +196,25 @@ impl FunctionalDependencies {
196196
}
197197

198198
/// Creates a new `FunctionalDependencies` object from the given constraints.
199+
///
200+
/// `nullable_flags` must contain one entry per field in the relation,
201+
/// indicating whether that field is nullable. A `UNIQUE` constraint whose
202+
/// source columns include any nullable field is **not** a functional
203+
/// dependency — because SQL treats `NULL` values as distinct, multiple rows
204+
/// may carry `NULL` in a unique-key column without violating the constraint.
205+
/// Such constraints are therefore omitted entirely. When all source columns
206+
/// are non-nullable a `UNIQUE` constraint is equivalent to a primary key and
207+
/// is recorded with `nullable = false`.
199208
pub fn new_from_constraints(
200209
constraints: Option<&Constraints>,
201-
n_field: usize,
210+
nullable_flags: &[bool],
202211
) -> Self {
212+
let n_field = nullable_flags.len();
203213
if let Some(Constraints { inner: constraints }) = constraints {
204214
// Construct dependency objects based on each individual constraint:
205215
let dependencies = constraints
206216
.iter()
207-
.map(|constraint| {
217+
.filter_map(|constraint| {
208218
// All the field indices are associated with the whole table
209219
// since we are dealing with table level constraints:
210220
let dependency = match constraint {
@@ -213,15 +223,27 @@ impl FunctionalDependencies {
213223
(0..n_field).collect::<Vec<_>>(),
214224
false,
215225
),
216-
Constraint::Unique(indices) => FunctionalDependence::new(
217-
indices.to_vec(),
218-
(0..n_field).collect::<Vec<_>>(),
219-
true,
220-
),
226+
Constraint::Unique(indices) => {
227+
// A UNIQUE constraint where any source column is
228+
// nullable is not a functional dependency: SQL does
229+
// not consider NULLs equal, so two rows may both
230+
// have NULL in the key and still satisfy the
231+
// constraint. Only emit an FD when all source
232+
// columns are non-nullable, in which case it is
233+
// equivalent to a primary key.
234+
if indices.iter().any(|&i| nullable_flags[i]) {
235+
return None;
236+
}
237+
FunctionalDependence::new(
238+
indices.to_vec(),
239+
(0..n_field).collect::<Vec<_>>(),
240+
false,
241+
)
242+
}
221243
};
222244
// As primary keys are guaranteed to be unique, set the
223245
// functional dependency mode to `Dependency::Single`:
224-
dependency.with_mode(Dependency::Single)
246+
Some(dependency.with_mode(Dependency::Single))
225247
})
226248
.collect::<Vec<_>>();
227249
Self::new(dependencies)
@@ -422,7 +444,6 @@ pub fn aggregate_functional_dependencies(
422444
) -> FunctionalDependencies {
423445
let mut aggregate_func_dependencies = vec![];
424446
let aggr_input_fields = aggr_input_schema.field_names();
425-
let aggr_fields = aggr_schema.fields();
426447
// Association covers the whole table:
427448
let target_indices = (0..aggr_schema.fields().len()).collect::<Vec<_>>();
428449
// Get functional dependencies of the schema:
@@ -484,9 +505,12 @@ pub fn aggregate_functional_dependencies(
484505
if !group_by_expr_names.is_empty() {
485506
let count = group_by_expr_names.len();
486507
let source_indices = (0..count).collect::<Vec<_>>();
487-
let nullable = source_indices
488-
.iter()
489-
.any(|idx| aggr_fields[*idx].is_nullable());
508+
// Aggregation with GROUP BY always produces unique output rows for
509+
// each distinct combination of GROUP BY keys. The nullable flag is
510+
// set to false here so that subsequent expansion (e.g. a second
511+
// GROUP BY on the aggregate output) is never blocked by source
512+
// field nullability.
513+
let nullable = false;
490514
// If GROUP BY expressions do not already act as a determinant:
491515
if !aggregate_func_dependencies.iter().any(|item| {
492516
// If `item.source_indices` is a subset of GROUP BY expressions, we shouldn't add

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2887,6 +2887,32 @@ mod tests {
28872887
Ok(())
28882888
}
28892889

2890+
#[test]
2891+
fn plan_builder_aggregate_does_not_expand_nullable_unique_group_by_exprs()
2892+
-> Result<()> {
2893+
let schema = Schema::new(vec![
2894+
Field::new("id", DataType::Int32, true),
2895+
Field::new("state", DataType::Utf8, false),
2896+
Field::new("salary", DataType::Int32, false),
2897+
]);
2898+
let constraints = Constraints::new_unverified(vec![Constraint::Unique(vec![0])]);
2899+
let table_source = table_source_with_constraints(&schema, constraints);
2900+
2901+
let options =
2902+
LogicalPlanBuilderOptions::new().with_add_implicit_group_by_exprs(true);
2903+
let plan = LogicalPlanBuilder::scan("employee_csv", table_source, None)?
2904+
.with_options(options)
2905+
.aggregate(vec![col("id")], vec![sum(col("salary"))])?
2906+
.build()?;
2907+
2908+
assert_snapshot!(plan, @r"
2909+
Aggregate: groupBy=[[employee_csv.id]], aggr=[[sum(employee_csv.salary)]]
2910+
TableScan: employee_csv
2911+
");
2912+
2913+
Ok(())
2914+
}
2915+
28902916
#[test]
28912917
fn test_join_metadata() -> Result<()> {
28922918
let left_schema = DFSchema::new_with_metadata(

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,6 @@ impl LogicalPlan {
354354
LogicalPlan::Ddl(ddl) => ddl.schema(),
355355
LogicalPlan::Unnest(Unnest { schema, .. }) => schema,
356356
LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => {
357-
// we take the schema of the static term as the schema of the entire recursive query
358357
static_term.schema()
359358
}
360359
}
@@ -2405,6 +2404,10 @@ impl SubqueryAlias {
24052404
// no field must share the same column name as this would lead to ambiguity when referencing
24062405
// columns in parent logical nodes.
24072406

2407+
// Capture whether the input is a RecursiveQuery before `plan` may be
2408+
// rebound to a wrapping Projection below.
2409+
let is_recursive_query = matches!(plan.as_ref(), LogicalPlan::RecursiveQuery(_));
2410+
24082411
// Compute unique aliases, if any, for each column of the input's schema.
24092412
let aliases = unique_field_aliases(plan.schema().fields());
24102413
let is_projection_needed = aliases.iter().any(Option::is_some);
@@ -2434,7 +2437,14 @@ impl SubqueryAlias {
24342437
// Requalify fields with the new `alias`.
24352438
let fields = plan.schema().fields().clone();
24362439
let meta_data = plan.schema().metadata().clone();
2437-
let func_dependencies = plan.schema().functional_dependencies().clone();
2440+
// Recursive queries do not expose the anchor's functional dependencies to
2441+
// the outer schema — the recursive term can produce rows that violate
2442+
// those dependencies, so they are intentionally dropped here.
2443+
let func_dependencies = if is_recursive_query {
2444+
FunctionalDependencies::empty()
2445+
} else {
2446+
plan.schema().functional_dependencies().clone()
2447+
};
24382448

24392449
let schema = DFSchema::from_unqualified_fields(fields, meta_data)?;
24402450
let schema = schema.as_arrow();
@@ -2856,9 +2866,11 @@ impl TableScan {
28562866
return plan_err!("table_name cannot be empty");
28572867
}
28582868
let schema = table_source.schema();
2869+
let nullable_flags: Vec<bool> =
2870+
schema.fields().iter().map(|f| f.is_nullable()).collect();
28592871
let func_dependencies = FunctionalDependencies::new_from_constraints(
28602872
table_source.constraints(),
2861-
schema.fields.len(),
2873+
&nullable_flags,
28622874
);
28632875
let projected_schema = projection
28642876
.as_ref()
@@ -5146,7 +5158,7 @@ mod tests {
51465158
Some(&Constraints::new_unverified(vec![Constraint::Unique(
51475159
vec![0],
51485160
)])),
5149-
1,
5161+
&[false],
51505162
),
51515163
)
51525164
.unwrap(),

datafusion/optimizer/src/eliminate_duplicated_expr.rs

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,12 @@ mod tests {
145145
use crate::OptimizerContext;
146146
use crate::assert_optimized_plan_eq_snapshot;
147147
use crate::test::*;
148-
use datafusion_expr::{col, logical_plan::builder::LogicalPlanBuilder};
148+
use arrow::datatypes::{DataType, Field, Schema};
149+
use datafusion_common::{Constraint, Constraints};
150+
use datafusion_expr::{
151+
col, logical_plan::builder::LogicalPlanBuilder,
152+
logical_plan::builder::table_source_with_constraints,
153+
};
149154
use std::sync::Arc;
150155

151156
macro_rules! assert_optimized_plan_equal {
@@ -200,4 +205,47 @@ mod tests {
200205
TableScan: test
201206
")
202207
}
208+
209+
#[test]
210+
fn eliminate_sort_exprs_pk_removes_dependent_key() -> Result<()> {
211+
// When `id` is a PRIMARY KEY (non-nullable), it uniquely determines
212+
// `val`, so `ORDER BY id, val` can safely be reduced to `ORDER BY id`.
213+
let schema = Schema::new(vec![
214+
Field::new("id", DataType::Int32, false),
215+
Field::new("val", DataType::Int32, false),
216+
]);
217+
let constraints = Constraints::new_unverified(vec![Constraint::PrimaryKey(vec![0])]);
218+
let source = table_source_with_constraints(&schema, constraints);
219+
let plan = LogicalPlanBuilder::scan("t", source, None)?
220+
.sort_by(vec![col("t.id"), col("t.val")])?
221+
.build()?;
222+
223+
assert_optimized_plan_equal!(plan, @r"
224+
Sort: t.id ASC NULLS LAST
225+
TableScan: t
226+
")
227+
}
228+
229+
#[test]
230+
fn eliminate_sort_exprs_nullable_unique_keeps_dependent_key() -> Result<()> {
231+
// When `id` is a nullable UNIQUE column, SQL allows multiple NULL
232+
// values in `id`. Because NULLs are not considered equal, multiple
233+
// rows may share `id = NULL` with different `val` values, so `id`
234+
// does NOT functionally determine `val`. `ORDER BY id, val` must
235+
// therefore keep both keys.
236+
let schema = Schema::new(vec![
237+
Field::new("id", DataType::Int32, true), // nullable
238+
Field::new("val", DataType::Int32, false),
239+
]);
240+
let constraints = Constraints::new_unverified(vec![Constraint::Unique(vec![0])]);
241+
let source = table_source_with_constraints(&schema, constraints);
242+
let plan = LogicalPlanBuilder::scan("t", source, None)?
243+
.sort_by(vec![col("t.id"), col("t.val")])?
244+
.build()?;
245+
246+
assert_optimized_plan_equal!(plan, @r"
247+
Sort: t.id ASC NULLS LAST, t.val ASC NULLS LAST
248+
TableScan: t
249+
")
250+
}
203251
}

datafusion/sqllogictest/test_files/cte.slt

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,3 +1319,34 @@ RESET datafusion.execution.enable_recursive_ctes;
13191319

13201320
statement ok
13211321
RESET datafusion.sql_parser.enable_ident_normalization;
1322+
1323+
# Regression test: functional dependencies from the static (anchor) term of a
1324+
# recursive CTE must NOT be propagated to the outer SubqueryAlias. The
1325+
# recursive term can produce rows that violate any uniqueness constraint that
1326+
# holds for the anchor alone. Without this guard, Filter(pk = const) on the
1327+
# CTE result would be mis-identified as scalar (at most 1 row) and return only
1328+
# one row instead of all matching rows.
1329+
statement ok
1330+
CREATE TABLE pk_table(id INT NOT NULL, val INT NOT NULL, PRIMARY KEY(id));
1331+
1332+
statement ok
1333+
INSERT INTO pk_table VALUES (1, 100), (2, 200);
1334+
1335+
# The recursive term produces a second row with id=1 (val=300). Without the
1336+
# FD fix, Filter(nodes.id = 1) would be deemed scalar and return only the
1337+
# first matching row.
1338+
query II rowsort
1339+
WITH RECURSIVE nodes AS (
1340+
SELECT id, val FROM pk_table
1341+
UNION ALL
1342+
SELECT 1 AS id, 300 AS val
1343+
FROM nodes
1344+
WHERE nodes.id = 2
1345+
)
1346+
SELECT id, val FROM nodes WHERE id = 1
1347+
----
1348+
1 100
1349+
1 300
1350+
1351+
statement ok
1352+
DROP TABLE pk_table;

datafusion/sqllogictest/test_files/group_by.slt

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3565,33 +3565,27 @@ SELECT r.sn, r.amount, SUM(r.amount)
35653565
GROUP BY r.sn
35663566
ORDER BY r.sn
35673567

3568-
# left semi join should propagate constraint of left side as is.
3569-
query IRR
3568+
# left semi join with a nullable UNIQUE key cannot safely propagate the
3569+
# constraint for expansion, because UNIQUE allows multiple NULLs.
3570+
statement error DataFusion error: Error during planning: Column in SELECT must be in GROUP BY or an aggregate function: While expanding wildcard, column "l\.amount" must appear in the GROUP BY clause or must be part of an aggregate function, currently only "l\.sn, sum\(l\.amount\)" appears in the SELECT clause satisfies this requirement
35703571
SELECT l.sn, l.amount, SUM(l.amount)
35713572
FROM (SELECT *
35723573
FROM sales_global_with_unique as l
35733574
LEFT SEMI JOIN sales_global_with_unique as r
35743575
ON l.amount >= r.amount + 10)
35753576
GROUP BY l.sn
35763577
ORDER BY l.sn
3577-
----
3578-
1 50 50
3579-
2 75 75
3580-
3 200 200
3581-
4 100 100
3582-
NULL 100 100
35833578

3584-
# Similarly, left anti join should propagate constraint of left side as is.
3585-
query IRR
3579+
# Similarly, left anti join with a nullable UNIQUE key cannot safely propagate
3580+
# the constraint for expansion.
3581+
statement error DataFusion error: Error during planning: Column in SELECT must be in GROUP BY or an aggregate function: While expanding wildcard, column "l\.amount" must appear in the GROUP BY clause or must be part of an aggregate function, currently only "l\.sn, sum\(l\.amount\)" appears in the SELECT clause satisfies this requirement
35863582
SELECT l.sn, l.amount, SUM(l.amount)
35873583
FROM (SELECT *
35883584
FROM sales_global_with_unique as l
35893585
LEFT ANTI JOIN sales_global_with_unique as r
35903586
ON l.amount >= r.amount + 10)
35913587
GROUP BY l.sn
35923588
ORDER BY l.sn
3593-
----
3594-
0 30 30
35953589

35963590
# Should support grouping by list column
35973591
query ?I
@@ -5641,3 +5635,32 @@ set datafusion.execution.target_partitions = 4;
56415635

56425636
statement count 0
56435637
drop table t;
5638+
5639+
# Test that GROUP BY with a UNIQUE constraint on a nullable column does not
5640+
# NULL rows. UNIQUE allows multiple NULLs (NULLs are not equal in SQL), so
5641+
# a nullable UNIQUE column cannot be used to eliminate other GROUP BY columns.
5642+
# Regression test for https://github.com/apache/datafusion/issues/21507
5643+
5644+
statement ok
5645+
CREATE TABLE t_unique_null(a INT, b INT, c INT, UNIQUE(a));
5646+
5647+
statement ok
5648+
INSERT INTO t_unique_null VALUES (1, 10, 100), (NULL, 20, 200), (NULL, 30, 300);
5649+
5650+
# The two NULL rows must stay in separate groups (grouped by b as well).
5651+
query II rowsort
5652+
SELECT a, SUM(c) AS total FROM t_unique_null GROUP BY a, b;
5653+
----
5654+
1 100
5655+
NULL 200
5656+
NULL 300
5657+
5658+
# GROUP BY on the UNIQUE column alone must still merge the NULL rows into one group.
5659+
query II rowsort
5660+
SELECT a, SUM(c) AS total FROM t_unique_null GROUP BY a;
5661+
----
5662+
1 100
5663+
NULL 500
5664+
5665+
statement ok
5666+
DROP TABLE t_unique_null;

0 commit comments

Comments
 (0)