diff --git a/dev/diffs/3.4.3.diff b/dev/diffs/3.4.3.diff index 169df737a2..d905dcf446 100644 --- a/dev/diffs/3.4.3.diff +++ b/dev/diffs/3.4.3.diff @@ -417,7 +417,7 @@ index daef11ae4d6..9f3cc9181f2 100644 assert(exchanges.size == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala -index f33432ddb6f..914afa6b01d 100644 +index f33432ddb6f..b375e285dde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen @@ -477,27 +477,7 @@ index f33432ddb6f..914afa6b01d 100644 assert(countSubqueryBroadcasts == 1) assert(countReusedSubqueryBroadcasts == 1) -@@ -1215,7 +1231,8 @@ abstract class DynamicPartitionPruningSuiteBase - } - - test("SPARK-32509: Unused Dynamic Pruning filter shouldn't affect " + -- "canonicalization and exchange reuse") { -+ "canonicalization and exchange reuse", -+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/4045")) { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df = sql( -@@ -1423,7 +1440,8 @@ abstract class DynamicPartitionPruningSuiteBase - } - } - -- test("SPARK-34637: DPP side broadcast query stage is created firstly") { -+ test("SPARK-34637: DPP side broadcast query stage is created firstly", -+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/4045")) { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { - val df = sql( - """ WITH v as ( -@@ -1577,6 +1595,7 @@ abstract class DynamicPartitionPruningSuiteBase +@@ -1577,6 +1593,7 @@ abstract class DynamicPartitionPruningSuiteBase val subqueryBroadcastExecs = collectWithSubqueries(df.queryExecution.executedPlan) { 
case s: SubqueryBroadcastExec => s @@ -505,7 +485,7 @@ index f33432ddb6f..914afa6b01d 100644 } assert(subqueryBroadcastExecs.size === 1) subqueryBroadcastExecs.foreach { subqueryBroadcastExec => -@@ -1729,6 +1748,10 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat +@@ -1729,6 +1746,10 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat case s: BatchScanExec => // we use f1 col for v2 tables due to schema pruning s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1"))) diff --git a/dev/diffs/3.5.8.diff b/dev/diffs/3.5.8.diff index cec58a78d1..d3a5c617dc 100644 --- a/dev/diffs/3.5.8.diff +++ b/dev/diffs/3.5.8.diff @@ -398,7 +398,7 @@ index c4fb4fa943c..a04b23870a8 100644 assert(exchanges.size == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala -index f33432ddb6f..914afa6b01d 100644 +index f33432ddb6f..b375e285dde 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen @@ -458,27 +458,7 @@ index f33432ddb6f..914afa6b01d 100644 assert(countSubqueryBroadcasts == 1) assert(countReusedSubqueryBroadcasts == 1) -@@ -1215,7 +1231,8 @@ abstract class DynamicPartitionPruningSuiteBase - } - - test("SPARK-32509: Unused Dynamic Pruning filter shouldn't affect " + -- "canonicalization and exchange reuse") { -+ "canonicalization and exchange reuse", -+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/4045")) { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df = sql( -@@ -1423,7 +1440,8 @@ abstract class DynamicPartitionPruningSuiteBase - } - } - -- test("SPARK-34637: DPP 
side broadcast query stage is created firstly") { -+ test("SPARK-34637: DPP side broadcast query stage is created firstly", -+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/4045")) { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { - val df = sql( - """ WITH v as ( -@@ -1577,6 +1595,7 @@ abstract class DynamicPartitionPruningSuiteBase +@@ -1577,6 +1593,7 @@ abstract class DynamicPartitionPruningSuiteBase val subqueryBroadcastExecs = collectWithSubqueries(df.queryExecution.executedPlan) { case s: SubqueryBroadcastExec => s @@ -486,7 +466,7 @@ index f33432ddb6f..914afa6b01d 100644 } assert(subqueryBroadcastExecs.size === 1) subqueryBroadcastExecs.foreach { subqueryBroadcastExec => -@@ -1729,6 +1748,10 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat +@@ -1729,6 +1746,10 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat case s: BatchScanExec => // we use f1 col for v2 tables due to schema pruning s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1"))) diff --git a/dev/diffs/4.0.2.diff b/dev/diffs/4.0.2.diff index d196cdbd3d..445e7e75a7 100644 --- a/dev/diffs/4.0.2.diff +++ b/dev/diffs/4.0.2.diff @@ -535,7 +535,7 @@ index 81713c777bc..b5f92ed9742 100644 assert(exchanges.size == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala -index 2c24cc7d570..8c214e7d05c 100644 +index 2c24cc7d570..12d897866da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DynamicPartitionPruningSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.GivenWhenThen @@ -595,27 +595,7 @@ index 2c24cc7d570..8c214e7d05c 100644 assert(countSubqueryBroadcasts == 1) assert(countReusedSubqueryBroadcasts == 1) -@@ -1215,7 +1231,8 @@ abstract class 
DynamicPartitionPruningSuiteBase - } - - test("SPARK-32509: Unused Dynamic Pruning filter shouldn't affect " + -- "canonicalization and exchange reuse") { -+ "canonicalization and exchange reuse", -+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/4045")) { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { - withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { - val df = sql( -@@ -1424,7 +1441,8 @@ abstract class DynamicPartitionPruningSuiteBase - } - } - -- test("SPARK-34637: DPP side broadcast query stage is created firstly") { -+ test("SPARK-34637: DPP side broadcast query stage is created firstly", -+ IgnoreComet("TODO: https://github.com/apache/datafusion-comet/issues/4045")) { - withSQLConf(SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { - val df = sql( - """ WITH v as ( -@@ -1578,6 +1596,7 @@ abstract class DynamicPartitionPruningSuiteBase +@@ -1578,6 +1594,7 @@ abstract class DynamicPartitionPruningSuiteBase val subqueryBroadcastExecs = collectWithSubqueries(df.queryExecution.executedPlan) { case s: SubqueryBroadcastExec => s @@ -623,7 +603,7 @@ index 2c24cc7d570..8c214e7d05c 100644 } assert(subqueryBroadcastExecs.size === 1) subqueryBroadcastExecs.foreach { subqueryBroadcastExec => -@@ -1730,6 +1749,10 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat +@@ -1730,6 +1747,10 @@ abstract class DynamicPartitionPruningV1Suite extends DynamicPartitionPruningDat case s: BatchScanExec => // we use f1 col for v2 tables due to schema pruning s.output.exists(_.exists(_.argString(maxFields = 100).contains("f1"))) diff --git a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala index eda41d02a1..679005d9b1 100644 --- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala +++ 
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.internal.SQLConf import org.apache.comet.CometConf._ -import org.apache.comet.rules.{CometExecRule, CometReuseSubquery, CometScanRule, EliminateRedundantTransitions} +import org.apache.comet.rules.{CometExecRule, CometPlanAdaptiveDynamicPruningFilters, CometReuseSubquery, CometScanRule, CometSpark34AqeDppFallbackRule, EliminateRedundantTransitions} import org.apache.comet.shims.ShimCometSparkSessionExtensions /** @@ -43,34 +43,44 @@ import org.apache.comet.shims.ShimCometSparkSessionExtensions * * Non-AQE (QueryExecution.preparations): * {{{ - * 1. PlanDynamicPruningFilters -- Spark creates DPP filters + * 1. PlanDynamicPruningFilters -- Spark creates non-AQE DPP (SubqueryBroadcastExec) * 2. PlanSubqueries -- Spark creates SubqueryExec for scalar subqueries * 3. EnsureRequirements -- Spark inserts shuffles/sorts * 4. ApplyColumnarRulesAndInsertTransitions: - * a. preColumnarTransitions: CometScanRule, CometExecRule (replace Spark -> Comet nodes) + * a. preColumnarTransitions: CometScanRule, CometExecRule + * - CometExecRule.convertSubqueryBroadcasts converts SubqueryBroadcastExec to + * CometSubqueryBroadcastExec for exchange reuse with Comet broadcasts * b. insertTransitions: ColumnarToRow/RowToColumnar added * c. postColumnarTransitions: EliminateRedundantTransitions * 5. 
ReuseExchangeAndSubquery -- Spark deduplicates subqueries (sees Comet nodes) * }}} * - * AQE (AdaptiveSparkPlanExec): + * AQE (AdaptiveSparkPlanExec, Spark 3.5+): * {{{ * Initial plan: - * queryStagePreparationRules: CometScanRule, CometExecRule (replace Spark -> Comet nodes) + * PlanAdaptiveSubqueries: creates SubqueryAdaptiveBroadcastExec (SAB) for AQE DPP + * queryStagePreparationRules: CometScanRule, CometExecRule + * - CometExecRule.convertSubqueryBroadcasts wraps SABs in + * CometSubqueryAdaptiveBroadcastExec to prevent Spark's + * PlanAdaptiveDynamicPruningFilters from replacing DPP with Literal.TrueLiteral * * Per stage (optimizeQueryStage + postStageCreationRules): - * 1. queryStageOptimizerRules: ReuseAdaptiveSubquery, CometReuseSubquery + * 1. queryStageOptimizerRules: + * a. PlanAdaptiveDynamicPruningFilters (Spark) -- skips wrapped SABs + * b. ReuseAdaptiveSubquery (Spark) + * c. CometPlanAdaptiveDynamicPruningFilters -- converts wrapped SABs to + * CometSubqueryBroadcastExec with BroadcastQueryStageExec for broadcast reuse + * d. CometReuseSubquery -- deduplicates converted subqueries * 2. postStageCreationRules -> ApplyColumnarRulesAndInsertTransitions: * a. preColumnarTransitions: CometScanRule, CometExecRule (no-ops, already converted) * b. insertTransitions * c. postColumnarTransitions: EliminateRedundantTransitions * }}} * - * CometReuseSubquery is needed in AQE because Spark's ReuseAdaptiveSubquery may run before - * Comet's node replacements in the initial plan construction, and the replacements can disrupt - * subquery reuse that was already applied. The shim-based registration - * (injectQueryStageOptimizerRuleShim) handles API availability: Spark 3.5+ has - * injectQueryStageOptimizerRule, Spark 3.4 does not (no-op). + * On Spark 3.4, injectQueryStageOptimizerRule is unavailable. CometExecRule does not wrap SABs, + * and CometPlanAdaptiveDynamicPruningFilters/CometReuseSubquery are not registered. 
AQE DPP scans + * fall back to Spark so that Spark's PlanAdaptiveDynamicPruningFilters handles them natively + * (with DPP). */ class CometSparkSessionExtensions extends (SparkSessionExtensions => Unit) @@ -79,8 +89,13 @@ class CometSparkSessionExtensions override def apply(extensions: SparkSessionExtensions): Unit = { extensions.injectColumnar { session => CometScanColumnar(session) } extensions.injectColumnar { session => CometExecColumnar(session) } + // Pre-3.5 only: tag AQE DPP regions so the conversion rules below leave them Spark-native. + // Registered before CometScanRule/CometExecRule so tags are in place when conversion runs. + // No-op on Spark 3.5+; see CometSpark34AqeDppFallbackRule's class docstring. + injectPreSpark35QueryStagePrepRuleShim(extensions, CometSpark34AqeDppFallbackRule) extensions.injectQueryStagePrepRule { session => CometScanRule(session) } extensions.injectQueryStagePrepRule { session => CometExecRule(session) } + injectQueryStageOptimizerRuleShim(extensions, CometPlanAdaptiveDynamicPruningFilters) injectQueryStageOptimizerRuleShim(extensions, CometReuseSubquery) } diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 4b15c26d27..213c3d0ad7 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -96,6 +96,13 @@ object CometExecRule { val SKIP_COMET_SHUFFLE_TAG: org.apache.spark.sql.catalyst.trees.TreeNodeTag[Unit] = org.apache.spark.sql.catalyst.trees.TreeNodeTag[Unit]("comet.skipCometShuffle") + /** + * Tag set on a `BroadcastExchangeExec` that should be left as a plain Spark broadcast rather + * than converted to `CometBroadcastExchangeExec`. Written by [[CometSpark34AqeDppFallbackRule]] + * on Spark < 3.5. See that rule's class docstring for the rationale. 
+ */ + val SKIP_COMET_BROADCAST_TAG: org.apache.spark.sql.catalyst.trees.TreeNodeTag[Unit] = + org.apache.spark.sql.catalyst.trees.TreeNodeTag[Unit]("comet.skipCometBroadcast") } /** @@ -297,6 +304,11 @@ case class CometExecRule(session: SparkSession) // broadcast exchange is forced to be enabled by Comet config. case plan if plan.children.exists(_.isInstanceOf[BroadcastExchangeExec]) => val newChildren = plan.children.map { + // Tagged by CometSpark34AqeDppFallbackRule on Spark < 3.5 to keep the build-side + // broadcast Spark-native so Spark's PlanAdaptiveDynamicPruningFilters can match it. + case b: BroadcastExchangeExec + if b.getTagValue(CometExecRule.SKIP_COMET_BROADCAST_TAG).isDefined => + b case b: BroadcastExchangeExec if b.children.forall(_.isInstanceOf[CometNativeExec]) => convertToComet(b, CometBroadcastExchangeExec).getOrElse(b) case other => other @@ -381,18 +393,22 @@ case class CometExecRule(session: SparkSession) } /** - * Replace SubqueryBroadcastExec with CometSubqueryBroadcastExec in a node's expressions. + * Replace SubqueryBroadcastExec with CometSubqueryBroadcastExec in a node's expressions + * (non-AQE DPP), and wrap SubqueryAdaptiveBroadcastExec in CometSubqueryAdaptiveBroadcastExec + * (AQE DPP) to protect it from Spark's PlanAdaptiveDynamicPruningFilters. * - * When CometExecRule converts BroadcastExchangeExec to CometBroadcastExchangeExec on the join - * side, the DPP subquery still references the original BroadcastExchangeExec. + * Non-AQE DPP: When CometExecRule converts BroadcastExchangeExec to CometBroadcastExchangeExec + * on the join side, the DPP subquery still references the original BroadcastExchangeExec. * ReuseExchangeAndSubquery (which runs after Comet rules) can't match them because they have * different types. By replacing SubqueryBroadcastExec with CometSubqueryBroadcastExec (which * wraps a CometBroadcastExchangeExec), both sides have the same exchange type and reuse works. 
* - * The BroadcastExchangeExec in the subquery has a CometNativeColumnarToRowExec child (inserted - * by ApplyColumnarRulesAndInsertTransitions because BroadcastExchangeExec expects row input). - * We strip this transition and create CometBroadcastExchangeExec with the underlying Comet plan - * directly. + * AQE DPP: Spark's PlanAdaptiveDynamicPruningFilters (queryStageOptimizerRule) pattern-matches + * on SubqueryAdaptiveBroadcastExec. When it can't find BroadcastHashJoinExec (Comet replaced + * it), it replaces DPP with Literal.TrueLiteral. We wrap SABs in + * CometSubqueryAdaptiveBroadcastExec to prevent this. CometPlanAdaptiveDynamicPruningFilters (a + * later queryStageOptimizerRule) unwraps and converts them with access to the materialized + * BroadcastQueryStageExec. */ private def convertSubqueryBroadcasts(plan: SparkPlan): SparkPlan = { plan.transformExpressionsUp { case inSub: InSubqueryExec => @@ -422,6 +438,32 @@ case class CometExecRule(session: SparkSession) } case _ => inSub } + case sab: SubqueryAdaptiveBroadcastExec if isSpark35Plus => + // Wrap SABs to prevent Spark's PlanAdaptiveDynamicPruningFilters from + // converting them to Literal.TrueLiteral. Spark's rule pattern-matches for + // BroadcastHashJoinExec, which Comet replaced with CometBroadcastHashJoinExec. + // Without wrapping, DPP is disabled for both Comet native scans and non-Comet + // scans (e.g., V2 BatchScan). CometPlanAdaptiveDynamicPruningFilters + // (queryStageOptimizerRule, 3.5+) unwraps and converts them later. + // + // On Spark 3.4, injectQueryStageOptimizerRule is unavailable. The isSpark35Plus + // guard leaves SABs unwrapped; CometSpark34AqeDppFallbackRule then tags the + // matching BHJ's build broadcast so Spark's rule can match it natively. 
+ assert( + sab.buildKeys.nonEmpty, + s"SubqueryAdaptiveBroadcastExec '${sab.name}' has empty buildKeys") + logInfo( + s"Wrapping SubqueryAdaptiveBroadcastExec '${sab.name}' in " + + "CometSubqueryAdaptiveBroadcastExec to preserve AQE DPP") + val indices = getSubqueryBroadcastIndices(sab) + val wrapped = CometSubqueryAdaptiveBroadcastExec( + sab.name, + indices, + sab.onlyInBroadcast, + sab.buildPlan, + sab.buildKeys, + sab.child) + inSub.withNewPlan(wrapped) case _ => inSub } } diff --git a/spark/src/main/scala/org/apache/comet/rules/CometPlanAdaptiveDynamicPruningFilters.scala b/spark/src/main/scala/org/apache/comet/rules/CometPlanAdaptiveDynamicPruningFilters.scala new file mode 100644 index 0000000000..20207ffa5f --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/rules/CometPlanAdaptiveDynamicPruningFilters.scala @@ -0,0 +1,436 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.rules + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Alias, BindReferences, DynamicPruningExpression, Expression, Literal} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec, CometNativeScanExec, CometSubqueryAdaptiveBroadcastExec, CometSubqueryBroadcastExec} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper} +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashedRelationBroadcastMode, HashJoin} + +import org.apache.comet.shims.{ShimPrepareExecutedPlan, ShimSubqueryBroadcast} + +/** + * Converts CometSubqueryAdaptiveBroadcastExec (wrapped AQE DPP) to CometSubqueryBroadcastExec + * inside CometNativeScanExec's partitionFilters. + * + * CometExecRule wraps SubqueryAdaptiveBroadcastExec in CometSubqueryAdaptiveBroadcastExec during + * queryStagePreparationRules to prevent Spark's PlanAdaptiveDynamicPruningFilters from replacing + * DPP with Literal.TrueLiteral (which happens because Spark can't find BroadcastHashJoinExec + * after Comet replaced it with CometBroadcastHashJoinExec). + * + * This rule runs as a queryStageOptimizerRule (after Spark's built-in rules). We find the + * matching CometBroadcastHashJoinExec (or BroadcastHashJoinExec for fallback) and create + * CometSubqueryBroadcastExec for broadcast reuse. + * + * Also handles the dual-filter problem: CometNativeScanExec.partitionFilters and + * CometScanExec.partitionFilters are separate InSubqueryExec instances. 
Both must be converted + * because CometScanExec.dynamicallySelectedPartitions evaluates its own partitionFilters. + * + * @see + * PlanAdaptiveDynamicPruningFilters (Spark's equivalent for BroadcastHashJoinExec) + * @see + * CometExecRule.convertSubqueryBroadcasts (non-AQE DPP + SAB wrapping) + */ +case object CometPlanAdaptiveDynamicPruningFilters + extends Rule[SparkPlan] + with AdaptiveSparkPlanHelper + with ShimSubqueryBroadcast + with ShimPrepareExecutedPlan + with Logging { + + override def apply(plan: SparkPlan): SparkPlan = { + if (!conf.dynamicPartitionPruningEnabled) { + return plan + } + + // TODO(#3510): CometNativeScanExec needs special handling because its makeCopy + // loses @transient scan and expression transformations. Once makeCopy is fixed + // (or CometScanExec wrapping is removed), replace both cases with a single + // plan.transformAllExpressions call matching Spark's PlanAdaptiveDynamicPruningFilters. + plan.transformUp { + case nativeScan: CometNativeScanExec if nativeScan.partitionFilters.exists(hasCometSAB) => + logDebug("Converting AQE DPP for CometNativeScanExec") + convertNativeScanDPP(nativeScan, plan) + case p: SparkPlan if !p.isInstanceOf[CometNativeScanExec] && hasWrappedSAB(p) => + logDebug(s"Converting AQE DPP for non-Comet node: ${p.nodeName}") + convertNonCometNodeDPP(p, plan) + } + } + + private def convertNativeScanDPP( + nativeScan: CometNativeScanExec, + stagePlan: SparkPlan): CometNativeScanExec = { + val newOuterFilters = nativeScan.partitionFilters.map(f => convertFilter(f, stagePlan)) + + if (newOuterFilters == nativeScan.partitionFilters) return nativeScan + + // Dual-filter invariant: CometNativeScanExec.partitionFilters and + // CometScanExec.partitionFilters are separate InSubqueryExec instances for the + // same DPP filters. Both must be converted because + // CometScanExec.dynamicallySelectedPartitions evaluates its own filters. 
+ assert( + nativeScan.scan != null, + "CometNativeScanExec with DPP filters must have a non-null CometScanExec") + val newInnerFilters = nativeScan.scan.partitionFilters.map(f => convertFilter(f, stagePlan)) + val newInnerScan = nativeScan.scan.copy(partitionFilters = newInnerFilters) + + nativeScan.copy(partitionFilters = newOuterFilters, scan = newInnerScan) + } + + private def convertFilter(filter: Expression, stagePlan: SparkPlan): Expression = { + filter.transformUp { case dpe @ DynamicPruningExpression(inSub: InSubqueryExec) => + extractSABData(inSub) match { + case Some(sabData) => + convertSAB(inSub, sabData, stagePlan) + case None => dpe + } + } + } + + /** + * Extracts SAB data from an InSubqueryExec's plan. Handles both: + * - CometSubqueryAdaptiveBroadcastExec (outer partitionFilters, wrapped by CometExecRule) + * - SubqueryAdaptiveBroadcastExec (inner CometScanExec.partitionFilters, never wrapped + * because CometScanExec is @transient and not part of the plan expression tree) + * + * Either form may itself be wrapped in a `ReusedSubqueryExec` when Spark's + * `ReuseAdaptiveSubquery` (which runs before our rule) dedupes identical DPP subqueries, e.g. + * TPC-DS q5/q14a/q14b/q54 where a single DPP pushes through a UNION ALL to multiple fact scans. + * The outer match below unwraps `ReusedSubqueryExec` before dispatching so the inner pattern + * match is reached regardless of reuse. If you add another wrapper type here, update + * `hasWrappedSAB` to match. The two must stay in sync or non-Comet nodes holding that wrapper + * will be skipped. 
+ */ + private case class SABData( + name: String, + indices: Seq[Int], + onlyInBroadcast: Boolean, + buildPlan: LogicalPlan, + buildKeys: Seq[Expression], + adaptivePlan: SparkPlan) + + private def extractSABData(inSub: InSubqueryExec): Option[SABData] = { + def extract(plan: BaseSubqueryExec): Option[SABData] = { + plan match { + case csab: CometSubqueryAdaptiveBroadcastExec => + Some( + SABData( + csab.name, + csab.indices, + csab.onlyInBroadcast, + csab.buildPlan, + csab.buildKeys, + csab.child)) + case sab: SubqueryAdaptiveBroadcastExec => + Some( + SABData( + sab.name, + getSubqueryBroadcastIndices(sab), + sab.onlyInBroadcast, + sab.buildPlan, + sab.buildKeys, + sab.child)) + case _ => None + } + } + inSub.plan match { + // ReusedSubqueryExec extends BaseSubqueryExec, so unwrap it before dispatching + // to `BaseSubqueryExec`. The order is load-bearing: if the general case runs + // first it catches the wrapper and extract() returns None, leaving a wrapped + // CSAB in the plan that throws at doExecute() time. See the scaladoc above. + case ReusedSubqueryExec(sub) => extract(sub) + case sub: BaseSubqueryExec => extract(sub) + case _ => None + } + } + + /** + * Converts an SAB following the same decision tree as Spark's + * PlanAdaptiveDynamicPruningFilters: + * + * 1. exchangeReuseEnabled + matching broadcast join found: Create CometSubqueryBroadcastExec + * (or SubqueryBroadcastExec for Spark fallback) wired to the join's broadcast. DPP uses + * broadcast reuse via AQE's stageCache. + * + * 2. No reusable broadcast + onlyInBroadcast=true: Literal.TrueLiteral. DPP is disabled + * (correct results, scans all partitions). + * + * 3. No reusable broadcast + onlyInBroadcast=false: Aggregate SubqueryExec on the build side + * (DPP via separate execution, matching Spark's PlanAdaptiveDynamicPruningFilters lines 68-79). 
+ */ + private def convertSAB( + inSub: InSubqueryExec, + sab: SABData, + stagePlan: SparkPlan): DynamicPruningExpression = { + val adaptivePlan = sab.adaptivePlan.asInstanceOf[AdaptiveSparkPlanExec] + + val sabKeyIds: Set[Any] = sab.buildKeys.flatMap(_.references.map(_.exprId)).toSet + assert( + sabKeyIds.nonEmpty, + s"DPP subquery '${sab.name}' has empty buildKeys - " + + "PlanAdaptiveSubqueries should always populate buildKeys") + + // Spark's PlanAdaptiveDynamicPruningFilters is constructed with rootPlan = the + // current AdaptiveSparkPlanExec (ASPE). Each ASPE (main query and each scalar + // subquery) gets its own rule instance pointing to itself. The rule searches + // rootPlan via find() to locate matching broadcast joins. + // + // Custom queryStageOptimizerRules (registered via injectQueryStageOptimizerRule) + // are instantiated once and shared across all ASPEs, so we don't get a per-ASPE + // rootPlan reference. We approximate Spark's behavior with two searches: + // + // 1. stagePlan: the plan arg to apply(), which is the current stage's child + // plan. Covers same-stage joins (the common case) and scalar subqueries + // where scan+join are under one exchange. + // 2. context.qe.executedPlan: the main query's ASPE, accessed via the shared + // AdaptiveExecutionContext. Covers cross-stage joins in the main query + // where a shuffle separates the scan from the broadcast join. + val rootPlan = adaptivePlan.context.qe.executedPlan + + val matchingJoin = findMatchingBroadcastJoin(sabKeyIds, stagePlan) + .orElse(findMatchingBroadcastJoin(sabKeyIds, rootPlan)) + val canReuse = conf.exchangeReuseEnabled && matchingJoin.isDefined + + if (canReuse) { + // Case 1: broadcast reuse. Matches Spark's PlanAdaptiveDynamicPruningFilters + // lines 44-64: construct a NEW exchange wrapping adaptivePlan.executedPlan, + // then wrap in a new ASPE. AQE's stageCache ensures the broadcast runs once + // via ReusedExchangeExec (same canonical form as the join's exchange). 
+ val (broadcastChild, isComet) = matchingJoin.get + val buildSidePlan = adaptivePlan.executedPlan + logDebug( + s"Matched DPP subquery '${sab.name}' to " + + s"${if (isComet) "Comet" else "Spark"} broadcast: " + + s"${broadcastChild.getClass.getSimpleName}") + + // Construct the exchange from buildSidePlan (not from the existing exchange), + // matching Spark's PlanAdaptiveDynamicPruningFilters lines 44-48. The existing + // exchange may belong to a different plan context (e.g., the main query) with + // different attribute IDs than the current SAB's build side (e.g., a scalar + // subquery). Using the existing exchange's output/mode would cause schema + // mismatch when CometSubqueryBroadcastExec projects keys by exprId. + val packedKeys = BindReferences.bindReferences( + HashJoin.rewriteKeyExpr(sab.buildKeys), + buildSidePlan.output) + val mode = HashedRelationBroadcastMode(packedKeys) + val newExchange = if (isComet) { + CometBroadcastExchangeExec(buildSidePlan, buildSidePlan.output, mode, buildSidePlan) + } else { + BroadcastExchangeExec(mode, buildSidePlan) + } + buildSidePlan.logicalLink.foreach(newExchange.setLogicalLink) + + // supportsColumnar must match the exchange. ASPE.getFinalPhysicalPlan + // applies postStageCreationRules(supportsColumnar) to the final plan. + // With supportsColumnar=false (the SAB ASPE's default), + // ApplyColumnarRulesAndInsertTransitions wraps the BroadcastQueryStageExec + // in ColumnarToRowExec, failing the assertion at ASPE.doExecuteBroadcast + // that expects BroadcastQueryStageExec as the final plan. + val newAdaptivePlan = adaptivePlan.copy( + inputPlan = newExchange, + supportsColumnar = newExchange.supportsColumnar) + // ASPE constructor applies queryStagePreparationRules to inputPlan, + // which clears the logicalLink tag as a side effect. Re-set it so + // getFinalPhysicalPlan (line 276) can read inputPlan.logicalLink. 
+ buildSidePlan.logicalLink.foreach(newAdaptivePlan.inputPlan.setLogicalLink) + + val subquery = if (isComet) { + CometSubqueryBroadcastExec(sab.name, sab.indices, sab.buildKeys, newAdaptivePlan) + } else { + createSubqueryBroadcastExec(sab.name, sab.indices, sab.buildKeys, newAdaptivePlan) + } + DynamicPruningExpression(inSub.withNewPlan(reuseOrRegisterSubquery(subquery, adaptivePlan))) + } else if (sab.onlyInBroadcast) { + // Case 2: no reusable broadcast, and the optimizer says DPP only makes sense + // with broadcast reuse. Disable DPP (Literal.TrueLiteral). + logInfo( + s"No reusable broadcast for DPP subquery '${sab.name}' " + + "(onlyInBroadcast=true), disabling DPP") + DynamicPruningExpression(Literal.TrueLiteral) + } else { + // Case 3: no reusable broadcast, but the optimizer says DPP is worthwhile + // even without broadcast reuse. Create an aggregate SubqueryExec on the build + // side to get distinct partition key values for pruning. + // Matches Spark's PlanAdaptiveDynamicPruningFilters lines 68-79. + val aliases = + sab.indices.map(idx => Alias(sab.buildKeys(idx), sab.buildKeys(idx).toString)()) + val aggregate = Aggregate(aliases, aliases, sab.buildPlan) + val sparkPlan = shimPrepareExecutedPlan(adaptivePlan, aggregate) + assert( + sparkPlan.isInstanceOf[AdaptiveSparkPlanExec], + "Expected AdaptiveSparkPlanExec from prepareExecutedPlan, " + + s"got ${sparkPlan.getClass.getSimpleName}") + val newAdaptivePlan = sparkPlan.asInstanceOf[AdaptiveSparkPlanExec] + val values = SubqueryExec(sab.name, newAdaptivePlan) + DynamicPruningExpression(InSubqueryExec(inSub.child, values, inSub.exprId)) + } + } + + /** + * Registers a DPP subquery in the shared AdaptiveExecutionContext.subqueryCache for cross-plan + * deduplication, matching ReuseAdaptiveSubquery's behavior. + * + * Our rule runs after Spark's ReuseAdaptiveSubquery (which can't see our subqueries because + * they don't exist yet when it runs). 
CometReuseSubquery uses a per-invocation local cache that + * doesn't span across the main query and scalar subquery plans. Using the shared context cache + * ensures that identical DPP subqueries across plans are deduplicated. + */ + private def reuseOrRegisterSubquery( + subquery: BaseSubqueryExec, + adaptivePlan: AdaptiveSparkPlanExec): BaseSubqueryExec = { + if (!conf.subqueryReuseEnabled) return subquery + val subqueryCache = adaptivePlan.context.subqueryCache + val cached = subqueryCache.getOrElseUpdate(subquery.canonicalized, subquery) + if (cached.ne(subquery)) { + logDebug(s"Reusing cached subquery for '${subquery.name}'") + ReusedSubqueryExec(cached) + } else { + subquery + } + } + + /** + * Finds a broadcast hash join whose build-side keys match the given exprIds. Searches for both + * CometBroadcastHashJoinExec and BroadcastHashJoinExec to handle cases where the join fell back + * to Spark. + * + * Spark's PlanAdaptiveDynamicPruningFilters uses sameResult() to match a constructed + * BroadcastExchangeExec against the join's build side. We can't do the same because + * sameResult() checks getClass equality first, so a BroadcastExchangeExec would never match a + * CometBroadcastExchangeExec. Matching on buildKey exprIds is semantically equivalent because + * SAB buildKeys originate from the same logical plan. 
+ */ + private def findMatchingBroadcastJoin( + sabKeyIds: Set[Any], + plan: SparkPlan): Option[(SparkPlan, Boolean)] = { + var result: Option[(SparkPlan, Boolean)] = None + find(plan) { + case join: CometBroadcastHashJoinExec => + result = extractBroadcastChild( + join.buildSide, + join.left, + join.right, + join.leftKeys, + join.rightKeys, + isCometJoin = true, + sabKeyIds) + result.isDefined + case join: BroadcastHashJoinExec => + result = extractBroadcastChild( + join.buildSide, + join.left, + join.right, + join.leftKeys, + join.rightKeys, + isCometJoin = false, + sabKeyIds) + result.isDefined + case _ => false + } + result + } + + private def extractBroadcastChild( + buildSide: BuildSide, + left: SparkPlan, + right: SparkPlan, + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + isCometJoin: Boolean, + sabKeyIds: Set[Any]): Option[(SparkPlan, Boolean)] = { + val joinBuildKeys = buildSide match { + case BuildLeft => leftKeys + case BuildRight => rightKeys + } + val joinKeyIds: Set[Any] = joinBuildKeys.flatMap(_.references.map(_.exprId)).toSet + if (sabKeyIds == joinKeyIds) { + val bc = buildSide match { + case BuildLeft => left + case BuildRight => right + } + Some((bc, isCometJoin)) + } else { + // exprId mismatch between SAB and candidate BHJ. This is expected when the plan + // contains multiple broadcast joins and we're iterating through non-matches. + // It is also the silent-DPP-loss failure mode if AQE has rewritten attribute + // exprIds across stage boundaries such that no BHJ matches the SAB. Log at debug + // so it's observable without being noisy. 
+ logDebug( + s"BHJ buildKey exprIds do not match SAB: sab=$sabKeyIds join=$joinKeyIds " + + s"(isCometJoin=$isCometJoin)") + None + } + } + + private def convertNonCometNodeDPP(node: SparkPlan, stagePlan: SparkPlan): SparkPlan = { + node.transformExpressions { + case expr if hasCometSAB(expr) => + convertFilter(expr, stagePlan) + } + } + + /** + * Checks if a SparkPlan's expressions contain a wrapped CometSubqueryAdaptiveBroadcastExec. + * Unlike hasCometSAB, this only checks for the wrapped variant. Unwrapped SABs on non-Comet + * nodes are handled by Spark's own PlanAdaptiveDynamicPruningFilters. + * + * Keep the set of accepted wrapper shapes in sync with `extractSABData`. If extractSABData + * learns to unwrap a new form, add the matching gate predicate here or non-Comet nodes holding + * that form will be silently skipped by `apply`. + */ + private def hasWrappedSAB(p: SparkPlan): Boolean = + p.expressions.exists(_.exists { + case DynamicPruningExpression( + InSubqueryExec(_, _: CometSubqueryAdaptiveBroadcastExec, _, _, _, _)) => + true + // ReuseAdaptiveSubquery wraps the CSAB when two scans share a DPP subquery + // (e.g. UNION ALL under one join). Mirrors extractSABData's unwrap. + case DynamicPruningExpression( + InSubqueryExec( + _, + ReusedSubqueryExec(_: CometSubqueryAdaptiveBroadcastExec), + _, + _, + _, + _)) => + true + case _ => false + }) + + /** + * Checks if an expression contains an SAB variant (wrapped or unwrapped). The outer + * CometNativeScanExec.partitionFilters has CometSubqueryAdaptiveBroadcastExec (wrapped by + * CometExecRule). The inner CometScanExec.partitionFilters may have the original + * SubqueryAdaptiveBroadcastExec (unwrapped, because CometScanExec is + * @transient). 
+ */ + private def hasCometSAB(e: Expression): Boolean = + e.exists { + case DynamicPruningExpression(inSub: InSubqueryExec) => + extractSABData(inSub).isDefined + case _ => false + } +} diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala index e6c58121b8..2cd5ba156c 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala @@ -43,7 +43,7 @@ import org.apache.spark.sql.types._ import org.apache.comet.{CometConf, CometNativeException, DataTypeSupport} import org.apache.comet.CometConf._ -import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, withInfo, withInfos} +import org.apache.comet.CometSparkSessionExtensions.{isCometLoaded, isSpark35Plus, withInfo, withInfos} import org.apache.comet.DataTypeSupport.isComplexType import org.apache.comet.iceberg.{CometIcebergNativeScanMetadata, IcebergReflection} import org.apache.comet.objectstore.NativeConfig @@ -113,6 +113,11 @@ case class CometScanRule(session: SparkSession) val fullPlan = plan def transformScan(scanNode: SparkPlan): SparkPlan = scanNode match { + // Tagged by CometSpark34AqeDppFallbackRule on Spark < 3.5 to keep a peer scan + // Spark-native for canonical symmetry in SMJ self-joins (SPARK-32509). 
+ case scan if scan.getTagValue(CometScanRule.SKIP_COMET_SCAN_TAG).isDefined => + withInfo(scan, "AQE DPP region fallback (Spark < 3.5)") + case scan if !CometConf.COMET_NATIVE_SCAN_ENABLED.get(conf) => withInfo(scan, "Comet Scan is not enabled") @@ -139,8 +144,17 @@ case class CometScanRule(session: SparkSession) private def transformV1Scan(plan: SparkPlan, scanExec: FileSourceScanExec): SparkPlan = { - if (scanExec.partitionFilters.exists(isAqeDynamicPruningFilter)) { - return withInfo(scanExec, "AQE Dynamic Partition Pruning is not supported") + // On Spark 3.4, injectQueryStageOptimizerRule is unavailable, so + // CometPlanAdaptiveDynamicPruningFilters cannot run. Fall back this scan to Spark so that + // Spark's PlanAdaptiveDynamicPruningFilters handles the SAB natively. Comet's narrower + // CometSpark34AqeDppFallbackRule (queryStagePrepRule on 3.4) then tags any matching BHJ's + // build-side broadcast so Spark's rule can match it via sameResult. See + // CometSpark34AqeDppFallbackRule's class docstring. + // + // On 3.5+, CometPlanAdaptiveDynamicPruningFilters rewrites SABs directly and this fallback + // is not needed. + if (!isSpark35Plus && scanExec.partitionFilters.exists(isAqeDynamicPruningFilter)) { + return withInfo(scanExec, "AQE Dynamic Partition Pruning requires Spark 3.5+") } scanExec.relation match { @@ -730,6 +744,15 @@ case class CometScanTypeChecker(scanImpl: String) extends DataTypeSupport with C object CometScanRule extends Logging { + /** + * Tag set on a scan (`FileSourceScanExec` or `BatchScanExec`) that should be left as a plain + * Spark scan rather than converted to a Comet scan. Written by + * [[CometSpark34AqeDppFallbackRule]] on Spark < 3.5. See that rule's class docstring for the + * rationale. 
+ */ + val SKIP_COMET_SCAN_TAG: org.apache.spark.sql.catalyst.trees.TreeNodeTag[Unit] = + org.apache.spark.sql.catalyst.trees.TreeNodeTag[Unit]("comet.skipCometScan") + /** * Validating object store configs can cause requests to be made to S3 APIs (such as when * resolving the region for a bucket). We use a cache to reduce the number of S3 calls. diff --git a/spark/src/main/scala/org/apache/comet/rules/CometSpark34AqeDppFallbackRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometSpark34AqeDppFallbackRule.scala new file mode 100644 index 0000000000..afedc1601c --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/rules/CometSpark34AqeDppFallbackRule.scala @@ -0,0 +1,392 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.rules + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{FileSourceScanExec, InSubqueryExec, SparkPlan, SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.datasources.HadoopFsRelation +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec + +import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus + +/** + * Preserves AQE Dynamic Partition Pruning on Spark < 3.5 by tagging specific nodes to stay + * Spark-native, so that Spark's PlanAdaptiveDynamicPruningFilters can match them natively. + * + * On Spark 3.5+, CometPlanAdaptiveDynamicPruningFilters (queryStageOptimizerRule) rewrites AQE + * DPP SABs into CometSubqueryBroadcastExec after Spark's own rule has run. Spark 3.4 does not + * expose injectQueryStageOptimizerRule (SPARK-45785 added it in 3.5), so the rewrite cannot run + * at the correct time. Rewriting the SAB at queryStagePrepRule time does not work either: AQE + * rebuilds plan nodes between prep and execution in ways that drop the `@transient` inner scan we + * would need to update. See the dual-filter handling in + * CometPlanAdaptiveDynamicPruningFilters.convertNativeScanDPP for why both filters need updating. + * + * Instead, on 3.4 we arrange for Spark's PlanAdaptiveDynamicPruningFilters to succeed on its own + * by keeping the BHJ's build-side exchange (or the peer branch of a self-join SMJ) Spark-native. + * This rule only writes skip-tags on nodes; it never rewrites expressions or plan structure. 
Tags + * survive AQE per-stage re-entry, matching the contract of SKIP_COMET_SHUFFLE_TAG from PR #4010. + * + * Registered via injectPreSpark35QueryStagePrepRuleShim before CometScanRule/CometExecRule in + * CometSparkSessionExtensions, so tags are in place when conversion runs. No-op on Spark 3.5+. + * + * Four cases handled: + * + * 1. SAB + matching BHJ (non-V1 fact scans: Hive / V2 / V2Filter). The cascade up from a non-V1 + * scan reaches a CometBroadcastHashJoinExec + CometBroadcastExchangeExec build side; Spark's + * class-sensitive sameResult check in PlanAdaptiveDynamicPruningFilters.scala:50-57 fails to + * match. Tag the BHJ's build-side BroadcastExchangeExec with SKIP_COMET_BROADCAST_TAG. With + * the build exchange Spark-native, Comet's BHJ conversion fails its + * forall(_.isInstanceOf[CometNativeExec]) guard and the BHJ stays Spark. Spark's rule then + * matches and creates SubqueryBroadcastExec. + * + * 2. SAB + matching BHJ on V1. CometScanRule.transformV1Scan already rejects the V1 fact scan via + * isAqeDynamicPruningFilter; the cascade keeps the BHJ and its BroadcastExchangeExec + * Spark-native. No tagging needed; this rule is a no-op for V1 BHJ. TPC-DS Q7 on 3.4 stays the + * same shape it has today, including the Comet acceleration on dim scans below the Spark + * broadcast. + * + * 3. SAB with no matching BHJ (V1 SMJ self-join, SPARK-32509 with + * AUTO_BROADCASTJOIN_THRESHOLD=-1). The logical Partition-Pruning rule attaches the SAB to only + * one branch of the self-join. transformV1Scan falls back the SAB-bearing scan; the peer scan (no + * SAB) Cometizes, producing canonical asymmetry that breaks shuffle exchange reuse. Tag the peer + * scan with SKIP_COMET_SCAN_TAG and any shuffle whose subtree contains the peer scan with + * SKIP_COMET_SHUFFLE_TAG. 
Both branches end up Spark-native with matching canonical forms; + * Spark's rule replaces the SAB with TrueLiteral, and FileSourceScanExec.doCanonicalize strips it + * via filterUnusedDynamicPruningExpressions (DataSourceScanExec.scala:731,736), restoring + * canonical symmetry for reuse. + * + * 4. SubqueryBroadcastExec-bearing scans (AQE re-optimize). On re-optimize cycles, + * ASPE.preprocessingRules (PlanAdaptiveSubqueries) fills the DPP slot with the already- + * materialized SubqueryBroadcastExec (produced by Spark's PlanAdaptiveDynamicPruningFilters on a + * previous pass) rather than the original SAB. The freshly-planned BroadcastExchangeExec on the + * main BHJ's build side is a new instance with no SKIP_COMET_BROADCAST_TAG carried over, so + * CometExecRule would Cometize it and lose AQE stageCache broadcast reuse with the DPP subquery's + * Spark broadcast. The rule also scans for SubqueryBroadcastExec (descending into QueryStageExec + * via AdaptiveSparkPlanHelper since the fact scan is already inside a materialized stage), + * extracts its buildKeys, finds the matching BHJ by exprId, and tags the build BE. + * + * Non-AQE DPP (#4011) is untouched: it produces SubqueryBroadcastExec, not the adaptive variant, + * and is handled by CometExecRule.convertSubqueryBroadcasts. + * + * Known limitation on 3.4: cross-plan scalar-subquery DPP. An SAB in a scalar subquery cannot see + * a matching BHJ in the main query because at prep-rule time each AdaptiveSparkPlanExec sees only + * its own plan. When the match fails, Spark's own rule falls back to TrueLiteral or aggregate + * SubqueryExec (same behavior as Spark-without-Comet on 3.4). 
+ * + * @see + * PlanAdaptiveDynamicPruningFilters (Spark's rule this code arranges to succeed) + * @see + * CometPlanAdaptiveDynamicPruningFilters (Spark 3.5+ equivalent via queryStageOptimizerRule) + */ +case object CometSpark34AqeDppFallbackRule + extends Rule[SparkPlan] + with AdaptiveSparkPlanHelper + with Logging { + + override def apply(plan: SparkPlan): SparkPlan = { + // Registered only on Spark < 3.5 via injectPreSpark35QueryStagePrepRuleShim. If the + // 3.5+ shim mis-registers this rule, fail loud rather than silently disable Comet. + assert( + !isSpark35Plus, + "CometSpark34AqeDppFallbackRule must only be registered on Spark < 3.5; " + + "see ShimCometSparkSessionExtensions.injectPreSpark35QueryStagePrepRuleShim") + + if (!conf.dynamicPartitionPruningEnabled) return plan + + val sabScans = findSabScans(plan) + val sbScans = findSubqueryBroadcastScans(plan) + if (sabScans.isEmpty && sbScans.isEmpty) return plan + + sabScans.foreach { case (scan, sab) => + tagForSab(plan, scan, sab) + } + + // AQE re-optimization path: see case 4 in the class-level docstring. On subsequent + // re-optimize cycles Spark's PlanAdaptiveSubqueries hands us a SubqueryBroadcastExec (not + // the original SAB), and the freshly-planned main-BHJ build BE has no tag carried over. + sbScans.foreach { case (scan, sb) => + tagForSubqueryBroadcast(plan, scan, sb) + } + + // This rule only tags; it never rewrites the plan structurally. + plan + } + + /** + * Find every scan whose `partitionFilters` contain a `SubqueryAdaptiveBroadcastExec`. + * + * Mirrors the SAB pattern matched by Spark's `PlanAdaptiveDynamicPruningFilters.apply` (lines + * 41-43): + * {{{ + * case DynamicPruningExpression(InSubqueryExec( + * value, SubqueryAdaptiveBroadcastExec(...), ...)) + * }}} + * + * Returns the scan node itself and the SAB found on it. 
+ */ + private def findSabScans(plan: SparkPlan): Seq[(SparkPlan, SubqueryAdaptiveBroadcastExec)] = { + val buf = scala.collection.mutable.ArrayBuffer[(SparkPlan, SubqueryAdaptiveBroadcastExec)]() + foreach(plan) { node => + extractFirstSab(node).foreach(sab => buf += ((node, sab))) + } + buf.toSeq + } + + private def extractFirstSab(node: SparkPlan): Option[SubqueryAdaptiveBroadcastExec] = { + node.expressions + .flatMap(_.collect { + case DynamicPruningExpression(inSub: InSubqueryExec) + if inSub.plan.isInstanceOf[SubqueryAdaptiveBroadcastExec] => + inSub.plan.asInstanceOf[SubqueryAdaptiveBroadcastExec] + }) + .headOption + } + + /** + * Find every scan whose DPP `partitionFilters` contain a `SubqueryBroadcastExec` - i.e. a DPP + * subquery that Spark's `PlanAdaptiveDynamicPruningFilters` has already materialized on a + * previous AQE pass. See the comment block in `apply` for why we need this in addition to the + * SAB path. + * + * Uses `AdaptiveSparkPlanHelper.foreach`, which descends through `QueryStageExec.plan` and + * `AdaptiveSparkPlanExec.executedPlan`. By the re-optimize pass where this matters, the fact + * scan is inside a `ShuffleQueryStageExec` whose `children` is `Seq.empty`, so a plain + * `plan.foreach` would stop there and miss the scan's `SubqueryBroadcastExec`. + */ + private def findSubqueryBroadcastScans( + plan: SparkPlan): Seq[(SparkPlan, SubqueryBroadcastExec)] = { + val buf = scala.collection.mutable.ArrayBuffer[(SparkPlan, SubqueryBroadcastExec)]() + foreach(plan) { node => + extractFirstSubqueryBroadcast(node).foreach(sb => buf += ((node, sb))) + } + buf.toSeq + } + + private def extractFirstSubqueryBroadcast(node: SparkPlan): Option[SubqueryBroadcastExec] = { + node.expressions + .flatMap(_.collect { + case DynamicPruningExpression(inSub: InSubqueryExec) + if inSub.plan.isInstanceOf[SubqueryBroadcastExec] => + inSub.plan.asInstanceOf[SubqueryBroadcastExec] + }) + .headOption + } + + /** + * Place tags for a single SAB-bearing scan. 
Behavior depends on whether a matching + * broadcast-hash join exists in the plan: + * - Matching BHJ found: tag its build-side `BroadcastExchangeExec` (case 1 above). + * - No matching BHJ: tag peer scans + their shuffles for canonical symmetry (case 3). + */ + private def tagForSab( + plan: SparkPlan, + scan: SparkPlan, + sab: SubqueryAdaptiveBroadcastExec): Unit = { + val sabKeyIds: Set[Any] = sab.buildKeys.flatMap(_.references.map(_.exprId)).toSet + if (sabKeyIds.isEmpty) { + logWarning(s"SAB '${sab.name}' has empty buildKeys; skipping") + return + } + + findMatchingBroadcastJoin(plan, sabKeyIds) match { + case Some(buildSide) => + tagBhjBuildBroadcast(buildSide, sab.name) + case None => + tagPeerScansAndShuffles(plan, scan, sab.name) + } + } + + /** + * Tag the matching BHJ's build-side `BroadcastExchangeExec` for a scan whose DPP filter holds a + * `SubqueryBroadcastExec` (post-PADPF form). Only the matching-BHJ case applies here: if PADPF + * already ran on a previous AQE cycle, it would have fallen back to aggregate `SubqueryExec` or + * `Literal.TrueLiteral` instead of producing a `SubqueryBroadcastExec` when no BHJ matched - so + * a `SubqueryBroadcastExec` always implies a BHJ with compatible build keys somewhere in the + * plan. No peer-scan tagging path is needed because the SMJ self-join case (SPARK-32509) never + * produces a `SubqueryBroadcastExec`. + */ + private def tagForSubqueryBroadcast( + plan: SparkPlan, + scan: SparkPlan, + sb: SubqueryBroadcastExec): Unit = { + val keyIds: Set[Any] = sb.buildKeys.flatMap(_.references.map(_.exprId)).toSet + if (keyIds.isEmpty) { + logWarning(s"SubqueryBroadcast '${sb.name}' has empty buildKeys; skipping") + return + } + + findMatchingBroadcastJoin(plan, keyIds) match { + case Some(buildSide) => + tagBhjBuildBroadcast(buildSide, sb.name) + case None => + // Nothing to tag. Either the BHJ has a different buildKey exprId set (e.g. 
AQE + // rewrote attributes across stages) or the matching join isn't in this plan snapshot. + logDebug( + s"SubqueryBroadcast '${sb.name}': no matching BHJ on this plan snapshot; no tag placed") + } + } + + /** + * Walk from a BHJ's build-side subtree root to the first `BroadcastExchangeExec` and tag it + * with `SKIP_COMET_BROADCAST_TAG`. If the subtree root already IS a BroadcastExchangeExec (most + * common), that's the one. Otherwise we walk until we find one or give up. + * + * Tagging the exchange is enough: when it stays Spark-native, `CometBroadcastHashJoinExec`'s + * conversion guard (`forall(_.isInstanceOf[CometNativeExec])`) fails because a Spark + * `BroadcastExchangeExec` is not a `CometNativeExec`, so the BHJ stays Spark too. Spark's + * `PlanAdaptiveDynamicPruningFilters` can then match it via `sameResult` + * (`PlanAdaptiveDynamicPruningFilters.scala:50-57`) and create a `SubqueryBroadcastExec`. + */ + private def tagBhjBuildBroadcast(buildSide: SparkPlan, dppName: String): Unit = { + val found = buildSide.find { + case _: BroadcastExchangeExec => true + case _ => false + } + found match { + case Some(be: BroadcastExchangeExec) => + be.setTagValue(CometExecRule.SKIP_COMET_BROADCAST_TAG, ()) + logDebug(s"Tagged BroadcastExchangeExec for DPP '$dppName' (BHJ build side)") + case _ => + logWarning( + s"DPP '$dppName': matched BHJ but could not locate BroadcastExchangeExec on " + + "build side; skipping") + } + } + + /** + * For SMJ-shaped DPP (SPARK-32509 with AUTO_BROADCASTJOIN_THRESHOLD=-1), there is no BHJ. The + * logical Partition-Pruning rule attaches the SAB to only one branch of a self-join; the peer + * branch has no SAB, so `transformV1Scan`'s V1 AQE DPP fallback fires for the SAB-bearing scan + * only. Peer scans Cometize normally, leading to canonical asymmetry that breaks shuffle + * exchange reuse (0 `ReusedExchangeExec` instead of 1). 
+ * + * Tag every peer scan (same relation as the SAB-bearing scan, different instance) with + * `SKIP_COMET_SCAN_TAG`, and every `ShuffleExchangeExec` whose subtree contains a peer scan + * with `SKIP_COMET_SHUFFLE_TAG`. Both branches end up Spark-native with matching canonical + * forms. + * + * For canonical equality the SAB-bearing scan's side is already handled: + * `FileSourceScanExec.doCanonicalize` strips `DynamicPruningExpression(Literal.TrueLiteral)` + * via `filterUnusedDynamicPruningExpressions` (`DataSourceScanExec.scala:731,736`), and Spark's + * rule replaces the SAB with `TrueLiteral` when it can't match. + */ + private def tagPeerScansAndShuffles( + plan: SparkPlan, + sabScan: SparkPlan, + sabName: String): Unit = { + val sabRelation = sabScan match { + case f: FileSourceScanExec => Some(f.relation) + case _ => None + } + if (sabRelation.isEmpty) { + logDebug( + s"SAB '$sabName': non-V1 scan with no BHJ match; no peer-tagging heuristic available") + return + } + val sabRel = sabRelation.get + + var taggedScans = 0 + foreach(plan) { + case peer: FileSourceScanExec if (peer ne sabScan) && sameRelation(peer.relation, sabRel) => + peer.setTagValue(CometScanRule.SKIP_COMET_SCAN_TAG, ()) + taggedScans += 1 + case _ => + } + + var taggedShuffles = 0 + foreach(plan) { + case sh: ShuffleExchangeExec if shuffleSubtreeContainsMatchingScan(sh, sabRel, sabScan) => + sh.setTagValue(CometExecRule.SKIP_COMET_SHUFFLE_TAG, ()) + taggedShuffles += 1 + case _ => + } + + logDebug( + s"SAB '$sabName' (no BHJ match): tagged $taggedScans peer scan(s) + " + + s"$taggedShuffles shuffle(s)") + } + + private def sameRelation(a: HadoopFsRelation, b: HadoopFsRelation): Boolean = { + (a eq b) || + (a.location.rootPaths == b.location.rootPaths && + a.dataSchema == b.dataSchema && + a.partitionSchema == b.partitionSchema) + } + + private def shuffleSubtreeContainsMatchingScan( + shuffle: ShuffleExchangeExec, + sabRelation: HadoopFsRelation, + sabScan: SparkPlan): Boolean = { + 
find(shuffle.child) { + case scan: FileSourceScanExec => + (scan eq sabScan) || sameRelation(scan.relation, sabRelation) + case _ => false + }.isDefined + } + + /** + * Mirrors `PlanAdaptiveDynamicPruningFilters.apply` lines 50-57: + * {{{ + * find(rootPlan) { + * case BroadcastHashJoinExec(_, _, _, BuildLeft, _, left, _, _) => + * left.sameResult(exchange) + * case BroadcastHashJoinExec(_, _, _, BuildRight, _, _, right, _) => + * right.sameResult(exchange) + * } + * }}} + * + * Our rule runs BEFORE `CometScanRule`/`CometExecRule`, so the plan is entirely Spark-native at + * this point. We only match `BroadcastHashJoinExec`. Instead of `sameResult` we match on + * join-side exprId equality with the SAB's buildKeys; this is semantically equivalent because + * SAB buildKeys originate from the same logical plan as the join's build-side keys. + * + * Returns the BHJ's build-side subtree (the one we want to keep Spark-native), or None if no + * matching join is found. + */ + private def findMatchingBroadcastJoin( + plan: SparkPlan, + sabKeyIds: Set[Any]): Option[SparkPlan] = { + var result: Option[SparkPlan] = None + find(plan) { + case j: BroadcastHashJoinExec => + val joinBuildKeys = j.buildSide match { + case BuildLeft => j.leftKeys + case BuildRight => j.rightKeys + } + val joinKeyIds: Set[Any] = joinBuildKeys.flatMap(_.references.map(_.exprId)).toSet + if (sabKeyIds == joinKeyIds) { + result = Some(j.buildSide match { + case BuildLeft => j.left + case BuildRight => j.right + }) + true + } else { + false + } + case _ => false + } + result + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometNativeScan.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometNativeScan.scala index da7f24183b..70f06f5741 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometNativeScan.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometNativeScan.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.ListBuffer 
import scala.jdk.CollectionConverters._ import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, PlanExpression} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getExistenceDefaultValues import org.apache.spark.sql.comet.{CometNativeExec, CometNativeScanExec, CometScanExec} import org.apache.spark.sql.execution.{FileSourceScanExec, InSubqueryExec, SubqueryAdaptiveBroadcastExec} @@ -31,7 +31,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.comet.{CometConf, ConfigEntry} import org.apache.comet.CometConf.COMET_EXEC_ENABLED -import org.apache.comet.CometSparkSessionExtensions.{hasExplainInfo, withInfo} +import org.apache.comet.CometSparkSessionExtensions.{hasExplainInfo, isSpark35Plus, withInfo} import org.apache.comet.objectstore.NativeConfig import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.{CometOperatorSerde, Compatible, OperatorOuterClass, SupportLevel} @@ -56,11 +56,17 @@ object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging { withInfo(scanExec, s"Full native scan disabled because ${COMET_EXEC_ENABLED.key} disabled") } - // Native DataFusion doesn't support AQE DPP (SubqueryAdaptiveBroadcastExec). - // Non-AQE DPP (SubqueryBroadcastExec/SubqueryExec) is supported through the lazy + // AQE DPP (SubqueryAdaptiveBroadcastExec) is converted to CometSubqueryBroadcastExec + // by CometPlanAdaptiveDynamicPruningFilters (queryStageOptimizerRule, Spark 3.5+). + // Non-AQE DPP (SubqueryBroadcastExec/SubqueryExec) is converted by + // CometExecRule.convertSubqueryBroadcasts. Both are resolved through the lazy // partition serialization path in CometNativeScanExec. 
- if (scanExec.partitionFilters.exists(isAqeDynamicPruningFilter)) { - withInfo(scanExec, "Native DataFusion scan does not support AQE dynamic pruning") + // + // On Spark 3.4, injectQueryStageOptimizerRule is unavailable, so the AQE DPP conversion + // rule can't run. CometScanRule.transformV1Scan rejects AQE DPP on 3.4, so this check + // is a safety net: if the scan somehow reached here with AQE DPP on 3.4, reject it. + if (!isSpark35Plus && scanExec.partitionFilters.exists(isAqeDynamicPruningFilter)) { + withInfo(scanExec, "Native DataFusion scan does not support AQE DPP on Spark 3.4") } if (SQLConf.get.ignoreCorruptFiles || @@ -82,9 +88,6 @@ object CometNativeScan extends CometOperatorSerde[CometScanExec] with Logging { !hasExplainInfo(scanExec) } - private def isDynamicPruningFilter(e: Expression): Boolean = - e.exists(_.isInstanceOf[PlanExpression[_]]) - /** Detects AQE DPP (SubqueryAdaptiveBroadcastExec), as opposed to non-AQE DPP. */ private def isAqeDynamicPruningFilter(e: Expression): Boolean = e.exists { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index 8012b18b22..c569c34d01 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.io.ChunkedByteBuffer import com.google.common.base.Objects -import org.apache.comet.{CometConf, CometRuntimeException, ConfigEntry} +import org.apache.comet.{CometConf, ConfigEntry} import org.apache.comet.serde.OperatorOuterClass import org.apache.comet.serde.operator.CometSink import org.apache.comet.shims.ShimCometBroadcastExchangeExec @@ -138,14 +138,16 @@ case class CometBroadcastExchangeExec( case ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _) if plan.isInstanceOf[CometPlan] => 
CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect() - case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) => - throw new CometRuntimeException( - "Child of CometBroadcastExchangeExec should be CometExec, " + - s"but got: ${s.plan.getClass}") case _ => - throw new CometRuntimeException( - "Child of CometBroadcastExchangeExec should be CometExec, " + - s"but got: ${child.getClass}") + // Non-Comet child (e.g., RowToColumnar -> LocalTableScan). Happens when + // AQE re-optimizes inside an ASPE and replaces the original Comet scan + // with a Spark-native node (e.g., empty broadcast triggers LocalTableScan). + logWarning( + "CometBroadcastExchangeExec child is not CometPlan: " + + s"${child.getClass.getSimpleName}. " + + "Wrapping in CometSparkToColumnarExec for Arrow serialization.") + val cometChild = CometSparkToColumnarExec(ColumnarToRowExec(child)) + CometExec.getByteArrayRdd(cometChild).collect() } val numRows = countsAndBytes.map(_._1).sum diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala index 050c6d431f..45d708aaef 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeScanExec.scala @@ -76,7 +76,8 @@ case class CometNativeScanExec( with DataSourceScanExec with ShimStreamSourceAwareSparkPlan { - override lazy val metadata: Map[String, String] = originalPlan.metadata + override lazy val metadata: Map[String, String] = + if (originalPlan != null) originalPlan.metadata else Map.empty /** * Prepare subquery plans before execution. @@ -313,6 +314,18 @@ case class CometNativeScanExec( } override def doCanonicalize(): CometNativeScanExec = { + // Canonicalize originalPlan but strip its DPP partition filters. + // originalPlan carries column selection and schema info needed for + // equals/hashCode. 
But its partitionFilters may have stale DPP + // expressions (e.g., SABs not yet converted to TrueLiteral) that + // would prevent exchange reuse between otherwise-identical scans. + val canonOriginal = if (originalPlan != null) { + val stripped = originalPlan.copy(partitionFilters = + CometScanUtils.filterUnusedDynamicPruningExpressions(originalPlan.partitionFilters)) + stripped.doCanonicalize() + } else { + null + } CometNativeScanExec( nativeOp, relation, @@ -326,7 +339,7 @@ case class CometNativeScanExec( QueryPlan.normalizePredicates(dataFilters, output), None, disableBucketedScan, - originalPlan.doCanonicalize(), + canonOriginal, SerializedPlan(None), null, // Transient scan not needed for canonicalization "" diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanUtils.scala index 4cd3996669..19a2b53af1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometScanUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometScanUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.comet import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression, Literal} +import org.apache.spark.sql.execution.{InSubqueryExec, SubqueryAdaptiveBroadcastExec} object CometScanUtils { @@ -28,6 +29,20 @@ object CometScanUtils { * DynamicPruningExpression(Literal.TrueLiteral) during Physical Planning */ def filterUnusedDynamicPruningExpressions(predicates: Seq[Expression]): Seq[Expression] = { - predicates.filterNot(_ == DynamicPruningExpression(Literal.TrueLiteral)) + // Strip DPP expressions for canonicalization. Matches Spark's + // FileSourceScanExec.filterUnusedDynamicPruningExpressions (TrueLiteral). + // Also strips unconverted SAB wrappers because AQE stageCache canonicalizes + // before our queryStageOptimizerRule converts them, so they would prevent + // exchange reuse between otherwise-identical scans. 
+ predicates.filterNot { + case DynamicPruningExpression(Literal.TrueLiteral) => true + case DynamicPruningExpression( + InSubqueryExec(_, _: CometSubqueryAdaptiveBroadcastExec, _, _, _, _)) => + true + case DynamicPruningExpression( + InSubqueryExec(_, _: SubqueryAdaptiveBroadcastExec, _, _, _, _)) => + true + case _ => false + } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSubqueryAdaptiveBroadcastExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSubqueryAdaptiveBroadcastExec.scala new file mode 100644 index 0000000000..99063e8370 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSubqueryAdaptiveBroadcastExec.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.{BaseSubqueryExec, SparkPlan, UnaryExecNode} + +/** + * Preserves SubqueryAdaptiveBroadcastExec data through Spark's PlanAdaptiveDynamicPruningFilters + * rule. + * + * Spark's PlanAdaptiveDynamicPruningFilters (queryStageOptimizerRule) pattern-matches on + * SubqueryAdaptiveBroadcastExec. When it can't find a matching BroadcastHashJoinExec (because + * Comet replaced it with CometBroadcastHashJoinExec), it replaces the DPP expression with + * Literal.TrueLiteral, disabling DPP entirely. + * + * CometExecRule wraps SABs in this class during queryStagePreparationRules (before + * queryStageOptimizerRules run). Spark's rule won't match this type. Then + * CometPlanAdaptiveDynamicPruningFilters (a custom queryStageOptimizerRule) unwraps the SAB data + * and converts it to CometSubqueryBroadcastExec with the join's BroadcastQueryStageExec for true + * broadcast reuse. + * + * Not executable - must be converted before execution. + */ +case class CometSubqueryAdaptiveBroadcastExec( + name: String, + indices: Seq[Int], + onlyInBroadcast: Boolean, + @transient buildPlan: LogicalPlan, + buildKeys: Seq[Expression], + child: SparkPlan) + extends BaseSubqueryExec + with UnaryExecNode { + + // This node must be converted to CometSubqueryBroadcastExec by + // CometPlanAdaptiveDynamicPruningFilters before execution. If we reach doExecute(), + // the rule didn't run (e.g., Spark 3.4 where injectQueryStageOptimizerRule is unavailable). 
+ protected override def doExecute(): RDD[InternalRow] = { + throw QueryExecutionErrors.executeCodePathUnsupportedError( + "CometSubqueryAdaptiveBroadcastExec (should have been converted by " + + "CometPlanAdaptiveDynamicPruningFilters)") + } + + protected override def doCanonicalize(): SparkPlan = { + val keys = buildKeys.map(k => QueryPlan.normalizeExpressions(k, child.output)) + copy(name = "dpp", buildKeys = keys, child = child.canonicalized) + } + + override protected def withNewChildInternal( + newChild: SparkPlan): CometSubqueryAdaptiveBroadcastExec = + copy(child = newChild) +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index cc7a63f0cb..b2f0dde630 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -86,6 +86,18 @@ case class CometShuffleExchangeExec( "CometColumnarExchange" } + // Exclude originalPlan from canonical form. It's a reference to the + // pre-Comet Spark exchange kept for metrics, not semantic content. + // Without this, two identical CometShuffleExchangeExec nodes with + // different originalPlans (e.g., one scan has DPP filters, one doesn't) + // would fail to match in AQE's stageCache, preventing exchange reuse. + // Matches CometBroadcastExchangeExec.doCanonicalize which also nulls + // originalPlan. 
+ override def doCanonicalize(): SparkPlan = { + val base = super.doCanonicalize().asInstanceOf[CometShuffleExchangeExec] + base.copy(originalPlan = null) + } + private lazy val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) @@ -278,8 +290,9 @@ object CometShuffleExchangeExec // A Comet shuffle wrapped around a stage that still contains a Spark FileSourceScanExec // with DPP produces inefficient row<->columnar transitions. This only happens when the - // scan fell back (e.g., AQE DPP not supported). If the scan converted to - // CometNativeScanExec, stageContainsDPPScan won't match (it checks FileSourceScanExec). + // scan fell back to Spark (e.g., AQE DPP on Spark 3.4, or unsupported scan type). + // On 3.5+ with AQE DPP, the scan converts to CometNativeScanExec and + // stageContainsDPPScan won't match (it checks FileSourceScanExec). if (stageContainsDPPScan(s)) { withInfos(s, Set("Stage contains a scan with Dynamic Partition Pruning")) return None diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala index 9276fe8190..55e3e1cf8a 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -39,8 +39,22 @@ trait ShimCometSparkSessionExtensions { true } - // injectQueryStageOptimizerRule not available on Spark 3.4 + // injectQueryStageOptimizerRule not available on Spark 3.4. + // CometPlanAdaptiveDynamicPruningFilters and CometReuseSubquery are not registered. + // On 3.4, Spark's PlanAdaptiveDynamicPruningFilters handles SABs directly (converting to + // TrueLiteral when it can't find BroadcastHashJoinExec, disabling DPP). 
def injectQueryStageOptimizerRuleShim( extensions: SparkSessionExtensions, rule: Rule[SparkPlan]): Unit = {} + + // Registers a queryStagePrepRule only on Spark < 3.5. The 3.5+ variants no-op this shim. + // Used by CometSpark34AqeDppFallbackRule, which is a correctness workaround for AQE DPP on + // Spark 3.4 where injectQueryStageOptimizerRule (and therefore + // CometPlanAdaptiveDynamicPruningFilters) is unavailable. See + // CometSpark34AqeDppFallbackRule's class docstring for details. + def injectPreSpark35QueryStagePrepRuleShim( + extensions: SparkSessionExtensions, + rule: Rule[SparkPlan]): Unit = { + extensions.injectQueryStagePrepRule(_ => rule) + } } diff --git a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala index 292ed2cb18..20663bb05b 100644 --- a/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala +++ b/spark/src/main/spark-3.4/org/apache/comet/shims/ShimSubqueryBroadcast.scala @@ -19,20 +19,25 @@ package org.apache.comet.shims -import org.apache.spark.sql.execution.{SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.{SparkPlan, SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} trait ShimSubqueryBroadcast { - /** - * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x has `index: Int`, - * Spark 4.x has `indices: Seq[Int]`. - */ def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = { Seq(sab.index) } - /** Same version shim for SubqueryBroadcastExec. 
*/ def getSubqueryBroadcastExecIndices(sub: SubqueryBroadcastExec): Seq[Int] = { Seq(sub.index) } + + def createSubqueryBroadcastExec( + name: String, + indices: Seq[Int], + buildKeys: Seq[Expression], + child: SparkPlan): SubqueryBroadcastExec = { + assert(indices.length == 1, s"Multi-index DPP not supported on Spark 3.4: indices=$indices") + SubqueryBroadcastExec(name, indices.head, buildKeys, child) + } } diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala index 15bf156f89..e28bdfc55b 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -45,4 +45,10 @@ trait ShimCometSparkSessionExtensions { rule: Rule[SparkPlan]): Unit = { extensions.injectQueryStageOptimizerRule(_ => rule) } + + // No-op on Spark >= 3.5. See the Spark 3.4 shim and + // CometSpark34AqeDppFallbackRule's class docstring for why this shim exists. 
+ def injectPreSpark35QueryStagePrepRuleShim( + extensions: SparkSessionExtensions, + rule: Rule[SparkPlan]): Unit = {} } diff --git a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala index 292ed2cb18..309251f1b7 100644 --- a/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala +++ b/spark/src/main/spark-3.5/org/apache/comet/shims/ShimSubqueryBroadcast.scala @@ -19,20 +19,25 @@ package org.apache.comet.shims -import org.apache.spark.sql.execution.{SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.{SparkPlan, SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} trait ShimSubqueryBroadcast { - /** - * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x has `index: Int`, - * Spark 4.x has `indices: Seq[Int]`. - */ def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = { Seq(sab.index) } - /** Same version shim for SubqueryBroadcastExec. */ def getSubqueryBroadcastExecIndices(sub: SubqueryBroadcastExec): Seq[Int] = { Seq(sub.index) } + + def createSubqueryBroadcastExec( + name: String, + indices: Seq[Int], + buildKeys: Seq[Expression], + child: SparkPlan): SubqueryBroadcastExec = { + assert(indices.length == 1, s"Multi-index DPP not supported on Spark 3.5: indices=$indices") + SubqueryBroadcastExec(name, indices.head, buildKeys, child) + } } diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/ShimPrepareExecutedPlan.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimPrepareExecutedPlan.scala new file mode 100644 index 0000000000..cee917ec8b --- /dev/null +++ b/spark/src/main/spark-3.x/org/apache/comet/shims/ShimPrepareExecutedPlan.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec + +trait ShimPrepareExecutedPlan { + def shimPrepareExecutedPlan( + adaptivePlan: AdaptiveSparkPlanExec, + plan: LogicalPlan): SparkPlan = { + QueryExecution.prepareExecutedPlan(adaptivePlan.context.session, plan, adaptivePlan.context) + } +} diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/ShimPrepareExecutedPlan.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimPrepareExecutedPlan.scala new file mode 100644 index 0000000000..cee917ec8b --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/ShimPrepareExecutedPlan.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec + +trait ShimPrepareExecutedPlan { + def shimPrepareExecutedPlan( + adaptivePlan: AdaptiveSparkPlanExec, + plan: LogicalPlan): SparkPlan = { + QueryExecution.prepareExecutedPlan(adaptivePlan.context.session, plan, adaptivePlan.context) + } +} diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/ShimPrepareExecutedPlan.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimPrepareExecutedPlan.scala new file mode 100644 index 0000000000..145bfd02ee --- /dev/null +++ b/spark/src/main/spark-4.1/org/apache/comet/shims/ShimPrepareExecutedPlan.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec + +trait ShimPrepareExecutedPlan { + def shimPrepareExecutedPlan( + adaptivePlan: AdaptiveSparkPlanExec, + plan: LogicalPlan): SparkPlan = { + QueryExecution.prepareExecutedPlan(plan, adaptivePlan.context) + } +} diff --git a/spark/src/main/spark-4.2/org/apache/comet/shims/ShimPrepareExecutedPlan.scala b/spark/src/main/spark-4.2/org/apache/comet/shims/ShimPrepareExecutedPlan.scala new file mode 100644 index 0000000000..145bfd02ee --- /dev/null +++ b/spark/src/main/spark-4.2/org/apache/comet/shims/ShimPrepareExecutedPlan.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan} +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec + +trait ShimPrepareExecutedPlan { + def shimPrepareExecutedPlan( + adaptivePlan: AdaptiveSparkPlanExec, + plan: LogicalPlan): SparkPlan = { + QueryExecution.prepareExecutedPlan(plan, adaptivePlan.context) + } +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala index cac636c45c..c78cc6b2d1 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/ShimCometSparkSessionExtensions.scala @@ -38,4 +38,10 @@ trait ShimCometSparkSessionExtensions { rule: Rule[SparkPlan]): Unit = { extensions.injectQueryStageOptimizerRule(_ => rule) } + + // No-op on Spark >= 3.5. See the Spark 3.4 shim and + // CometSpark34AqeDppFallbackRule's class docstring for why this shim exists. 
+ def injectPreSpark35QueryStagePrepRuleShim( + extensions: SparkSessionExtensions, + rule: Rule[SparkPlan]): Unit = {} } diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/ShimSubqueryBroadcast.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/ShimSubqueryBroadcast.scala index 73d9e53c4a..42c472b5bb 100644 --- a/spark/src/main/spark-4.x/org/apache/comet/shims/ShimSubqueryBroadcast.scala +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/ShimSubqueryBroadcast.scala @@ -19,20 +19,24 @@ package org.apache.comet.shims -import org.apache.spark.sql.execution.{SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.execution.{SparkPlan, SubqueryAdaptiveBroadcastExec, SubqueryBroadcastExec} trait ShimSubqueryBroadcast { - /** - * Gets the build key indices from SubqueryAdaptiveBroadcastExec. Spark 3.x has `index: Int`, - * Spark 4.x has `indices: Seq[Int]`. - */ def getSubqueryBroadcastIndices(sab: SubqueryAdaptiveBroadcastExec): Seq[Int] = { sab.indices } - /** Same version shim for SubqueryBroadcastExec. 
*/ def getSubqueryBroadcastExecIndices(sub: SubqueryBroadcastExec): Seq[Int] = { sub.indices } + + def createSubqueryBroadcastExec( + name: String, + indices: Seq[Int], + buildKeys: Seq[Expression], + child: SparkPlan): SubqueryBroadcastExec = { + SubqueryBroadcastExec(name, indices, buildKeys, child) + } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index d1cde77f78..1ac1659754 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -31,14 +31,15 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, CatalogTable} -import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression, ExpressionInfo, Hex} +import org.apache.spark.sql.catalyst.expressions.{DynamicPruningExpression, Expression, ExpressionInfo, Hex, Literal} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateMode, BloomFilterAggregate} import org.apache.spark.sql.comet._ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} +import org.apache.spark.sql.connector.catalog.InMemoryTableCatalog import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, CartesianProductExec, 
SortMergeJoinExec} import org.apache.spark.sql.execution.reuse.ReuseExchangeAndSubquery import org.apache.spark.sql.execution.window.WindowExec @@ -109,40 +110,7 @@ class CometExecSuite extends CometTestBase { } } - test("DPP fallback") { - withTempDir { path => - // create test data - val factPath = s"${path.getAbsolutePath}/fact.parquet" - val dimPath = s"${path.getAbsolutePath}/dim.parquet" - withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { - val one_day = 24 * 60 * 60000 - val fact = Range(0, 100) - .map(i => (i, new java.sql.Date(System.currentTimeMillis() + i * one_day), i.toString)) - .toDF("fact_id", "fact_date", "fact_str") - fact.write.partitionBy("fact_date").parquet(factPath) - val dim = Range(0, 10) - .map(i => (i, new java.sql.Date(System.currentTimeMillis() + i * one_day), i.toString)) - .toDF("dim_id", "dim_date", "dim_str") - dim.write.parquet(dimPath) - } - - // note that this test does not trigger DPP with v2 data source - Seq("parquet").foreach { v1List => - withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> v1List) { - spark.read.parquet(factPath).createOrReplaceTempView("dpp_fact") - spark.read.parquet(dimPath).createOrReplaceTempView("dpp_dim") - val df = - spark.sql( - "select * from dpp_fact join dpp_dim on fact_date = dim_date where dim_id > 7") - val (_, cometPlan) = checkSparkAnswer(df) - val infos = new ExtendedExplainInfo().generateExtendedInfo(cometPlan) - assert(infos.contains("AQE Dynamic Partition Pruning is not supported")) - } - } - } - } - - test("DPP fallback avoids inefficient Comet shuffle (#3874)") { + test("AQE DPP: fallback avoids inefficient Comet shuffle (#3874)") { withTempDir { path => val factPath = s"${path.getAbsolutePath}/fact.parquet" val dimPath = s"${path.getAbsolutePath}/dim.parquet" @@ -179,7 +147,7 @@ class CometExecSuite extends CometTestBase { } } - test("non-AQE DPP with BHJ works with CometNativeScanExec") { + test("non-AQE DPP: BHJ works with CometNativeScanExec") { withTempDir { path => val 
factPath = s"${path.getAbsolutePath}/fact.parquet" val dimPath = s"${path.getAbsolutePath}/dim.parquet" @@ -218,13 +186,13 @@ class CometExecSuite extends CometTestBase { val infos = new ExtendedExplainInfo().generateExtendedInfo(cometPlan) assert( - !infos.contains("AQE Dynamic Partition Pruning is not supported"), + !infos.contains("AQE Dynamic Partition Pruning"), s"Should not fall back for non-AQE DPP:\n$infos") } } } - test("non-AQE DPP with SMJ works with CometNativeScanExec") { + test("non-AQE DPP: SMJ works with CometNativeScanExec") { withTempDir { path => val factPath = s"${path.getAbsolutePath}/fact.parquet" val dimPath = s"${path.getAbsolutePath}/dim.parquet" @@ -258,13 +226,13 @@ class CometExecSuite extends CometTestBase { val infos = new ExtendedExplainInfo().generateExtendedInfo(cometPlan) assert( - !infos.contains("AQE Dynamic Partition Pruning is not supported"), + !infos.contains("AQE Dynamic Partition Pruning"), s"Should not fall back for non-AQE DPP:\n$infos") } } } - test("non-AQE DPP with BHJ reuses broadcast exchange") { + test("non-AQE DPP: BHJ reuses broadcast exchange") { withTempDir { dir => val path = s"${dir.getAbsolutePath}/data" withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { @@ -335,7 +303,7 @@ class CometExecSuite extends CometTestBase { } } - test("non-AQE DPP with non-atomic type (struct/array) join key") { + test("non-AQE DPP: non-atomic type (struct/array) join key") { withTempDir { dir => val path = s"${dir.getAbsolutePath}/data" withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { @@ -373,6 +341,46 @@ class CometExecSuite extends CometTestBase { // Regression tests for DPP exchange/subquery reuse (from DynamicPartitionPruningSuite) + /** + * Asserts common AQE DPP plan-shape expectations. Pass `None` to skip a check. Counts cover the + * whole plan including subqueries (uses collectWithSubqueries). + * + * - `expectedSABs`: leftover `CometSubqueryAdaptiveBroadcastExec` nodes. 
After + * `CometPlanAdaptiveDynamicPruningFilters` runs, this should always be 0. + * - `expectedCometSubqueryBroadcasts`: `CometSubqueryBroadcastExec` count. Non-zero means + * broadcast reuse was wired up for native DPP. + * - `expectedReusedExchanges`: `ReusedExchangeExec` count. Confirms AQE stageCache matched + * the SAB's broadcast to an existing broadcast. + */ + private def assertAqeDppShape( + plan: SparkPlan, + expectedSABs: Int = 0, + expectedCometSubqueryBroadcasts: Option[Int] = None, + expectedReusedExchanges: Option[Int] = None): Unit = { + val remainingSABs = collectWithSubqueries(plan) { + case s: CometSubqueryAdaptiveBroadcastExec => s + } + assert( + remainingSABs.size == expectedSABs, + s"Expected $expectedSABs unconverted CometSubqueryAdaptiveBroadcastExec, " + + s"found ${remainingSABs.size}:\n${plan.treeString}") + expectedCometSubqueryBroadcasts.foreach { n => + val subqueries = collectWithSubqueries(plan) { case s: CometSubqueryBroadcastExec => + s + } + assert( + subqueries.size == n, + s"Expected $n CometSubqueryBroadcastExec, found ${subqueries.size}:" + + s"\n${plan.treeString}") + } + expectedReusedExchanges.foreach { n => + val reused = collectWithSubqueries(plan) { case e: ReusedExchangeExec => e } + assert( + reused.size == n, + s"Expected $n ReusedExchangeExec, found ${reused.size}:\n${plan.treeString}") + } + } + private def withDppTables(f: => Unit): Unit = { val factData = Seq( (1000, 1, 1, 10), @@ -463,7 +471,7 @@ class CometExecSuite extends CometTestBase { } } - test("DPP broadcast exchange reuse") { + test("non-AQE DPP: broadcast exchange reuse") { withDppTables { withSQLConf( SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", @@ -486,7 +494,7 @@ class CometExecSuite extends CometTestBase { } } - test("DPP subquery reuse with uncorrelated scalar subquery") { + test("non-AQE DPP: subquery reuse with uncorrelated scalar subquery") { withDppTables { withSQLConf( SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", @@ -523,7 +531,7 @@ class 
CometExecSuite extends CometTestBase { } } - test("DPP with non-atomic type (struct/array) join key") { + test("non-AQE DPP: non-atomic type (struct/array) join key with withDppTables") { withDppTables { withSQLConf( SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", @@ -542,7 +550,7 @@ class CometExecSuite extends CometTestBase { } } - test("DPP non-atomic type uses CometSubqueryBroadcastExec") { + test("non-AQE DPP: non-atomic type uses CometSubqueryBroadcastExec") { withDppTables { withSQLConf( SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", @@ -614,7 +622,7 @@ class CometExecSuite extends CometTestBase { } } - test("non-AQE DPP with two separate broadcast joins") { + test("non-AQE DPP: two separate broadcast joins") { withTempDir { dir => val path = s"${dir.getAbsolutePath}/data" withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { @@ -665,7 +673,7 @@ class CometExecSuite extends CometTestBase { } } - test("non-AQE DPP fallback when broadcast exchange is not Comet") { + test("non-AQE DPP: fallback when broadcast exchange is not Comet") { withTempDir { dir => val path = s"${dir.getAbsolutePath}/data" withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { @@ -702,7 +710,7 @@ class CometExecSuite extends CometTestBase { } } - test("non-AQE DPP with empty broadcast result") { + test("non-AQE DPP: empty broadcast result") { withTempDir { dir => val path = s"${dir.getAbsolutePath}/data" withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { @@ -737,7 +745,7 @@ class CometExecSuite extends CometTestBase { } } - test("non-AQE DPP resolves both outer and inner partition filters") { + test("non-AQE DPP: resolves both outer and inner partition filters") { // CometNativeScanExec.partitionFilters and CometScanExec.partitionFilters contain // different InSubqueryExec instances. Both must be resolved for partition selection // to work correctly. 
This test verifies correct results, which requires both sets @@ -786,6 +794,1002 @@ class CometExecSuite extends CometTestBase { } } + // On 3.5+, CometPlanAdaptiveDynamicPruningFilters converts SABs to + // CometSubqueryBroadcastExec with broadcast reuse. On 3.4, AQE DPP falls back to Spark. + test("AQE DPP: BHJ works with CometNativeScanExec") { + withTempDir { path => + val factPath = s"${path.getAbsolutePath}/fact.parquet" + val dimPath = s"${path.getAbsolutePath}/dim.parquet" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + val one_day = 24 * 60 * 60000 + val fact = Range(0, 100) + .map(i => (i, new java.sql.Date(System.currentTimeMillis() + (i % 10) * one_day))) + .toDF("fact_id", "fact_date") + fact.write.partitionBy("fact_date").parquet(factPath) + val dim = Range(0, 10) + .map(i => (i, new java.sql.Date(System.currentTimeMillis() + i * one_day))) + .toDF("dim_id", "dim_date") + dim.write.parquet(dimPath) + } + + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + spark.read.parquet(factPath).createOrReplaceTempView("aqe_dpp_fact") + spark.read.parquet(dimPath).createOrReplaceTempView("aqe_dpp_dim") + val df = spark.sql( + "select * from aqe_dpp_fact join aqe_dpp_dim on fact_date = dim_date where dim_id > 7") + val (_, cometPlan) = checkSparkAnswer(df) + val infos = new ExtendedExplainInfo().generateExtendedInfo(cometPlan) + + if (isSpark35Plus) { + // Verify native scan with DPP + val nativeScans = collect(cometPlan) { case s: CometNativeScanExec => s } + assert(nativeScans.nonEmpty, "Expected CometNativeScanExec in plan") + val dppScans = nativeScans.filter( + _.partitionFilters.exists(_.isInstanceOf[DynamicPruningExpression])) + assert( + dppScans.nonEmpty, + "Expected at least one CometNativeScanExec with DynamicPruningExpression") + + // Verify CometSubqueryBroadcastExec with AdaptiveSparkPlanExec child + // (matches Spark's SubqueryBroadcastExec wrapping an ASPE that goes + 
// through AQE stageCache for broadcast reuse via ReusedExchangeExec) + val cometSubqueries = collectWithSubqueries(cometPlan) { + case s: CometSubqueryBroadcastExec => s + } + assert( + cometSubqueries.nonEmpty, + "Expected CometSubqueryBroadcastExec for broadcast reuse") + cometSubqueries.foreach { csb => + assert( + csb.child.isInstanceOf[AdaptiveSparkPlanExec], + "Expected AdaptiveSparkPlanExec child but got " + + s"${csb.child.getClass.getSimpleName}") + } + + // Verify broadcast reuse: the subquery's ASPE final plan should contain + // ReusedExchangeExec (AQE stageCache matched the join's broadcast) + import org.apache.spark.sql.execution.exchange.ReusedExchangeExec + cometSubqueries.foreach { csb => + val aspe = csb.child.asInstanceOf[AdaptiveSparkPlanExec] + val hasReusedExchange = collect(aspe) { case r: ReusedExchangeExec => + r + }.nonEmpty || collect(aspe) { case b: BroadcastQueryStageExec => + b + }.nonEmpty + assert( + hasReusedExchange, + "DPP subquery's ASPE should contain ReusedExchangeExec or " + + "BroadcastQueryStageExec for broadcast reuse") + } + + // Verify no unconverted SABs remain + assertAqeDppShape(cometPlan) + + // Verify no fallback + assert( + !infos.contains("AQE Dynamic Partition Pruning"), + s"Should not fall back for AQE DPP:\n$infos") + } else { + // 3.4: scan falls back to Spark so Spark handles DPP natively + assert( + infos.contains("AQE Dynamic Partition Pruning requires Spark 3.5+"), + s"Expected 3.4 AQE DPP fallback message but got:\n$infos") + } + } + } + } + + // With Comet BHJ disabled, the join stays as BroadcastHashJoinExec. Our rule finds + // it and creates SubqueryBroadcastExec (not CometSubqueryBroadcastExec). 
+ test("AQE DPP: fallback when broadcast exchange is not Comet") { + withTempDir { dir => + val path = s"${dir.getAbsolutePath}/data" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark + .range(100) + .selectExpr("cast(id % 10 as int) as store_id", "cast(id as int) as amount") + .write + .partitionBy("store_id") + .parquet(s"$path/fact") + spark + .range(10) + .selectExpr("cast(id as int) as store_id", "cast(id as string) as country") + .write + .parquet(s"$path/dim") + } + + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + CometConf.COMET_EXEC_BROADCAST_EXCHANGE_ENABLED.key -> "false", + CometConf.COMET_EXEC_BROADCAST_HASH_JOIN_ENABLED.key -> "false") { + spark.read.parquet(s"$path/fact").createOrReplaceTempView("aqe_fact_fallback") + spark.read.parquet(s"$path/dim").createOrReplaceTempView("aqe_dim_fallback") + + val df = spark.sql("""SELECT f.amount, f.store_id + |FROM aqe_fact_fallback f JOIN aqe_dim_fallback d + |ON f.store_id = d.store_id + |WHERE d.country = 'DE'""".stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + + // Verify no CometSubqueryBroadcastExec — Spark handles DPP with its own + // SubqueryBroadcastExec since the join is Spark's BroadcastHashJoinExec + if (isSpark35Plus) { + val cometSubqueries = collectWithSubqueries(cometPlan) { + case s: CometSubqueryBroadcastExec => s + } + assert( + cometSubqueries.isEmpty, + "Should not have CometSubqueryBroadcastExec when Comet BHJ is disabled") + } + } + } + } + + // No broadcast to reuse, so DPP falls back to Literal.TrueLiteral. 
+ test("AQE DPP: SMJ disables DPP gracefully") { + withTempDir { path => + val factPath = s"${path.getAbsolutePath}/fact.parquet" + val dimPath = s"${path.getAbsolutePath}/dim.parquet" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + val one_day = 24 * 60 * 60000 + val fact = Range(0, 100) + .map(i => (i, new java.sql.Date(System.currentTimeMillis() + (i % 10) * one_day))) + .toDF("fact_id", "fact_date") + fact.write.partitionBy("fact_date").parquet(factPath) + val dim = Range(0, 10) + .map(i => (i, new java.sql.Date(System.currentTimeMillis() + i * one_day))) + .toDF("dim_id", "dim_date") + dim.write.parquet(dimPath) + } + + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + spark.read.parquet(factPath).createOrReplaceTempView("aqe_dpp_fact_smj") + spark.read.parquet(dimPath).createOrReplaceTempView("aqe_dpp_dim_smj") + val df = spark.sql( + "select * from aqe_dpp_fact_smj join aqe_dpp_dim_smj " + + "on fact_date = dim_date where dim_id > 7") + val (_, cometPlan) = checkSparkAnswer(df) + + if (isSpark35Plus) { + // Verify native scan is used (DPP disabled via TrueLiteral, but scan still native) + val nativeScans = collect(cometPlan) { case s: CometNativeScanExec => s } + assert(nativeScans.nonEmpty, "Expected CometNativeScanExec in plan") + + // No CometSubqueryBroadcastExec (DPP was disabled), no unconverted SABs + assertAqeDppShape(cometPlan, expectedCometSubqueryBroadcasts = Some(0)) + + // Case 2 of CometPlanAdaptiveDynamicPruningFilters: SMJ with REUSE_BROADCAST_ONLY + // (Spark's default) sets onlyInBroadcast=true on the SAB. With no reusable + // broadcast, the rule replaces the DPP filter with DynamicPruningExpression( + // Literal.TrueLiteral). This distinguishes Case 2 from Case 3 (aggregate + // SubqueryExec), which would appear as a nested SubqueryExec in the filter. 
+ val trueLiteralFilters = nativeScans.flatMap(_.partitionFilters).collect { + case DynamicPruningExpression(Literal.TrueLiteral) => true + } + assert( + trueLiteralFilters.nonEmpty, + "Expected DynamicPruningExpression(TrueLiteral) for onlyInBroadcast=true SMJ, " + + s"got partitionFilters: ${nativeScans.map(_.partitionFilters).mkString("; ")}") + } + } + } + } + + // Each DPP filter should match the correct broadcast join by buildKeys exprId. + test("AQE DPP: two separate broadcast joins") { + withTempDir { dir => + val path = s"${dir.getAbsolutePath}/data" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark + .range(100) + .selectExpr( + "cast(id % 5 as int) as store_id", + "cast(id % 3 as int) as region_id", + "cast(id as int) as amount") + .write + .partitionBy("store_id", "region_id") + .parquet(s"$path/fact") + spark + .range(5) + .selectExpr("cast(id as int) as store_id", "cast(id as string) as store_name") + .write + .parquet(s"$path/store_dim") + spark + .range(3) + .selectExpr("cast(id as int) as region_id", "cast(id as string) as region_name") + .write + .parquet(s"$path/region_dim") + } + + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + spark.read.parquet(s"$path/fact").createOrReplaceTempView("aqe_fact_two_joins") + spark.read.parquet(s"$path/store_dim").createOrReplaceTempView("aqe_store_dim") + spark.read.parquet(s"$path/region_dim").createOrReplaceTempView("aqe_region_dim") + + val df = spark.sql("""SELECT f.amount, s.store_name, r.region_name + |FROM aqe_fact_two_joins f + |JOIN aqe_store_dim s ON f.store_id = s.store_id + |JOIN aqe_region_dim r ON f.region_id = r.region_id + |WHERE s.store_name = '1' AND r.region_name = '2'""".stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + + if (isSpark35Plus) { + val nativeScans = collect(cometPlan) { case s: CometNativeScanExec => s } + assert(nativeScans.nonEmpty, "Expected CometNativeScanExec in plan") + + val 
dppScans = nativeScans.filter( + _.partitionFilters.exists(_.isInstanceOf[DynamicPruningExpression])) + assert(dppScans.nonEmpty, "Expected DPP filters on native scan") + + // Verify no unconverted SABs + assertAqeDppShape(cometPlan) + } + } + } + } + + test("AQE DPP: empty broadcast result") { + withTempDir { dir => + val path = s"${dir.getAbsolutePath}/data" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark + .range(100) + .selectExpr("cast(id % 10 as int) as store_id", "cast(id as int) as amount") + .write + .partitionBy("store_id") + .parquet(s"$path/fact") + spark + .range(10) + .selectExpr("cast(id as int) as store_id", "cast(id as string) as country") + .write + .parquet(s"$path/dim") + } + + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + spark.read.parquet(s"$path/fact").createOrReplaceTempView("aqe_fact_empty") + spark.read.parquet(s"$path/dim").createOrReplaceTempView("aqe_dim_empty") + + val df = spark.sql("""SELECT f.amount, f.store_id + |FROM aqe_fact_empty f JOIN aqe_dim_empty d + |ON f.store_id = d.store_id + |WHERE d.country = 'NONEXISTENT'""".stripMargin) + val result = df.collect() + assert(result.isEmpty, s"Expected empty result but got ${result.length} rows") + checkSparkAnswer(df) + } + } + } + + // Both outer (CometNativeScanExec) and inner (CometScanExec) partition filters must + // be resolved. Correct results prove both filter sets were converted. 
+ test("AQE DPP: resolves both outer and inner partition filters") { + withTempDir { dir => + val path = s"${dir.getAbsolutePath}/data" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark + .range(100) + .selectExpr( + "cast(id % 10 as int) as store_id", + "cast(id as int) as date_id", + "cast(id as int) as amount") + .write + .partitionBy("store_id") + .parquet(s"$path/fact") + spark + .range(10) + .selectExpr("cast(id as int) as store_id", "cast(id as string) as country") + .write + .parquet(s"$path/dim") + } + + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + spark.read.parquet(s"$path/fact").createOrReplaceTempView("aqe_fact_dual") + spark.read.parquet(s"$path/dim").createOrReplaceTempView("aqe_dim_dual") + + val df = spark.sql("""SELECT f.date_id, f.store_id + |FROM aqe_fact_dual f JOIN aqe_dim_dual d + |ON f.store_id = d.store_id + |WHERE d.country = '3'""".stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + + if (isSpark35Plus) { + val nativeScans = collect(cometPlan) { case s: CometNativeScanExec => s } + assert(nativeScans.nonEmpty, "Expected CometNativeScanExec in plan") + + val dppScans = nativeScans.filter( + _.partitionFilters.exists(_.isInstanceOf[DynamicPruningExpression])) + assert(dppScans.nonEmpty, "Expected DPP filter on native scan") + + // Verify CometSubqueryBroadcastExec is present (not TrueLiteral fallback) + val cometSubqueries = collectWithSubqueries(cometPlan) { + case s: CometSubqueryBroadcastExec => s + } + assert( + cometSubqueries.nonEmpty, + "Expected CometSubqueryBroadcastExec (DPP should be active, not TrueLiteral)") + } + } + } + } + + // DPP subquery reuses the join's broadcast via AQE stageCache. 
+ test("AQE DPP: broadcast exchange reuse") { + withDppTables { + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + + val df = sql("""SELECT /*+ BROADCAST(f)*/ + |f.date_id, f.store_id, f.product_id, f.units_sold FROM fact_np f + |JOIN code_stats s + |ON f.store_id = s.store_id WHERE f.date_id <= 1030""".stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + + if (isSpark35Plus) { + // Verify no unconverted SABs remain + assertAqeDppShape(cometPlan) + + // If DPP subqueries are present, verify they use CometSubqueryBroadcastExec + // with AdaptiveSparkPlanExec children (ASPE wrapping broadcast for stageCache reuse) + val cometSubqueries = collectWithSubqueries(cometPlan) { + case s: CometSubqueryBroadcastExec => s + } + cometSubqueries.foreach { csb => + assert( + csb.child.isInstanceOf[AdaptiveSparkPlanExec], + "CometSubqueryBroadcastExec child should be AdaptiveSparkPlanExec, " + + s"got ${csb.child.getClass.getSimpleName}") + } + } + } + } + } + + test("AQE DPP: non-atomic type (struct/array) join key") { + withTempDir { dir => + val path = s"${dir.getAbsolutePath}/data" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + spark + .range(100) + .selectExpr( + "cast(id % 10 as int) as store_id", + "cast(id as int) as date_id", + "cast(id * 2 as int) as units_sold") + .write + .partitionBy("store_id") + .parquet(s"$path/fact") + spark + .range(10) + .selectExpr("cast(id as int) as store_id", "cast(id as string) as country") + .write + .parquet(s"$path/dim") + } + + Seq("struct", "array").foreach { dataType => + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") { + spark.read.parquet(s"$path/fact").createOrReplaceTempView("aqe_fact_nonatomic") + spark.read.parquet(s"$path/dim").createOrReplaceTempView("aqe_dim_nonatomic") + val df = spark.sql(s"""SELECT 
f.date_id, f.store_id FROM aqe_fact_nonatomic f + |JOIN aqe_dim_nonatomic d + |ON $dataType(f.store_id) = $dataType(d.store_id) + |WHERE d.country = '3'""".stripMargin) + checkSparkAnswer(df) + } + } + } + } + + test("AQE DPP: non-atomic type uses CometSubqueryBroadcastExec") { + withDppTables { + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + + Seq("struct", "array").foreach { dataType => + val df = + sql(s"""SELECT f.date_id, f.product_id, f.units_sold, f.store_id FROM fact_stats f + |JOIN dim_stats s + |ON $dataType(f.store_id) = $dataType(s.store_id) WHERE s.country = 'DE' + """.stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + + if (isSpark35Plus) { + val cometSubqueries = collectWithSubqueries(cometPlan) { + case s: CometSubqueryBroadcastExec => s + } + assert( + cometSubqueries.nonEmpty, + s"Expected DPP with CometSubqueryBroadcastExec for $dataType key:\n" + + cometPlan.treeString) + + assertAqeDppShape(cometPlan) + } + } + } + } + } + + // With onlyInBroadcast=false and exchange reuse disabled, our rule falls through to + // Case 3 (aggregate SubqueryExec) instead of broadcast reuse or TrueLiteral. + // Reproduces DynamicPartitionPruningSuite "simple inner join triggers DPP with mock-up tables". 
+ test("AQE DPP: inner join with broadcast reuse disabled") { + withDppTables { + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") { + val df = sql("""SELECT f.date_id, f.store_id FROM fact_sk f + |JOIN dim_store s ON f.store_id = s.store_id AND s.country = 'NL'""".stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + + if (isSpark35Plus) { + val nativeScans = collect(cometPlan) { case s: CometNativeScanExec => s } + assert(nativeScans.nonEmpty, "Expected CometNativeScanExec in plan") + + assertAqeDppShape(cometPlan) + } + } + } + } + + // Scan is in a shuffle stage separated from the broadcast join. Cross-stage + // broadcast search (via context.qe.executedPlan) must find the join. + test("AQE DPP: avoid reordering broadcast join keys") { + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + withTable("large", "dimTwo", "dimThree") { + spark + .range(100) + .select($"id", ($"id" + 1).as("A"), ($"id" + 2).as("B")) + .write + .partitionBy("A") + .format("parquet") + .mode("overwrite") + .saveAsTable("large") + + spark + .range(10) + .select($"id", ($"id" + 1).as("C"), ($"id" + 2).as("D")) + .write + .format("parquet") + .mode("overwrite") + .saveAsTable("dimTwo") + + spark + .range(10) + .select($"id", ($"id" + 1).as("E"), ($"id" + 2).as("F"), ($"id" + 3).as("G")) + .write + .format("parquet") + .mode("overwrite") + .saveAsTable("dimThree") + + val fact = sql("SELECT * from large") + val dim = sql("SELECT * from dimTwo") + val prod = sql("SELECT * from dimThree") + + val df = 
fact + .join(dim, fact.col("A") === dim.col("C") && fact.col("B") === dim.col("D"), "LEFT") + .join( + broadcast(prod), + fact.col("B") === prod.col("F") && fact.col("A") === prod.col("E")) + .where(prod.col("G") > 5) + + val (_, cometPlan) = checkSparkAnswer(df) + + if (isSpark35Plus) { + val dpExprs = flatMap(cometPlan) { + case s: CometNativeScanExec => + s.partitionFilters.collect { case d: DynamicPruningExpression => d.child } + case _ => Nil + } + val hasSubquery = dpExprs.exists { + case InSubqueryExec(_, _: SubqueryExec, _, _, _, _) => true + case _ => false + } + val hasBroadcast = dpExprs.exists { + case InSubqueryExec(_, _: SubqueryBroadcastExec, _, _, _, _) => true + case InSubqueryExec(_, _: CometSubqueryBroadcastExec, _, _, _, _) => true + case _ => false + } + assert(!hasSubquery, "Should not have SubqueryExec DPP") + assert(hasBroadcast, "Should have broadcast DPP") + } + } + } + } + + // Cross-plan subquery deduplication via the shared AdaptiveExecutionContext.subqueryCache. 
+ test("AQE DPP: uncorrelated scalar subquery with broadcast reuse") { + withDppTables { + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + val df = sql(""" + |SELECT d.store_id, + | SUM(f.units_sold), + | (SELECT SUM(f.units_sold) + | FROM fact_stats f JOIN dim_stats d ON d.store_id = f.store_id + | WHERE d.country = 'US') AS total_prod + |FROM fact_stats f JOIN dim_stats d ON d.store_id = f.store_id + |WHERE d.country = 'US' + |GROUP BY 1 + """.stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + + if (isSpark35Plus) { + val countBroadcasts = collectWithSubqueries(cometPlan) { + case _: SubqueryBroadcastExec => 1 + case _: CometSubqueryBroadcastExec => 1 + }.sum + val countReused = collectWithSubqueries(cometPlan) { + case ReusedSubqueryExec(_: SubqueryBroadcastExec) => 1 + case ReusedSubqueryExec(_: CometSubqueryBroadcastExec) => 1 + }.sum + + assert(countBroadcasts == 1, s"Expected 1 SubqueryBroadcast, got $countBroadcasts") + assert(countReused == 1, s"Expected 1 ReusedSubquery, got $countReused") + } + } + } + } + + // From RemoveRedundantProjectsSuite "join with ordering requirement". + // DPP subquery uses ReusedExchangeExec so collectWithSubqueries doesn't + // double-count project nodes. 
+  test("AQE DPP: join with ordering requirement project count") {
+    withSQLConf(
+      SQLConf.USE_V1_SOURCE_LIST.key -> "parquet",
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true") {
+      withTable("testViewTable") {
+        withTempView("testView") {
+          spark
+            .range(0, 100, 1)
+            .selectExpr(
+              "id as key",
+              "id as a",
+              "id as b",
+              "cast(id as string) as c",
+              "cast(id as string) as d")
+            .write
+            .format("parquet")
+            .mode("overwrite")
+            .partitionBy("key")
+            .saveAsTable("testViewTable")
+          sql("CREATE OR REPLACE TEMP VIEW testView AS SELECT * FROM testViewTable")
+
+          val query = "select * from (select key, a, c, b from testView) as t1 join " +
+            "(select key, a, b, c from testView) as t2 on t1.key = t2.key where t2.a > 50"
+
+          val (sparkPlan, cometPlan) = checkSparkAnswer(sql(query))
+
+          if (isSpark35Plus) {
+            val sparkProjects = collectWithSubqueries(sparkPlan) { case p: ProjectExec => p }
+            val cometProjects = collectWithSubqueries(cometPlan) {
+              case p: ProjectExec => p
+              case p: CometProjectExec => p
+            }
+            assert(
+              cometProjects.size == sparkProjects.size,
+              s"Comet project count (${cometProjects.size}) should match " +
+                s"Spark (${sparkProjects.size})")
+          }
+        }
+      }
+    }
+  }
+
+  // SPARK-39447: the SHUFFLE_MERGE hint forces SMJ; the empty CTE means the DPP subquery's
+  // ASPE re-optimizes to LocalTableScan. CometBroadcastExchangeExec must handle the
+  // resulting non-Comet child gracefully.
+ test("AQE DPP: SPARK-39447 avoid assertion in doExecuteBroadcast") { + withDppTables { + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED.key -> "true") { + val df = sql(""" + |WITH empty_result AS ( + | SELECT * FROM fact_stats WHERE product_id < 0 + |) + |SELECT * + |FROM (SELECT /*+ SHUFFLE_MERGE(fact_sk) */ empty_result.store_id + | FROM fact_sk + | JOIN empty_result + | ON fact_sk.product_id = empty_result.product_id) t2 + | JOIN empty_result + | ON t2.store_id = empty_result.store_id + """.stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + checkAnswer(df, Nil) + + if (isSpark35Plus) { + assertAqeDppShape(cometPlan) + } + } + } + } + + // SPARK-32509: previously IgnoreComet(#4045). Unused DPP filter with + // AUTO_BROADCASTJOIN_THRESHOLD=-1 (no broadcast). Should not affect exchange reuse. + test("AQE DPP: unused DPP filter and exchange reuse (SPARK-32509)") { + withDppTables { + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + val df = sql(""" WITH view1 as ( + | SELECT f.store_id FROM fact_stats f WHERE f.units_sold = 70 + | ) + | SELECT * FROM view1 v1 join view1 v2 WHERE v1.store_id = v2.store_id + """.stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + checkAnswer(df, Row(15, 15) :: Nil) + + import org.apache.spark.sql.execution.exchange.ReusedExchangeExec + val reusedExchanges = collect(cometPlan) { case r: ReusedExchangeExec => r } + assert( + reusedExchanges.size == 1, + s"Expected 1 ReusedExchangeExec, got ${reusedExchanges.size}.\n" + + s"Plan:\n${cometPlan.treeString}") + } + } + } + + // SPARK-34637: previously IgnoreComet(#4045). DPP side broadcast query stage + // should be created before the main join's broadcast stage. 
+ test("AQE DPP: broadcast query stage creation order (SPARK-34637)") { + withDppTables { + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + val df = sql(""" WITH v as ( + | SELECT f.store_id FROM fact_stats f WHERE f.units_sold = 70 group by f.store_id + | ) + | SELECT * FROM v v1 join v v2 WHERE v1.store_id = v2.store_id + """.stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + checkAnswer(df, Row(15, 15) :: Nil) + + if (isSpark35Plus) { + val cometSubqueries = collectWithSubqueries(cometPlan) { + case s: CometSubqueryBroadcastExec => s + } + assert( + cometSubqueries.nonEmpty, + s"Expected CometSubqueryBroadcastExec for DPP.\nPlan:\n${cometPlan.treeString}") + } else { + // On 3.4, fall back to Spark-native DPP: expect a SubqueryBroadcastExec + // (not Comet) indicating Spark's PlanAdaptiveDynamicPruningFilters ran. + val sparkSubqueries = collectWithSubqueries(cometPlan) { + case s: SubqueryBroadcastExec => s + } + assert( + sparkSubqueries.nonEmpty, + "Expected Spark SubqueryBroadcastExec for DPP on 3.4 fallback. " + + "If empty, Spark's rule killed DPP (likely because Comet BHJ was " + + s"not falling back).\nPlan:\n${cometPlan.treeString}") + } + } + } + } + + // Reproduces DynamicPartitionPruningSuiteV2: SPARK-34637 with V2 BatchScan. + // Uses InMemoryTableCatalog so Spark creates BatchScanExec (not FileSourceScanExec). + // Comet replaces BroadcastHashJoinExec with CometBroadcastHashJoinExec, which + // breaks Spark's PlanAdaptiveDynamicPruningFilters pattern match for non-Comet scans. + test("AQE DPP: V2 BatchScan broadcast query stage creation order (SPARK-34637)") { + // On Spark 4.1+, the shuffle between partial/final aggregates is elided for this + // plan, which removes the only Comet entry point (CometColumnarShuffle over a Spark + // shuffle) that would let the cascade reach CometBroadcastHashJoinExec. 
Without a
+    // Comet BHJ, CometPlanAdaptiveDynamicPruningFilters falls into its Spark-native
+    // branch and produces SubqueryBroadcastExec instead of CometSubqueryBroadcastExec.
+    // DPP is still correct and broadcast reuse still fires, so we branch the
+    // assertion by version rather than skipping the whole test.
+    //
+    // Enabling CometSparkToColumnar (COMET_SPARK_TO_ARROW_ENABLED plus adding "BatchScan" to
+    // COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST) would give Comet a scan-level entry
+    // point, but it also exposes a separate bug in CometExecRule.transform
+    // (https://github.com/apache/datafusion-comet/issues/4145): SAB/SubqueryBroadcastExec
+    // wrapping runs only on the post-convertNode tree, so when convertNode wraps a scan in
+    // CometSparkToColumnarExec the wrapped scan's runtimeFilters/partitionFilters are
+    // hidden from the SAB-wrapping pass. Any V2 scan routed through CometSparkToColumnarExec
+    // with DPP filters skips SAB wrapping. That bug is tracked separately and is not exercised here.
+ val factData = Seq( + (1000, 1, 1, 10), + (1010, 2, 1, 10), + (1020, 2, 1, 10), + (1030, 3, 2, 10), + (1040, 3, 2, 50), + (1050, 3, 2, 50), + (1060, 3, 2, 50), + (1070, 4, 2, 10), + (1080, 4, 3, 20), + (1090, 4, 3, 10), + (1100, 4, 3, 10), + (1110, 5, 3, 10), + (1120, 6, 4, 10), + (1130, 7, 4, 50), + (1140, 8, 4, 50), + (1150, 9, 1, 20), + (1160, 10, 1, 20), + (1170, 11, 1, 30), + (1180, 12, 2, 20), + (1190, 13, 2, 20), + (1200, 14, 3, 40), + (1200, 15, 3, 70), + (1210, 16, 4, 10), + (1220, 17, 4, 20), + (1230, 18, 4, 20), + (1240, 19, 5, 40), + (1250, 20, 5, 40), + (1260, 21, 5, 40), + (1270, 22, 5, 50), + (1280, 23, 1, 50), + (1290, 24, 1, 50), + (1300, 25, 1, 50)) + + import testImplicits._ + withSQLConf( + "spark.sql.catalog.testcat" -> classOf[InMemoryTableCatalog].getName, + "spark.sql.defaultCatalog" -> "testcat", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + + factData + .toDF("date_id", "store_id", "product_id", "units_sold") + .write + .partitionBy("store_id") + .saveAsTable("fact_stats_v2") + + try { + val df = sql(""" WITH v as ( + | SELECT f.store_id FROM fact_stats_v2 f + | WHERE f.units_sold = 70 GROUP BY f.store_id + | ) + | SELECT * FROM v v1 JOIN v v2 WHERE v1.store_id = v2.store_id + """.stripMargin) + + val (_, cometPlan) = checkSparkAnswer(df) + checkAnswer(df, Row(15, 15) :: Nil) + + // SABs should have been unwrapped by CometPlanAdaptiveDynamicPruningFilters on 3.5+. + if (isSpark35Plus) { + assertAqeDppShape(cometPlan) + } + + if (isSpark35Plus && !isSpark41Plus) { + // 3.5 - 4.0: CometPlanAdaptiveDynamicPruningFilters rewrites the SAB into + // CometSubqueryBroadcastExec with the join's CometBroadcastExchange for native + // broadcast reuse. 
+ val cometSubqueries = collectWithSubqueries(cometPlan) { + case s: CometSubqueryBroadcastExec => s + } + assert( + cometSubqueries.nonEmpty, + s"V2 scan should have CometSubqueryBroadcastExec for DPP:\n$cometPlan") + } else { + // 3.4 and 4.1+: DPP runs as Spark-native SubqueryBroadcastExec. + // - 3.4: CometSpark34AqeDppFallbackRule keeps the BHJ build broadcast Spark-native + // so Spark's PlanAdaptiveDynamicPruningFilters can create SubqueryBroadcastExec + // and AQE stageCache can dedupe with the DPP subquery's broadcast. + // - 4.1+: Partial/final aggregate shuffle is elided, which removes Comet's entry + // point for this query, so CometPlanAdaptiveDynamicPruningFilters falls into + // its Spark-native branch. + val sparkSubqueries = collectWithSubqueries(cometPlan) { + case s: SubqueryBroadcastExec => s + } + assert( + sparkSubqueries.nonEmpty, + s"V2 scan should have SubqueryBroadcastExec for DPP:\n$cometPlan") + + // Broadcast reuse: the DPP subquery's BroadcastExchange must be reused in the main + // plan as a ReusedExchangeExec (or appear directly). Mirrors Spark's + // DynamicPartitionPruningSuiteBase.checkPartitionPruningPredicate hasReuse check + // (DynamicPartitionPruningSuite.scala:207-231). Without this, the main BHJ's build + // side would run a second broadcast with the same data. 
+ sparkSubqueries.foreach { s => + val dppBroadcast = s.child match { + case aspe: AdaptiveSparkPlanExec => + val bqs = collectFirst(aspe) { case b: BroadcastQueryStageExec => b } + assert( + bqs.isDefined, + s"Expected BroadcastQueryStageExec under DPP subquery's ASPE:\n$cometPlan") + bqs.get.broadcast + case other => + fail(s"Unexpected SubqueryBroadcastExec child: ${other.getClass.getSimpleName}") + } + val hasReuse = find(cometPlan) { + case ReusedExchangeExec(_, e) => e eq dppBroadcast + case b: BroadcastExchangeLike => b eq dppBroadcast + case _ => false + }.isDefined + assert(hasReuse, s"DPP broadcast should be reused in main plan:\n$cometPlan") + } + } + } finally { + sql("DROP TABLE IF EXISTS testcat.fact_stats_v2") + } + } + } + + // Regression for the TPC-DS q5/q14a/q14b/q54 failure: two fact scans inside a single + // UNION ALL that joins a dimension once produce DPP subqueries that share their + // logical build plan (since the join pushes DPP down to both scans via one subquery). + // Spark's ReuseAdaptiveSubquery (which runs before our rule) collapses them into + // ReusedSubqueryExec(CometSubqueryAdaptiveBroadcastExec). Our rule's extractSABData + // must unwrap ReusedSubqueryExec before inspecting the inner plan; otherwise the + // wrapped CSAB survives to runtime and doExecute() throws. 
+ test("AQE DPP: ReuseAdaptiveSubquery wraps CSAB in ReusedSubqueryExec") { + withTempDir { path => + val fact1Path = s"${path.getAbsolutePath}/fact1.parquet" + val fact2Path = s"${path.getAbsolutePath}/fact2.parquet" + val dimPath = s"${path.getAbsolutePath}/dim.parquet" + withSQLConf(CometConf.COMET_EXEC_ENABLED.key -> "false") { + val one_day = 24 * 60 * 60000 + val fact1 = Range(0, 100) + .map(i => (i, new java.sql.Date(System.currentTimeMillis() + (i % 10) * one_day))) + .toDF("fact_id", "fact_date") + fact1.write.partitionBy("fact_date").parquet(fact1Path) + val fact2 = Range(100, 200) + .map(i => (i, new java.sql.Date(System.currentTimeMillis() + (i % 10) * one_day))) + .toDF("fact_id", "fact_date") + fact2.write.partitionBy("fact_date").parquet(fact2Path) + val dim = Range(0, 10) + .map(i => (i, new java.sql.Date(System.currentTimeMillis() + i * one_day))) + .toDF("dim_id", "dim_date") + dim.write.parquet(dimPath) + } + + withSQLConf( + SQLConf.USE_V1_SOURCE_LIST.key -> "parquet", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.DYNAMIC_PARTITION_PRUNING_REUSE_BROADCAST_ONLY.key -> "true") { + spark.read.parquet(fact1Path).createOrReplaceTempView("aqe_dpp_reuse_fact1") + spark.read.parquet(fact2Path).createOrReplaceTempView("aqe_dpp_reuse_fact2") + spark.read.parquet(dimPath).createOrReplaceTempView("aqe_dpp_reuse_dim") + + // Mirror TPC-DS q54: UNION ALL of two fact tables inside a single join to + // the dimension. DPP is pushed through the UNION to both fact scans from a + // single logical DynamicPruningSubquery, so both SABs share their buildPlan + // and canonicalize identically. ReuseAdaptiveSubquery then wraps one in a + // ReusedSubqueryExec, exercising the bug path. 
+ val df = spark.sql(""" + |SELECT f.fact_id, f.fact_date + |FROM ( + | SELECT fact_id, fact_date FROM aqe_dpp_reuse_fact1 + | UNION ALL + | SELECT fact_id, fact_date FROM aqe_dpp_reuse_fact2 + |) f + |JOIN aqe_dpp_reuse_dim d ON f.fact_date = d.dim_date + |WHERE d.dim_id > 7 + """.stripMargin) + val (_, cometPlan) = checkSparkAnswer(df) + + if (isSpark35Plus) { + // Regression check: without the ReusedSubqueryExec unwrap in extractSABData, + // one CSAB survives the rule and trips CometSubqueryAdaptiveBroadcastExec.doExecute + // at runtime. assertAqeDppShape verifies no CSABs remain in the final plan. + assertAqeDppShape(cometPlan) + + // Subquery reuse: exactly one canonical CometSubqueryBroadcastExec, plus at + // least one ReusedSubqueryExec(CometSubqueryBroadcastExec) pointer for the + // second fact scan. Without this dedup, both fact scans would evaluate the + // DPP subquery independently. + val cometSubqueries = collectWithSubqueries(cometPlan) { + case s: CometSubqueryBroadcastExec => s + } + assert( + cometSubqueries.size == 1, + "Expected exactly 1 CometSubqueryBroadcastExec (shared between fact scans), " + + s"got ${cometSubqueries.size}:\n${cometPlan.treeString}") + val reusedCsbs = collectWithSubqueries(cometPlan) { + case r @ ReusedSubqueryExec(_: CometSubqueryBroadcastExec) => r + } + assert( + reusedCsbs.nonEmpty, + "Expected at least one ReusedSubqueryExec(CometSubqueryBroadcastExec) " + + s"for the second fact scan's DPP filter:\n${cometPlan.treeString}") + + // Broadcast reuse via AQE stageCache: the DPP subquery's ASPE and the main + // BHJ should share the same underlying CometBroadcastExchange. Without this, + // we'd build two identical broadcasts of the dim. 
+ val dppBroadcast = cometSubqueries.head.child match { + case aspe: AdaptiveSparkPlanExec => + val bqs = collectFirst(aspe) { case b: BroadcastQueryStageExec => b } + assert( + bqs.isDefined, + "Expected BroadcastQueryStageExec inside DPP subquery's ASPE:\n" + + cometPlan.treeString) + bqs.get.broadcast + case other => + fail( + s"Unexpected CometSubqueryBroadcastExec child: ${other.getClass.getSimpleName}") + } + val hasReuse = find(cometPlan) { + case ReusedExchangeExec(_, e) => e eq dppBroadcast + case b: BroadcastExchangeLike => b eq dppBroadcast + case _ => false + }.isDefined + assert( + hasReuse, + "DPP subquery's broadcast should be reused by the main BHJ " + + s"(via AQE stageCache):\n${cometPlan.treeString}") + } else { + // Spark 3.4: injectQueryStageOptimizerRule is unavailable, so + // CometPlanAdaptiveDynamicPruningFilters can't run. V1 fact scans are rejected + // to Spark by CometScanRule.transformV1Scan, and CometSpark34AqeDppFallbackRule + // tags the BHJ's build-side BroadcastExchange so Spark's own + // PlanAdaptiveDynamicPruningFilters handles DPP natively. Expected shape + // mirrors the 3.5+ assertions but with Spark-native node types. + val sparkSubqueries = collectWithSubqueries(cometPlan) { + case s: SubqueryBroadcastExec => s + } + assert( + sparkSubqueries.size == 1, + "Expected exactly 1 SubqueryBroadcastExec on 3.4 (Spark-native DPP, " + + s"shared between fact scans), got ${sparkSubqueries.size}. 
If 0, " + + "CometSpark34AqeDppFallbackRule didn't keep the BHJ Spark-native and " + + s"Spark's rule killed DPP:\n${cometPlan.treeString}") + val reusedSparkSubqueries = collectWithSubqueries(cometPlan) { + case r @ ReusedSubqueryExec(_: SubqueryBroadcastExec) => r + } + assert( + reusedSparkSubqueries.nonEmpty, + "Expected at least one ReusedSubqueryExec(SubqueryBroadcastExec) on 3.4 " + + s"for the second fact scan's DPP filter:\n${cometPlan.treeString}") + val dppBroadcast = sparkSubqueries.head.child match { + case aspe: AdaptiveSparkPlanExec => + val bqs = collectFirst(aspe) { case b: BroadcastQueryStageExec => b } + assert( + bqs.isDefined, + "Expected BroadcastQueryStageExec inside DPP subquery's ASPE:\n" + + cometPlan.treeString) + bqs.get.broadcast + case other => + fail(s"Unexpected SubqueryBroadcastExec child: ${other.getClass.getSimpleName}") + } + val hasReuse = find(cometPlan) { + case ReusedExchangeExec(_, e) => e eq dppBroadcast + case b: BroadcastExchangeLike => b eq dppBroadcast + case _ => false + }.isDefined + assert( + hasReuse, + "DPP subquery's broadcast should be reused by the main BHJ on 3.4:\n" + + cometPlan.treeString) + } + } + } + } + test("ShuffleQueryStageExec could be direct child node of CometBroadcastExchangeExec") { withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { val table = "src" diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackRepro3949Suite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackRepro3949Suite.scala index 1e0dae391d..56a1b44070 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackRepro3949Suite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometDppFallbackRepro3949Suite.scala @@ -105,10 +105,21 @@ class CometDppFallbackRepro3949Suite extends CometTestBase { found } + // Exercises the stickiness mechanism: once shuffleSupported decides to fall back (because + // stageContainsDPPScan finds a FileSourceScanExec with DPP), 
the decision persists even + // when AQE wraps the child in a QueryStageExec (hiding the scan from tree walks). + // + // On 3.5+, AQE DPP scans normally convert to CometNativeScanExec (via + // CometPlanAdaptiveDynamicPruningFilters), so stageContainsDPPScan doesn't trigger. + // We disable native scan to force the FileSourceScanExec fallback and exercise the path. test("mechanism: DPP fallback decision is sticky across an AQE-style child wrap") { withTempDir { dir => buildDppTables(dir, "mech") withSQLConf( + // Disable native scan so the scan stays as FileSourceScanExec with DPP, + // producing the mixed state (Comet shuffle wrapping Spark DPP scan) that + // stageContainsDPPScan is designed to catch. + CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", SQLConf.PREFER_SORTMERGEJOIN.key -> "true", SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", @@ -132,11 +143,11 @@ class CometDppFallbackRepro3949Suite extends CometTestBase { } // Simulate AQE stage prep: wrap the shuffle's child in an opaque LeafExecNode, - // matching how `ShuffleQueryStageExec` presents to `.exists` walks (its `children` - // is `Seq.empty`). `withNewChildren` preserves tree-node tags, so if the fix is in - // place the explain-info tag on `shuffle` carries over to `postAqeShuffle`, and the - // decision short-circuits to None. Without the fix, the DPP walk re-runs, fails to - // see the scan, and flips to Some(...). + // matching how ShuffleQueryStageExec presents to .exists walks (its children + // is Seq.empty). withNewChildren preserves tree-node tags, so if the fix is in + // place the explain-info tag on shuffle carries over to postAqeShuffle, and the + // decision short-circuits to None. Without the fix, the DPP walk re-runs, fails + // to see the scan, and flips to Some(...). 
val hiddenChild = OpaqueStageStub(shuffle.child.output) val postAqeShuffle = shuffle.withNewChildren(Seq(hiddenChild)).asInstanceOf[ShuffleExchangeExec] diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometShuffleFallbackStickinessSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometShuffleFallbackStickinessSuite.scala index 4302ab1391..b671e04042 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometShuffleFallbackStickinessSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometShuffleFallbackStickinessSuite.scala @@ -82,7 +82,12 @@ class CometShuffleFallbackStickinessSuite extends CometTestBase { spark.read.parquet(factPath).createOrReplaceTempView("t_sticky_fact") spark.read.parquet(dimPath).createOrReplaceTempView("t_sticky_dim") + // Disable native scan so the scan stays as FileSourceScanExec with DPP, + // producing the mixed state (Spark shuffle wrapping Spark DPP scan) that + // stageContainsDPPScan is designed to catch. With native scan enabled, AQE DPP + // scans convert to CometNativeScanExec and the shuffle goes native. withSQLConf( + CometConf.COMET_NATIVE_SCAN_ENABLED.key -> "false", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", SQLConf.PREFER_SORTMERGEJOIN.key -> "true", SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",