From 1d08bb70eb67ff0b02430c953f47e8d2a5d2471c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 30 Apr 2026 13:10:53 -0600 Subject: [PATCH 1/4] fix: broadcast exchange bypasses AQE partition coalescing When a shuffle feeds into a broadcast exchange, AQE may coalesce the shuffle partitions (e.g. from 200 to 7). CometBroadcastExchangeExec was extracting the underlying plan from AQEShuffleReadExec and executing it directly, bypassing the coalescing. This caused the broadcast collect to run against all original shuffle partitions instead of the coalesced count, inflating task overhead and shuffle data. Execute through the AQEShuffleReadExec to respect partition coalescing. --- .../comet/CometBroadcastExchangeExec.scala | 18 +++++++++---- .../apache/comet/exec/CometJoinSuite.scala | 26 +++++++++++++++++++ 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index 8012b18b22..26eafe51ad 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -110,6 +110,12 @@ case class CometBroadcastExchangeExec( @transient private lazy val maxBroadcastRows = 512000000 + private def getByteArrayRdd(plan: SparkPlan): RDD[(Long, ChunkedByteBuffer)] = { + plan.executeColumnar().mapPartitionsInternal { iter => + Utils.serializeBatches(iter) + } + } + def getNumPartitions(): Int = { child.executeColumnar().getNumPartitions } @@ -125,16 +131,18 @@ case class CometBroadcastExchangeExec( val countsAndBytes = child match { case c: CometPlan => CometExec.getByteArrayRdd(c).collect() - case AQEShuffleReadExec(s: ShuffleQueryStageExec, _) + // Execute through AQEShuffleReadExec to respect AQE partition coalescing + case aqe @ AQEShuffleReadExec(s: ShuffleQueryStageExec, _) if s.plan.isInstanceOf[CometPlan] => - CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect() + getByteArrayRdd(aqe).collect() case s: ShuffleQueryStageExec if s.plan.isInstanceOf[CometPlan] => CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect() case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] => CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect() - case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), _) - if plan.isInstanceOf[CometPlan] => - CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect() + case aqe @ AQEShuffleReadExec( + ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), + _) if plan.isInstanceOf[CometPlan] => + getByteArrayRdd(aqe).collect() case ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _) if plan.isInstanceOf[CometPlan] => CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect() diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index 49fbe10c30..f7956edac3 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -443,4 +443,30 @@ class CometJoinSuite extends CometTestBase { } } } + + test("Broadcast exchange respects AQE shuffle partition coalescing") { + // When a shuffle feeds into a broadcast exchange, AQE may coalesce the shuffle + // partitions. The broadcast collect should execute through the AQEShuffleReadExec + // to use coalesced partitions rather than bypassing it. + val numPartitions = 200 + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString, + SQLConf.PREFER_SORTMERGEJOIN.key -> "false", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB", + SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") { + withParquetTable((0 until 100).map(i => (i, i % 5)), "small_tbl") { + withParquetTable((0 until 10000).map(i => (i, i + 2)), "large_tbl") { + val query = + s"""SELECT /*+ BROADCAST(a) */ * + |FROM (SELECT /*+ REPARTITION($numPartitions) */ * FROM small_tbl) a + |JOIN large_tbl b ON a._1 = b._1""".stripMargin + + checkSparkAnswerAndOperator( + sql(query), + Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + } + } + } + } } From 069faa05a24aea093bd9e40217fed4a7ac120e9d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 1 May 2026 13:55:42 -0600 Subject: [PATCH 2/4] test: assert AQE coalesce metric to catch broadcast bypass regressions Use REBALANCE so AQE actually inserts AQEShuffleReadExec, then verify the read exec's `numPartitions` driver metric reflects coalescing. The metric is only set when AQEShuffleReadExec.executeColumnar runs, so a broadcast that bypasses the wrapper leaves it at 0 and fails the test. --- .../apache/comet/exec/CometJoinSuite.scala | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index f7956edac3..f63d7be96c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec} +import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec import org.apache.spark.sql.internal.SQLConf import org.apache.comet.CometConf @@ -459,12 +460,32 @@ class CometJoinSuite extends CometTestBase { withParquetTable((0 until 10000).map(i => (i, i + 2)), "large_tbl") { val query = s"""SELECT /*+ BROADCAST(a) */ * - |FROM (SELECT /*+ REPARTITION($numPartitions) */ * FROM small_tbl) a + |FROM (SELECT /*+ REBALANCE(_1) */ * FROM small_tbl) a |JOIN large_tbl b ON a._1 = b._1""".stripMargin - checkSparkAnswerAndOperator( + val (_, cometPlan) = checkSparkAnswerAndOperator( sql(query), Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec])) + + // The shuffle partitions feeding the broadcast should be coalesced by + // AQE. AQEShuffleReadExec.executeColumnar() lazily builds its shuffleRDD + // and, as a side effect, sets the "numPartitions" driver metric to + // partitionSpecs.length. If the broadcast collect bypasses the wrapper + // (the bug this test guards against), executeColumnar is never called + // and the metric stays at its initial 0. + val readExecs = collect(cometPlan) { case r: AQEShuffleReadExec => r } + assert(readExecs.nonEmpty, "Expected AQEShuffleReadExec in plan") + readExecs.foreach { r => + val coalesced = r.metrics("numPartitions").value + assert( + coalesced > 0, + "AQEShuffleReadExec.numPartitions metric was never updated; the " + + "broadcast collect likely bypassed AQEShuffleReadExec") + assert( + coalesced < numPartitions, + s"Expected AQE to coalesce shuffle partitions below $numPartitions, " + + s"got $coalesced") + } } } } From d51c4e2b6e13189d74627ae447e91d83d29171b1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 3 May 2026 11:28:05 -0600 Subject: [PATCH 3/4] style: drop unused string interpolator in broadcast AQE test --- .../test/scala/org/apache/comet/exec/CometJoinSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala index f63d7be96c..b2b0a7ea18 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala @@ -459,9 +459,9 @@ class CometJoinSuite extends CometTestBase { withParquetTable((0 until 100).map(i => (i, i % 5)), "small_tbl") { withParquetTable((0 until 10000).map(i => (i, i + 2)), "large_tbl") { val query = - s"""SELECT /*+ BROADCAST(a) */ * - |FROM (SELECT /*+ REBALANCE(_1) */ * FROM small_tbl) a - |JOIN large_tbl b ON a._1 = b._1""".stripMargin + """SELECT /*+ BROADCAST(a) */ * + |FROM (SELECT /*+ REBALANCE(_1) */ * FROM small_tbl) a + |JOIN large_tbl b ON a._1 = b._1""".stripMargin val (_, cometPlan) = checkSparkAnswerAndOperator( sql(query), From 8e2213d0b8ff4ab8d067e0c10a4d3a7a9f6479b9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 4 May 2026 10:35:40 -0600 Subject: [PATCH 4/4] simplify broadcast collect to dispatch by execution instead of type peeking Replace the per-wrapper-type match block with a single getByteArrayRdd(child) call. This executes through AQE wrappers (coalescing, skew splits, local reads) instead of pattern-matching past them, preventing future regressions when AQE adds new wrappers. --- .../comet/CometBroadcastExchangeExec.scala | 32 ++----------------- 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala index 6d0098d00d..9e2296067b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala @@ -35,8 +35,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPar import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -129,34 +128,7 @@ case class CometBroadcastExchangeExec( setJobGroupOrTag(sparkContext, this) val beforeCollect = System.nanoTime() - val countsAndBytes = child match { - case c: CometPlan => CometExec.getByteArrayRdd(c).collect() - // Execute through AQEShuffleReadExec to respect AQE partition coalescing - case aqe @ AQEShuffleReadExec(s: ShuffleQueryStageExec, _) - if s.plan.isInstanceOf[CometPlan] => - getByteArrayRdd(aqe).collect() - case s: ShuffleQueryStageExec if s.plan.isInstanceOf[CometPlan] => - CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect() - case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] => - CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect() - case aqe @ AQEShuffleReadExec( - ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), - _) if plan.isInstanceOf[CometPlan] => - getByteArrayRdd(aqe).collect() - case ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _) - if plan.isInstanceOf[CometPlan] => - CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect() - case _ => - // Non-Comet child (e.g., RowToColumnar -> LocalTableScan). Happens when - // AQE re-optimizes inside an ASPE and replaces the original Comet scan - // with a Spark-native node (e.g., empty broadcast triggers LocalTableScan). - logWarning( - "CometBroadcastExchangeExec child is not CometPlan: " + - s"${child.getClass.getSimpleName}. " + - "Wrapping in CometSparkToColumnarExec for Arrow serialization.") - val cometChild = CometSparkToColumnarExec(ColumnarToRowExec(child)) - CometExec.getByteArrayRdd(cometChild).collect() - } + val countsAndBytes = getByteArrayRdd(child).collect() val numRows = countsAndBytes.map(_._1).sum val input = countsAndBytes.iterator.map(countAndBytes => countAndBytes._2)