Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 215 additions & 5 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,18 @@ mod required_indices;

use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::array::Array;
use std::collections::HashSet;
use std::sync::Arc;

use datafusion_common::{
Column, DFSchema, HashMap, JoinType, Result, assert_eq_or_internal_err,
Column, DFSchema, HashMap, JoinType, Result, ScalarValue, assert_eq_or_internal_err,
get_required_group_by_exprs_indices, internal_datafusion_err, internal_err,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::{
Aggregate, Distinct, EmptyRelation, Expr, Projection, TableScan, Unnest, Window,
logical_plan::LogicalPlan,
logical_plan::LogicalPlan, utils::expr_to_columns,
};

use crate::optimize_projections::required_indices::RequiredIndices;
Expand Down Expand Up @@ -146,7 +148,8 @@ fn optimize_projections(
let n_group_exprs = aggregate.group_expr_len()?;
// Offset aggregate indices so that they point to valid indices at
// `aggregate.aggr_expr`:
let (group_by_reqs, aggregate_reqs) = indices.split_off(n_group_exprs);
let (group_by_reqs, aggregate_reqs) =
indices.clone().split_off(n_group_exprs);

// Get absolutely necessary GROUP BY fields.
//
Expand Down Expand Up @@ -197,6 +200,22 @@ fn optimize_projections(
)));
}

if new_aggr_expr.is_empty()
&& let Some(input) =
remove_unused_unnest_from_duplicate_insensitive_input(
aggregate.input.as_ref(),
&new_group_bys,
)?
{
let aggregate =
Aggregate::try_new(Arc::new(input), new_group_bys, new_aggr_expr)?;
return optimize_projections(
LogicalPlan::Aggregate(aggregate),
config,
indices,
);
}

let all_exprs_iter = new_group_bys.iter().chain(new_aggr_expr.iter());
let schema = aggregate.input.schema();
let necessary_indices =
Expand Down Expand Up @@ -492,6 +511,113 @@ fn optimize_projections(
}
}

/// Tries to rewrite `input` (the child of a duplicate-insensitive aggregate)
/// by removing an `Unnest` node whose unnested outputs are not referenced by
/// any of `required_exprs`.
///
/// Returns `Ok(Some(plan))` with the unnest-free replacement input when the
/// rewrite is provably safe, and `Ok(None)` when the input shape does not
/// match or the unnest cannot be removed.
fn remove_unused_unnest_from_duplicate_insensitive_input(
    input: &LogicalPlan,
    required_exprs: &[Expr],
) -> Result<Option<LogicalPlan>> {
    match input {
        // Aggregate sits directly on the unnest: replace the unnest with its
        // own input.
        LogicalPlan::Unnest(unnest)
            if can_remove_unused_unnest_for_exprs(unnest, required_exprs)? =>
        {
            // `unnest` is only borrowed here, so the inner plan must be
            // cloned in every case; `as_ref().clone()` states that directly.
            // (The previous `Arc::unwrap_or_clone(Arc::clone(..))` could
            // never take the unwrap path, because the extra `Arc::clone`
            // always kept the strong count above 1.)
            Ok(Some(unnest.input.as_ref().clone()))
        }
        // Aggregate sits on a projection of the unnest: keep only the
        // projection expressions the aggregate needs and re-root them on the
        // unnest's input.
        LogicalPlan::Projection(projection) => {
            let LogicalPlan::Unnest(unnest) = projection.input.as_ref() else {
                return Ok(None);
            };
            let required_projection_exprs = RequiredIndices::new()
                .with_exprs(&projection.schema, required_exprs.iter())
                .get_at_indices(&projection.expr);

            if can_remove_unused_unnest_for_exprs(unnest, &required_projection_exprs)? {
                Projection::try_new(required_projection_exprs, Arc::clone(&unnest.input))
                    .map(LogicalPlan::Projection)
                    .map(Some)
            } else {
                Ok(None)
            }
        }
        _ => Ok(None),
    }
}

/// Returns `true` when `unnest` can be removed without changing which rows a
/// duplicate-insensitive consumer observes, given that only `exprs` are
/// required above it.
///
/// Two conditions must both hold:
/// 1. the unnest never drops rows (every unnested list is a provably
///    non-empty literal), so removing it cannot resurrect rows the unnest
///    would have filtered out; and
/// 2. none of the required expressions reads an unnested output column, so
///    removal only eliminates duplicate-producing columns nobody looks at.
fn can_remove_unused_unnest_for_exprs(unnest: &Unnest, exprs: &[Expr]) -> Result<bool> {
    if !unnest_preserves_at_least_one_row_per_input(unnest) {
        return Ok(false);
    }

    // Collect every column referenced by the required expressions.
    let mut columns = HashSet::new();
    for expr in exprs {
        expr_to_columns(expr, &mut columns)?;
    }

    // Map each referenced output column back to the unnest input column it
    // depends on; if any of those inputs is an unnested column, the unnest
    // is still live and must be kept.
    for column in columns {
        let output_index = unnest.schema.index_of_column(&column)?;
        if is_unnested_input_index(unnest, unnest.dependency_indices[output_index]) {
            return Ok(false);
        }
    }

    Ok(true)
}

/// Reports whether input column `input_index` is one of the columns this
/// `Unnest` actually expands (either as a list column or a struct column).
fn is_unnested_input_index(unnest: &Unnest, input_index: usize) -> bool {
    let unnested_as_list = unnest
        .list_type_columns
        .iter()
        .any(|(idx, _)| *idx == input_index);
    let unnested_as_struct = unnest
        .struct_type_columns
        .iter()
        .any(|idx| *idx == input_index);
    unnested_as_list || unnested_as_struct
}

/// Returns `true` when every list column this `Unnest` expands is a literal,
/// non-null, non-empty list, which guarantees the unnest emits at least one
/// output row per input row (i.e. it can only duplicate rows, never drop
/// them).
///
/// NOTE(review): only the top-level list literal is inspected here. For an
/// unnest of depth > 1, a non-empty outer list could still contain empty
/// inner lists that drop rows at the deeper level — confirm that depth > 1
/// cases cannot reach this check or are rejected elsewhere.
fn unnest_preserves_at_least_one_row_per_input(unnest: &Unnest) -> bool {
    unnest.list_type_columns.iter().all(|(input_index, _)| {
        unnest_input_expr(unnest, *input_index)
            .and_then(literal_non_empty_list)
            // Anything we cannot prove non-empty might drop rows: be
            // conservative and keep the unnest.
            .unwrap_or(false)
    })
}

/// Looks up the expression that feeds the unnest's input column
/// `input_index`. Only a `Projection` directly below the unnest exposes its
/// expressions; any other input shape yields `None`.
fn unnest_input_expr(unnest: &Unnest, input_index: usize) -> Option<&Expr> {
    if let LogicalPlan::Projection(projection) = unnest.input.as_ref() {
        projection.expr.get(input_index)
    } else {
        None
    }
}

/// Decides whether `expr` is a literal list that is provably non-empty.
///
/// Returns `Some(true)` for a non-null literal list with at least one
/// element, `Some(false)` for a null or empty literal list, and `None` when
/// the expression is not a list literal at all (nothing can be proven).
fn literal_non_empty_list(expr: &Expr) -> Option<bool> {
    // Look through a single alias layer; the projection typically aliases
    // the literal (e.g. `List([1, 2]) AS elem`).
    let unaliased = if let Expr::Alias(Alias { expr, .. }) = expr {
        expr.as_ref()
    } else {
        expr
    };
    let Expr::Literal(value, _) = unaliased else {
        return None;
    };

    // A list scalar holds exactly one row, so row 0 is the literal itself;
    // each list flavor exposes its element count differently.
    let non_empty = match value {
        ScalarValue::List(array) => {
            has_valid_first_value(array.as_ref()) && array.value_length(0) > 0
        }
        ScalarValue::LargeList(array) => {
            has_valid_first_value(array.as_ref()) && array.value_length(0) > 0
        }
        ScalarValue::FixedSizeList(array) => {
            has_valid_first_value(array.as_ref()) && array.value_length() > 0
        }
        ScalarValue::ListView(array) => {
            has_valid_first_value(array.as_ref()) && array.value_sizes()[0] > 0
        }
        ScalarValue::LargeListView(array) => {
            has_valid_first_value(array.as_ref()) && array.value_sizes()[0] > 0
        }
        _ => return None,
    };
    Some(non_empty)
}

/// Returns `true` when the array has a first slot and that slot is not null.
fn has_valid_first_value(array: &impl Array) -> bool {
    if array.is_empty() {
        false
    } else {
        array.is_valid(0)
    }
}

/// Optimizes uncorrelated subquery plans embedded in expressions of the given
/// plan node (e.g., `Expr::ScalarSubquery`). `map_children` only visits direct
/// plan inputs, so subqueries must be handled separately.
Expand Down Expand Up @@ -947,9 +1073,10 @@ mod tests {
test_table_scan_with_name,
};
use crate::{OptimizerContext, OptimizerRule};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::array::ListArray;
use arrow::datatypes::{DataType, Field, Int64Type, Schema};
use datafusion_common::{
Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
Column, DFSchema, DFSchemaRef, JoinType, Result, ScalarValue, TableReference,
};
use datafusion_expr::ExprFunctionExt;
use datafusion_expr::{
Expand Down Expand Up @@ -986,6 +1113,22 @@ mod tests {
}};
}

/// Schema with a single non-nullable `id: UInt32` column, used by the
/// unused-unnest pruning tests.
fn id_schema() -> Schema {
    let id_field = Field::new("id", DataType::UInt32, false);
    Schema::new(vec![id_field])
}

/// Builds an `Int64` list literal expression from `values` (one list row,
/// possibly containing nulls).
fn list_literal_expr(values: Vec<Option<i64>>) -> Expr {
    let rows = vec![Some(values)];
    let array = ListArray::from_iter_primitive::<Int64Type, _, _>(rows);
    Expr::Literal(ScalarValue::List(Arc::new(array)), None)
}

/// Builds `Unnest(Projection(id, <literal list> AS elem, TableScan))` — the
/// plan shape the unused-unnest pruning tests operate on.
fn id_elem_unnest_plan(values: Vec<Option<i64>>) -> Result<LogicalPlanBuilder> {
    let schema = id_schema();
    let elem = list_literal_expr(values).alias("elem");
    table_scan(Some("test"), &schema, None)?
        .project(vec![col("id"), elem])?
        .unnest_column(Column::from_name("elem"))
}

#[derive(Debug, Hash, PartialEq, Eq)]
struct NoOpUserDefined {
exprs: Vec<Expr>,
Expand Down Expand Up @@ -1312,6 +1455,73 @@ mod tests {
)
}

// The literal list is non-empty, so the unnest can only duplicate rows;
// under a duplicate-insensitive GROUP BY the whole Unnest/Projection pair
// is expected to be pruned down to the bare table scan.
#[test]
fn remove_unused_non_empty_literal_unnest_under_group_by() -> Result<()> {
let plan = id_elem_unnest_plan(vec![Some(1), Some(2)])?
.aggregate(vec![col("id")], Vec::<Expr>::new())?
.build()?;

assert_optimized_plan_equal!(
plan,
@r"
Aggregate: groupBy=[[test.id]], aggr=[[]]
TableScan: test projection=[id]
"
)
}

// Same pruning as above, but with an intermediate projection between the
// unnest and the aggregate: the Projection branch of
// remove_unused_unnest_from_duplicate_insensitive_input must handle it.
#[test]
fn remove_unused_unnest_below_projection_under_group_by() -> Result<()> {
let plan = id_elem_unnest_plan(vec![Some(1), Some(2)])?
.project(vec![col("id")])?
.aggregate(vec![col("id")], Vec::<Expr>::new())?
.build()?;

assert_optimized_plan_equal!(
plan,
@r"
Aggregate: groupBy=[[test.id]], aggr=[[]]
TableScan: test projection=[id]
"
)
}

// Negative case: the GROUP BY references the unnested `elem` column itself,
// so the unnest is live and must be kept in the optimized plan.
#[test]
fn keep_referenced_unnest_under_group_by() -> Result<()> {
let plan = id_elem_unnest_plan(vec![Some(1), Some(2)])?
.aggregate(vec![col("elem")], Vec::<Expr>::new())?
.build()?;

assert_optimized_plan_equal!(
plan,
@r"
Aggregate: groupBy=[[elem]], aggr=[[]]
Unnest: lists[elem|depth=1] structs[]
Projection: List([1, 2]) AS elem
TableScan: test projection=[]
"
)
}

// Negative case: an empty literal list makes the unnest drop rows, so even
// though `elem` is unused above the GROUP BY, pruning the unnest would
// change the grouped result — it must stay in the plan.
#[test]
fn keep_unused_empty_literal_unnest_under_group_by() -> Result<()> {
let empty_list: Vec<Option<i64>> = vec![];
let plan = id_elem_unnest_plan(empty_list)?
.aggregate(vec![col("id")], Vec::<Expr>::new())?
.build()?;

assert_optimized_plan_equal!(
plan,
@r"
Aggregate: groupBy=[[test.id]], aggr=[[]]
Projection: test.id
Unnest: lists[elem|depth=1] structs[]
Projection: test.id, List([]) AS elem
TableScan: test projection=[id]
"
)
}

#[test]
fn test_neg_push_down() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down
55 changes: 51 additions & 4 deletions datafusion/sqllogictest/test_files/unnest.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1348,8 +1348,8 @@ SELECT * FROM (
(3, arrow_cast(NULL, 'List(Int64)'))
) AS t(id, arr);

# Reproducer for the optimization gap: the unused `elem` output is duplicate-insensitive
# below this GROUP BY, but the current plan still keeps Unnest/UnnestExec.
# The unused `elem` output only duplicates rows below this GROUP BY, so it can
# be pruned without changing the grouped ids.
query I
SELECT id
FROM (
Expand All @@ -1372,9 +1372,43 @@ FROM (
GROUP BY id;
----
logical_plan
<slt:ignore>Unnest:<slt:ignore>
01)Aggregate: groupBy=[[unused_unnest_pruning.id]], aggr=[[]]
02)--TableScan: unused_unnest_pruning projection=[id]
physical_plan
<slt:ignore>UnnestExec<slt:ignore>
01)AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[]
02)--RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=1
03)----AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[]
04)------DataSourceExec: partitions=1, partition_sizes=[1]

# DISTINCT is implemented as a duplicate-insensitive aggregate before projection
# pruning, so the same unused non-empty literal UNNEST can be removed.
query I
SELECT DISTINCT id
FROM (
SELECT id, UNNEST(make_array(1, 2, 3)) AS elem
FROM unused_unnest_pruning
)
ORDER BY id;
----
1
2
3

query TT
EXPLAIN SELECT DISTINCT id
FROM (
SELECT id, UNNEST(make_array(1, 2, 3)) AS elem
FROM unused_unnest_pruning
);
----
logical_plan
01)Aggregate: groupBy=[[unused_unnest_pruning.id]], aggr=[[]]
02)--TableScan: unused_unnest_pruning projection=[id]
physical_plan
01)AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[]
02)--RepartitionExec: partitioning=Hash([id@0], 4), input_partitions=1
03)----AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[]
04)------DataSourceExec: partitions=1, partition_sizes=[1]

# Counterexample: removing UNNEST here would change cardinality.
query I rowsort
Expand Down Expand Up @@ -1417,5 +1451,18 @@ FROM (
----
2

# Empty and NULL input lists can remove rows before grouping, so this UNNEST
# must not be pruned even though `elem` is not projected above the GROUP BY.
query I
SELECT id
FROM (
SELECT id, UNNEST(arr) AS elem
FROM unused_unnest_pruning
)
GROUP BY id
ORDER BY id;
----
1

statement ok
DROP TABLE unused_unnest_pruning;
Loading