From 0258299b2656c1baab487caec0a6b7f443c6b788 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:01:50 -0500 Subject: [PATCH 1/5] feat: Allow struct field access projections to be pushed down into scans --- Cargo.lock | 1 + datafusion/expr-common/src/lib.rs | 3 + datafusion/expr-common/src/placement.rs | 68 +++++ datafusion/expr/src/expr.rs | 24 +- datafusion/expr/src/lib.rs | 1 + datafusion/expr/src/udf.rs | 52 ++++ datafusion/functions/src/core/getfield.rs | 111 ++++++- .../optimizer/src/optimize_projections/mod.rs | 17 +- .../physical-expr-common/src/physical_expr.rs | 19 ++ .../physical-expr/src/expressions/column.rs | 6 +- .../physical-expr/src/expressions/literal.rs | 6 +- datafusion/physical-expr/src/projection.rs | 5 + .../physical-expr/src/scalar_function.rs | 12 +- datafusion/physical-optimizer/Cargo.toml | 1 + .../src/output_requirements.rs | 9 +- .../src/projection_pushdown.rs | 275 ++++++++++++++++++ .../physical-plan/src/coalesce_partitions.rs | 6 +- datafusion/physical-plan/src/filter.rs | 98 ++++++- datafusion/physical-plan/src/projection.rs | 222 +++++++++++--- .../physical-plan/src/repartition/mod.rs | 11 +- datafusion/physical-plan/src/sorts/sort.rs | 150 +++++++++- .../src/sorts/sort_preserving_merge.rs | 5 +- .../test_files/projection_pushdown.slt | 85 +++--- datafusion/sqllogictest/test_files/unnest.slt | 4 +- 24 files changed, 1043 insertions(+), 148 deletions(-) create mode 100644 datafusion/expr-common/src/placement.rs diff --git a/Cargo.lock b/Cargo.lock index 2cf439134dda9..8539b444bdaea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2439,6 +2439,7 @@ dependencies = [ "datafusion-physical-expr-common", "datafusion-physical-plan", "datafusion-pruning", + "indexmap 2.13.0", "insta", "itertools 0.14.0", "recursive", diff --git a/datafusion/expr-common/src/lib.rs b/datafusion/expr-common/src/lib.rs index 2be066beaad24..758f2540a22f2 100644 --- a/datafusion/expr-common/src/lib.rs +++ b/datafusion/expr-common/src/lib.rs @@ -41,7 +41,10 @@ pub mod dyn_eq; pub mod groups_accumulator; pub mod interval_arithmetic; pub mod operator; +pub mod placement; pub mod signature; pub mod sort_properties; pub mod statistics; pub mod type_coercion; + +pub use placement::ExpressionPlacement; diff --git a/datafusion/expr-common/src/placement.rs b/datafusion/expr-common/src/placement.rs new file mode 100644 index 0000000000000..0527cda11e466 --- /dev/null +++ b/datafusion/expr-common/src/placement.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Expression placement classification for scalar functions. +//! +//! This module determines where in the query plan expressions should be placed +//! to optimize data flow: +//! +//! 
- **Leaf placement**: Cheap expressions (field accessors, column references)
+//!   are pushed down to leaf nodes (near data sources) to reduce data volume early.
+//! - **Root placement**: Expensive expressions (computations, aggregates) are kept
+//!   at root nodes (after filtering) to operate on less data.
+
+/// Classification of expression placement for scalar functions.
+///
+/// This enum is used by [`ScalarUDFImpl::placement`] to allow
+/// functions to make context-dependent decisions about where they should
+/// be placed in the query plan based on the nature of their arguments.
+///
+/// For example, `get_field(struct_col, 'field_name')` is
+/// leaf-pushable (static field lookup), but `string_col like '%foo%'`
+/// performs expensive per-row computation and should be placed
+/// further up the tree so that it can run after filtering, sorting, etc.
+///
+/// # Why not pass in expressions directly to decide placement?
+///
+/// There are two reasons for using this enum instead of passing in the full expressions:
+///
+/// 1. **ScalarUDFImpl cannot reference PhysicalExpr**: The trait is defined in datafusion-expr,
+///    which cannot reference datafusion-physical-expr since the latter depends on the former
+///    (it would create a circular dependency).
+/// 2. **Simplicity**: Without this enum abstracting away the logical / physical distinction,
+///    we would need two distinct methods on ScalarUDFImpl: one for logical expression placement
+///    and one for physical expression placement. This would require implementors to duplicate logic
+///    and would increase complexity for UDF authors.
+///
+/// [`ScalarUDFImpl::placement`]: https://docs.rs/datafusion-expr/latest/datafusion_expr/trait.ScalarUDFImpl.html#tymethod.placement
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum ExpressionPlacement {
+    /// Argument is a literal constant value or an expression that can be
+    /// evaluated to a constant at planning time.
+    Literal,
+    /// Argument is a simple column reference.
+    Column,
+    /// Argument is a complex expression that can be safely placed at leaf nodes.
+    /// For example, if `get_field(struct_col, 'field_name')` is implemented as a
+    /// leaf-pushable expression, its placement would be this variant.
+    /// Then `other_leaf_function(get_field(...), 42)` could also be classified as
+    /// leaf-pushable using the knowledge that `get_field(...)` is leaf-pushable.
+    PlaceAtLeafs,
+    /// Argument is a complex expression that should be placed at root nodes.
+    /// For example, `min(col1 + col2)` is not leaf-pushable because it requires per-row computation.
+    PlaceAtRoot,
+}
diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 454839fdb75ac..414a4e52d0011 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -27,7 +27,7 @@ use std::sync::Arc;
 use crate::expr_fn::binary_expr;
 use crate::function::WindowFunctionSimplification;
 use crate::logical_plan::Subquery;
-use crate::{AggregateUDF, Volatility};
+use crate::{AggregateUDF, ExpressionPlacement, Volatility};
 use crate::{ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};

 use arrow::datatypes::{DataType, Field, FieldRef};
@@ -1933,6 +1933,28 @@ impl Expr {
         }
     }

+    /// Returns the placement classification of this expression.
+    ///
+    /// This tells us whether optimizers should preferentially
+    /// move this expression towards the leaves of the execution plan
+    /// tree (for cheap expressions, or expressions that reduce the data size)
+    /// or towards the root of the execution plan tree (for expensive expressions
+    /// that should run after filtering or parallelization, or expressions that increase the data size).
+    pub fn placement(&self) -> ExpressionPlacement {
+        match self {
+            Expr::Column(_) => ExpressionPlacement::Column,
+            Expr::Literal(_, _) => ExpressionPlacement::Literal,
+            Expr::ScalarFunction(func) => {
+                // Classify each argument's placement for context-aware decision making
+                let arg_placements: Vec<ExpressionPlacement> =
+                    func.args.iter().map(|arg| arg.placement()).collect();
+
+                func.func.placement_with_args(&arg_placements)
+            }
+            _ => ExpressionPlacement::PlaceAtRoot,
+        }
+    }
+
     /// Return all references to columns in this expression.
     ///
     /// # Example
diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs
index 978e9f627565c..d40406de0e041 100644
--- a/datafusion/expr/src/lib.rs
+++ b/datafusion/expr/src/lib.rs
@@ -92,6 +92,7 @@ pub use datafusion_doc::{
     DocSection, Documentation, DocumentationBuilder, aggregate_doc_sections,
     scalar_doc_sections, window_doc_sections,
 };
+pub use datafusion_expr_common::ExpressionPlacement;
 pub use datafusion_expr_common::accumulator::Accumulator;
 pub use datafusion_expr_common::columnar_value::ColumnarValue;
 pub use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 870e318a62c3d..97f8fbba2b4d3 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -31,6 +31,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::{ExprSchema, Result, ScalarValue, not_impl_err};
 use datafusion_expr_common::dyn_eq::{DynEq, DynHash};
 use datafusion_expr_common::interval_arithmetic::Interval;
+use datafusion_expr_common::placement::ExpressionPlacement;
 use std::any::Any;
 use std::cmp::Ordering;
 use std::fmt::Debug;
@@ -123,6 +124,22 @@ impl ScalarUDF {
         Self { inner: fun }
     }

+    /// Returns the placement classification of this function given its arguments' placement.
+    ///
+    /// This allows functions to make context-dependent decisions about where they should
+    /// be placed in the query plan. For example, `get_field(struct_col, 'field_name')` is
+    /// leaf-pushable (static field lookup), but `string_col like '%foo%'`
+    /// performs expensive per-row computation and should be placed
+    /// further up the tree so that it can run after filtering, sorting, etc.
+    ///
+    /// See [`ScalarUDFImpl::placement`] for more details.
+    pub fn placement_with_args(
+        &self,
+        args: &[ExpressionPlacement],
+    ) -> ExpressionPlacement {
+        self.inner.placement(args)
+    }
+
     /// Return the underlying [`ScalarUDFImpl`] trait object for this function
     pub fn inner(&self) -> &Arc<dyn ScalarUDFImpl> {
         &self.inner
     }
@@ -885,6 +902,37 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync {
     fn documentation(&self) -> Option<&Documentation> {
         None
     }
+
+    /// Returns the placement classification of this function given its arguments' placement.
+    ///
+    /// This method allows functions to make context-dependent decisions about
+    /// where they should be placed in the query plan. The default implementation
+    /// returns [`ExpressionPlacement::PlaceAtRoot`] (conservative default).
+    ///
+    /// Leaf-pushable functions are lightweight accessor functions like `get_field`
+    /// (struct field access) that simply access nested data within a column
+    /// without significant computation.
+    /// These can be pushed down to leaf nodes near data sources to reduce data volume early in the plan.
+    ///
+    /// [`ExpressionPlacement::PlaceAtRoot`] represents expressions that should be kept after filtering,
+    /// such as expensive computations or aggregates that benefit from operating
+    /// on fewer rows.
+    ///
+    /// # Example
+    ///
+    /// - `get_field(struct_col, 'field_name')` with a literal key is leaf-pushable, as it
+    ///   performs a metadata-only (cheap) extraction of a sub-array from a struct column.
+    ///   Thus, it can be placed near the data source to reduce data volume early.
+    /// - `string_col like '%foo%'` performs expensive per-row computation and should be placed
+    ///   further up the tree so that it can run after filtering, sorting, etc.
+    ///
+    /// # Arguments
+    ///
+    /// * `args` - Classification of each argument's placement, collected from the expression tree
+    ///   by the caller.
+    fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement {
+        ExpressionPlacement::PlaceAtRoot
+    }
 }

 /// ScalarUDF that adds an alias to the underlying function. It is better to
@@ -1012,6 +1060,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl {
     fn documentation(&self) -> Option<&Documentation> {
         self.inner.documentation()
     }
+
+    fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement {
+        self.inner.placement(args)
+    }
 }

 #[cfg(test)]
diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs
index 47a903639dde5..09bf68b44635d 100644
--- a/datafusion/functions/src/core/getfield.rs
+++ b/datafusion/functions/src/core/getfield.rs
@@ -33,8 +33,8 @@ use datafusion_common::{
 };
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::simplify::ExprSimplifyResult;
 use datafusion_expr::{
-    ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
-    ScalarUDFImpl, Signature, Volatility,
+    ColumnarValue, Documentation, Expr, ExpressionPlacement, ReturnFieldArgs,
+    ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
 };
 use datafusion_macros::user_doc;
@@ -499,6 +499,37 @@ impl ScalarUDFImpl for GetFieldFunc {
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
+
+    fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement {
+        // get_field is leaf-pushable if:
+        // 1. The struct/map argument is a Column or PlaceAtLeafs (not Literal or PlaceAtRoot)
+        // 2. All key arguments are literals (static field access, not dynamic per-row lookup)
+        //
+        // A literal base is not considered leaf-pushable because it would be constant-folded anyway.
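+        //
+        // Illustrative decision table (mirrors the unit tests below; the names are
+        // examples only):
+        //   get_field(struct_col, 'name')     -> PlaceAtLeafs (static key)
+        //   get_field(struct_col['a'], 'b')   -> PlaceAtLeafs (base is PlaceAtLeafs)
+        //   get_field(struct_col, key_col)    -> PlaceAtRoot  (dynamic per-row key)
+        //   get_field(lit_struct, 'name')     -> PlaceAtRoot  (left to constant folding)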
+ if args.is_empty() { + return ExpressionPlacement::PlaceAtRoot; + } + + // Check if the base (struct/map) argument is Column or PlaceAtLeafs + if !matches!( + args[0], + ExpressionPlacement::Column | ExpressionPlacement::PlaceAtLeafs + ) { + return ExpressionPlacement::PlaceAtRoot; + } + + // All key arguments (after the first) must be literals for static field access + let keys_literal = args + .iter() + .skip(1) + .all(|a| *a == ExpressionPlacement::Literal); + + if keys_literal { + ExpressionPlacement::PlaceAtLeafs + } else { + ExpressionPlacement::PlaceAtRoot + } + } } #[cfg(test)] @@ -542,4 +573,80 @@ mod tests { Ok(()) } + + #[test] + fn test_placement_with_args_literal_key() { + let func = GetFieldFunc::new(); + + // get_field(col, 'literal') -> leaf-pushable (static field access) + let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Literal]; + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeafs); + + // get_field(col, 'a', 'b') -> leaf-pushable (nested static field access) + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::Literal, + ExpressionPlacement::Literal, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeafs); + + // get_field(get_field(col, 'a'), 'b') represented as PlaceAtLeafs for base + let args = vec![ + ExpressionPlacement::PlaceAtLeafs, + ExpressionPlacement::Literal, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeafs); + } + + #[test] + fn test_placement_with_args_column_key() { + let func = GetFieldFunc::new(); + + // get_field(col, other_col) -> NOT leaf-pushable (dynamic per-row lookup) + let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Column]; + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtRoot); + + // get_field(col, 'a', other_col) -> NOT leaf-pushable (dynamic nested lookup) + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::Literal, + ExpressionPlacement::Column, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtRoot); + } + + #[test] + fn test_placement_with_args_root() { + let func = GetFieldFunc::new(); + + // get_field(root_expr, 'literal') -> NOT leaf-pushable + let args = vec![ + ExpressionPlacement::PlaceAtRoot, + ExpressionPlacement::Literal, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtRoot); + + // get_field(col, root_expr) -> NOT leaf-pushable + let args = vec![ + ExpressionPlacement::Column, + ExpressionPlacement::PlaceAtRoot, + ]; + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtRoot); + } + + #[test] + fn test_placement_with_args_edge_cases() { + let func = GetFieldFunc::new(); + + // Empty args -> NOT leaf-pushable + assert_eq!(func.placement(&[]), ExpressionPlacement::PlaceAtRoot); + + // Just base, no key -> PlaceAtLeafs (not a valid call but should handle gracefully) + let args = vec![ExpressionPlacement::Column]; + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeafs); + + // Literal base with literal key -> NOT leaf-pushable (would be constant-folded) + let args = vec![ExpressionPlacement::Literal, ExpressionPlacement::Literal]; + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtRoot); + } } diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index f97b05ea68fbd..764adac03faff 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -30,8 +30,8 @@ use 
datafusion_common::{
 };
 use datafusion_expr::expr::Alias;
 use datafusion_expr::{
-    Aggregate, Distinct, EmptyRelation, Expr, Projection, TableScan, Unnest, Window,
-    logical_plan::LogicalPlan,
+    Aggregate, Distinct, EmptyRelation, Expr, ExpressionPlacement, Projection, TableScan,
+    Unnest, Window, logical_plan::LogicalPlan,
 };

 use crate::optimize_projections::required_indices::RequiredIndices;
@@ -530,9 +530,11 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Projection>> {
         *count > 1
-            && !is_expr_trivial(
-                &prev_projection.expr
-                    [prev_projection.schema.index_of_column(col).unwrap()],
+            && matches!(
+                prev_projection.expr
+                    [prev_projection.schema.index_of_column(col).unwrap()]
+                    .placement(),
+                ExpressionPlacement::PlaceAtRoot
             )
     }) {
         // no change
@@ -586,11 +588,6 @@ fn merge_consecutive_projections(proj: Projection) -> Result<Transformed<Projection>> {
-fn is_expr_trivial(expr: &Expr) -> bool {
-    matches!(expr, Expr::Column(_) | Expr::Literal(_, _))
-}
-
 /// Rewrites a projection expression using the projection before it (i.e. its input)
 /// This is a subroutine to the `merge_consecutive_projections` function.
 ///
diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs
index 2358a21940912..67d39356653d3 100644
--- a/datafusion/physical-expr-common/src/physical_expr.rs
+++ b/datafusion/physical-expr-common/src/physical_expr.rs
@@ -35,6 +35,7 @@ use datafusion_common::{
 };
 use datafusion_expr_common::columnar_value::ColumnarValue;
 use datafusion_expr_common::interval_arithmetic::Interval;
+use datafusion_expr_common::placement::ExpressionPlacement;
 use datafusion_expr_common::sort_properties::ExprProperties;
 use datafusion_expr_common::statistics::Distribution;
@@ -430,6 +431,24 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash {
     fn is_volatile_node(&self) -> bool {
         false
     }
+
+    /// Returns the placement classification of this expression.
+    ///
+    /// Leaf-pushable expressions include:
+    /// - Column references (`ExpressionPlacement::Column`)
+    /// - Literal values (`ExpressionPlacement::Literal`)
+    /// - Struct field access via `get_field` (`ExpressionPlacement::PlaceAtLeafs`)
+    /// - Nested combinations of field accessors (e.g., `col['a']['b']`)
+    ///
+    /// This is used to identify expressions that are cheap to duplicate or
+    /// don't benefit from caching/partitioning optimizations.
+    ///
+    /// **Performance note**: Expressions marked as `PlaceAtLeafs` may be pushed
+    /// below filters during optimization. If an expression does per-row work,
+    /// marking it leaf-pushable may slow things down by causing evaluation on more rows.
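+    ///
+    /// # Example
+    ///
+    /// Illustrative classifications under this scheme:
+    ///
+    /// ```text
+    /// col@0                  -> Column
+    /// lit(1)                 -> Literal
+    /// get_field(col@0, "a")  -> PlaceAtLeafs
+    /// col@0 LIKE '%foo%'     -> PlaceAtRoot (the default)
+    /// ```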
+ fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::PlaceAtRoot + } } #[deprecated( diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 8c7e8c319fff4..d8830b98251e2 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -29,7 +29,7 @@ use arrow::{ }; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Result, internal_err, plan_err}; -use datafusion_expr::ColumnarValue; +use datafusion_expr::{ColumnarValue, ExpressionPlacement}; /// Represents the column at a given index in a RecordBatch /// @@ -146,6 +146,10 @@ impl PhysicalExpr for Column { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } + + fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::Column + } } impl Column { diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 1f3fefc60b7ad..f2a8abb1e5fe2 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -30,7 +30,7 @@ use arrow::{ }; use datafusion_common::metadata::FieldMetadata; use datafusion_common::{Result, ScalarValue}; -use datafusion_expr::Expr; +use datafusion_expr::{Expr, ExpressionPlacement}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::{ExprProperties, SortProperties}; @@ -134,6 +134,10 @@ impl PhysicalExpr for Literal { fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(self, f) } + + fn placement(&self) -> ExpressionPlacement { + ExpressionPlacement::Literal + } } /// Create a literal expression diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 540fd620c92ce..71a104f53021b 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -670,6 +670,11 @@ impl ProjectionExprs { stats.column_statistics = column_statistics; Ok(stats) } + + /// Get the projection expressions as a slice. 
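+    ///
+    /// This exposes the backing slice directly; callers that need an owned copy
+    /// (e.g. to build a new `ProjectionExec`) can use `exprs().to_vec()`, as the
+    /// filter and sort pushdown paths in this patch do.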
+ pub fn exprs(&self) -> &[ProjectionExpr] { + &self.exprs + } } impl<'a> IntoIterator for &'a ProjectionExprs { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index aa090743ad441..6b2c451bd8156 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -45,8 +45,8 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::fields_with_udf; use datafusion_expr::{ - ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, Volatility, - expr_vec_fmt, + ColumnarValue, ExpressionPlacement, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, + Volatility, expr_vec_fmt, }; /// Physical expression of a scalar function @@ -362,6 +362,14 @@ impl PhysicalExpr for ScalarFunctionExpr { fn is_volatile_node(&self) -> bool { self.fun.signature().volatility == Volatility::Volatile } + + fn placement(&self) -> ExpressionPlacement { + // Classify each argument's placement for context-aware decision making + let arg_placements: Vec = + self.args.iter().map(|arg| arg.placement()).collect(); + + self.fun.placement_with_args(&arg_placements) + } } #[cfg(test)] diff --git a/datafusion/physical-optimizer/Cargo.toml b/datafusion/physical-optimizer/Cargo.toml index 395da10d629ba..620265dfdc0f6 100644 --- a/datafusion/physical-optimizer/Cargo.toml +++ b/datafusion/physical-optimizer/Cargo.toml @@ -50,6 +50,7 @@ datafusion-physical-expr = { workspace = true } datafusion-physical-expr-common = { workspace = true } datafusion-physical-plan = { workspace = true } datafusion-pruning = { workspace = true } +indexmap = { workspace = true } itertools = { workspace = true } recursive = { workspace = true, optional = true } diff --git a/datafusion/physical-optimizer/src/output_requirements.rs b/datafusion/physical-optimizer/src/output_requirements.rs index 0dc6a25fbc0b7..d899cbe73106a 100644 --- a/datafusion/physical-optimizer/src/output_requirements.rs +++ b/datafusion/physical-optimizer/src/output_requirements.rs @@ -256,18 +256,13 @@ impl ExecutionPlan for OutputRequirementExec { &self, projection: &ProjectionExec, ) -> Result>> { - // If the projection does not narrow the schema, we should not try to push it down: - let proj_exprs = projection.expr(); - if proj_exprs.len() >= projection.input().schema().fields().len() { - return Ok(None); - } - let mut requirements = self.required_input_ordering().swap_remove(0); if let Some(reqs) = requirements { let mut updated_reqs = vec![]; let (lexes, soft) = reqs.into_alternatives(); for lex in lexes.into_iter() { - let Some(updated_lex) = update_ordering_requirement(lex, proj_exprs)? + let Some(updated_lex) = + update_ordering_requirement(lex, projection.expr())? 
else { return Ok(None); }; diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index 281d61aecf538..ae64de35c1c0b 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -23,6 +23,7 @@ use crate::PhysicalOptimizerRule; use arrow::datatypes::{Fields, Schema, SchemaRef}; use datafusion_common::alias::AliasGenerator; +use indexmap::IndexMap; use std::collections::HashSet; use std::sync::Arc; @@ -31,7 +32,9 @@ use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, }; use datafusion_common::{JoinSide, JoinType, Result}; +use datafusion_expr::ExpressionPlacement; use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::projection::ProjectionExpr; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use datafusion_physical_plan::ExecutionPlan; use datafusion_physical_plan::joins::NestedLoopJoinExec; @@ -76,6 +79,13 @@ impl PhysicalOptimizerRule for ProjectionPushdown { }) .map(|t| t.data)?; + // First, try to split mixed projections (beneficial + non-beneficial expressions) + // This allows the beneficial parts to be pushed down while keeping non-beneficial parts above. + let plan = plan + .transform_down(|plan| try_split_projection(plan, &alias_generator)) + .map(|t| t.data)?; + + // Then apply the normal projection pushdown logic plan.transform_down(remove_unnecessary_projections).data() } @@ -88,6 +98,271 @@ impl PhysicalOptimizerRule for ProjectionPushdown { } } +/// Tries to split a projection to extract beneficial sub-expressions for pushdown. +/// +/// This function walks each expression in the projection and extracts beneficial +/// sub-expressions (like `get_field`) from within larger non-beneficial expressions. +/// For example: +/// - Input: `get_field(col, 'foo') + 1` +/// - Output: Inner projection: `get_field(col, 'foo') AS __extracted_0`, Outer: `__extracted_0 + 1` +/// +/// This enables the beneficial parts to be pushed down while keeping non-beneficial +/// expressions (like literals and computations) above. +fn try_split_projection( + plan: Arc, + alias_generator: &AliasGenerator, +) -> Result>> { + let Some(projection) = plan.as_any().downcast_ref::() else { + return Ok(Transformed::no(plan)); + }; + + let input_schema = projection.input().schema(); + let mut extractor = + LeafExpressionExtractor::new(input_schema.as_ref(), alias_generator); + + // Extract leaf-pushable sub-expressions from each projection expression + let mut outer_exprs = Vec::new(); + let mut has_extractions = false; + + for proj_expr in projection.expr() { + // If this is already an expression from an extraction don't try to re-extract it (would cause infinite recursion) + if proj_expr.alias.starts_with("__extracted") { + outer_exprs.push(proj_expr.clone()); + continue; + } + + // Only extract from root-level expressions. If the entire expression is + // already PlaceAtLeafs (like `get_field(col, 'foo')`), it can be pushed as-is. + // We only need to split when there's a root expression with leaf-pushable + // sub-expressions (like `get_field(col, 'foo') + 1`). 
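+        // For such mixed expressions, extraction yields, e.g. (alias illustrative):
+        //   inner: get_field(col, 'foo') AS __extracted_1, plus pass-through columns
+        //   outer: __extracted_1 + 1 under the original alias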
+ if matches!( + proj_expr.expr.placement(), + ExpressionPlacement::PlaceAtLeafs + ) { + outer_exprs.push(proj_expr.clone()); + continue; + } + + let rewritten = extractor.extract(Arc::clone(&proj_expr.expr))?; + if !Arc::ptr_eq(&rewritten, &proj_expr.expr) { + has_extractions = true; + } + outer_exprs.push(ProjectionExpr::new(rewritten, proj_expr.alias.clone())); + } + + if !has_extractions { + return Ok(Transformed::no(plan)); + } + + // Collect columns needed by outer expressions that aren't extracted + extractor.collect_columns_needed(&outer_exprs)?; + + // Build inner projection from extracted expressions + needed columns + let inner_exprs = extractor.build_inner_projection()?; + + if inner_exprs.is_empty() { + return Ok(Transformed::no(plan)); + } + + // Create the inner projection (to be pushed down) + let inner = ProjectionExec::try_new(inner_exprs, Arc::clone(projection.input()))?; + + // Rewrite outer expressions to reference the inner projection's output schema + let inner_schema = inner.schema(); + let final_outer_exprs = extractor.finalize_outer_exprs(outer_exprs, &inner_schema)?; + + // Create the outer projection (stays above) + let outer = ProjectionExec::try_new(final_outer_exprs, Arc::new(inner))?; + + Ok(Transformed::yes(Arc::new(outer))) +} + +/// Extracts beneficial leaf-pushable sub-expressions from larger expressions. +/// +/// Similar to `JoinFilterRewriter`, this struct walks expression trees top-down +/// and extracts sub-expressions where `placement() == ExpressionPlacement::PlaceAtLeafs` +/// (beneficial leaf-pushable expressions like field accessors). +/// +/// The extracted expressions are replaced with column references pointing to +/// an inner projection that computes these sub-expressions. +struct LeafExpressionExtractor<'a> { + /// Extracted leaf-pushable expressions: maps expression -> alias + extracted: IndexMap, String>, + /// Columns needed by outer expressions: maps input column index -> alias + columns_needed: IndexMap, + /// Input schema for the projection + input_schema: &'a Schema, + /// Alias generator for unique names + alias_generator: &'a AliasGenerator, +} + +impl<'a> LeafExpressionExtractor<'a> { + fn new(input_schema: &'a Schema, alias_generator: &'a AliasGenerator) -> Self { + Self { + extracted: IndexMap::new(), + columns_needed: IndexMap::new(), + input_schema, + alias_generator, + } + } + + /// Extracts beneficial leaf-pushable sub-expressions from the given expression. + /// + /// Walks the expression tree top-down and replaces beneficial leaf-pushable + /// sub-expressions with column references to the inner projection. + fn extract(&mut self, expr: Arc) -> Result> { + // Top-down: check self first, then recurse to children + if matches!(expr.placement(), ExpressionPlacement::PlaceAtLeafs) { + // Extract this entire sub-tree + return Ok(self.add_extracted_expr(expr)); + } + + // Not extractable at this level - recurse into children + let children = expr.children(); + if children.is_empty() { + return Ok(expr); + } + + let mut new_children = Vec::with_capacity(children.len()); + let mut any_changed = false; + + for child in children { + let new_child = self.extract(Arc::clone(child))?; + if !Arc::ptr_eq(&new_child, child) { + any_changed = true; + } + new_children.push(new_child); + } + + if any_changed { + expr.with_new_children(new_children) + } else { + Ok(expr) + } + } + + /// Adds an expression to the extracted set and returns a column reference. + /// + /// If the same expression was already extracted, reuses the existing alias. 
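+    ///
+    /// For example, if `get_field(s, 'x')` is extracted twice, the inner projection
+    /// gets a single `get_field(s, 'x') AS __extracted_N` entry and both call sites
+    /// are rewritten to reference that alias (names illustrative).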
+ fn add_extracted_expr( + &mut self, + expr: Arc, + ) -> Arc { + if let Some(alias) = self.extracted.get(&expr) { + // Already extracted - return a column reference + // The index will be determined later in finalize + Arc::new(Column::new(alias, 0)) as Arc + } else { + // New extraction + let alias = self.alias_generator.next("__extracted"); + self.extracted.insert(expr, alias.clone()); + Arc::new(Column::new(&alias, 0)) as Arc + } + } + + /// Collects columns from outer expressions that need to be passed through inner projection. + fn collect_columns_needed(&mut self, outer_exprs: &[ProjectionExpr]) -> Result<()> { + for proj in outer_exprs { + proj.expr.apply(|e| { + if let Some(col) = e.as_any().downcast_ref::() { + // Check if this column references an extracted expression (by alias) + let is_extracted = + self.extracted.values().any(|alias| alias == col.name()); + + if !is_extracted && !self.columns_needed.contains_key(&col.index()) { + // This is an original input column - need to pass it through + let field = self.input_schema.field(col.index()); + self.columns_needed + .insert(col.index(), field.name().clone()); + } + } + Ok(TreeNodeRecursion::Continue) + })?; + } + Ok(()) + } + + /// Builds the inner projection expressions from extracted expressions and needed columns. + fn build_inner_projection(&self) -> Result> { + let mut result: Vec = self + .extracted + .iter() + .map(|(expr, alias)| ProjectionExpr::new(Arc::clone(expr), alias.clone())) + .collect(); + + // Add columns needed by outer expressions + for (&col_idx, alias) in &self.columns_needed { + let field = self.input_schema.field(col_idx); + result.push(ProjectionExpr::new( + Arc::new(Column::new(field.name(), col_idx)), + alias.clone(), + )); + } + + Ok(result) + } + + /// Finalizes the outer expressions by fixing column indices to match the inner projection. + fn finalize_outer_exprs( + &self, + outer_exprs: Vec, + inner_schema: &Schema, + ) -> Result> { + // Build a map from alias name to index in inner projection + let mut alias_to_idx: IndexMap<&str, usize> = self + .extracted + .values() + .enumerate() + .map(|(idx, alias)| (alias.as_str(), idx)) + .collect(); + + // Add columns needed by outer expressions + let base_idx = self.extracted.len(); + for (i, (_, alias)) in self.columns_needed.iter().enumerate() { + alias_to_idx.insert(alias, base_idx + i); + } + + // Build a map from original column index to inner projection index + let mut col_idx_to_inner: IndexMap = IndexMap::new(); + for (i, (&col_idx, _)) in self.columns_needed.iter().enumerate() { + col_idx_to_inner.insert(col_idx, base_idx + i); + } + + // Rewrite column references in outer expressions + outer_exprs + .into_iter() + .map(|proj| { + let new_expr = Arc::clone(&proj.expr) + .transform(|e| { + if let Some(col) = e.as_any().downcast_ref::() { + // First check if it's a reference to an extracted expression + if let Some(&idx) = alias_to_idx.get(col.name()) { + let field = inner_schema.field(idx); + return Ok(Transformed::yes(Arc::new(Column::new( + field.name(), + idx, + )) + as Arc)); + } + // Then check if it's an original column that needs remapping + if let Some(&inner_idx) = col_idx_to_inner.get(&col.index()) { + let field = inner_schema.field(inner_idx); + return Ok(Transformed::yes(Arc::new(Column::new( + field.name(), + inner_idx, + )) + as Arc)); + } + } + Ok(Transformed::no(e)) + })? + .data; + Ok(ProjectionExpr::new(new_expr, proj.alias)) + }) + .collect() + } +} + /// Tries to push down parts of the filter. 
///
/// See [JoinFilterRewriter] for details.
diff --git a/datafusion/physical-plan/src/coalesce_partitions.rs b/datafusion/physical-plan/src/coalesce_partitions.rs
index 22dcc85d6ea3a..a58001645b87b 100644
--- a/datafusion/physical-plan/src/coalesce_partitions.rs
+++ b/datafusion/physical-plan/src/coalesce_partitions.rs
@@ -249,10 +249,12 @@ impl ExecutionPlan for CoalescePartitionsExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down:
-        if projection.expr().len() >= projection.input().schema().fields().len() {
+        // Only push projections that are beneficial (narrow the schema or contain
+        // leaf-pushable expressions). Pure column projections that don't narrow the
+        // schema provide no benefit.
+        if !projection.is_leaf_pushable_or_narrows_schema() {
             return Ok(None);
         }
+
         // CoalescePartitionsExec always has a single child, so zero indexing is safe.
         make_with_child(projection, projection.input().children()[0]).map(|e| {
             if self.fetch.is_some() {
diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs
index 50fae84b85d0d..827ef232691b6 100644
--- a/datafusion/physical-plan/src/filter.rs
+++ b/datafusion/physical-plan/src/filter.rs
@@ -35,7 +35,7 @@ use crate::filter_pushdown::{
 };
 use crate::metrics::{MetricBuilder, MetricType};
 use crate::projection::{
-    EmbeddedProjection, ProjectionExec, ProjectionExpr, make_with_child,
+    EmbeddedProjection, ProjectionExec, ProjectionExpr, ProjectionWithDependencies,
     try_embed_projection, update_expr,
 };
 use crate::{
@@ -567,20 +567,92 @@ impl ExecutionPlan for FilterExec {
         &self,
         projection: &ProjectionExec,
     ) -> Result<Option<Arc<dyn ExecutionPlan>>> {
-        // If the projection does not narrow the schema, we should not try to push it down:
-        if projection.expr().len() < projection.input().schema().fields().len() {
-            // Each column in the predicate expression must exist after the projection.
-            if let Some(new_predicate) =
-                update_expr(self.predicate(), projection.expr(), false)?
-            {
-                return FilterExecBuilder::from(self)
-                    .with_input(make_with_child(projection, self.input())?)
-                    .with_predicate(new_predicate)
-                    .build()
-                    .map(|e| Some(Arc::new(e) as _));
+        // If we already have an embedded projection we cannot continue.
+        // This is not a real problem: calling this method generates the embedded
+        // projection, so we should not have one already!
+        if self.projection().is_some() {
+            return Ok(None);
+        }
+
+        // Push projection through filter if:
+        // - It narrows the schema (drops columns), OR
+        // - It is leaf-pushable (columns or cheap expressions like get_field)
+        if !projection.is_leaf_pushable_or_narrows_schema() {
+            return try_embed_projection(projection, self);
+        }
+
+        let pushed_down_projection = projection.projection_expr().clone();
+
+        // Collect columns needed by the predicate
+        let columns_needed_by_predicate = collect_columns(self.predicate());
+        let columns_needed_by_predicate: Vec<usize> = columns_needed_by_predicate
+            .into_iter()
+            .map(|c| c.index())
+            .collect();
+
+        // Augment projection with columns needed by predicate
+        let projection_with_deps = ProjectionWithDependencies::new(
+            &pushed_down_projection,
+            &columns_needed_by_predicate,
+            &self.input.schema(),
+        )?;
+
+        // Rewrite predicate to reference augmented projection output
+        let new_predicate = match update_expr(
+            self.predicate(),
+            projection_with_deps.combined_projection.exprs(),
+            false,
+        )? {
+            Some(expr) => expr,
+            None => {
+                return internal_err!(
+                    "Failed to rewrite predicate for projection pushdown"
+                );
+            }
+        };
+
+        // Restore projection indices (if augmentation was needed)
+        let restore_projection = projection_with_deps
+            .restore_projection
+            .as_ref()
+            .map(|_rp| (0..pushed_down_projection.exprs().len()).collect());
+
+        // Create the new projection that we will push down
+        let input = ProjectionExec::try_new(
+            projection_with_deps.combined_projection.exprs().to_vec(),
+            Arc::clone(self.input()),
+        )?;
+
+        // Now that we have the new projection, ask the question again: is it still
+        // worth pushing down? Accounting for the columns needed by the predicate may
+        // mean the projection no longer narrows the schema. (If it contained
+        // leaf-pushable expressions and all we did was add the predicate's columns,
+        // it is still worth pushing down.)
+        if !input.is_leaf_pushable_or_narrows_schema() {
+            // Effectively bail out of this whole process and just embed the projection
+            return try_embed_projection(projection, self);
+        }
+
+        // Try to push down further
+        let input: Arc<dyn ExecutionPlan> = if let Some(new_input) =
+            input.input().try_swapping_with_projection(&input)?
+        {
+            new_input
+        } else {
+            Arc::new(input)
+        };
+
+        // Create the new FilterExec with the new predicate and projection.
+        // Propagate builder errors with `?` rather than panicking via unwrap.
+        let new_filter = FilterExecBuilder::from(self)
+            .with_input(input)
+            .with_predicate(new_predicate)
+            .apply_projection(restore_projection)?
+            .build()?;
+
+        Ok(Some(Arc::new(new_filter) as _))
     }

     fn gather_filters_for_pushdown(
diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs
index 8d4c775f87348..5d988a5fdbef3 100644
--- a/datafusion/physical-plan/src/projection.rs
+++ b/datafusion/physical-plan/src/projection.rs
@@ -20,7 +20,7 @@
 //! of a projection on table `t1` where the expressions `a`, `b`, and `a+b` are the
 //! projection expressions. `SELECT` without `FROM` will only evaluate expressions.

-use super::expressions::{Column, Literal};
+use super::expressions::Column;
 use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use super::{
     DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream,
@@ -35,7 +35,7 @@ use crate::filter_pushdown::{
 use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn, JoinOnRef};
 use crate::{DisplayFormatType, ExecutionPlan, PhysicalExpr};
 use std::any::Any;
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -48,6 +48,7 @@ use datafusion_common::tree_node::{
 };
 use datafusion_common::{DataFusionError, JoinSide, Result, internal_err};
 use datafusion_execution::TaskContext;
+use datafusion_expr::ExpressionPlacement;
 use datafusion_physical_expr::equivalence::ProjectionMapping;
 use datafusion_physical_expr::projection::Projector;
 use datafusion_physical_expr::utils::{collect_columns, reassign_expr_columns};
@@ -216,6 +217,72 @@ impl ProjectionExec {
         }
         Ok(alias_map)
     }
+
+    /// Returns true if the projection is beneficial to push through most operators.
+    ///
+    /// Some cases to consider:
+    ///
+    /// - Sorts: sorts do expensive work (re-arranging rows) and thus benefit from
+    ///   having as little data underneath them as possible.
+    ///   We don't want to push down expressions such as literals that could be
+    ///   evaluated after the sort.
+    ///   In principle we would also want to push down complex expressions that
+    ///   reduce the amount of data (e.g. `a like '%foo%'` converts a large string
+    ///   column into a boolean column), but we currently have no good way to
+    ///   estimate that.
+    /// - RepartitionExec / CoalesceBatchesExec: these operators change the parallelism
+    ///   or batch sizes and are designed to optimize CPU work for operators above them.
+    ///   Pushing expensive expressions below them defeats their purpose, so we avoid that.
+    /// - Filters: filters reduce the amount of data processed by upstream operators,
+    ///   so pushing expensive computation below them would apply that computation to
+    ///   more rows.
+    ///   Again, if we knew that `a like '%foo%'` shrinks the projection significantly
+    ///   and the filter is not selective we might actually want to push it down, but
+    ///   we have no good way to estimate that currently.
+    /// - Joins: joins benefit from having less data under them (they may have to
+    ///   select sparse rows), but they also act as filters.
+    ///
+    /// Given the information currently available we cannot make perfect decisions
+    /// here, so we stick to the obvious cases:
+    ///
+    /// - If the projection narrows the schema (drops columns) and contains only
+    ///   column references, it always makes sense to push it down.
+    /// - If the projection contains a leaf-pushable (`PlaceAtLeafs`) expression
+    ///   (which can reduce the data size significantly at a very low computational
+    ///   cost) and no computationally expensive (`PlaceAtRoot`) expressions or
+    ///   literals, we also consider it beneficial to push down.
+    ///
+    /// In all other cases we consider the projection not beneficial to push down.
+    /// In particular, pure column references that don't narrow the schema are NOT
+    /// beneficial to push, as they just rearrange the plan without any gain.
+    ///
+    /// Note: Projections are split by `try_split_projection` before reaching this function,
+    /// so if any expression is PlaceAtLeafs, all expressions should be leaf-pushable.
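+    ///
+    /// # Example (illustrative)
+    ///
+    /// ```text
+    /// [a@0, c@2]            over (a, b, c) -> true  (narrows, all columns)
+    /// [get_field(s@1, 'x')]                -> true  (leaf-pushable, no root exprs)
+    /// [a@0, b@1, c@2]       over (a, b, c) -> false (pure reorder/rename)
+    /// [a@0 + 1]                            -> false (PlaceAtRoot expression)
+    /// ```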
+    pub fn is_leaf_pushable_or_narrows_schema(&self) -> bool {
+        let all_columns = self
+            .expr()
+            .iter()
+            .all(|proj_expr| proj_expr.expr.as_any().downcast_ref::<Column>().is_some());
+        let narrows_schema = self.expr().len() < self.input().schema().fields().len();
+        let all_columns_and_narrows_schema = all_columns && narrows_schema;
+
+        let has_leaf_expressions = self
+            .expr()
+            .iter()
+            .any(|p| matches!(p.expr.placement(), ExpressionPlacement::PlaceAtLeafs));
+
+        let has_root_expressions = self.expr().iter().any(|p| {
+            matches!(
+                p.expr.placement(),
+                ExpressionPlacement::PlaceAtRoot | ExpressionPlacement::Literal
+            )
+        });
+
+        (has_leaf_expressions && !has_root_expressions) || all_columns_and_narrows_schema
+    }
 }

 impl DisplayAs for ProjectionExec {
@@ -279,18 +346,13 @@ impl ExecutionPlan for ProjectionExec {
     }

     fn benefits_from_input_partitioning(&self) -> Vec<bool> {
-        let all_simple_exprs =
-            self.projector
-                .projection()
-                .as_ref()
-                .iter()
-                .all(|proj_expr| {
-                    proj_expr.expr.as_any().is::<Column>()
-                        || proj_expr.expr.as_any().is::<Literal>()
-                });
-        // If expressions are all either column_expr or Literal, then all computations in this projection are reorder or rename,
-        // and projection would not benefit from the repartition, benefits_from_input_partitioning will return false.
-        vec![!all_simple_exprs]
+        // If all projection expressions are either column references or leaf-pushable
+        // expressions, then all operations are cheap and don't benefit from partitioning.
+        let all_cheap_expressions = self
+            .expr()
+            .iter()
+            .all(|p| !matches!(p.expr.placement(), ExpressionPlacement::PlaceAtRoot));
+        vec![!all_cheap_expressions]
     }

     fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
@@ -692,7 +754,9 @@ pub fn remove_unnecessary_projections(
     if is_projection_removable(projection) {
         return Ok(Transformed::yes(Arc::clone(projection.input())));
     }
-    // If it does, check if we can push it under its child(ren):
+
+    // Try to push the projection under its child(ren). Each operator's
+    // try_swapping_with_projection handles the operator-specific logic.
     projection
         .input()
         .try_swapping_with_projection(projection)?
@@ -758,13 +822,6 @@ pub fn make_with_child(
         .map(|e| Arc::new(e) as _)
 }

-/// Returns `true` if all the expressions in the argument are `Column`s.
-pub fn all_columns(exprs: &[ProjectionExpr]) -> bool {
-    exprs
-        .iter()
-        .all(|proj_expr| proj_expr.expr.as_any().is::<Column>())
-}
-
 /// Updates the given lexicographic ordering according to given projected
 /// expressions using the [`update_expr`] function.
 pub fn update_ordering(
@@ -801,6 +858,96 @@ pub fn update_ordering_requirement(
     Ok(LexRequirement::new(updated_exprs))
 }

+/// Augments a projection with additional columns needed by an operator.
+///
+/// This helper encapsulates the pattern used by FilterExec and other operators
+/// that need to rewrite expressions after projections drop columns in [`ExecutionPlan::try_swapping_with_projection`].
+/// +/// # Arguments +/// * `base_projection` - The original projection expressions +/// * `needed_columns` - Column indices that are needed by expressions in the operator +/// * `input_schema` - The input schema before projection +/// +/// # Returns +/// * `ProjectionWithDependencies` struct with: +/// - `combined_projection`: Augmented projection including base + additional columns +/// - `restore_projection`: Optional projection to restore original columns only +/// +/// # Example +/// If projection selects [A, C] but sort needs column B: +/// - combined_projection: [A, C, B] +/// - restore_projection: [0, 1] (selects A, C from the combined projection output) +#[derive(Debug, Clone)] +pub struct ProjectionWithDependencies { + /// Combined projection including base and additional needed columns. + pub combined_projection: ProjectionExprs, + /// Projection that restores original columns after operator consumes the + /// columns that were added for dependencies. + pub restore_projection: Option, +} + +impl ProjectionWithDependencies { + /// Creates a new augmented projection with dependencies. + pub fn new( + base_projection: &ProjectionExprs, + needed_columns: &[usize], + input_schema: &SchemaRef, + ) -> Result { + let base_len = base_projection.iter().count(); + + // Collect columns already in the base projection + let base_indices: HashSet = base_projection + .iter() + .filter_map(|proj_expr| { + proj_expr + .expr + .as_any() + .downcast_ref::() + .map(|col| col.index()) + }) + .collect(); + + // Find columns needed by the operator that aren't in the projection + let additional_columns: Vec = needed_columns + .iter() + .filter(|idx| !base_indices.contains(idx)) + .copied() + .collect(); + + if additional_columns.is_empty() { + // No additional columns needed + return Ok(Self { + combined_projection: base_projection.clone(), + restore_projection: None, + }); + } + + // Create projection for the additional columns + let additional_projection = + ProjectionExprs::from_indices(&additional_columns, input_schema); + + // Combine base projection with additional columns + let combined_projection = ProjectionExprs::new( + base_projection + .iter() + .cloned() + .chain(additional_projection.iter().cloned()), + ); + + // Create restore projection that selects only original columns (indices 0..base_len) + // from the combined projection's output schema + let combined_schema = combined_projection.project_schema(input_schema)?; + let restore_indices: Vec = (0..base_len).collect(); + let restore_projection = + ProjectionExprs::from_indices(&restore_indices, &combined_schema); + + Ok(Self { + combined_projection, + restore_projection: Some(restore_projection), + }) + } +} + /// Downcasts all the expressions in `exprs` to `Column`s. If any of the given /// expressions is not a `Column`, returns `None`. pub fn physical_to_column_exprs( @@ -1002,14 +1149,20 @@ fn try_unifying_projections( }) .unwrap(); }); - // Merging these projections is not beneficial, e.g - // If an expression is not trivial and it is referred more than 1, unifies projections will be - // beneficial as caching mechanism for non-trivial computations. - // See discussion in: https://github.com/apache/datafusion/issues/8296 - if column_ref_map.iter().any(|(column, count)| { - *count > 1 && !is_expr_trivial(&Arc::clone(&child.expr()[column.index()].expr)) - }) { - return Ok(None); + // Don't merge if: + // 1. 
A non-trivial expression is referenced more than once (caching benefit) + // See discussion in: https://github.com/apache/datafusion/issues/8296 + // 2. The child projection has PlaceAtLeafs (like get_field) that should be pushed + // down to the data source separately + for (column, count) in column_ref_map.iter() { + let placement = child.expr()[column.index()].expr.placement(); + // Don't merge if multi-referenced root level (caching) + if (*count > 1 && matches!(placement, ExpressionPlacement::PlaceAtRoot)) + // Don't merge if child has PlaceAtLeafs (should push to source) + || matches!(placement, ExpressionPlacement::PlaceAtLeafs) + { + return Ok(None); + } } for proj_expr in projection.expr() { // If there is no match in the input projection, we cannot unify these @@ -1034,7 +1187,7 @@ fn collect_column_indices(exprs: &[ProjectionExpr]) -> Vec { .iter() .flat_map(|proj_expr| collect_columns(&proj_expr.expr)) .map(|x| x.index()) - .collect::>() + .collect::>() .into_iter() .collect::>(); indices.sort(); @@ -1117,13 +1270,6 @@ fn new_columns_for_join_on( (new_columns.len() == hash_join_on.len()).then_some(new_columns) } -/// Checks if the given expression is trivial. -/// An expression is considered trivial if it is either a `Column` or a `Literal`. -fn is_expr_trivial(expr: &Arc) -> bool { - expr.as_any().downcast_ref::().is_some() - || expr.as_any().downcast_ref::().is_some() -} - #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 612c7bb27ddf4..567f4694c2b56 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -34,7 +34,7 @@ use crate::coalesce::LimitedBatchCoalescer; use crate::execution_plan::{CardinalityEffect, EvaluationType, SchedulingType}; use crate::hash_utils::create_hashes; use crate::metrics::{BaselineMetrics, SpillMetrics}; -use crate::projection::{ProjectionExec, all_columns, make_with_child, update_expr}; +use crate::projection::{ProjectionExec, make_with_child, update_expr}; use crate::sorts::streaming_merge::StreamingMergeBuilder; use crate::spill::spill_manager::SpillManager; use crate::spill::spill_pool::{self, SpillPoolWriter}; @@ -1123,15 +1123,8 @@ impl ExecutionPlan for RepartitionExec { &self, projection: &ProjectionExec, ) -> Result>> { - // If the projection does not narrow the schema, we should not try to push it down. - if projection.expr().len() >= projection.input().schema().fields().len() { - return Ok(None); - } - // If pushdown is not beneficial or applicable, break it. 
- if projection.benefits_from_input_partitioning()[0] - || !all_columns(projection.expr()) - { + if projection.benefits_from_input_partitioning()[0] { return Ok(None); } diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs index a8361f7b2941e..4908a5dcd51c1 100644 --- a/datafusion/physical-plan/src/sorts/sort.rs +++ b/datafusion/physical-plan/src/sorts/sort.rs @@ -36,7 +36,7 @@ use crate::limit::LimitStream; use crate::metrics::{ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics, }; -use crate::projection::{ProjectionExec, make_with_child, update_ordering}; +use crate::projection::{ProjectionExec, ProjectionWithDependencies, update_ordering}; use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder}; use crate::spill::get_record_batch_memory_size; use crate::spill::in_progress_spill_file::InProgressSpillFile; @@ -65,6 +65,7 @@ use datafusion_execution::runtime_env::RuntimeEnv; use datafusion_physical_expr::LexOrdering; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_expr::expressions::{DynamicFilterPhysicalExpr, lit}; +use datafusion_physical_expr::utils::collect_columns; use futures::{StreamExt, TryStreamExt}; use log::{debug, trace}; @@ -1391,21 +1392,99 @@ impl ExecutionPlan for SortExec { &self, projection: &ProjectionExec, ) -> Result>> { - // If the projection does not narrow the schema, we should not try to push it down. - if projection.expr().len() >= projection.input().schema().fields().len() { + // Only push projections that are trivial (column refs, field accessors) or + // narrow the schema (drop columns). Non-trivial projections that don't narrow + // the schema would cause Sort to process more data than necessary. + if !projection.is_leaf_pushable_or_narrows_schema() { return Ok(None); } - let Some(updated_exprs) = update_ordering(self.expr.clone(), projection.expr())? + // 1. Collect columns needed by sort expressions + let mut columns_needed_by_sort = std::collections::HashSet::new(); + for sort_expr in &self.expr { + columns_needed_by_sort.extend(collect_columns(&sort_expr.expr)); + } + let columns_needed_by_sort: Vec = columns_needed_by_sort + .into_iter() + .map(|c| c.index()) + .collect(); + + // 2. Augment projection if needed + let pushed_down_projection = projection.projection_expr().clone(); + let projection_with_deps = ProjectionWithDependencies::new( + &pushed_down_projection, + &columns_needed_by_sort, + &self.input().schema(), + )?; + + // 3. Rewrite sort expressions for the augmented projection + let Some(updated_exprs) = update_ordering( + self.expr.clone(), + projection_with_deps.combined_projection.as_ref(), + )? else { return Ok(None); }; - Ok(Some(Arc::new( - SortExec::new(updated_exprs, make_with_child(projection, self.input())?) - .with_fetch(self.fetch()) - .with_preserve_partitioning(self.preserve_partitioning()), - ))) + // 4. Update common_sort_prefix if present + let updated_prefix = if !self.common_sort_prefix.is_empty() { + if let Some(prefix_ordering) = + LexOrdering::new(self.common_sort_prefix.iter().cloned()) + { + update_ordering( + prefix_ordering, + projection_with_deps.combined_projection.as_ref(), + )? + .map(|ordering| ordering.into_iter().collect()) + .unwrap_or_default() + } else { + vec![] + } + } else { + vec![] + }; + + // 5. Create projection to push down + let input = ProjectionExec::try_new( + projection_with_deps.combined_projection.exprs().to_vec(), + Arc::clone(self.input()), + )?; + + // 6. 
Re-validate if still beneficial + if !input.is_leaf_pushable_or_narrows_schema() { + // Augmentation made the projection non-beneficial, bail out + return Ok(None); + } + + // 7. Try recursive pushdown + let input: Arc = if let Some(new_input) = + input.input().try_swapping_with_projection(&input)? + { + new_input + } else { + Arc::new(input) + }; + + // 8. Build new SortExec + let new_sort = SortExec::new(updated_exprs, input) + .with_fetch(self.fetch()) + .with_preserve_partitioning(self.preserve_partitioning()); + + // Update common_sort_prefix + let mut new_sort = new_sort; + new_sort.common_sort_prefix = updated_prefix; + + let mut res: Arc = Arc::new(new_sort); + + // 9. Wrap in restore projection if needed + if let Some(restore_projection) = projection_with_deps.restore_projection { + res = Arc::new(ProjectionExec::try_new( + restore_projection.exprs().to_vec(), + res, + )?); + } + + Ok(Some(res)) } fn gather_filters_for_pushdown( @@ -1447,6 +1526,7 @@ mod tests { use crate::test::exec::{BlockingExec, assert_strong_count_converges_to_zero}; use crate::test::{assert_is_pending, make_partition}; + use crate::projection::ProjectionExprs; use arrow::array::*; use arrow::compute::SortOptions; use arrow::datatypes::*; @@ -2715,4 +2795,56 @@ mod tests { Ok(()) } + + #[test] + fn test_projection_with_dependencies_basic() -> Result<()> { + // Test the ProjectionWithDependencies helper struct + use crate::projection::ProjectionWithDependencies; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + + // Create a projection [a, c] (drops b) + let projection = ProjectionExprs::from_indices(&[0, 2], &schema); + + // Augment it with column b (index 1) + let augmented = ProjectionWithDependencies::new(&projection, &[1], &schema)?; + + // Should have a restore projection since we added a column + assert!(augmented.restore_projection.is_some()); + + // Combined projection should have all 3 columns + assert_eq!(augmented.combined_projection.exprs().len(), 3); + + Ok(()) + } + + #[test] + fn test_projection_with_dependencies_no_augmentation_needed() -> Result<()> { + // Test the case where no augmentation is needed + use crate::projection::ProjectionWithDependencies; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + Field::new("c", DataType::Int32, false), + ])); + + // Create a projection [a, b, c] (keeps all) + let projection = ProjectionExprs::from_indices(&[0, 1, 2], &schema); + + // Augment it with column b (index 1) which is already included + let augmented = ProjectionWithDependencies::new(&projection, &[1], &schema)?; + + // Should NOT have a restore projection since no columns were added + assert!(augmented.restore_projection.is_none()); + + // Combined projection should be the same as the original + assert_eq!(augmented.combined_projection.exprs().len(), 3); + + Ok(()) + } } diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index 68c457a0d8a3c..bd81b25814ad2 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -391,8 +391,9 @@ impl ExecutionPlan for SortPreservingMergeExec { &self, projection: &ProjectionExec, ) -> Result>> { - // If the projection does not narrow the schema, we should not try to push it down. 
- if projection.expr().len() >= projection.input().schema().fields().len() { + // Only push projections that are trivial or narrow the schema to avoid + // evaluating expressions (like literals) on all input rows. + if !projection.is_leaf_pushable_or_narrows_schema() { return Ok(None); } diff --git a/datafusion/sqllogictest/test_files/projection_pushdown.slt b/datafusion/sqllogictest/test_files/projection_pushdown.slt index 4be83589495e7..f125ca92fb0d8 100644 --- a/datafusion/sqllogictest/test_files/projection_pushdown.slt +++ b/datafusion/sqllogictest/test_files/projection_pushdown.slt @@ -221,9 +221,8 @@ logical_plan 02)--Filter: simple_struct.id > Int64(2) 03)----TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] physical_plan -01)ProjectionExec: expr=[id@0 as id, get_field(s@1, value) as simple_struct.s[value]] -02)--FilterExec: id@0 > 2 -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] +01)FilterExec: id@0 > 2 +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] # Verify correctness query II @@ -245,9 +244,9 @@ logical_plan 02)--Filter: simple_struct.id > Int64(2) 03)----TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] physical_plan -01)ProjectionExec: expr=[id@0 as id, get_field(s@1, value) + 1 as simple_struct.s[value] + Int64(1)] -02)--FilterExec: id@0 > 2 -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] +01)ProjectionExec: expr=[id@1 as id, __extracted_1@0 + 1 as simple_struct.s[value] + Int64(1)] +02)--FilterExec: id@1 > 2 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __extracted_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] # Verify correctness query II @@ -269,9 +268,8 @@ logical_plan 02)--Filter: get_field(simple_struct.s, Utf8("value")) > Int64(150) 03)----TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(150)] physical_plan -01)ProjectionExec: expr=[id@0 as id, get_field(s@1, label) as simple_struct.s[label]] -02)--FilterExec: get_field(s@1, value) > 150 -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet +01)FilterExec: get_field(s@2, value) > 150, projection=[id@0, simple_struct.s[label]@1] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, label) as simple_struct.s[label], s], file_type=parquet # 
Verify correctness query IT @@ -543,9 +541,8 @@ logical_plan 04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] physical_plan 01)SortExec: expr=[simple_struct.s[value]@1 ASC NULLS LAST], preserve_partitioning=[false] -02)--ProjectionExec: expr=[id@0 as id, get_field(s@1, value) as simple_struct.s[value]] -03)----FilterExec: id@0 > 1 -04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] +02)--FilterExec: id@0 > 1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] # Verify correctness query II @@ -570,9 +567,8 @@ logical_plan 04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] physical_plan 01)SortExec: TopK(fetch=2), expr=[simple_struct.s[value]@1 ASC NULLS LAST], preserve_partitioning=[false] -02)--ProjectionExec: expr=[id@0 as id, get_field(s@1, value) as simple_struct.s[value]] -03)----FilterExec: id@0 > 1 -04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] +02)--FilterExec: id@0 > 1 +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] # Verify correctness query II @@ -595,9 +591,9 @@ logical_plan 04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] physical_plan 01)SortExec: TopK(fetch=2), expr=[id@0 ASC NULLS LAST], preserve_partitioning=[false] -02)--ProjectionExec: expr=[id@0 as id, get_field(s@1, value) + 1 as simple_struct.s[value] + Int64(1)] -03)----FilterExec: id@0 > 1 -04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 1 AND DynamicFilter [ empty ], pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] +02)--ProjectionExec: expr=[id@1 as id, __extracted_1@0 + 1 as simple_struct.s[value] + Int64(1)] +03)----FilterExec: id@1 > 1 +04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __extracted_1, id], file_type=parquet, predicate=id@0 > 1 AND DynamicFilter [ empty ], pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] # Verify correctness query II @@ -735,10 +731,9 @@ logical_plan physical_plan 01)SortPreservingMergeExec: [id@0 ASC NULLS LAST] 02)--SortExec: expr=[id@0 ASC NULLS LAST], preserve_partitioning=[true] -03)----ProjectionExec: expr=[id@0 
as id, get_field(s@1, value) as multi_struct.s[value]] -04)------FilterExec: id@0 > 2 -05)--------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 -06)----------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] +03)----FilterExec: id@0 > 2 +04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=3 +05)--------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[id, get_field(s@1, value) as multi_struct.s[value]], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] # Verify correctness query II @@ -817,9 +812,8 @@ logical_plan 02)--Filter: get_field(nullable_struct.s, Utf8("value")) IS NOT NULL 03)----TableScan: nullable_struct projection=[id, s], partial_filters=[get_field(nullable_struct.s, Utf8("value")) IS NOT NULL] physical_plan -01)ProjectionExec: expr=[id@0 as id, get_field(s@1, label) as nullable_struct.s[label]] -02)--FilterExec: get_field(s@1, value) IS NOT NULL -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nullable.parquet]]}, projection=[id, s], file_type=parquet +01)FilterExec: get_field(s@2, value) IS NOT NULL, projection=[id@0, nullable_struct.s[label]@1] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nullable.parquet]]}, projection=[id, get_field(s@1, label) as nullable_struct.s[label], s], file_type=parquet # Verify correctness query IT @@ -942,9 +936,8 @@ logical_plan 04)------TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] physical_plan 01)ProjectionExec: expr=[__common_expr_1@0 + __common_expr_1@0 as doubled] -02)--ProjectionExec: expr=[get_field(s@0, value) as __common_expr_1] -03)----FilterExec: id@0 > 2, projection=[s@1] -04)------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] +02)--FilterExec: id@1 > 2, projection=[__common_expr_1@0] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, 
projection=[get_field(s@1, value) as __common_expr_1, id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] # Verify correctness query I @@ -966,9 +959,8 @@ logical_plan 02)--Filter: simple_struct.id > Int64(2) 03)----TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(2)] physical_plan -01)ProjectionExec: expr=[get_field(s@0, value) as simple_struct.s[value], get_field(s@0, label) as simple_struct.s[label]] -02)--FilterExec: id@0 > 2, projection=[s@1] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] +01)FilterExec: id@2 > 2, projection=[simple_struct.s[value]@0, simple_struct.s[label]@1] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as simple_struct.s[value], get_field(s@1, label) as simple_struct.s[label], id], file_type=parquet, predicate=id@0 > 2, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 2, required_guarantees=[] # Verify correctness query IT @@ -1015,9 +1007,9 @@ logical_plan 02)--Filter: simple_struct.id > Int64(1) 03)----TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] physical_plan -01)ProjectionExec: expr=[get_field(s@0, value) * 2 + CAST(character_length(get_field(s@0, label)) AS Int64) as score] -02)--FilterExec: id@0 > 1, projection=[s@1] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] +01)ProjectionExec: expr=[__extracted_1@0 * 2 + CAST(character_length(__extracted_2@1) AS Int64) as score] +02)--FilterExec: id@2 > 1, projection=[__extracted_1@0, __extracted_2@1] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __extracted_1, get_field(s@1, label) as __extracted_2, id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] # Verify correctness query I @@ -1091,9 +1083,8 @@ logical_plan 02)--Filter: simple_struct.id > Int64(1) 03)----TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] physical_plan -01)ProjectionExec: expr=[id@0 as id, get_field(s@1, value) as simple_struct.s[value]] -02)--FilterExec: id@0 > 1 -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] +01)FilterExec: id@0 > 1 +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, 
required_guarantees=[] # Verify correctness query II @@ -1110,9 +1101,8 @@ logical_plan 02)--Filter: simple_struct.id > Int64(1) AND (simple_struct.id < Int64(4) OR simple_struct.id = Int64(5)) 03)----TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1), simple_struct.id < Int64(4) OR simple_struct.id = Int64(5)] physical_plan -01)ProjectionExec: expr=[get_field(s@0, value) as simple_struct.s[value]] -02)--FilterExec: id@0 > 1 AND (id@0 < 4 OR id@0 = 5), projection=[s@1] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 1 AND (id@0 < 4 OR id@0 = 5), pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1 AND (id_null_count@1 != row_count@2 AND id_min@3 < 4 OR id_null_count@1 != row_count@2 AND id_min@3 <= 5 AND 5 <= id_max@0), required_guarantees=[] +01)FilterExec: id@1 > 1 AND (id@1 < 4 OR id@1 = 5), projection=[simple_struct.s[value]@0] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as simple_struct.s[value], id], file_type=parquet, predicate=id@0 > 1 AND (id@0 < 4 OR id@0 = 5), pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1 AND (id_null_count@1 != row_count@2 AND id_min@3 < 4 OR id_null_count@1 != row_count@2 AND id_min@3 <= 5 AND 5 <= id_max@0), required_guarantees=[] # Verify correctness - should return rows where (id > 1) AND ((id < 4) OR (id = 5)) # That's: id=2,3 (1 Int64(1) AND simple_struct.id < Int64(5) 03)----TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1), simple_struct.id < Int64(5)] physical_plan -01)ProjectionExec: expr=[get_field(s@0, value) as simple_struct.s[value]] -02)--FilterExec: id@0 > 1 AND id@0 < 5, projection=[s@1] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 1 AND id@0 < 5, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1 AND id_null_count@1 != row_count@2 AND id_min@3 < 5, required_guarantees=[] +01)FilterExec: id@1 > 1 AND id@1 < 5, projection=[simple_struct.s[value]@0] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as simple_struct.s[value], id], file_type=parquet, predicate=id@0 > 1 AND id@0 < 5, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1 AND id_null_count@1 != row_count@2 AND id_min@3 < 5, required_guarantees=[] # Verify correctness - should return rows where 1 < id < 5 (id=2,3,4) query I @@ -1151,9 +1140,8 @@ logical_plan 02)--Filter: simple_struct.id > Int64(1) 03)----TableScan: simple_struct projection=[id, s], partial_filters=[simple_struct.id > Int64(1)] physical_plan -01)ProjectionExec: expr=[get_field(s@1, value) as simple_struct.s[value], get_field(s@1, label) as simple_struct.s[label], id@0 as id] -02)--FilterExec: id@0 > 1 -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[] +01)FilterExec: id@2 > 1 
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as simple_struct.s[value], get_field(s@1, label) as simple_struct.s[label], id], file_type=parquet, predicate=id@0 > 1, pruning_predicate=id_null_count@1 != row_count@2 AND id_max@0 > 1, required_guarantees=[]
 
 # Verify correctness - note that id is now at index 2 in the augmented projection
 query ITI
@@ -1171,9 +1159,8 @@ logical_plan
 02)--Filter: character_length(get_field(simple_struct.s, Utf8("label"))) > Int32(4)
 03)----TableScan: simple_struct projection=[s], partial_filters=[character_length(get_field(simple_struct.s, Utf8("label"))) > Int32(4)]
 physical_plan
-01)ProjectionExec: expr=[get_field(s@0, value) as simple_struct.s[value]]
-02)--FilterExec: character_length(get_field(s@0, label)) > 4
-03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[s], file_type=parquet
+01)FilterExec: character_length(get_field(s@1, label)) > 4, projection=[simple_struct.s[value]@0]
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as simple_struct.s[value], s], file_type=parquet
 
 # Verify correctness - filter on rows where label length > 4
 # From the data: alpha(5), beta(4), gamma(5), delta(5), epsilon(7); only beta (length 4) is filtered out
diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt
index f939cd0154a82..1a6b82020c667 100644
--- a/datafusion/sqllogictest/test_files/unnest.slt
+++ b/datafusion/sqllogictest/test_files/unnest.slt
@@ -673,8 +673,8 @@ logical_plan
 physical_plan
 01)ProjectionExec: expr=[__unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1],depth=2)@0 as UNNEST(UNNEST(UNNEST(recursive_unnest_table.column3)[c1])), column3@1 as column3]
 02)--UnnestExec
-03)----ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3]
-04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
+04)------ProjectionExec: expr=[get_field(__unnest_placeholder(recursive_unnest_table.column3,depth=1)@0, c1) as __unnest_placeholder(UNNEST(recursive_unnest_table.column3)[c1]), column3@1 as column3]
 05)--------UnnestExec
 06)----------ProjectionExec: expr=[column3@0 as __unnest_placeholder(recursive_unnest_table.column3), column3@0 as column3]
 07)------------DataSourceExec: partitions=1, partition_sizes=[1]

From e8331ca0fc59c4ba77d18fad96342f06acfbc767 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Tue, 27 Jan 2026 11:22:21 -0500
Subject: [PATCH 2/5] rename placement_with_args -> placement

---
 datafusion/expr/src/expr.rs                     | 2 +-
 datafusion/expr/src/udf.rs                      | 2 +-
 datafusion/functions/src/core/getfield.rs       | 8 ++++----
 datafusion/physical-expr/src/scalar_function.rs | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 414a4e52d0011..190157db33b09 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1949,7 +1949,7 @@ impl Expr {
         let arg_placements: Vec<ExpressionPlacement> =
            func.args.iter().map(|arg| arg.placement()).collect();
-        func.func.placement_with_args(&arg_placements)
+        func.func.placement(&arg_placements)
         }
         _ => ExpressionPlacement::PlaceAtRoot,
     }
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 97f8fbba2b4d3..8ca7fded8edf5 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -133,7 +133,7 @@ impl ScalarUDF {
     /// as further up the tree so that it can be run after filtering, sorting, etc.
     ///
     /// See [`ScalarUDFImpl::placement`] for more details.
-    pub fn placement_with_args(
+    pub fn placement(
         &self,
         args: &[ExpressionPlacement],
     ) -> ExpressionPlacement {
diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs
index 09bf68b44635d..49ae52c9df1a3 100644
--- a/datafusion/functions/src/core/getfield.rs
+++ b/datafusion/functions/src/core/getfield.rs
@@ -575,7 +575,7 @@ mod tests {
     }
 
     #[test]
-    fn test_placement_with_args_literal_key() {
+    fn test_placement_literal_key() {
         let func = GetFieldFunc::new();
 
         // get_field(col, 'literal') -> leaf-pushable (static field access)
@@ -599,7 +599,7 @@ mod tests {
     }
 
     #[test]
-    fn test_placement_with_args_column_key() {
+    fn test_placement_column_key() {
         let func = GetFieldFunc::new();
 
         // get_field(col, other_col) -> NOT leaf-pushable (dynamic per-row lookup)
@@ -616,7 +616,7 @@ mod tests {
     }
 
     #[test]
-    fn test_placement_with_args_root() {
+    fn test_placement_root() {
         let func = GetFieldFunc::new();
 
         // get_field(root_expr, 'literal') -> NOT leaf-pushable
@@ -635,7 +635,7 @@ mod tests {
     }
 
     #[test]
-    fn test_placement_with_args_edge_cases() {
+    fn test_placement_edge_cases() {
         let func = GetFieldFunc::new();
 
         // Empty args -> NOT leaf-pushable
diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs
index 6b2c451bd8156..21ad2e4fe5b6e 100644
--- a/datafusion/physical-expr/src/scalar_function.rs
+++ b/datafusion/physical-expr/src/scalar_function.rs
@@ -368,7 +368,7 @@ impl PhysicalExpr for ScalarFunctionExpr {
         let arg_placements: Vec<ExpressionPlacement> =
             self.args.iter().map(|arg| arg.placement()).collect();
 
-        self.fun.placement_with_args(&arg_placements)
+        self.fun.placement(&arg_placements)
     }
 }
 
From d773c9c859587f93f9a4b7b9ea0caa1fbdf1ec1e Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Tue, 27 Jan 2026 11:47:36 -0500
Subject: [PATCH 3/5] fix typos

---
 datafusion/expr-common/src/placement.rs       |  2 +-
 datafusion/expr/src/udf.rs                    |  5 +----
 datafusion/functions/src/core/getfield.rs     | 22 +++++++++----------
 .../physical-expr-common/src/physical_expr.rs |  4 ++--
 .../src/projection_pushdown.rs                |  8 +++----
 datafusion/physical-plan/src/projection.rs    | 12 +++++-----
 6 files changed, 25 insertions(+), 28 deletions(-)

diff --git a/datafusion/expr-common/src/placement.rs b/datafusion/expr-common/src/placement.rs
index 0527cda11e466..063e3b109859e 100644
--- a/datafusion/expr-common/src/placement.rs
+++ b/datafusion/expr-common/src/placement.rs
@@ -61,7 +61,7 @@ pub enum ExpressionPlacement {
     /// leaf-pushable expression, then it would return this variant.
     /// Then `other_leaf_function(get_field(...), 42)` could also be classified as
     /// leaf-pushable using the knowledge that `get_field(...)` is leaf-pushable.
-    PlaceAtLeafs,
+    PlaceAtLeaves,
     /// Argument is a complex expression that should be placed at root nodes.
     /// For example, `min(col1 + col2)` is not leaf-pushable because it requires per-row computation.
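     /// A few illustrative classifications (hypothetical expressions, following
     /// the `get_field` placement rules elsewhere in this patch):
     ///
     /// ```text
     /// get_field(struct_col, 'name')       -> PlaceAtLeaves (static key)
     /// get_field(get_field(s, 'a'), 'b')   -> PlaceAtLeaves (nested static access)
     /// get_field(struct_col, key_col)      -> PlaceAtRoot   (per-row key lookup)
     /// get_field(lit_struct, 'name')       -> PlaceAtRoot   (would be constant-folded)
     /// ```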
PlaceAtRoot, diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 8ca7fded8edf5..78533a63c8223 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -133,10 +133,7 @@ impl ScalarUDF { /// as further up the tree so that it can be run after filtering, sorting, etc. /// /// See [`ScalarUDFImpl::placement`] for more details. - pub fn placement( - &self, - args: &[ExpressionPlacement], - ) -> ExpressionPlacement { + pub fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { self.inner.placement(args) } diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 49ae52c9df1a3..582281ca5d5f2 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -502,7 +502,7 @@ impl ScalarUDFImpl for GetFieldFunc { fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement { // get_field is leaf-pushable if: - // 1. The struct/map argument is a Column or PlaceAtLeafs (not Literal or PlaceAtRoot) + // 1. The struct/map argument is a Column or PlaceAtLeaves (not Literal or PlaceAtRoot) // 2. All key arguments are literals (static field access, not dynamic per-row lookup) // // Literal base is not considered leaf-pushable because it would be constant-folded anyway. @@ -510,10 +510,10 @@ impl ScalarUDFImpl for GetFieldFunc { return ExpressionPlacement::PlaceAtRoot; } - // Check if the base (struct/map) argument is Column or PlaceAtLeafs + // Check if the base (struct/map) argument is Column or PlaceAtLeaves if !matches!( args[0], - ExpressionPlacement::Column | ExpressionPlacement::PlaceAtLeafs + ExpressionPlacement::Column | ExpressionPlacement::PlaceAtLeaves ) { return ExpressionPlacement::PlaceAtRoot; } @@ -525,7 +525,7 @@ impl ScalarUDFImpl for GetFieldFunc { .all(|a| *a == ExpressionPlacement::Literal); if keys_literal { - ExpressionPlacement::PlaceAtLeafs + ExpressionPlacement::PlaceAtLeaves } else { ExpressionPlacement::PlaceAtRoot } @@ -580,7 +580,7 @@ mod tests { // get_field(col, 'literal') -> leaf-pushable (static field access) let args = vec![ExpressionPlacement::Column, ExpressionPlacement::Literal]; - assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeafs); + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeaves); // get_field(col, 'a', 'b') -> leaf-pushable (nested static field access) let args = vec![ @@ -588,14 +588,14 @@ mod tests { ExpressionPlacement::Literal, ExpressionPlacement::Literal, ]; - assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeafs); + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeaves); - // get_field(get_field(col, 'a'), 'b') represented as PlaceAtLeafs for base + // get_field(get_field(col, 'a'), 'b') represented as PlaceAtLeaves for base let args = vec![ - ExpressionPlacement::PlaceAtLeafs, + ExpressionPlacement::PlaceAtLeaves, ExpressionPlacement::Literal, ]; - assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeafs); + assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeaves); } #[test] @@ -641,9 +641,9 @@ mod tests { // Empty args -> NOT leaf-pushable assert_eq!(func.placement(&[]), ExpressionPlacement::PlaceAtRoot); - // Just base, no key -> PlaceAtLeafs (not a valid call but should handle gracefully) + // Just base, no key -> PlaceAtLeaves (not a valid call but should handle gracefully) let args = vec![ExpressionPlacement::Column]; - assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeafs); + 
assert_eq!(func.placement(&args), ExpressionPlacement::PlaceAtLeaves); // Literal base with literal key -> NOT leaf-pushable (would be constant-folded) let args = vec![ExpressionPlacement::Literal, ExpressionPlacement::Literal]; diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 67d39356653d3..79838dbce4ef9 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -437,13 +437,13 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { /// Leaf-pushable expressions include: /// - Column references (`ExpressionPlacement::Column`) /// - Literal values (`ExpressionPlacement::Literal`) - /// - Struct field access via `get_field` (`ExpressionPlacement::PlaceAtLeafs`) + /// - Struct field access via `get_field` (`ExpressionPlacement::PlaceAtLeaves`) /// - Nested combinations of field accessors (e.g., `col['a']['b']`) /// /// This is used to identify expressions that are cheap to duplicate or /// don't benefit from caching/partitioning optimizations. /// - /// **Performance note**: Expressions marked as `PlaceAtLeafs` may be pushed + /// **Performance note**: Expressions marked as `PlaceAtLeaves` may be pushed /// below filters during optimization. If an expression does per-row work, /// marking it leaf-pushable may slow things down by causing evaluation on more rows. fn placement(&self) -> ExpressionPlacement { diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index ae64de35c1c0b..ce85e4c20bc4a 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -132,12 +132,12 @@ fn try_split_projection( } // Only extract from root-level expressions. If the entire expression is - // already PlaceAtLeafs (like `get_field(col, 'foo')`), it can be pushed as-is. + // already PlaceAtLeaves (like `get_field(col, 'foo')`), it can be pushed as-is. // We only need to split when there's a root expression with leaf-pushable // sub-expressions (like `get_field(col, 'foo') + 1`). if matches!( proj_expr.expr.placement(), - ExpressionPlacement::PlaceAtLeafs + ExpressionPlacement::PlaceAtLeaves ) { outer_exprs.push(proj_expr.clone()); continue; @@ -180,7 +180,7 @@ fn try_split_projection( /// Extracts beneficial leaf-pushable sub-expressions from larger expressions. /// /// Similar to `JoinFilterRewriter`, this struct walks expression trees top-down -/// and extracts sub-expressions where `placement() == ExpressionPlacement::PlaceAtLeafs` +/// and extracts sub-expressions where `placement() == ExpressionPlacement::PlaceAtLeaves` /// (beneficial leaf-pushable expressions like field accessors). /// /// The extracted expressions are replaced with column references pointing to @@ -212,7 +212,7 @@ impl<'a> LeafExpressionExtractor<'a> { /// sub-expressions with column references to the inner projection. 
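     /// As a worked example (hypothetical expression; the `__extracted_1` alias
     /// matches the naming visible in the slt plans elsewhere in this patch),
     /// extracting from `get_field(s, 'value') + 1` proceeds top-down:
     ///
     /// ```text
     /// `+`                    is PlaceAtRoot   -> recurse into children
     /// get_field(s, 'value')  is PlaceAtLeaves -> moved to the inner projection
     ///                                            as `get_field(s@1, value) AS __extracted_1`
     /// lit(1)                                  -> unchanged
     ///
     /// rewritten outer expression: `__extracted_1@0 + 1`
     /// ```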
    fn extract(&mut self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
         // Top-down: check self first, then recurse to children
-        if matches!(expr.placement(), ExpressionPlacement::PlaceAtLeafs) {
+        if matches!(expr.placement(), ExpressionPlacement::PlaceAtLeaves) {
             // Extract this entire sub-tree
             return Ok(self.add_extracted_expr(expr));
         }
diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs
index 5d988a5fdbef3..09ad1b2fcf6d7 100644
--- a/datafusion/physical-plan/src/projection.rs
+++ b/datafusion/physical-plan/src/projection.rs
@@ -254,13 +254,13 @@ impl ProjectionExec {
     ///
     /// This is true when:
     /// - The projection narrows the schema (drops columns) - saves memory, OR
-    /// - Any expression is PlaceAtLeafs (like get_field) - beneficial computation pushdown
+    /// - Any expression is PlaceAtLeaves (like get_field) - beneficial computation pushdown
     ///
     /// Pure Column references that don't narrow the schema are NOT beneficial to push,
     /// as they just rearrange the plan without any gain.
     ///
     /// Note: Projections are split by `try_split_projection` before reaching this function,
-    /// so if any expression is PlaceAtLeafs, all expressions should be leaf-pushable.
+    /// so if any expression is PlaceAtLeaves, all expressions should be leaf-pushable.
     pub fn is_leaf_pushable_or_narrows_schema(&self) -> bool {
         let all_columns = self
             .expr()
@@ -272,7 +272,7 @@ impl ProjectionExec {
         let has_leaf_expressions = self
             .expr()
             .iter()
-            .any(|p| matches!(p.expr.placement(), ExpressionPlacement::PlaceAtLeafs));
+            .any(|p| matches!(p.expr.placement(), ExpressionPlacement::PlaceAtLeaves));
 
         let has_root_expressions = self.expr().iter().any(|p| {
             matches!(
@@ -1152,14 +1152,14 @@ fn try_unifying_projections(
     // Don't merge if:
     // 1. A non-trivial expression is referenced more than once (caching benefit)
     //    See discussion in: https://github.com/apache/datafusion/issues/8296
-    // 2. The child projection has PlaceAtLeafs (like get_field) that should be pushed
+    // 2. The child projection has PlaceAtLeaves (like get_field) that should be pushed
     //    down to the data source separately
     for (column, count) in column_ref_map.iter() {
         let placement = child.expr()[column.index()].expr.placement();
         // Don't merge if multi-referenced root level (caching)
         if (*count > 1 && matches!(placement, ExpressionPlacement::PlaceAtRoot))
-            // Don't merge if child has PlaceAtLeafs (should push to source)
-            || matches!(placement, ExpressionPlacement::PlaceAtLeafs)
+            // Don't merge if child has PlaceAtLeaves (should push to source)
+            || matches!(placement, ExpressionPlacement::PlaceAtLeaves)
         {
             return Ok(None);
         }

From 1993d7f1cf86a09c76fbd4a2211f7db212262bb4 Mon Sep 17 00:00:00 2001
From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com>
Date: Tue, 27 Jan 2026 12:25:45 -0500
Subject: [PATCH 4/5] fix more typo

---
 datafusion/expr/src/expr.rs            | 2 +-
 datafusion/expr/src/udf.rs             | 8 ++++----
 datafusion/physical-plan/src/filter.rs | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs
index 190157db33b09..3d396e16b82d3 100644
--- a/datafusion/expr/src/expr.rs
+++ b/datafusion/expr/src/expr.rs
@@ -1936,7 +1936,7 @@ impl Expr {
     /// Returns the placement classification of this expression.
/// /// This tells us if optimizers should preferentially - /// move this expression towards the leafs of the execution plan + /// move this expression towards the leaves of the execution plan /// tree (for cheap expressions or expressions that reduce the data size) /// or towards the root of the execution plan tree (for expensive expressions /// that should be run after filtering or parallelization, or expressions that increase the data size). diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 78533a63c8223..18bbf721b9be4 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -918,15 +918,15 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { /// # Example /// /// - `get_field(struct_col, 'field_name')` with a literal key is leaf-pushable as it - /// performs metadata only (cheap) extraction of a sub-array from a struct column. - /// Thus, it can be placed near the data source to minimize data early. + /// performs metadata only (cheap) extraction of a sub-array from a struct column. + /// Thus, it can be placed near the data source to minimize data early. /// - `string_col like '%foo%'` performs expensive per-row computation and should be placed - /// further up the tree so that it can be run after filtering, sorting, etc. + /// further up the tree so that it can be run after filtering, sorting, etc. /// /// # Arguments /// /// * `args` - Classification of each argument's placement, collected from the expression tree - /// by the caller. + /// by the caller. fn placement(&self, _args: &[ExpressionPlacement]) -> ExpressionPlacement { ExpressionPlacement::PlaceAtRoot } diff --git a/datafusion/physical-plan/src/filter.rs b/datafusion/physical-plan/src/filter.rs index 827ef232691b6..b6624444fd880 100644 --- a/datafusion/physical-plan/src/filter.rs +++ b/datafusion/physical-plan/src/filter.rs @@ -568,7 +568,7 @@ impl ExecutionPlan for FilterExec { projection: &ProjectionExec, ) -> Result>> { // If we have an embedded projection already we cannot continue - // This is not a real problem: calling this method generates the embeded projection + // This is not a real problem: calling this method generates the embedded projection // so we should not have one already! if self.projection().is_some() { return Ok(None); From 36dda081a7df1f7d9e819f17255cfde439ea39ad Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 27 Jan 2026 14:25:12 -0500 Subject: [PATCH 5/5] Implement logical optimizer rule --- .../optimizer/src/extract_leaf_expressions.rs | 612 ++++++++++++++++++ datafusion/optimizer/src/lib.rs | 1 + datafusion/optimizer/src/optimizer.rs | 2 + datafusion/optimizer/src/test/mod.rs | 21 + .../sqllogictest/test_files/explain.slt | 4 + .../test_files/projection_pushdown.slt | 88 +-- .../test_files/push_down_filter.slt | 11 +- 7 files changed, 696 insertions(+), 43 deletions(-) create mode 100644 datafusion/optimizer/src/extract_leaf_expressions.rs diff --git a/datafusion/optimizer/src/extract_leaf_expressions.rs b/datafusion/optimizer/src/extract_leaf_expressions.rs new file mode 100644 index 0000000000000..74d9680e1593a --- /dev/null +++ b/datafusion/optimizer/src/extract_leaf_expressions.rs @@ -0,0 +1,612 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [`ExtractLeafExpressions`] extracts `PlaceAtLeaves` sub-expressions into projections.
+//!
+//! This optimizer rule normalizes the plan so that all `PlaceAtLeaves` computations
+//! (like field accessors) live in Projection nodes, making them eligible for pushdown
+//! by the `OptimizeProjections` rule.
+
+use indexmap::IndexSet;
+use std::sync::Arc;
+
+use datafusion_common::alias::AliasGenerator;
+use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::{Column, DFSchema, Result};
+use datafusion_expr::logical_plan::LogicalPlan;
+use datafusion_expr::{Expr, ExpressionPlacement, Projection};
+use indexmap::IndexMap;
+
+use crate::optimizer::ApplyOrder;
+use crate::{OptimizerConfig, OptimizerRule};
+
+/// Extracts `PlaceAtLeaves` sub-expressions from all nodes into projections.
+///
+/// This normalizes the plan so that all `PlaceAtLeaves` computations (like field
+/// accessors) live in Projection nodes, making them eligible for pushdown.
+///
+/// # Example
+///
+/// Given a filter with a struct field access:
+///
+/// ```text
+/// Filter: user['status'] = 'active'
+///   TableScan: t [user]
+/// ```
+///
+/// This rule extracts the field access into a projection:
+///
+/// ```text
+/// Filter: __leaf_1 = 'active'
+///   Projection: user['status'] AS __leaf_1, user
+///     TableScan: t [user]
+/// ```
+///
+/// The `OptimizeProjections` rule can then push this projection down to the scan.
+#[derive(Default, Debug)]
+pub struct ExtractLeafExpressions {}
+
+impl ExtractLeafExpressions {
+    /// Create a new [`ExtractLeafExpressions`]
+    pub fn new() -> Self {
+        Self {}
+    }
+}
+
+impl OptimizerRule for ExtractLeafExpressions {
+    fn name(&self) -> &str {
+        "extract_leaf_expressions"
+    }
+
+    fn apply_order(&self) -> Option<ApplyOrder> {
+        Some(ApplyOrder::TopDown)
+    }
+
+    fn rewrite(
+        &self,
+        plan: LogicalPlan,
+        config: &dyn OptimizerConfig,
+    ) -> Result<Transformed<LogicalPlan>> {
+        let alias_generator = config.alias_generator();
+        extract_from_plan(plan, alias_generator)
+    }
+}
+
+/// Extracts `PlaceAtLeaves` sub-expressions from a plan node.
+fn extract_from_plan(
+    plan: LogicalPlan,
+    alias_generator: &Arc<AliasGenerator>,
+) -> Result<Transformed<LogicalPlan>> {
+    // Handle specific node types that can benefit from extraction.
+    //
+    // Schema-preserving nodes (output schema = input schema):
+    // - Filter: predicate doesn't affect output columns
+    // - Sort: ordering doesn't affect output columns
+    // - Limit: fetch/skip don't affect output columns
+    //
+    // Schema-transforming nodes require special handling:
+    // - Aggregate: handled separately to preserve output schema
+    // - Projection: skipped (handled by OptimizeProjections)
+    match &plan {
+        // Schema-preserving nodes
+        LogicalPlan::Filter(_) | LogicalPlan::Sort(_) | LogicalPlan::Limit(_) => {}
+
+        // Aggregate needs special handling
+        LogicalPlan::Aggregate(_) => {
+            return extract_from_aggregate(plan, alias_generator);
+        }
+
+        // Skip everything else
+        _ => {
+            return Ok(Transformed::no(plan));
+        }
+    }
+
+    // Skip nodes with no children
+    if plan.inputs().is_empty() {
+        return Ok(Transformed::no(plan));
+    }
+
+    // For nodes with multiple children (e.g., Join), we only extract from the first input
+    // for now to keep the logic simple. A more sophisticated implementation could handle
+    // multiple inputs.
+    let input_schema = Arc::clone(plan.inputs()[0].schema());
+    let mut extractor =
+        LeafExpressionExtractor::new(input_schema.as_ref(), alias_generator);
+
+    // Transform expressions using map_expressions
+    let transformed = plan.map_expressions(|expr| extractor.extract(expr))?;
+
+    if !extractor.has_extractions() {
+        return Ok(transformed);
+    }
+
+    // For non-Projection nodes (like Filter, Sort, etc.), we need to pass through
+    // ALL columns from the input schema, not just those referenced in expressions.
+    // This is because these nodes don't change the schema - they pass through all columns.
+    for col in input_schema.columns() {
+        extractor.columns_needed.insert(col);
+    }
+
+    // Build projection with extracted expressions + pass-through columns
+    // Clone the first input to wrap in Arc
+    let first_input = transformed.data.inputs()[0].clone();
+    let inner_projection = extractor.build_projection(Arc::new(first_input))?;
+
+    // Update plan to use new projection as input
+    let new_inputs: Vec<LogicalPlan> =
+        std::iter::once(LogicalPlan::Projection(inner_projection))
+            .chain(
+                transformed
+                    .data
+                    .inputs()
+                    .iter()
+                    .skip(1)
+                    .map(|p| (*p).clone()),
+            )
+            .collect();
+
+    let new_plan = transformed
+        .data
+        .with_new_exprs(transformed.data.expressions(), new_inputs)?;
+
+    // Add an outer projection to restore the original schema
+    // This ensures the optimized plan has the same output schema
+    let original_schema_exprs: Vec<Expr> = input_schema
+        .columns()
+        .into_iter()
+        .map(Expr::Column)
+        .collect();
+
+    let outer_projection =
+        Projection::try_new(original_schema_exprs, Arc::new(new_plan))?;
+
+    Ok(Transformed::yes(LogicalPlan::Projection(outer_projection)))
+}
+
+/// Extracts `PlaceAtLeaves` sub-expressions from Aggregate nodes.
+///
+/// For Aggregates, we extract from:
+/// - Group-by expressions (full expressions or sub-expressions)
+/// - Arguments inside aggregate functions (NOT the aggregate function itself)
+fn extract_from_aggregate(
+    plan: LogicalPlan,
+    alias_generator: &Arc<AliasGenerator>,
+) -> Result<Transformed<LogicalPlan>> {
+    // Use a match rather than let-else here: `plan` would be moved into the
+    // pattern, so the non-Aggregate arm must rebind it to hand it back unchanged.
+    let agg = match plan {
+        LogicalPlan::Aggregate(agg) => agg,
+        other => return Ok(Transformed::no(other)),
+    };
+
+    // Capture original output schema for restoration
+    let original_schema = Arc::clone(&agg.schema);
+
+    let input_schema = agg.input.schema();
+    let mut extractor =
+        LeafExpressionExtractor::new(input_schema.as_ref(), alias_generator);
+
+    // Extract from group-by expressions
+    let mut new_group_by = Vec::with_capacity(agg.group_expr.len());
+    let mut has_extractions = false;
+
+    for expr in &agg.group_expr {
+        let transformed = extractor.extract(expr.clone())?;
+        if transformed.transformed {
+            has_extractions = true;
+        }
+        new_group_by.push(transformed.data);
+    }
+
+    // Extract from aggregate function arguments (not the function itself)
+    let mut new_aggr = Vec::with_capacity(agg.aggr_expr.len());
+
+    for expr in &agg.aggr_expr {
+        let transformed = extract_from_aggregate_args(expr.clone(), &mut extractor)?;
+        if transformed.transformed {
+            has_extractions = true;
+        }
+        new_aggr.push(transformed.data);
+    }
+
+    if !has_extractions {
+        return Ok(Transformed::no(LogicalPlan::Aggregate(agg)));
+    }
+
+    // Track columns needed by the aggregate (for pass-through)
+    for expr in new_group_by.iter().chain(new_aggr.iter()) {
+        for col in expr.column_refs() {
+            extractor.columns_needed.insert(col.clone());
+        }
+    }
+
+    // Build inner projection with extracted expressions + pass-through columns
+    let inner_projection = extractor.build_projection(Arc::clone(&agg.input))?;
+
+    // Create new Aggregate with transformed expressions
+    let new_agg = datafusion_expr::logical_plan::Aggregate::try_new(
+        Arc::new(LogicalPlan::Projection(inner_projection)),
+        new_group_by,
+        new_aggr,
+    )?;
+
+    // Create outer projection to restore original schema names
+    let outer_exprs: Vec<Expr> = original_schema
+        .iter()
+        .zip(new_agg.schema.columns())
+        .map(|((original_qual, original_field), new_col)| {
+            // Map from new schema column to original schema name, preserving qualifier
+            Expr::Column(new_col)
+                .alias_qualified(original_qual.cloned(), original_field.name())
+        })
+        .collect();
+
+    let outer_projection =
+        Projection::try_new(outer_exprs, Arc::new(LogicalPlan::Aggregate(new_agg)))?;
+
+    Ok(Transformed::yes(LogicalPlan::Projection(outer_projection)))
+}
+
+/// Extracts `PlaceAtLeaves` sub-expressions from aggregate function arguments.
+///
+/// This extracts from inside the aggregate (e.g., from `sum(get_field(x, 'y'))`
+/// we extract `get_field(x, 'y')`), but NOT the aggregate function itself.
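+/// As a sketch (the `__leaf_1` alias follows the tests below):
+///
+/// ```text
+/// sum(get_field(s, 'value'))
+///   => sum(__leaf_1), with the inner projection gaining
+///      `get_field(s, 'value') AS __leaf_1`; the `sum` itself
+///      stays in the Aggregate node.
+/// ```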
+fn extract_from_aggregate_args(
+    expr: Expr,
+    extractor: &mut LeafExpressionExtractor,
+) -> Result<Transformed<Expr>> {
+    match expr {
+        Expr::AggregateFunction(mut agg_func) => {
+            // Extract from arguments, not the function itself
+            let mut any_changed = false;
+            let mut new_args = Vec::with_capacity(agg_func.params.args.len());
+
+            for arg in agg_func.params.args {
+                let transformed = extractor.extract(arg)?;
+                if transformed.transformed {
+                    any_changed = true;
+                }
+                new_args.push(transformed.data);
+            }
+
+            if any_changed {
+                agg_func.params.args = new_args;
+                Ok(Transformed::yes(Expr::AggregateFunction(agg_func)))
+            } else {
+                agg_func.params.args = new_args;
+                Ok(Transformed::no(Expr::AggregateFunction(agg_func)))
+            }
+        }
+        // For aliased aggregates, process the inner expression
+        Expr::Alias(alias) => {
+            let transformed = extract_from_aggregate_args(*alias.expr, extractor)?;
+            Ok(
+                transformed
+                    .update_data(|e| e.alias_qualified(alias.relation, alias.name)),
+            )
+        }
+        // For other expressions, use regular extraction
+        other => extractor.extract(other),
+    }
+}
+
+/// Extracts `PlaceAtLeaves` sub-expressions from larger expressions.
+struct LeafExpressionExtractor<'a> {
+    /// Extracted expressions: maps schema_name -> (original_expr, alias)
+    extracted: IndexMap<String, (Expr, String)>,
+    /// Columns needed for pass-through
+    columns_needed: IndexSet<Column>,
+    /// Input schema
+    input_schema: &'a DFSchema,
+    /// Alias generator
+    alias_generator: &'a Arc<AliasGenerator>,
+}
+
+impl<'a> LeafExpressionExtractor<'a> {
+    fn new(input_schema: &'a DFSchema, alias_generator: &'a Arc<AliasGenerator>) -> Self {
+        Self {
+            extracted: IndexMap::new(),
+            columns_needed: IndexSet::new(),
+            input_schema,
+            alias_generator,
+        }
+    }
+
+    /// Extracts `PlaceAtLeaves` sub-expressions, returning rewritten expression.
+    fn extract(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
+        // Walk top-down to find PlaceAtLeaves sub-expressions
+        expr.transform_down(|e| {
+            match e.placement() {
+                ExpressionPlacement::PlaceAtLeaves => {
+                    // Extract this entire sub-tree
+                    let col_ref = self.add_extracted(e)?;
+                    Ok(Transformed::yes(col_ref))
+                }
+                ExpressionPlacement::Column => {
+                    // Track columns for pass-through
+                    if let Expr::Column(col) = &e {
+                        self.columns_needed.insert(col.clone());
+                    }
+                    Ok(Transformed::no(e))
+                }
+                _ => {
+                    // Continue recursing into children
+                    Ok(Transformed::no(e))
+                }
+            }
+        })
+    }
+
+    /// Adds an expression to extracted set, returns column reference.
+    fn add_extracted(&mut self, expr: Expr) -> Result<Expr> {
+        let schema_name = expr.schema_name().to_string();
+
+        // Deduplication: reuse existing alias if same expression
+        if let Some((_, alias)) = self.extracted.get(&schema_name) {
+            return Ok(Expr::Column(Column::new_unqualified(alias)));
+        }
+
+        // Track columns referenced by this expression
+        for col in expr.column_refs() {
+            self.columns_needed.insert(col.clone());
+        }
+
+        // Generate unique alias
+        let alias = self.alias_generator.next("__leaf");
+        self.extracted.insert(schema_name, (expr, alias.clone()));
+
+        Ok(Expr::Column(Column::new_unqualified(&alias)))
+    }
+
+    fn has_extractions(&self) -> bool {
+        !self.extracted.is_empty()
+    }
+
+    /// Builds projection with extracted expressions + pass-through columns.
+    fn build_projection(&self, input: Arc<LogicalPlan>) -> Result<Projection> {
+        let mut proj_exprs = Vec::new();
+
+        // Add extracted expressions with their aliases
+        for (_, (expr, alias)) in &self.extracted {
+            proj_exprs.push(expr.clone().alias(alias));
+        }
+
+        // Add pass-through columns that are in the input schema
+        for col in &self.columns_needed {
+            // Only add if the column exists in the input schema
+            if self.input_schema.has_column(col) {
+                proj_exprs.push(Expr::Column(col.clone()));
+            }
+        }
+
+        Projection::try_new(proj_exprs, input)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use super::*;
+    use crate::test::*;
+    use crate::{OptimizerContext, assert_optimized_plan_eq_snapshot};
+    use arrow::datatypes::DataType;
+    use datafusion_common::Result;
+    use datafusion_expr::expr::ScalarFunction;
+    use datafusion_expr::{
+        ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
+        TypeSignature, col, lit, logical_plan::builder::LogicalPlanBuilder,
+    };
+
+    /// A mock UDF that simulates a leaf-pushable function like `get_field`.
+    /// It returns `PlaceAtLeaves` when its first argument is Column or PlaceAtLeaves.
+    #[derive(Debug, PartialEq, Eq, Hash)]
+    struct MockLeafFunc {
+        signature: Signature,
+    }
+
+    impl MockLeafFunc {
+        fn new() -> Self {
+            Self {
+                signature: Signature::new(
+                    TypeSignature::Any(2),
+                    datafusion_expr::Volatility::Immutable,
+                ),
+            }
+        }
+    }
+
+    impl ScalarUDFImpl for MockLeafFunc {
+        fn as_any(&self) -> &dyn std::any::Any {
+            self
+        }
+
+        fn name(&self) -> &str {
+            "mock_leaf"
+        }
+
+        fn signature(&self) -> &Signature {
+            &self.signature
+        }
+
+        fn return_type(&self, _args: &[DataType]) -> Result<DataType> {
+            Ok(DataType::Utf8)
+        }
+
+        fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+            unimplemented!("This is only used for testing optimization")
+        }
+
+        fn placement(&self, args: &[ExpressionPlacement]) -> ExpressionPlacement {
+            // Return PlaceAtLeaves if first arg is Column or PlaceAtLeaves
+            // (like get_field does)
+            match args.first() {
+                Some(ExpressionPlacement::Column)
+                | Some(ExpressionPlacement::PlaceAtLeaves) => {
+                    ExpressionPlacement::PlaceAtLeaves
+                }
+                _ => ExpressionPlacement::PlaceAtRoot,
+            }
+        }
+    }
+
+    fn mock_leaf(expr: Expr, name: &str) -> Expr {
+        Expr::ScalarFunction(ScalarFunction::new_udf(
+            Arc::new(ScalarUDF::new_from_impl(MockLeafFunc::new())),
+            vec![expr, lit(name)],
+        ))
+    }
+
+    macro_rules! assert_optimized_plan_equal {
+        (
+            $plan:expr,
+            @ $expected:literal $(,)?
+        ) => {{
+            let optimizer_ctx = OptimizerContext::new().with_max_passes(1);
+            let rules: Vec<Arc<dyn OptimizerRule + Send + Sync>> =
+                vec![Arc::new(ExtractLeafExpressions::new())];
+            assert_optimized_plan_eq_snapshot!(optimizer_ctx, rules, $plan, @ $expected,)
+        }};
+    }
+
+    #[test]
+    fn test_extract_from_filter() -> Result<()> {
+        let table_scan = test_table_scan_with_struct()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(mock_leaf(col("user"), "status").eq(lit("active")))?
+            .build()?;
+
+        // Note: An outer projection is added to preserve the original schema
+        assert_optimized_plan_equal!(plan, @r#"
+        Projection: test.user
+          Filter: __leaf_1 = Utf8("active")
+            Projection: mock_leaf(test.user, Utf8("status")) AS __leaf_1, test.user
+              TableScan: test
+        "#)
+    }
+
+    #[test]
+    fn test_no_extraction_for_column() -> Result<()> {
+        let table_scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(col("a").eq(lit(1)))?
+            .build()?;
+
+        // No extraction should happen for simple columns
+        assert_optimized_plan_equal!(plan, @r"
+        Filter: test.a = Int32(1)
+          TableScan: test
+        ")
+    }
+
+    #[test]
+    fn test_no_extraction_for_projection() -> Result<()> {
+        // Projections are skipped - they're handled by OptimizeProjections
+        let table_scan = test_table_scan_with_struct()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .project(vec![
+                mock_leaf(col("user"), "name")
+                    .is_not_null()
+                    .alias("has_name"),
+            ])?
+            .build()?;
+
+        // No extraction from Projections - they're schema-transforming
+        assert_optimized_plan_equal!(plan, @r#"
+        Projection: mock_leaf(test.user, Utf8("name")) IS NOT NULL AS has_name
+          TableScan: test
+        "#)
+    }
+
+    #[test]
+    fn test_filter_with_deduplication() -> Result<()> {
+        let table_scan = test_table_scan_with_struct()?;
+        let field_access = mock_leaf(col("user"), "name");
+        // Filter with the same expression used twice
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(
+                field_access
+                    .clone()
+                    .is_not_null()
+                    .and(field_access.is_null()),
+            )?
+            .build()?;
+
+        // Same expression should be extracted only once
+        assert_optimized_plan_equal!(plan, @r#"
+        Projection: test.user
+          Filter: __leaf_1 IS NOT NULL AND __leaf_1 IS NULL
+            Projection: mock_leaf(test.user, Utf8("name")) AS __leaf_1, test.user
+              TableScan: test
+        "#)
+    }
+
+    #[test]
+    fn test_already_leaf_expression_in_filter() -> Result<()> {
+        let table_scan = test_table_scan_with_struct()?;
+        // A bare mock_leaf expression is already PlaceAtLeaves
+        // When compared to a literal, the comparison is PlaceAtRoot so extraction happens
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .filter(mock_leaf(col("user"), "name").eq(lit("test")))?
+            .build()?;
+
+        assert_optimized_plan_equal!(plan, @r#"
+        Projection: test.user
+          Filter: __leaf_1 = Utf8("test")
+            Projection: mock_leaf(test.user, Utf8("name")) AS __leaf_1, test.user
+              TableScan: test
+        "#)
+    }
+
+    #[test]
+    fn test_extract_from_aggregate_group_by() -> Result<()> {
+        use datafusion_expr::test::function_stub::count;
+
+        let table_scan = test_table_scan_with_struct()?;
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(vec![mock_leaf(col("user"), "status")], vec![count(lit(1))])?
+            .build()?;
+
+        // Group-by expression is PlaceAtLeaves, so it gets extracted
+        assert_optimized_plan_equal!(plan, @r#"
+        Projection: __leaf_1 AS mock_leaf(test.user,Utf8("status")), COUNT(Int32(1)) AS COUNT(Int32(1))
+          Aggregate: groupBy=[[__leaf_1]], aggr=[[COUNT(Int32(1))]]
+            Projection: mock_leaf(test.user, Utf8("status")) AS __leaf_1, test.user
+              TableScan: test
+        "#)
+    }
+
+    #[test]
+    fn test_extract_from_aggregate_args() -> Result<()> {
+        use datafusion_expr::test::function_stub::count;
+
+        let table_scan = test_table_scan_with_struct()?;
+        // Use count(mock_leaf(...)) since count works with any type
+        let plan = LogicalPlanBuilder::from(table_scan)
+            .aggregate(
+                vec![col("user")],
+                vec![count(mock_leaf(col("user"), "value"))],
+            )?
+            .build()?;
+
+        // Aggregate argument is PlaceAtLeaves, so it gets extracted
+        assert_optimized_plan_equal!(plan, @r#"
+        Projection: test.user AS user, COUNT(__leaf_1) AS COUNT(mock_leaf(test.user,Utf8("value")))
+          Aggregate: groupBy=[[test.user]], aggr=[[COUNT(__leaf_1)]]
+            Projection: mock_leaf(test.user, Utf8("value")) AS __leaf_1, test.user
+              TableScan: test
+        "#)
+    }
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index f8ab453591e91..e3b38c89ef944 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -58,6 +58,7 @@ pub mod eliminate_nested_union {
 }
 pub mod eliminate_outer_join;
 pub mod extract_equijoin_predicate;
+pub mod extract_leaf_expressions;
 pub mod filter_null_join_keys;
 pub mod optimize_projections;
 pub mod optimize_unions;
diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs
index 877a84fe4dc14..d7c9867a1e456 100644
--- a/datafusion/optimizer/src/optimizer.rs
+++ b/datafusion/optimizer/src/optimizer.rs
@@ -43,6 +43,7 @@ use crate::eliminate_join::EliminateJoin;
 use crate::eliminate_limit::EliminateLimit;
 use crate::eliminate_outer_join::EliminateOuterJoin;
 use crate::extract_equijoin_predicate::ExtractEquijoinPredicate;
+use crate::extract_leaf_expressions::ExtractLeafExpressions;
 use crate::filter_null_join_keys::FilterNullJoinKeys;
 use crate::optimize_projections::OptimizeProjections;
 use crate::optimize_unions::OptimizeUnions;
@@ -260,6 +261,7 @@ impl Optimizer {
             // that might benefit from the following rules
             Arc::new(EliminateGroupByConstant::new()),
             Arc::new(CommonSubexprEliminate::new()),
+            Arc::new(ExtractLeafExpressions::new()),
             Arc::new(OptimizeProjections::new()),
         ];
 
diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs
index a45983950496d..0a4daa86e9157 100644
--- a/datafusion/optimizer/src/test/mod.rs
+++ b/datafusion/optimizer/src/test/mod.rs
@@ -45,6 +45,27 @@ pub fn test_table_scan() -> Result<LogicalPlan> {
     test_table_scan_with_name("test")
 }
 
+/// Returns fields for a test table with a struct column
+pub fn test_table_scan_with_struct_fields() -> Vec<Field> {
+    vec![Field::new(
+        "user",
+        DataType::Struct(
+            vec![
+                Field::new("name", DataType::Utf8, true),
+                Field::new("status", DataType::Utf8, true),
+            ]
+            .into(),
+        ),
+        true,
+    )]
+}
+
+/// some tests share a common table with a struct column
+pub fn test_table_scan_with_struct() -> Result<LogicalPlan> {
+    let schema = Schema::new(test_table_scan_with_struct_fields());
+    table_scan(Some("test"), &schema, None)?.build()
+}
+
 /// Scan an empty data source, mainly used in tests
 pub fn scan_empty(
     name: Option<&str>,
diff --git a/datafusion/sqllogictest/test_files/explain.slt b/datafusion/sqllogictest/test_files/explain.slt
index 6f615ec391c9e..7a2c661ad93ce 100644
--- a/datafusion/sqllogictest/test_files/explain.slt
+++ b/datafusion/sqllogictest/test_files/explain.slt
@@ -197,6 +197,7 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE
 logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE
 logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE
 logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE
+logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE
 logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c]
 logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE
 logical_plan after optimize_unions SAME TEXT AS ABOVE
@@ -219,6 +220,7 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE
logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true @@ -558,6 +560,7 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE logical_plan after optimize_projections TableScan: simple_explain_test projection=[a, b, c] logical_plan after rewrite_set_comparison SAME TEXT AS ABOVE logical_plan after optimize_unions SAME TEXT AS ABOVE @@ -580,6 +583,7 @@ logical_plan after push_down_filter SAME TEXT AS ABOVE logical_plan after single_distinct_aggregation_to_group_by SAME TEXT AS ABOVE logical_plan after eliminate_group_by_constant SAME TEXT AS ABOVE logical_plan after common_sub_expression_eliminate SAME TEXT AS ABOVE +logical_plan after extract_leaf_expressions SAME TEXT AS ABOVE logical_plan after optimize_projections SAME TEXT AS ABOVE logical_plan TableScan: simple_explain_test projection=[a, b, c] initial_physical_plan DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/example.csv]]}, projection=[a, b, c], file_type=csv, has_header=true diff --git a/datafusion/sqllogictest/test_files/projection_pushdown.slt b/datafusion/sqllogictest/test_files/projection_pushdown.slt index f125ca92fb0d8..18a03cf348ebf 100644 --- a/datafusion/sqllogictest/test_files/projection_pushdown.slt +++ b/datafusion/sqllogictest/test_files/projection_pushdown.slt @@ -265,11 +265,13 @@ EXPLAIN SELECT id, s['label'] FROM simple_struct WHERE s['value'] > 150; ---- logical_plan 01)Projection: simple_struct.id, get_field(simple_struct.s, Utf8("label")) -02)--Filter: get_field(simple_struct.s, Utf8("value")) > Int64(150) -03)----TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(150)] +02)--Projection: simple_struct.s, simple_struct.id +03)----Filter: __leaf_3 > Int64(150) +04)------Projection: get_field(simple_struct.s, Utf8("value")) AS __leaf_3, simple_struct.s, simple_struct.id +05)--------TableScan: simple_struct projection=[id, s], partial_filters=[get_field(simple_struct.s, Utf8("value")) > Int64(150)] physical_plan -01)FilterExec: get_field(s@2, value) > 150, projection=[id@0, simple_struct.s[label]@1] -02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, label) as simple_struct.s[label], s], file_type=parquet +01)FilterExec: __leaf_3@2 > 150, projection=[id@0, simple_struct.s[label]@1] +02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, label) as simple_struct.s[label], get_field(s@1, value) as __leaf_3], file_type=parquet # Verify correctness query IT @@ -751,13 +753,16 @@ query TT EXPLAIN SELECT 
s['label'], SUM(s['value']) FROM multi_struct GROUP BY s['label']; ---- logical_plan -01)Aggregate: groupBy=[[get_field(multi_struct.s, Utf8("label"))]], aggr=[[sum(get_field(multi_struct.s, Utf8("value")))]] -02)--TableScan: multi_struct projection=[s] +01)Projection: __leaf_1 AS multi_struct.s[label], sum(__leaf_2) AS sum(multi_struct.s[value]) +02)--Aggregate: groupBy=[[__leaf_1]], aggr=[[sum(__leaf_2)]] +03)----Projection: get_field(multi_struct.s, Utf8("label")) AS __leaf_1, get_field(multi_struct.s, Utf8("value")) AS __leaf_2 +04)------TableScan: multi_struct projection=[s] physical_plan -01)AggregateExec: mode=FinalPartitioned, gby=[multi_struct.s[label]@0 as multi_struct.s[label]], aggr=[sum(multi_struct.s[value])] -02)--RepartitionExec: partitioning=Hash([multi_struct.s[label]@0], 4), input_partitions=3 -03)----AggregateExec: mode=Partial, gby=[get_field(s@0, label) as multi_struct.s[label]], aggr=[sum(multi_struct.s[value])] -04)------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[s], file_type=parquet +01)ProjectionExec: expr=[__leaf_1@0 as multi_struct.s[label], sum(__leaf_2)@1 as sum(multi_struct.s[value])] +02)--AggregateExec: mode=FinalPartitioned, gby=[__leaf_1@0 as __leaf_1], aggr=[sum(__leaf_2)] +03)----RepartitionExec: partitioning=Hash([__leaf_1@0], 4), input_partitions=3 +04)------AggregateExec: mode=Partial, gby=[__leaf_1@0 as __leaf_1], aggr=[sum(__leaf_2)] +05)--------DataSourceExec: file_groups={3 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part1.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part2.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part3.parquet, WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part4.parquet], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/multi/part5.parquet]]}, projection=[get_field(s@1, label) as __leaf_1, get_field(s@1, value) as __leaf_2], file_type=parquet # Verify correctness query TI @@ -809,11 +814,13 @@ EXPLAIN SELECT id, s['label'] FROM nullable_struct WHERE s['value'] IS NOT NULL; ---- logical_plan 01)Projection: nullable_struct.id, get_field(nullable_struct.s, Utf8("label")) -02)--Filter: get_field(nullable_struct.s, Utf8("value")) IS NOT NULL -03)----TableScan: nullable_struct projection=[id, s], partial_filters=[get_field(nullable_struct.s, Utf8("value")) IS NOT NULL] +02)--Projection: nullable_struct.s, nullable_struct.id +03)----Filter: __leaf_3 IS NOT NULL +04)------Projection: get_field(nullable_struct.s, Utf8("value")) AS __leaf_3, nullable_struct.s, nullable_struct.id +05)--------TableScan: nullable_struct projection=[id, s], partial_filters=[get_field(nullable_struct.s, Utf8("value")) IS NOT NULL] physical_plan -01)FilterExec: get_field(s@2, value) IS NOT NULL, projection=[id@0, nullable_struct.s[label]@1] -02)--DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nullable.parquet]]}, projection=[id, get_field(s@1, label) as nullable_struct.s[label], s], file_type=parquet
+01)FilterExec: __leaf_3@2 IS NOT NULL, projection=[id@0, nullable_struct.s[label]@1]
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/nullable.parquet]]}, projection=[id, get_field(s@1, label) as nullable_struct.s[label], get_field(s@1, value) as __leaf_3], file_type=parquet

# Verify correctness
query IT
@@ -1156,11 +1163,13 @@ EXPLAIN SELECT s['value'] FROM simple_struct WHERE length(s['label']) > 4;
----
logical_plan
01)Projection: get_field(simple_struct.s, Utf8("value"))
-02)--Filter: character_length(get_field(simple_struct.s, Utf8("label"))) > Int32(4)
-03)----TableScan: simple_struct projection=[s], partial_filters=[character_length(get_field(simple_struct.s, Utf8("label"))) > Int32(4)]
+02)--Projection: simple_struct.s
+03)----Filter: character_length(__leaf_3) > Int32(4)
+04)------Projection: get_field(simple_struct.s, Utf8("label")) AS __leaf_3, simple_struct.s
+05)--------TableScan: simple_struct projection=[s], partial_filters=[character_length(get_field(simple_struct.s, Utf8("label"))) > Int32(4)]
physical_plan
-01)FilterExec: character_length(get_field(s@1, label)) > 4, projection=[simple_struct.s[value]@0]
-02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as simple_struct.s[value], s], file_type=parquet
+01)FilterExec: character_length(__leaf_3@1) > 4, projection=[simple_struct.s[value]@0]
+02)--DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as simple_struct.s[value], get_field(s@1, label) as __leaf_3], file_type=parquet

# Verify correctness - keep only rows where label length > 4
# From the data - alpha(5), beta(4), gamma(5), delta(5), epsilon(7) - only beta is filtered out
@@ -1187,12 +1196,13 @@ EXPLAIN SELECT id FROM simple_struct ORDER BY s['value'];
----
logical_plan
01)Projection: simple_struct.id
-02)--Sort: get_field(simple_struct.s, Utf8("value")) ASC NULLS LAST
-03)----TableScan: simple_struct projection=[id, s]
+02)--Sort: __leaf_1 ASC NULLS LAST
+03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __leaf_1, simple_struct.id
+04)------TableScan: simple_struct projection=[id, s]
physical_plan
-01)ProjectionExec: expr=[id@0 as id]
-02)--SortExec: expr=[get_field(s@1, value) ASC NULLS LAST], preserve_partitioning=[false]
-03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet
+01)ProjectionExec: expr=[id@1 as id]
+02)--SortExec: expr=[__leaf_1@0 ASC NULLS LAST], preserve_partitioning=[false]
+03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __leaf_1, id], file_type=parquet

# Verify correctness
query I
@@ -1215,13 +1225,13 @@ EXPLAIN SELECT id, s['value'] FROM simple_struct ORDER BY id, s['label'];
----
logical_plan
01)Projection: simple_struct.id, simple_struct.s[value]
-02)--Sort: simple_struct.id ASC NULLS LAST, get_field(simple_struct.s, Utf8("label")) ASC NULLS LAST
-03)----Projection: simple_struct.id, get_field(simple_struct.s, Utf8("value")), simple_struct.s +02)--Sort: simple_struct.id ASC NULLS LAST, __leaf_1 ASC NULLS LAST +03)----Projection: get_field(simple_struct.s, Utf8("label")) AS __leaf_1, simple_struct.id, get_field(simple_struct.s, Utf8("value")) 04)------TableScan: simple_struct projection=[id, s] physical_plan -01)ProjectionExec: expr=[id@0 as id, simple_struct.s[value]@1 as simple_struct.s[value]] -02)--SortExec: expr=[id@0 ASC NULLS LAST, get_field(s@2, label) ASC NULLS LAST], preserve_partitioning=[false] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, get_field(s@1, value) as simple_struct.s[value], s], file_type=parquet +01)ProjectionExec: expr=[id@1 as id, simple_struct.s[value]@2 as simple_struct.s[value]] +02)--SortExec: expr=[id@1 ASC NULLS LAST, __leaf_1@0 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, label) as __leaf_1, id, get_field(s@1, value) as simple_struct.s[value]], file_type=parquet # Verify correctness query II @@ -1243,12 +1253,13 @@ EXPLAIN SELECT id FROM simple_struct ORDER BY s['value'] LIMIT 2; ---- logical_plan 01)Projection: simple_struct.id -02)--Sort: get_field(simple_struct.s, Utf8("value")) ASC NULLS LAST, fetch=2 -03)----TableScan: simple_struct projection=[id, s] +02)--Sort: __leaf_1 ASC NULLS LAST, fetch=2 +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __leaf_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s] physical_plan -01)ProjectionExec: expr=[id@0 as id] -02)--SortExec: TopK(fetch=2), expr=[get_field(s@1, value) ASC NULLS LAST], preserve_partitioning=[false] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet +01)ProjectionExec: expr=[id@1 as id] +02)--SortExec: TopK(fetch=2), expr=[__leaf_1@0 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, value) as __leaf_1, id], file_type=parquet # Verify correctness query I @@ -1268,12 +1279,13 @@ EXPLAIN SELECT id FROM simple_struct ORDER BY s['value'] * 2; ---- logical_plan 01)Projection: simple_struct.id -02)--Sort: get_field(simple_struct.s, Utf8("value")) * Int64(2) ASC NULLS LAST -03)----TableScan: simple_struct projection=[id, s] +02)--Sort: __leaf_1 * Int64(2) ASC NULLS LAST +03)----Projection: get_field(simple_struct.s, Utf8("value")) AS __leaf_1, simple_struct.id +04)------TableScan: simple_struct projection=[id, s] physical_plan -01)ProjectionExec: expr=[id@0 as id] -02)--SortExec: expr=[get_field(s@1, value) * 2 ASC NULLS LAST], preserve_partitioning=[false] -03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[id, s], file_type=parquet +01)ProjectionExec: expr=[id@1 as id] +02)--SortExec: expr=[__leaf_1@0 * 2 ASC NULLS LAST], preserve_partitioning=[false] +03)----DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/scratch/projection_pushdown/simple.parquet]]}, projection=[get_field(s@1, 
value) as __leaf_1, id], file_type=parquet # Verify correctness query I diff --git a/datafusion/sqllogictest/test_files/push_down_filter.slt b/datafusion/sqllogictest/test_files/push_down_filter.slt index 4353f805c848b..df74d1b7bdd6d 100644 --- a/datafusion/sqllogictest/test_files/push_down_filter.slt +++ b/datafusion/sqllogictest/test_files/push_down_filter.slt @@ -115,12 +115,13 @@ query TT explain select * from (select column1, unnest(column2) as o from d) where o['a'] = 1; ---- physical_plan -01)ProjectionExec: expr=[column1@0 as column1, __unnest_placeholder(d.column2,depth=1)@1 as o] -02)--FilterExec: get_field(__unnest_placeholder(d.column2,depth=1)@1, a) = 1 +01)ProjectionExec: expr=[column1@1 as column1, __unnest_placeholder(d.column2,depth=1)@0 as o] +02)--FilterExec: __leaf_3@0 = 1, projection=[__unnest_placeholder(d.column2,depth=1)@1, column1@2] 03)----RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 -04)------UnnestExec -05)--------ProjectionExec: expr=[column1@0 as column1, column2@1 as __unnest_placeholder(d.column2)] -06)----------DataSourceExec: partitions=1, partition_sizes=[1] +04)------ProjectionExec: expr=[get_field(__unnest_placeholder(d.column2,depth=1)@1, a) as __leaf_3, __unnest_placeholder(d.column2,depth=1)@1 as __unnest_placeholder(d.column2,depth=1), column1@0 as column1] +05)--------UnnestExec +06)----------ProjectionExec: expr=[column1@0 as column1, column2@1 as __unnest_placeholder(d.column2)] +07)------------DataSourceExec: partitions=1, partition_sizes=[1] statement ok drop table d;
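
To make the rewrite exercised by these plan changes concrete, here is a minimal, self-contained Rust sketch of the extraction step: leaf-pushable subexpressions (struct field accesses in this patch) are replaced in the parent expression by `__leaf_N` column references and collected (deduplicated) so that a projection computing them can be inserted below the Filter/Sort/Aggregate. The toy `Expr` enum, `extract_leaves` function, and `main` driver are invented for illustration only; this is not the actual `ExtractLeafExpressions` implementation nor DataFusion's `Expr` API.

// Illustrative sketch only (not part of this patch).
use std::collections::BTreeMap;

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Expr {
    Column(String),
    Literal(String),
    GetField(Box<Expr>, String), // leaf-pushable: static struct field access
    IsNotNull(Box<Expr>),        // root-placed: recurse into children
    Eq(Box<Expr>, Box<Expr>),    // root-placed: recurse into children
}

/// Replace each leaf-pushable subexpression with a `__leaf_N` column
/// reference; `out` maps extracted expressions to their aliases, so the
/// same expression is extracted only once (mirroring
/// test_filter_with_deduplication above).
fn extract_leaves(expr: &Expr, out: &mut BTreeMap<Expr, String>) -> Expr {
    match expr {
        Expr::GetField(..) => {
            // Assign the next alias only if this expression is new
            let next = format!("__leaf_{}", out.len() + 1);
            let alias = out.entry(expr.clone()).or_insert(next).clone();
            Expr::Column(alias)
        }
        Expr::IsNotNull(inner) => {
            Expr::IsNotNull(Box::new(extract_leaves(inner, out)))
        }
        Expr::Eq(lhs, rhs) => Expr::Eq(
            Box::new(extract_leaves(lhs, out)),
            Box::new(extract_leaves(rhs, out)),
        ),
        other => other.clone(),
    }
}

fn main() {
    // get_field(user, 'name') = 'test'  ==>  __leaf_1 = 'test', with
    // get_field(user, 'name') AS __leaf_1 emitted in a projection below.
    let pred = Expr::Eq(
        Box::new(Expr::GetField(
            Box::new(Expr::Column("user".into())),
            "name".into(),
        )),
        Box::new(Expr::Literal("test".into())),
    );
    let mut extracted = BTreeMap::new();
    let rewritten = extract_leaves(&pred, &mut extracted);
    println!("rewritten predicate: {rewritten:?}");
    println!("projection below:    {extracted:?}");
}

Running the sketch prints Eq(Column("__leaf_1"), Literal("test")) together with a one-entry map for __leaf_1, mirroring the Filter: __leaf_1 = Utf8("test") over Projection: ... AS __leaf_1 shape asserted throughout the plans above.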