diff --git a/src/query/sql/src/planner/binder/bind_table_reference/bind_join.rs b/src/query/sql/src/planner/binder/bind_table_reference/bind_join.rs index 2a8a172429e0e..56bdcb1980c5c 100644 --- a/src/query/sql/src/planner/binder/bind_table_reference/bind_join.rs +++ b/src/query/sql/src/planner/binder/bind_table_reference/bind_join.rs @@ -394,7 +394,7 @@ impl Binder { // If there are outer columns in right child, then the join is a correlated lateral join let opt_ctx = OptimizerContext::new(self.ctx.clone(), self.metadata.clone()); let mut decorrelator = SubqueryDecorrelatorOptimizer::new(opt_ctx, Some(self.clone())); - right_child = decorrelator.flatten_plan( + let (flatten_plan, derived_columns) = decorrelator.flatten_plan( &left_child, &right_child, &right_prop.outer_columns, @@ -403,10 +403,12 @@ impl Binder { }, false, )?; + right_child = flatten_plan; let original_num_conditions = left_conditions.len(); decorrelator.add_equi_conditions( None, &right_prop.outer_columns, + &derived_columns, &mut right_conditions, &mut left_conditions, )?; diff --git a/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/decorrelate.rs b/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/decorrelate.rs index 8ba8d32fa5728..d9c669a4cd63b 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/decorrelate.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/decorrelate.rs @@ -28,6 +28,7 @@ use databend_common_expression::types::DataType; use databend_common_expression::types::NumberScalar; use databend_common_functions::BUILTIN_FUNCTIONS; +use super::DerivedColumnScope; use crate::ColumnSet; use crate::binder::ColumnBindingBuilder; use crate::binder::JoinPredicate; @@ -244,17 +245,17 @@ impl SubqueryDecorrelatorOptimizer { Ok(Some(result)) } - pub fn try_decorrelate_subquery( + pub(crate) fn try_decorrelate_subquery( &mut self, outer: &SExpr, subquery: &SubqueryExpr, flatten_info: &mut FlattenInfo, is_conjunctive_predicate: bool, - ) -> Result<(SExpr, UnnestResult)> { + ) -> Result<(SExpr, UnnestResult, DerivedColumnScope)> { match subquery.typ { SubqueryType::Scalar => { let correlated_columns = &subquery.outer_columns; - let flatten_plan = self.flatten_plan( + let (flatten_plan, derived_columns) = self.flatten_plan( outer, &subquery.subquery, correlated_columns, @@ -267,6 +268,7 @@ impl SubqueryDecorrelatorOptimizer { self.add_equi_conditions( subquery.span, correlated_columns, + &derived_columns, &mut right_conditions, &mut left_conditions, )?; @@ -304,16 +306,20 @@ impl SubqueryDecorrelatorOptimizer { Arc::new(outer.clone()), Arc::new(flatten_plan), ); - Ok((s_expr, UnnestResult::SingleJoin)) + Ok((s_expr, UnnestResult::SingleJoin, derived_columns)) } SubqueryType::Exists | SubqueryType::NotExists => { - if is_conjunctive_predicate { - if let Some(result) = self.try_decorrelate_simple_subquery(outer, subquery)? { - return Ok((result, UnnestResult::SimpleJoin { output_index: None })); - } + if is_conjunctive_predicate + && let Some(result) = self.try_decorrelate_simple_subquery(outer, subquery)? + { + return Ok(( + result, + UnnestResult::SimpleJoin { output_index: None }, + Default::default(), + )); } let correlated_columns = &subquery.outer_columns; - let flatten_plan = self.flatten_plan( + let (flatten_plan, derived_columns) = self.flatten_plan( outer, &subquery.subquery, correlated_columns, @@ -326,6 +332,7 @@ impl SubqueryDecorrelatorOptimizer { self.add_equi_conditions( subquery.span, correlated_columns, + &derived_columns, &mut left_conditions, &mut right_conditions, )?; @@ -368,11 +375,15 @@ impl SubqueryDecorrelatorOptimizer { Arc::new(outer.clone()), Arc::new(flatten_plan), ); - Ok((s_expr, UnnestResult::MarkJoin { marker_index })) + Ok(( + s_expr, + UnnestResult::MarkJoin { marker_index }, + derived_columns, + )) } SubqueryType::Any => { let correlated_columns = &subquery.outer_columns; - let flatten_plan = self.flatten_plan( + let (flatten_plan, derived_columns) = self.flatten_plan( outer, &subquery.subquery, correlated_columns, @@ -384,6 +395,7 @@ impl SubqueryDecorrelatorOptimizer { self.add_equi_conditions( subquery.span, correlated_columns, + &derived_columns, &mut left_conditions, &mut right_conditions, )?; @@ -453,16 +465,18 @@ impl SubqueryDecorrelatorOptimizer { Arc::new(flatten_plan), ), UnnestResult::MarkJoin { marker_index }, + derived_columns, )) } _ => unreachable!(), } } - pub fn add_equi_conditions( + pub(crate) fn add_equi_conditions( &self, span: Span, correlated_columns: &ColumnSet, + derived_columns: &DerivedColumnScope, left_conditions: &mut Vec, right_conditions: &mut Vec, ) -> Result<()> { @@ -482,15 +496,15 @@ impl SubqueryDecorrelatorOptimizer { .table_index(column_entry.table_index()) .build(), }); - let Some(derive_column) = self.derived_columns.get(&correlated_column) else { + let Some(derive_column) = derived_columns.resolve(correlated_column) else { continue; }; - let column_entry = metadata.column(*derive_column); + let column_entry = metadata.column(derive_column); let left_column = ScalarExpr::BoundColumnRef(BoundColumnRef { span, column: ColumnBindingBuilder::new( column_entry.name(), - *derive_column, + derive_column, Box::from(column_entry.data_type()), Visibility::Visible, ) @@ -507,43 +521,40 @@ impl SubqueryDecorrelatorOptimizer { // If correlated_columns only occur in equi-conditions, such as `where t1.a = t.a and t1.b = t.b`(t1 is outer table) // Then we won't join outer and inner table. pub(crate) fn join_outer_inner_table( - &mut self, + &self, filter: &Filter, correlated_columns: &ColumnSet, - ) -> Result { - Ok(!filter.predicates.iter().all(|predicate| { + ) -> Result<(bool, DerivedColumnScope)> { + let mut derived_columns = DerivedColumnScope::default(); + let can_reuse_inner_columns = filter.predicates.iter().all(|predicate| { if predicate .used_columns() .iter() - .any(|column| correlated_columns.contains(column)) + .all(|column| !correlated_columns.contains(column)) { - if let ScalarExpr::FunctionCall(func) = predicate { - if func.func_name == "eq" { - if let ( - ScalarExpr::BoundColumnRef(left), - ScalarExpr::BoundColumnRef(right), - ) = (&func.arguments[0], &func.arguments[1]) - { - if correlated_columns.contains(&left.column.index) - && !correlated_columns.contains(&right.column.index) - { - self.derived_columns - .insert(left.column.index, right.column.index); - } - if !correlated_columns.contains(&left.column.index) - && correlated_columns.contains(&right.column.index) - { - self.derived_columns - .insert(right.column.index, left.column.index); - } - return true; - } - } + return true; + } + if let ScalarExpr::FunctionCall(func) = predicate + && func.func_name == "eq" + && let (ScalarExpr::BoundColumnRef(left), ScalarExpr::BoundColumnRef(right)) = + (&func.arguments[0], &func.arguments[1]) + { + if correlated_columns.contains(&left.column.index) + && !correlated_columns.contains(&right.column.index) + { + derived_columns.record(left.column.index, right.column.index); + } + if !correlated_columns.contains(&left.column.index) + && correlated_columns.contains(&right.column.index) + { + derived_columns.record(right.column.index, left.column.index); } - return false; + true + } else { + false } - true - })) + }); + Ok((!can_reuse_inner_columns, derived_columns)) } // Try folding the subquery into a constant value expression, diff --git a/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/flatten_plan.rs b/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/flatten_plan.rs index 2627ff5659738..eececf84a96d5 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/flatten_plan.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/flatten_plan.rs @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; use std::sync::Arc; use databend_common_exception::ErrorCode; @@ -24,6 +23,8 @@ use databend_common_expression::types::DataType; use databend_common_expression::types::NumberDataType; use databend_common_expression::types::NumberScalar; +use super::DerivedColumnScope; +use super::FlattenPlanResult; use crate::ColumnEntry; use crate::ColumnSet; use crate::Metadata; @@ -67,21 +68,41 @@ use crate::plans::WindowFuncType; use crate::plans::WindowPartition; impl SubqueryDecorrelatorOptimizer { + pub(crate) fn flatten_plan( + &mut self, + outer: &SExpr, + subquery: &SExpr, + correlated_columns: &ColumnSet, + flatten_info: &mut FlattenInfo, + need_cross_join: bool, + ) -> Result { + self.flatten_plan_with_scope( + outer, + subquery, + correlated_columns, + flatten_info, + need_cross_join, + &Default::default(), + ) + } + #[recursive::recursive] - pub fn flatten_plan( + fn flatten_plan_with_scope( &mut self, outer: &SExpr, subquery: &SExpr, correlated_columns: &ColumnSet, flatten_info: &mut FlattenInfo, need_cross_join: bool, - ) -> Result { + derived_columns: &DerivedColumnScope, + ) -> Result { let prop = subquery.derive_relational_prop()?; if prop.outer_columns.is_empty() { - if !need_cross_join { - return Ok(subquery.clone()); - } - return self.rewrite_to_join_then_aggr(outer, subquery, correlated_columns); + return if need_cross_join { + self.rewrite_to_join_then_aggr(outer, subquery, correlated_columns, derived_columns) + } else { + Ok((subquery.clone(), derived_columns.clone())) + }; } match subquery.plan() { @@ -92,6 +113,7 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns, flatten_info, need_cross_join, + derived_columns, ), RelOperator::ProjectSet(project_set) => self.flatten_sub_project_set( outer, @@ -100,6 +122,7 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns, flatten_info, need_cross_join, + derived_columns, ), RelOperator::Filter(filter) => self.flatten_sub_filter( outer, @@ -108,10 +131,16 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns, flatten_info, need_cross_join, + derived_columns, + ), + RelOperator::Join(join) => self.flatten_sub_join( + outer, + subquery, + join, + correlated_columns, + flatten_info, + derived_columns, ), - RelOperator::Join(join) => { - self.flatten_sub_join(outer, subquery, join, correlated_columns, flatten_info) - } RelOperator::Aggregate(aggregate) => self.flatten_sub_aggregate( outer, subquery, @@ -119,6 +148,7 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns, flatten_info, need_cross_join, + derived_columns, ), RelOperator::Sort(sort) => self.flatten_sub_sort( outer, @@ -127,6 +157,7 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns, flatten_info, need_cross_join, + derived_columns, ), RelOperator::Limit(limit) => self.flatten_sub_limit( outer, @@ -135,6 +166,7 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns, flatten_info, need_cross_join, + derived_columns, ), RelOperator::UnionAll(op) => self.flatten_sub_union_all( outer, @@ -143,13 +175,22 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns, flatten_info, need_cross_join, + derived_columns, + ), + RelOperator::Window(op) => self.flatten_sub_window( + outer, + subquery, + op, + correlated_columns, + flatten_info, + derived_columns, + ), + RelOperator::ExpressionScan(scan) => self.flatten_sub_expression_scan( + subquery, + scan, + correlated_columns, + derived_columns, ), - RelOperator::Window(op) => { - self.flatten_sub_window(outer, subquery, op, correlated_columns, flatten_info) - } - RelOperator::ExpressionScan(scan) => { - self.flatten_sub_expression_scan(subquery, scan, correlated_columns) - } _ => Err(ErrorCode::SemanticError( "Invalid plan type for flattening subquery", )), @@ -164,17 +205,19 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns: &ColumnSet, flatten_info: &mut FlattenInfo, mut need_cross_join: bool, - ) -> Result { + derived_columns: &DerivedColumnScope, + ) -> Result { if !eval_scalar.used_columns()?.is_disjoint(correlated_columns) { need_cross_join = true; } - let flatten_plan = self.flatten_plan( + let (flatten_plan, derived_columns) = self.flatten_plan_with_scope( outer, subquery.unary_child(), correlated_columns, flatten_info, need_cross_join, + derived_columns, )?; let metadata = self.metadata.clone(); @@ -187,11 +230,15 @@ impl SubqueryDecorrelatorOptimizer { .chain(correlated_columns.iter().copied().map(Item::Index)) .map(|item| match item { Item::Scalar(item) => Ok(ScalarItem { - scalar: self.flatten_scalar(&item.scalar, correlated_columns)?, + scalar: self.flatten_scalar( + &item.scalar, + correlated_columns, + &derived_columns, + )?, index: item.index, }), Item::Index(old) => Ok(Self::scalar_item_from_index( - self.get_derived(old)?, + derived_columns.must_resolve(old)?, "outer.", &metadata, )), @@ -233,9 +280,9 @@ impl SubqueryDecorrelatorOptimizer { }); } - Ok(SExpr::create_unary( - Arc::new(EvalScalar { items }.into()), - Arc::new(flatten_plan), + Ok(( + flatten_plan.build_unary(EvalScalar { items }), + derived_columns, )) } @@ -247,7 +294,8 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns: &ColumnSet, flatten_info: &mut FlattenInfo, mut need_cross_join: bool, - ) -> Result { + derived_columns: &DerivedColumnScope, + ) -> Result { if !project_set .srfs .iter() @@ -260,38 +308,35 @@ impl SubqueryDecorrelatorOptimizer { { need_cross_join = true; } - let flatten_plan = self.flatten_plan( + let (flatten_plan, derived_columns) = self.flatten_plan_with_scope( outer, subquery.unary_child(), correlated_columns, flatten_info, need_cross_join, + derived_columns, )?; let mut srfs = Vec::with_capacity(project_set.srfs.len()); for item in project_set.srfs.iter() { let new_item = ScalarItem { - scalar: self.flatten_scalar(&item.scalar, correlated_columns)?, + scalar: self.flatten_scalar(&item.scalar, correlated_columns, &derived_columns)?, index: item.index, }; srfs.push(new_item); } let metadata = self.metadata.read(); - let scalar_items = self - .derived_columns - .values() - .map(|index| Self::scalar_item_from_index(*index, "outer.", &metadata)) + let scalar_items = derived_columns + .visible_symbols() + .into_iter() + .map(|index| Self::scalar_item_from_index(index, "outer.", &metadata)) .collect(); - Ok(SExpr::create_unary( - Arc::new(ProjectSet { srfs }.into()), - Arc::new(SExpr::create_unary( - Arc::new( - EvalScalar { - items: scalar_items, - } - .into(), - ), - Arc::new(flatten_plan), - )), + Ok(( + flatten_plan + .build_unary(EvalScalar { + items: scalar_items, + }) + .build_unary(ProjectSet { srfs }), + derived_columns, )) } @@ -303,30 +348,34 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns: &ColumnSet, flatten_info: &mut FlattenInfo, need_cross_join: bool, - ) -> Result { - let mut predicates = Vec::with_capacity(filter.predicates.len()); - let need_cross_join = need_cross_join - || if self.join_outer_inner_table(filter, correlated_columns)? { - self.derived_columns.clear(); - true - } else { - false - }; - let flatten_plan = self.flatten_plan( + derived_columns: &DerivedColumnScope, + ) -> Result { + let (join_outer_inner, sub_derived_columns) = + self.join_outer_inner_table(filter, correlated_columns)?; + let need_cross_join = need_cross_join || join_outer_inner; + let child_scope = if join_outer_inner { + DerivedColumnScope::default() + } else { + derived_columns.child_scope_for_filter(sub_derived_columns) + }; + let (flatten_plan, derived_columns) = self.flatten_plan_with_scope( outer, subquery.unary_child(), correlated_columns, flatten_info, need_cross_join, + &child_scope, )?; - for predicate in filter.predicates.iter() { - predicates.push(self.flatten_scalar(predicate, correlated_columns)?); - } - let filter_plan = Filter { predicates }.into(); - Ok(SExpr::create_unary( - Arc::new(filter_plan), - Arc::new(flatten_plan), + let predicates = filter + .predicates + .iter() + .map(|predicate| self.flatten_scalar(predicate, correlated_columns, &derived_columns)) + .collect::>()?; + + Ok(( + flatten_plan.build_unary(Filter { predicates }), + derived_columns, )) } @@ -337,7 +386,8 @@ impl SubqueryDecorrelatorOptimizer { join: &Join, correlated_columns: &ColumnSet, flatten_info: &mut FlattenInfo, - ) -> Result { + derived_columns: &DerivedColumnScope, + ) -> Result { // Helper function to check if conditions need a cross join fn needs_cross_join( conditions: &[JoinEquiCondition], @@ -358,7 +408,7 @@ impl SubqueryDecorrelatorOptimizer { fn process_conditions( conditions: &[ScalarExpr], correlated_columns: &ColumnSet, - derived_columns: &HashMap, + derived_columns: &DerivedColumnScope, need_cross_join: bool, ) -> Result> { if need_cross_join { @@ -368,10 +418,8 @@ impl SubqueryDecorrelatorOptimizer { let mut new_condition = condition.clone(); for col in condition.used_columns() { if correlated_columns.contains(&col) { - let new_col = derived_columns.get(&col).ok_or_else(|| { - ErrorCode::Internal(format!("Missing derived column {col}")) - })?; - new_condition.replace_column(col, *new_col)?; + let new_col = derived_columns.must_resolve(col)?; + new_condition.replace_column(col, new_col)?; } } Ok(new_condition) @@ -403,19 +451,21 @@ impl SubqueryDecorrelatorOptimizer { } } - let left_flatten_plan = self.flatten_plan( + let (left_flatten_plan, left_derived_columns) = self.flatten_plan_with_scope( outer, subquery.left_child(), correlated_columns, flatten_info, left_need_cross_join, + derived_columns, )?; - let right_flatten_plan = self.flatten_plan( + let (right_flatten_plan, right_derived_columns) = self.flatten_plan_with_scope( outer, subquery.right_child(), correlated_columns, flatten_info, right_need_cross_join, + derived_columns, )?; let left_conditions = join @@ -426,7 +476,7 @@ impl SubqueryDecorrelatorOptimizer { let left_conditions = process_conditions( &left_conditions, correlated_columns, - &self.derived_columns, + &left_derived_columns, left_need_cross_join, )?; let right_conditions = join @@ -437,18 +487,21 @@ impl SubqueryDecorrelatorOptimizer { let right_conditions = process_conditions( &right_conditions, correlated_columns, - &self.derived_columns, + &right_derived_columns, right_need_cross_join, )?; + let mut derived_columns = derived_columns.clone(); + derived_columns.absorb_child_scope(&left_derived_columns); + derived_columns.absorb_child_scope(&right_derived_columns); let non_equi_conditions = process_conditions( &join.non_equi_conditions, correlated_columns, - &self.derived_columns, + &derived_columns, true, )?; - Ok(SExpr::create_binary( - Arc::new( + Ok(( + SExpr::create_binary( Join { equi_conditions: JoinEquiCondition::new_conditions( left_conditions, @@ -463,11 +516,11 @@ impl SubqueryDecorrelatorOptimizer { is_lateral: false, single_to_inner: None, build_side_cache_info: None, - } - .into(), + }, + left_flatten_plan, + right_flatten_plan, ), - Arc::new(left_flatten_plan), - Arc::new(right_flatten_plan), + derived_columns, )) } @@ -479,16 +532,18 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns: &ColumnSet, flatten_info: &mut FlattenInfo, mut need_cross_join: bool, - ) -> Result { + derived_columns: &DerivedColumnScope, + ) -> Result { if !aggregate.used_columns()?.is_disjoint(correlated_columns) { need_cross_join = true; } - let flatten_plan = self.flatten_plan( + let (flatten_plan, derived_columns) = self.flatten_plan_with_scope( outer, subquery.unary_child(), correlated_columns, flatten_info, need_cross_join, + derived_columns, )?; let metadata = self.metadata.clone(); @@ -500,14 +555,15 @@ impl SubqueryDecorrelatorOptimizer { .chain(correlated_columns.iter().copied().map(Item::Index)) .map(|item| match item { Item::Scalar(item) => { - let scalar = self.flatten_scalar(&item.scalar, correlated_columns)?; + let scalar = + self.flatten_scalar(&item.scalar, correlated_columns, &derived_columns)?; Ok(ScalarItem { scalar, index: item.index, }) } Item::Index(old) => Ok(Self::scalar_item_from_index( - self.get_derived(old)?, + derived_columns.must_resolve(old)?, "outer.", &metadata, )), @@ -517,7 +573,7 @@ impl SubqueryDecorrelatorOptimizer { let mut agg_items = Vec::with_capacity(aggregate.aggregate_functions.len()); for item in aggregate.aggregate_functions.iter() { - let scalar = self.flatten_scalar(&item.scalar, correlated_columns)?; + let scalar = self.flatten_scalar(&item.scalar, correlated_columns, &derived_columns)?; if let ScalarExpr::AggregateFunction(AggregateFunction { func_name, .. }) = &scalar { // For scalar subquery, we'll convert it to single join. // Single join is similar to left outer join, if there isn't matched row in the right side, we'll add NULL value for the right side. @@ -534,19 +590,16 @@ impl SubqueryDecorrelatorOptimizer { index: item.index, }) } - Ok(SExpr::create_unary( - Arc::new( - Aggregate { - mode: AggregateMode::Initial, - group_items, - aggregate_functions: agg_items, - from_distinct: aggregate.from_distinct, - rank_limit: aggregate.rank_limit.clone(), - grouping_sets: aggregate.grouping_sets.clone(), - } - .into(), - ), - Arc::new(flatten_plan), + Ok(( + flatten_plan.build_unary(Aggregate { + mode: AggregateMode::Initial, + group_items, + aggregate_functions: agg_items, + from_distinct: aggregate.from_distinct, + rank_limit: aggregate.rank_limit.clone(), + grouping_sets: aggregate.grouping_sets.clone(), + }), + derived_columns, )) } @@ -558,14 +611,16 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns: &ColumnSet, flatten_info: &mut FlattenInfo, need_cross_join: bool, - ) -> Result { + derived_columns: &DerivedColumnScope, + ) -> Result { // Currently, we don't support sort contain subquery. - let flatten_plan = self.flatten_plan( + let (flatten_plan, derived_columns) = self.flatten_plan_with_scope( outer, subquery.unary_child(), correlated_columns, flatten_info, need_cross_join, + derived_columns, )?; // Check if sort contains `count() or distinct count()`. if sort.items.iter().any(|item| { @@ -580,9 +635,9 @@ impl SubqueryDecorrelatorOptimizer { }) { flatten_info.from_count_func = false; } - Ok(SExpr::create_unary( - subquery.plan.clone(), - Arc::new(flatten_plan), + Ok(( + flatten_plan.build_unary(subquery.plan.clone()), + derived_columns, )) } @@ -594,15 +649,17 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns: &ColumnSet, flatten_info: &mut FlattenInfo, need_cross_join: bool, - ) -> Result { - let (flatten_plan, order_by) = match subquery.unary_child().plan() { + derived_columns: &DerivedColumnScope, + ) -> Result { + let (flatten_plan, derived_columns, order_by) = match subquery.unary_child().plan() { RelOperator::Sort(sort) => { - let flatten_plan = self.flatten_plan( + let (flatten_plan, derived_columns) = self.flatten_plan_with_scope( outer, subquery.unary_child().unary_child(), correlated_columns, flatten_info, need_cross_join, + derived_columns, )?; if sort.items.iter().any(|item| { @@ -622,11 +679,7 @@ impl SubqueryDecorrelatorOptimizer { .items .iter() .map(|item| { - let index = self - .derived_columns - .get(&item.index) - .copied() - .unwrap_or(item.index); + let index = derived_columns.resolve_or_self(item.index); Ok(WindowOrderByInfo { order_by_item: Self::scalar_item_from_index(index, "", &metadata), asc: Some(item.asc), @@ -634,22 +687,23 @@ impl SubqueryDecorrelatorOptimizer { }) }) .collect::>>()?; - (flatten_plan, order_by) + (flatten_plan, derived_columns, order_by) } - _ => ( - self.flatten_plan( + _ => { + let (flatten_plan, derived_columns) = self.flatten_plan_with_scope( outer, subquery.unary_child(), correlated_columns, flatten_info, need_cross_join, - )?, - vec![], - ), + derived_columns, + )?; + (flatten_plan, derived_columns, vec![]) + } }; if limit.limit.is_none() && limit.offset == 0 { - return Ok(flatten_plan); + return Ok((flatten_plan, derived_columns)); } let metadata = self.metadata.read(); @@ -658,7 +712,7 @@ impl SubqueryDecorrelatorOptimizer { .copied() .map(|old| { Ok(Self::scalar_item_from_index( - self.get_derived(old)?, + derived_columns.must_resolve(old)?, "outer.", &metadata, )) @@ -712,27 +766,21 @@ impl SubqueryDecorrelatorOptimizer { let window_child = if sort_items.is_empty() { flatten_plan } else { - SExpr::create_unary( - Arc::new( - Sort { - items: sort_items, - limit: None, - after_exchange: None, - pre_projection: None, - window_partition: if partition_by.is_empty() { - None - } else { - Some(WindowPartition { - partition_by: partition_by.clone(), - top: None, - func: WindowFuncType::RowNumber, - }) - }, - } - .into(), - ), - Arc::new(flatten_plan), - ) + flatten_plan.build_unary(Sort { + items: sort_items, + limit: None, + after_exchange: None, + pre_projection: None, + window_partition: if partition_by.is_empty() { + None + } else { + Some(WindowPartition { + partition_by: partition_by.clone(), + top: None, + func: WindowFuncType::RowNumber, + }) + }, + }) }; let row_number_column = ScalarExpr::BoundColumnRef(BoundColumnRef { @@ -779,12 +827,11 @@ impl SubqueryDecorrelatorOptimizer { })); } - Ok(SExpr::create_unary( - Arc::new(Filter { predicates }.into()), - Arc::new(SExpr::create_unary( - Arc::new(window_plan.into()), - Arc::new(window_child), - )), + Ok(( + window_child + .build_unary(window_plan) + .build_unary(Filter { predicates }), + derived_columns, )) } @@ -795,18 +842,20 @@ impl SubqueryDecorrelatorOptimizer { window: &Window, correlated_columns: &ColumnSet, flatten_info: &mut FlattenInfo, - ) -> Result { + derived_columns: &DerivedColumnScope, + ) -> Result { if !window.used_columns()?.is_disjoint(correlated_columns) { return Err(ErrorCode::SemanticError( "correlated columns in window functions not supported", )); } - let flatten_plan = self.flatten_plan( + let (flatten_plan, derived_columns) = self.flatten_plan_with_scope( outer, subquery.unary_child(), correlated_columns, flatten_info, true, + derived_columns, )?; let metadata = self.metadata.read(); let partition_by = window @@ -816,7 +865,7 @@ impl SubqueryDecorrelatorOptimizer { .map(Ok) .chain(correlated_columns.iter().copied().map(|old| { Ok(Self::scalar_item_from_index( - self.get_derived(old)?, + derived_columns.must_resolve(old)?, "outer.", &metadata, )) @@ -824,21 +873,18 @@ impl SubqueryDecorrelatorOptimizer { .collect::>()?; drop(metadata); - Ok(SExpr::create_unary( - Arc::new( - Window { - span: window.span, - index: window.index, - function: window.function.clone(), - arguments: window.arguments.clone(), - partition_by, - order_by: window.order_by.clone(), - frame: window.frame.clone(), - limit: window.limit, - } - .into(), - ), - Arc::new(flatten_plan), + Ok(( + flatten_plan.build_unary(Window { + span: window.span, + index: window.index, + function: window.function.clone(), + arguments: window.arguments.clone(), + partition_by, + order_by: window.order_by.clone(), + frame: window.frame.clone(), + limit: window.limit, + }), + derived_columns, )) } @@ -850,7 +896,8 @@ impl SubqueryDecorrelatorOptimizer { correlated_columns: &ColumnSet, flatten_info: &mut FlattenInfo, mut need_cross_join: bool, - ) -> Result { + derived_columns: &DerivedColumnScope, + ) -> Result { if !union_all.used_columns()?.is_disjoint(correlated_columns) { need_cross_join = true; } @@ -864,28 +911,28 @@ impl SubqueryDecorrelatorOptimizer { need_cross_join || !correlated_columns.is_subset(&right_prop.outer_columns); let mut union_all = union_all.clone(); - let left_flatten_plan = self.flatten_plan( + let (left_flatten_plan, left_derived) = self.flatten_plan_with_scope( outer, subquery.left_child(), correlated_columns, flatten_info, left_need_cross_join, + derived_columns, )?; - let left_derived = std::mem::take(&mut self.derived_columns); Self::rewrite_union_branch_outputs( &mut union_all.left_outputs, correlated_columns, &left_derived, )?; - let right_flatten_plan = self.flatten_plan( + let (right_flatten_plan, right_derived) = self.flatten_plan_with_scope( outer, subquery.right_child(), correlated_columns, flatten_info, right_need_cross_join, + derived_columns, )?; - let right_derived = std::mem::take(&mut self.derived_columns); Self::rewrite_union_branch_outputs( &mut union_all.right_outputs, correlated_columns, @@ -893,6 +940,7 @@ impl SubqueryDecorrelatorOptimizer { )?; let mut metadata = self.metadata.write(); + let mut derived_columns = DerivedColumnScope::default(); union_all .output_indexes .extend(correlated_columns.iter().copied().map(|old| { @@ -900,26 +948,25 @@ impl SubqueryDecorrelatorOptimizer { let name = column_entry.name(); let data_type = column_entry.data_type(); let new = metadata.add_derived_column(name, data_type); - self.derived_columns.insert(old, new); + derived_columns.record(old, new); new })); - Ok(SExpr::create_binary( - Arc::new(union_all.clone().into()), - Arc::new(left_flatten_plan), - Arc::new(right_flatten_plan), + Ok(( + SExpr::create_binary(union_all, left_flatten_plan, right_flatten_plan), + derived_columns, )) } fn rewrite_union_branch_outputs( branch_outputs: &mut Vec<(Symbol, Option)>, correlated_columns: &ColumnSet, - derived: &HashMap, + derived: &DerivedColumnScope, ) -> Result<()> { *branch_outputs = branch_outputs .drain(..) .map(|(old, mut expr)| { - let Some(&new) = derived.get(&old) else { + let Some(new) = derived.resolve(old) else { return Ok((old, expr)); }; if let Some(expr) = &mut expr { @@ -928,10 +975,7 @@ impl SubqueryDecorrelatorOptimizer { Ok((new, expr)) }) .chain(correlated_columns.iter().copied().map(|old| { - let new = derived - .get(&old) - .copied() - .ok_or_else(|| ErrorCode::Internal(format!("Missing derived column {old}")))?; + let new = derived.must_resolve(old)?; Ok((new, None)) })) .collect::>()?; @@ -943,16 +987,17 @@ impl SubqueryDecorrelatorOptimizer { subquery: &SExpr, scan: &ExpressionScan, correlated_columns: &ColumnSet, - ) -> Result { + derived_columns: &DerivedColumnScope, + ) -> Result { let binder = self.binder.as_ref().unwrap(); - for correlated_column in correlated_columns.iter() { + let mut derived_columns = derived_columns.clone(); + for correlated_column in correlated_columns.iter().copied() { let derived_column_index = binder .expression_scan_context - .get_derived_column(scan.expression_scan_index, *correlated_column); - self.derived_columns - .insert(*correlated_column, derived_column_index); + .get_derived_column(scan.expression_scan_index, correlated_column); + derived_columns.record(correlated_column, derived_column_index); } - Ok(subquery.clone()) + Ok((subquery.clone(), derived_columns)) } fn rewrite_to_join_then_aggr( @@ -960,8 +1005,10 @@ impl SubqueryDecorrelatorOptimizer { outer: &SExpr, subquery: &SExpr, correlated_columns: &ColumnSet, - ) -> Result { - let outer = self.clone_outer_recursive(outer)?; + derived_columns: &DerivedColumnScope, + ) -> Result { + let mut derived_columns = derived_columns.clone(); + let outer = self.clone_outer_recursive(outer, &mut derived_columns)?; // Wrap logical get with distinct to eliminate duplicates rows. let metadata = self.metadata.read(); @@ -970,42 +1017,39 @@ impl SubqueryDecorrelatorOptimizer { .copied() .map(|old| { Ok(Self::scalar_item_from_index( - self.get_derived(old)?, + derived_columns.must_resolve(old)?, "", &metadata, )) }) .collect::>()?; - let aggr = SExpr::create_unary( - Arc::new( - Aggregate { - mode: AggregateMode::Initial, - group_items, - ..Default::default() - } - .into(), - ), - Arc::new(outer), - ); + let aggr = outer.build_unary(Aggregate { + mode: AggregateMode::Initial, + group_items, + ..Default::default() + }); - Ok(SExpr::create_binary( - Arc::new(Join::default().into()), - Arc::new(aggr), - Arc::new(subquery.clone()), + Ok(( + SExpr::create_binary(Join::default(), aggr, subquery.clone()), + derived_columns, )) } #[recursive::recursive] - fn clone_outer_recursive(&mut self, outer: &SExpr) -> Result { + fn clone_outer_recursive( + &mut self, + outer: &SExpr, + derived_columns: &mut DerivedColumnScope, + ) -> Result { let children = outer .children .iter() - .map(|child| Ok(self.clone_outer_recursive(child)?.into())) + .map(|child| Ok(self.clone_outer_recursive(child, derived_columns)?.into())) .collect::>()?; Ok(SExpr::create( - self.clone_outer_plan(outer.plan())?, + self.clone_outer_plan(outer.plan(), derived_columns)?, children, None, None, @@ -1013,20 +1057,30 @@ impl SubqueryDecorrelatorOptimizer { )) } - fn clone_outer_plan(&mut self, plan: &RelOperator) -> Result { + fn clone_outer_plan( + &mut self, + plan: &RelOperator, + derived_columns: &mut DerivedColumnScope, + ) -> Result { let op = match plan { RelOperator::DummyTableScan(scan) => scan.clone().into(), - RelOperator::ConstantTableScan(scan) => self.clone_outer_constant_table_scan(scan)?, - RelOperator::Scan(scan) => self.clone_outer_scan(scan), - RelOperator::RecursiveCteScan(scan) => self.clone_outer_recursive_cte_scan(scan)?, - RelOperator::UnionAll(union_all) => self.clone_outer_union_all(union_all)?, + RelOperator::ConstantTableScan(scan) => { + self.clone_outer_constant_table_scan(scan, derived_columns)? + } + RelOperator::Scan(scan) => self.clone_outer_scan(scan, derived_columns), + RelOperator::RecursiveCteScan(scan) => { + self.clone_outer_recursive_cte_scan(scan, derived_columns)? + } + RelOperator::UnionAll(union_all) => { + self.clone_outer_union_all(union_all, derived_columns)? + } RelOperator::Sequence(sequence) => self.clone_outer_sequence(sequence), - RelOperator::EvalScalar(eval) => self.clone_outer_eval_scalar(eval)?, + RelOperator::EvalScalar(eval) => self.clone_outer_eval_scalar(eval, derived_columns)?, RelOperator::Limit(limit) => limit.clone().into(), RelOperator::Sort(sort) => { let mut sort = sort.clone(); for old in sort.used_columns() { - sort.replace_column(old, self.get_derived(old)?); + sort.replace_column(old, derived_columns.must_resolve(old)?); } sort.into() } @@ -1034,7 +1088,7 @@ impl SubqueryDecorrelatorOptimizer { let mut filter = filter.clone(); for predicate in &mut filter.predicates { for old in predicate.used_columns() { - predicate.replace_column(old, self.get_derived(old)?)?; + predicate.replace_column(old, derived_columns.must_resolve(old)?)?; } } filter.into() @@ -1042,7 +1096,7 @@ impl SubqueryDecorrelatorOptimizer { RelOperator::Join(join) => { let mut join = join.clone(); for old in join.used_columns()? { - join.replace_column(old, self.get_derived(old)?)?; + join.replace_column(old, derived_columns.must_resolve(old)?)?; } if let Some(mark) = &mut join.marker_index { let mut metadata = self.metadata.write(); @@ -1050,7 +1104,7 @@ impl SubqueryDecorrelatorOptimizer { let name = column_entry.name(); let data_type = column_entry.data_type(); let new_mark = metadata.add_derived_column(name, data_type); - self.derived_columns.insert(*mark, new_mark); + derived_columns.record(*mark, new_mark); *mark = new_mark; } join.into() @@ -1060,10 +1114,10 @@ impl SubqueryDecorrelatorOptimizer { let metadata = self.metadata.clone(); let mut metadata = metadata.write(); for item in &mut aggregate.group_items { - *item = self.clone_outer_scalar_item(item, &mut metadata)?; + *item = self.clone_outer_scalar_item(item, &mut metadata, derived_columns)?; } for func in &mut aggregate.aggregate_functions { - *func = self.clone_outer_scalar_item(func, &mut metadata)?; + *func = self.clone_outer_scalar_item(func, &mut metadata, derived_columns)?; } aggregate.rank_limit = None; if aggregate.grouping_sets.is_some() { @@ -1083,7 +1137,11 @@ impl SubqueryDecorrelatorOptimizer { Ok(op) } - fn clone_outer_constant_table_scan(&mut self, scan: &ConstantTableScan) -> Result { + fn clone_outer_constant_table_scan( + &mut self, + scan: &ConstantTableScan, + derived_columns: &mut DerivedColumnScope, + ) -> Result { let mut metadata = self.metadata.write(); let ((values, fields), columns) = scan .columns @@ -1095,7 +1153,7 @@ impl SubqueryDecorrelatorOptimizer { let derived_index = metadata.add_derived_column(name, field.data_type().clone()); let field = DataField::new(&derived_index.to_string(), field.data_type().clone()); - self.derived_columns.insert(index, derived_index); + derived_columns.record(index, derived_index); Ok(((value, field), derived_index)) }) .collect::, Vec<_>), ColumnSet)>>()?; @@ -1109,7 +1167,11 @@ impl SubqueryDecorrelatorOptimizer { .into()) } - fn clone_outer_scan(&mut self, scan: &Scan) -> RelOperator { + fn clone_outer_scan( + &mut self, + scan: &Scan, + derived_columns: &mut DerivedColumnScope, + ) -> RelOperator { let mut metadata = self.metadata.write(); let columns = scan .columns @@ -1119,7 +1181,7 @@ impl SubqueryDecorrelatorOptimizer { let column_entry = metadata.column(col).clone(); let derived_index = metadata.add_derived_column(column_entry.name(), column_entry.data_type()); - self.derived_columns.insert(col, derived_index); + derived_columns.record(col, derived_index); derived_index }) .collect(); @@ -1127,7 +1189,11 @@ impl SubqueryDecorrelatorOptimizer { scan.derive_decorrelated_scan(columns, scan_id).into() } - fn clone_outer_recursive_cte_scan(&mut self, scan: &RecursiveCteScan) -> Result { + fn clone_outer_recursive_cte_scan( + &mut self, + scan: &RecursiveCteScan, + derived_columns: &mut DerivedColumnScope, + ) -> Result { let mut metadata = self.metadata.write(); let fields = scan .fields @@ -1137,7 +1203,7 @@ impl SubqueryDecorrelatorOptimizer { let column_entry = metadata.column(index).clone(); let derived_index = metadata.add_derived_column(column_entry.name(), column_entry.data_type()); - self.derived_columns.insert(index, derived_index); + derived_columns.record(index, derived_index); Ok(DataField::new( &derived_index.to_string(), field.data_type().clone(), @@ -1153,18 +1219,25 @@ impl SubqueryDecorrelatorOptimizer { .into()) } - fn clone_outer_union_all(&mut self, union_all: &UnionAll) -> Result { + fn clone_outer_union_all( + &mut self, + union_all: &UnionAll, + derived_columns: &mut DerivedColumnScope, + ) -> Result { let mut union_all = union_all.clone(); union_all.left_outputs = union_all .left_outputs .drain(..) .map(|(old, mut expr)| { - let Some(&new) = self.derived_columns.get(&old) else { + let Some(new) = derived_columns.resolve(old) else { return Ok((old, expr)); }; if let Some(expr) = &mut expr { for used_column in expr.used_columns() { - expr.replace_column(used_column, self.get_derived(used_column)?)?; + expr.replace_column( + used_column, + derived_columns.must_resolve(used_column)?, + )?; } } Ok((new, expr)) @@ -1174,12 +1247,15 @@ impl SubqueryDecorrelatorOptimizer { .right_outputs .drain(..) .map(|(old, mut expr)| { - let Some(&new) = self.derived_columns.get(&old) else { + let Some(new) = derived_columns.resolve(old) else { return Ok((old, expr)); }; if let Some(expr) = &mut expr { for used_column in expr.used_columns() { - expr.replace_column(used_column, self.get_derived(used_column)?)?; + expr.replace_column( + used_column, + derived_columns.must_resolve(used_column)?, + )?; } } Ok((new, expr)) @@ -1196,7 +1272,7 @@ impl SubqueryDecorrelatorOptimizer { let name = column_entry.name().to_string(); let data_type = column_entry.data_type(); let new = metadata.add_derived_column(name, data_type); - self.derived_columns.insert(old, new); + derived_columns.record(old, new); new }) .collect(); @@ -1208,13 +1284,17 @@ impl SubqueryDecorrelatorOptimizer { sequence.clone().into() } - fn clone_outer_eval_scalar(&mut self, eval: &EvalScalar) -> Result { + fn clone_outer_eval_scalar( + &mut self, + eval: &EvalScalar, + derived_columns: &mut DerivedColumnScope, + ) -> Result { let metadata = self.metadata.clone(); let mut metadata = metadata.write(); let items = eval .items .iter() - .map(|item| self.clone_outer_scalar_item(item, &mut metadata)) + .map(|item| self.clone_outer_scalar_item(item, &mut metadata, derived_columns)) .collect::>()?; Ok(EvalScalar { items }.into()) } @@ -1223,12 +1303,13 @@ impl SubqueryDecorrelatorOptimizer { &mut self, ScalarItem { scalar, index }: &ScalarItem, metadata: &mut Metadata, + derived_columns: &mut DerivedColumnScope, ) -> Result { let mut scalar = scalar.clone(); let index = *index; match scalar { ScalarExpr::BoundColumnRef(ref mut column_ref) if column_ref.column.index == index => { - let new_index = self.get_derived(index)?; + let new_index = derived_columns.must_resolve(index)?; column_ref.column.index = new_index; Ok(ScalarItem { scalar, @@ -1237,14 +1318,14 @@ impl SubqueryDecorrelatorOptimizer { } _ => { for old in scalar.used_columns() { - scalar.replace_column(old, self.get_derived(old)?)?; + scalar.replace_column(old, derived_columns.must_resolve(old)?)?; } let column_entry = metadata.column(index); let name = column_entry.name(); let data_type = column_entry.data_type(); let old = index; let index = metadata.add_derived_column(name, data_type); - self.derived_columns.insert(old, index); + derived_columns.record(old, index); Ok(ScalarItem { scalar, index }) } } @@ -1264,13 +1345,6 @@ impl SubqueryDecorrelatorOptimizer { index, } } - - pub fn get_derived(&self, old: Symbol) -> Result { - self.derived_columns - .get(&old) - .copied() - .ok_or_else(|| ErrorCode::Internal(format!("Missing derived column {old}"))) - } } enum Item<'a> { diff --git a/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/flatten_scalar.rs b/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/flatten_scalar.rs index be1f09fae49c4..aafeb995b6072 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/flatten_scalar.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/flatten_scalar.rs @@ -15,6 +15,7 @@ use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use super::DerivedColumnScope; use crate::ColumnSet; use crate::binder::ColumnBindingBuilder; use crate::optimizer::optimizers::operator::SubqueryDecorrelatorOptimizer; @@ -29,22 +30,23 @@ use crate::plans::UDFCall; impl SubqueryDecorrelatorOptimizer { #[recursive::recursive] pub(crate) fn flatten_scalar( - &mut self, + &self, scalar: &ScalarExpr, correlated_columns: &ColumnSet, + derived_columns: &DerivedColumnScope, ) -> Result { match scalar { ScalarExpr::BoundColumnRef(bound_column) => { let column_binding = bound_column.column.clone(); if correlated_columns.contains(&column_binding.index) { - let index = self.derived_columns.get(&column_binding.index).unwrap(); + let index = derived_columns.must_resolve(column_binding.index)?; let metadata = self.metadata.read(); - let column_entry = metadata.column(*index); + let column_entry = metadata.column(index); return Ok(ScalarExpr::BoundColumnRef(BoundColumnRef { span: scalar.span(), column: ColumnBindingBuilder::new( column_entry.name(), - *index, + index, Box::new(column_entry.data_type()), column_binding.visibility, ) @@ -57,12 +59,16 @@ impl SubqueryDecorrelatorOptimizer { ScalarExpr::AggregateFunction(agg) => { let mut args = Vec::with_capacity(agg.args.len()); for arg in &agg.args { - args.push(self.flatten_scalar(arg, correlated_columns)?); + args.push(self.flatten_scalar(arg, correlated_columns, derived_columns)?); } let mut sort_descs = Vec::with_capacity(agg.sort_descs.len()); for desc in &agg.sort_descs { sort_descs.push(AggregateFunctionScalarSortDesc { - expr: self.flatten_scalar(&desc.expr, correlated_columns)?, + expr: self.flatten_scalar( + &desc.expr, + correlated_columns, + derived_columns, + )?, is_reuse_index: desc.is_reuse_index, nulls_first: desc.nulls_first, asc: desc.asc, @@ -83,7 +89,7 @@ impl SubqueryDecorrelatorOptimizer { let arguments = func .arguments .iter() - .map(|arg| self.flatten_scalar(arg, correlated_columns)) + .map(|arg| self.flatten_scalar(arg, correlated_columns, derived_columns)) .collect::>>()?; Ok(ScalarExpr::FunctionCall(FunctionCall { span: func.span, @@ -93,7 +99,8 @@ impl SubqueryDecorrelatorOptimizer { })) } ScalarExpr::CastExpr(cast_expr) => { - let scalar = self.flatten_scalar(&cast_expr.argument, correlated_columns)?; + let scalar = + self.flatten_scalar(&cast_expr.argument, correlated_columns, derived_columns)?; Ok(ScalarExpr::CastExpr(CastExpr { span: cast_expr.span, is_try: cast_expr.is_try, @@ -105,7 +112,7 @@ impl SubqueryDecorrelatorOptimizer { let arguments = udf .arguments .iter() - .map(|arg| self.flatten_scalar(arg, correlated_columns)) + .map(|arg| self.flatten_scalar(arg, correlated_columns, derived_columns)) .collect::>>()?; Ok(ScalarExpr::UDFCall(UDFCall { span: udf.span, diff --git a/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/mod.rs b/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/mod.rs index 791a749965951..5639cfc1aefeb 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/mod.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/mod.rs @@ -18,6 +18,78 @@ mod flatten_plan; mod flatten_scalar; mod subquery_decorrelator; +use std::collections::HashMap; + +use databend_common_exception::ErrorCode; +use databend_common_exception::Result; +use databend_common_expression::Symbol; pub use subquery_decorrelator::FlattenInfo; pub use subquery_decorrelator::SubqueryDecorrelatorOptimizer; pub use subquery_decorrelator::UnnestResult; + +use crate::optimizer::ir::SExpr; + +type FlattenPlanResult = (SExpr, DerivedColumnScope); + +#[derive(Clone, Default)] +pub(crate) struct DerivedColumnScope { + inherited: HashMap, + local: HashMap, +} + +impl DerivedColumnScope { + fn child_scope_for_filter(&self, filter_derived_columns: Self) -> Self { + let mut inherited = self.snapshot(); + inherited.extend( + filter_derived_columns + .inherited + .into_iter() + .chain(filter_derived_columns.local), + ); + Self { + inherited, + local: HashMap::new(), + } + } + + fn absorb_child_scope(&mut self, child: &Self) { + for (old, new) in child.snapshot() { + if self.resolve(old) == Some(new) { + continue; + } + if !self.local.contains_key(&old) { + self.local.insert(old, new); + } + } + } + + fn record(&mut self, old: Symbol, new: Symbol) { + self.local.insert(old, new); + } + + pub(crate) fn resolve(&self, old: Symbol) -> Option { + self.local + .get(&old) + .copied() + .or_else(|| self.inherited.get(&old).copied()) + } + + pub(crate) fn resolve_or_self(&self, old: Symbol) -> Symbol { + self.resolve(old).unwrap_or(old) + } + + pub(crate) fn must_resolve(&self, old: Symbol) -> Result { + self.resolve(old) + .ok_or_else(|| ErrorCode::Internal(format!("Missing derived column {old}"))) + } + + pub(crate) fn visible_symbols(&self) -> Vec { + self.snapshot().into_values().collect() + } + + fn snapshot(&self) -> HashMap { + let mut visible = self.inherited.clone(); + visible.extend(self.local.iter().map(|(&old, &new)| (old, new))); + visible + } +} diff --git a/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/subquery_decorrelator.rs b/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/subquery_decorrelator.rs index bf9582b6c2eb3..571cbf2a2b4ef 100644 --- a/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/subquery_decorrelator.rs +++ b/src/query/sql/src/planner/optimizer/optimizers/operator/decorrelate/subquery_decorrelator.rs @@ -12,9 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashMap; use std::sync::Arc; -use std::vec; use databend_common_catalog::table_context::TableContext; use databend_common_exception::ErrorCode; @@ -144,7 +142,6 @@ pub struct FlattenInfo { pub struct SubqueryDecorrelatorOptimizer { pub(crate) ctx: Arc, pub(crate) metadata: MetadataRef, - pub(crate) derived_columns: HashMap, pub(crate) binder: Option, } @@ -153,7 +150,6 @@ impl SubqueryDecorrelatorOptimizer { Self { ctx: opt_ctx.get_table_ctx(), metadata: opt_ctx.get_metadata(), - derived_columns: Default::default(), binder, } } @@ -183,7 +179,7 @@ impl SubqueryDecorrelatorOptimizer { for item in eval.items.iter_mut() { (item.scalar, outer) = self.try_rewrite_subquery(&item.scalar, outer, false)?; } - Ok(SExpr::create_unary(Arc::new(eval.into()), Arc::new(outer))) + Ok(outer.build_unary(eval)) } RelOperator::Filter(plan) => { @@ -192,7 +188,7 @@ impl SubqueryDecorrelatorOptimizer { for pred in plan.predicates.iter_mut() { (*pred, outer) = self.try_rewrite_subquery(pred, outer, true)?; } - Ok(SExpr::create_unary(Arc::new(plan.into()), Arc::new(outer))) + Ok(outer.build_unary(plan)) } RelOperator::ProjectSet(plan) => { @@ -201,7 +197,7 @@ impl SubqueryDecorrelatorOptimizer { for item in plan.srfs.iter_mut() { (item.scalar, outer) = self.try_rewrite_subquery(&item.scalar, outer, false)?; } - Ok(SExpr::create_unary(Arc::new(plan.into()), Arc::new(outer))) + Ok(outer.build_unary(plan)) } RelOperator::Aggregate(plan) => { @@ -213,7 +209,7 @@ impl SubqueryDecorrelatorOptimizer { for item in plan.aggregate_functions.iter_mut() { (item.scalar, outer) = self.try_rewrite_subquery(&item.scalar, outer, false)?; } - Ok(SExpr::create_unary(Arc::new(plan.into()), Arc::new(outer))) + Ok(outer.build_unary(plan)) } RelOperator::Window(plan) => { @@ -235,14 +231,14 @@ impl SubqueryDecorrelatorOptimizer { } } - Ok(SExpr::create_unary(Arc::new(plan.into()), Arc::new(outer))) + Ok(outer.build_unary(plan)) } RelOperator::Sort(sort) => { let mut outer = self.optimize_sync(s_expr.unary_child())?; let Some(mut window) = sort.window_partition.clone() else { - return Ok(SExpr::create_unary(s_expr.plan.clone(), Arc::new(outer))); + return Ok(outer.build_unary(s_expr.plan.clone())); }; for item in window.partition_by.iter_mut() { @@ -253,7 +249,7 @@ impl SubqueryDecorrelatorOptimizer { ..sort.clone() }; - Ok(SExpr::create_unary(Arc::new(sort.into()), Arc::new(outer))) + Ok(outer.build_unary(sort)) } RelOperator::Join(join) => { @@ -306,10 +302,7 @@ impl SubqueryDecorrelatorOptimizer { (*pred, outer) = self.try_rewrite_subquery(pred, outer, true)?; } let filter = Filter { predicates }; - return Ok(SExpr::create_unary( - Arc::new(filter.into()), - Arc::new(outer), - )); + return Ok(outer.build_unary(filter)); } RelOperator::UnionAll(_) | RelOperator::Sequence(_) => Ok(SExpr::create_binary( @@ -321,10 +314,9 @@ impl SubqueryDecorrelatorOptimizer { RelOperator::Limit(_) | RelOperator::Udf(_) | RelOperator::AsyncFunction(_) - | RelOperator::MaterializedCTE(_) => Ok(SExpr::create_unary( - s_expr.plan.clone(), - Arc::new(self.optimize_sync(s_expr.unary_child())?), - )), + | RelOperator::MaterializedCTE(_) => Ok(self + .optimize_sync(s_expr.unary_child())? + .build_unary(s_expr.plan.clone())), RelOperator::DummyTableScan(_) | RelOperator::Scan(_) @@ -427,12 +419,13 @@ impl SubqueryDecorrelatorOptimizer { let mut flatten_info = FlattenInfo { from_count_func: false, }; - let (outer, result) = if prop.outer_columns.is_empty() { - self.try_rewrite_uncorrelated_subquery( + let (outer, result, derived_columns) = if prop.outer_columns.is_empty() { + let (outer, result) = self.try_rewrite_uncorrelated_subquery( outer, &subquery, is_conjunctive_predicate, - )? + )?; + (outer, result, Default::default()) } else { // todo: optimize outer before decorrelate subquery self.try_decorrelate_subquery( @@ -478,8 +471,8 @@ impl SubqueryDecorrelatorOptimizer { (marker_index, marker_index.to_string()) } else if let UnnestResult::SingleJoin = result { let mut output_column = subquery.output_column; - if let Some(index) = self.derived_columns.get(&output_column.index) { - output_column.index = *index; + if let Some(index) = derived_columns.resolve(output_column.index) { + output_column.index = index; } ( output_column.index, @@ -555,8 +548,6 @@ impl SubqueryDecorrelatorOptimizer { } else { column_ref }; - // After finishing rewriting subquery, we should clear the derived columns. - self.derived_columns.clear(); Ok((scalar, outer)) } } @@ -590,8 +581,7 @@ impl SubqueryDecorrelatorOptimizer { before_exchange: false, lazy_columns: Default::default(), }; - subquery_expr = - SExpr::create_unary(Arc::new(limit.into()), Arc::new(subquery_expr)); + subquery_expr = subquery_expr.build_unary(limit); // We will rewrite EXISTS subquery into the form `COUNT(*) = 1`. // For example, `EXISTS(SELECT a FROM t WHERE a > 1)` will be rewritten into @@ -648,10 +638,7 @@ impl SubqueryDecorrelatorOptimizer { ], }; - let agg_s_expr = Arc::new(SExpr::create_unary( - Arc::new(agg.into()), - Arc::new(subquery_expr), - )); + let agg_s_expr = Arc::new(subquery_expr.build_unary(agg)); let mut output_index = None; let rewritten_subquery = if is_conjunctive_predicate { @@ -660,7 +647,7 @@ impl SubqueryDecorrelatorOptimizer { }; // Filter: COUNT(*) = 1 or COUNT(*) != 1 // └── Aggregate: COUNT(*) - SExpr::create_unary(Arc::new(filter.into()), agg_s_expr) + agg_s_expr.ref_build_unary(filter) } else { let column_index = self.metadata.write().add_derived_column( "_exists_scalar_subquery".to_string(), @@ -673,7 +660,7 @@ impl SubqueryDecorrelatorOptimizer { index: column_index, }], }; - SExpr::create_unary(Arc::new(eval_scalar.into()), agg_s_expr) + agg_s_expr.ref_build_unary(eval_scalar) }; let cross_join = Join { diff --git a/src/query/sql/tests/it/optimizer/decorrelate_correlated_aliases.rs b/src/query/sql/tests/it/optimizer/decorrelate_correlated_aliases.rs new file mode 100644 index 0000000000000..73dd5e68bf45b --- /dev/null +++ b/src/query/sql/tests/it/optimizer/decorrelate_correlated_aliases.rs @@ -0,0 +1,94 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use databend_common_exception::Result; + +use crate::framework::golden::SqlTestCase; +use crate::framework::golden::open_golden_file; +use crate::framework::golden::setup_context; +use crate::framework::golden::write_case_header; + +async fn write_optimized_case(file: &mut impl std::io::Write, case: &SqlTestCase) -> Result<()> { + let ctx = setup_context(case).await?; + ctx.set_cluster_node_num(1); + + let raw_plan = ctx.bind_sql(case.sql).await?; + let optimized_plan = ctx.optimize_plan(raw_plan.clone()).await?; + + write_case_header(file, case)?; + writeln!(file, "raw_plan:")?; + writeln!(file, "{}", raw_plan.format_indent(Default::default())?)?; + writeln!(file, "optimized_plan:")?; + writeln!( + file, + "{}", + optimized_plan.format_indent(Default::default())? + )?; + writeln!(file)?; + + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 1)] +async fn test_decorrelate_correlated_alias_regressions() -> Result<()> { + let mut file = open_golden_file("optimizer", "decorrelate_correlated_aliases.txt")?; + + let cases = [ + SqlTestCase { + name: "nested_filter_alias_reaches_limit_rewrite", + description: "Filter-derived correlated aliases must remain visible while rewriting a deeper correlated LIMIT subtree.", + setup_sqls: &[], + sql: r#" + SELECT * + FROM (VALUES (1, 1)) AS t1(a, b) + WHERE EXISTS ( + SELECT 1 + FROM ( + SELECT t2.a + FROM (VALUES (1, 1)) AS t2(a, b) + WHERE t2.b = t1.b + LIMIT 1 + ) AS s + WHERE s.a = t1.a + ) + "#, + }, + SqlTestCase { + name: "nested_filter_alias_survives_deeper_join_rewrite", + description: "A deeper join remap must override a stale filter-local alias instead of collapsing the correlated predicate into a self-equality.", + setup_sqls: &[], + sql: r#" + SELECT * + FROM (VALUES (1, 1)) AS t1(a, b) + WHERE EXISTS ( + SELECT 1 + FROM ( + SELECT t2.a + FROM (VALUES (1, 1)) AS t2(a, b) + JOIN (VALUES (1)) AS t3(c) + ON t2.a = t1.a + WHERE t2.b = t1.b + ) AS s + WHERE s.a = t1.a + ) + "#, + }, + ]; + + for case in &cases { + write_optimized_case(&mut file, case).await?; + } + + Ok(()) +} diff --git a/src/query/sql/tests/it/optimizer/decorrelate_correlated_aliases.txt b/src/query/sql/tests/it/optimizer/decorrelate_correlated_aliases.txt new file mode 100644 index 0000000000000..9fda60990915f --- /dev/null +++ b/src/query/sql/tests/it/optimizer/decorrelate_correlated_aliases.txt @@ -0,0 +1,157 @@ +=== nested_filter_alias_reaches_limit_rewrite === +description: Filter-derived correlated aliases must remain visible while rewriting a deeper correlated LIMIT subtree. +sql: + SELECT * + FROM (VALUES (1, 1)) AS t1(a, b) + WHERE EXISTS ( + SELECT 1 + FROM ( + SELECT t2.a + FROM (VALUES (1, 1)) AS t2(a, b) + WHERE t2.b = t1.b + LIMIT 1 + ) AS s + WHERE s.a = t1.a + ) + +raw_plan: +EvalScalar +├── scalars: [a (#0) AS (#0), b (#1) AS (#1)] +└── Filter + ├── filters: [SUBQUERY AS (#4)] + ├── subquerys + │ └── Subquery (Exists) + │ ├── output_column: 1 (#4) + │ └── EvalScalar + │ ├── scalars: [1 AS (#4)] + │ └── Filter + │ ├── filters: [eq(a (#2), a (#0))] + │ └── Limit + │ ├── limit: [1] + │ ├── offset: [0] + │ └── EvalScalar + │ ├── scalars: [a (#2) AS (#2)] + │ └── Filter + │ ├── filters: [eq(b (#3), b (#1))] + │ └── ConstantTableScan + │ ├── columns: [a (#2), b (#3)] + │ └── num_rows: [1] + └── ConstantTableScan + ├── columns: [a (#0), b (#1)] + └── num_rows: [1] + +optimized_plan: +Exchange(Merge) +└── EvalScalar + ├── scalars: [a (#0) AS (#0), b (#1) AS (#1), marker (#6) AS (#8)] + └── Filter + ├── filters: [is_true(marker (#6))] + └── Join(RightMark) + ├── build keys: [a (#2), b (#3)] + ├── probe keys: [a (#0), b (#1)] + ├── other filters: [] + ├── Exchange(Broadcast) + │ └── EvalScalar + │ ├── scalars: [a (#2) AS (#2), b (#3) AS (#3), 1 AS (#4), correlated_limit_row_number (#5) AS (#7)] + │ └── Filter + │ ├── filters: [lte(correlated_limit_row_number (#5), 1)] + │ └── Window + │ ├── aggregate function: row_number + │ ├── partition items: [a (#2) AS (#2), b (#3) AS (#3)] + │ ├── order by items: [] + │ ├── frame: [Rows: Preceding(None) ~ CurrentRow] + │ └── Sort + │ ├── sort keys: [a (#2) ASC NULLS LAST, b (#3) ASC NULLS LAST] + │ ├── limit: [NONE] + │ ├── window top: 1 + │ ├── window function: RowNumber + │ └── Exchange(Hash) + │ ├── Exchange(Hash): keys: [a (#2), b (#3)] + │ └── ConstantTableScan + │ ├── columns: [a (#2), b (#3)] + │ └── num_rows: [1] + └── ConstantTableScan + ├── columns: [a (#0), b (#1)] + └── num_rows: [1] + + +=== nested_filter_alias_survives_deeper_join_rewrite === +description: A deeper join remap must override a stale filter-local alias instead of collapsing the correlated predicate into a self-equality. +sql: + SELECT * + FROM (VALUES (1, 1)) AS t1(a, b) + WHERE EXISTS ( + SELECT 1 + FROM ( + SELECT t2.a + FROM (VALUES (1, 1)) AS t2(a, b) + JOIN (VALUES (1)) AS t3(c) + ON t2.a = t1.a + WHERE t2.b = t1.b + ) AS s + WHERE s.a = t1.a + ) + +raw_plan: +EvalScalar +├── scalars: [a (#0) AS (#0), b (#1) AS (#1)] +└── Filter + ├── filters: [SUBQUERY AS (#5)] + ├── subquerys + │ └── Subquery (Exists) + │ ├── output_column: 1 (#5) + │ └── EvalScalar + │ ├── scalars: [1 AS (#5)] + │ └── Filter + │ ├── filters: [eq(a (#2), a (#0))] + │ └── EvalScalar + │ ├── scalars: [a (#2) AS (#2)] + │ └── Filter + │ ├── filters: [eq(b (#3), b (#1))] + │ └── Join(Inner) + │ ├── build keys: [] + │ ├── probe keys: [] + │ ├── other filters: [eq(a (#2), a (#0))] + │ ├── ConstantTableScan + │ │ ├── columns: [c (#4)] + │ │ └── num_rows: [1] + │ └── ConstantTableScan + │ ├── columns: [a (#2), b (#3)] + │ └── num_rows: [1] + └── ConstantTableScan + ├── columns: [a (#0), b (#1)] + └── num_rows: [1] + +optimized_plan: +Exchange(Merge) +└── EvalScalar + ├── scalars: [a (#0) AS (#0), b (#1) AS (#1), marker (#6) AS (#9)] + └── Filter + ├── filters: [is_true(marker (#6))] + └── Join(RightMark) + ├── build keys: [a (#2), b (#3)] + ├── probe keys: [a (#0), b (#1)] + ├── other filters: [] + ├── Exchange(Broadcast) + │ └── EvalScalar + │ ├── scalars: [a (#2) AS (#2), b (#3) AS (#3), 1 AS (#5), a (#7) AS (#8)] + │ └── EvalScalar + │ ├── scalars: [a (#2) AS (#2), a (#2) AS (#2), b (#3) AS (#3), a (#0) AS (#7)] + │ └── Filter + │ ├── filters: [eq(a (#2), a (#0))] + │ └── Join(Cross) + │ ├── build keys: [] + │ ├── probe keys: [] + │ ├── other filters: [] + │ ├── Exchange(Broadcast) + │ │ └── ConstantTableScan + │ │ ├── columns: [c (#4)] + │ │ └── num_rows: [1] + │ └── ConstantTableScan + │ ├── columns: [a (#2), b (#3)] + │ └── num_rows: [1] + └── ConstantTableScan + ├── columns: [a (#0), b (#1)] + └── num_rows: [1] + + diff --git a/src/query/sql/tests/it/optimizer/mod.rs b/src/query/sql/tests/it/optimizer/mod.rs index 9cb7859cc30c8..92cf3e052994f 100644 --- a/src/query/sql/tests/it/optimizer/mod.rs +++ b/src/query/sql/tests/it/optimizer/mod.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +mod decorrelate_correlated_aliases; mod eager_aggregation; mod normalize_scalar; mod push_down_filter_project_set;