diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
index c569c34d01..9e2296067b 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -35,8 +35,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPar
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -110,6 +109,12 @@ case class CometBroadcastExchangeExec(
   @transient
   private lazy val maxBroadcastRows = 512000000
 
+  private def getByteArrayRdd(plan: SparkPlan): RDD[(Long, ChunkedByteBuffer)] = {
+    plan.executeColumnar().mapPartitionsInternal { iter =>
+      Utils.serializeBatches(iter)
+    }
+  }
+
   def getNumPartitions(): Int = {
     child.executeColumnar().getNumPartitions
   }
@@ -123,32 +128,7 @@ case class CometBroadcastExchangeExec(
         setJobGroupOrTag(sparkContext, this)
         val beforeCollect = System.nanoTime()
 
-        val countsAndBytes = child match {
-          case c: CometPlan => CometExec.getByteArrayRdd(c).collect()
-          case AQEShuffleReadExec(s: ShuffleQueryStageExec, _)
-              if s.plan.isInstanceOf[CometPlan] =>
-            CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
-          case s: ShuffleQueryStageExec if s.plan.isInstanceOf[CometPlan] =>
-            CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
-          case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] =>
-            CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
-          case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), _)
-              if plan.isInstanceOf[CometPlan] =>
-            CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
-          case ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _)
-              if plan.isInstanceOf[CometPlan] =>
-            CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
-          case _ =>
-            // Non-Comet child (e.g., RowToColumnar -> LocalTableScan). Happens when
-            // AQE re-optimizes inside an ASPE and replaces the original Comet scan
-            // with a Spark-native node (e.g., empty broadcast triggers LocalTableScan).
-            logWarning(
-              "CometBroadcastExchangeExec child is not CometPlan: " +
-                s"${child.getClass.getSimpleName}. " +
-                "Wrapping in CometSparkToColumnarExec for Arrow serialization.")
-            val cometChild = CometSparkToColumnarExec(ColumnarToRowExec(child))
-            CometExec.getByteArrayRdd(cometChild).collect()
-        }
+        val countsAndBytes = getByteArrayRdd(child).collect()
 
         val numRows = countsAndBytes.map(_._1).sum
         val input = countsAndBytes.iterator.map(countAndBytes => countAndBytes._2)
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
index 6ccc2a5778..ccc8faa830 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
+import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
 import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf
@@ -517,4 +518,50 @@ class CometJoinSuite extends CometTestBase {
       }
     }
   }
+
+  test("Broadcast exchange respects AQE shuffle partition coalescing") {
+    // When a shuffle feeds into a broadcast exchange, AQE may coalesce the shuffle
+    // partitions. The broadcast collect should execute through the AQEShuffleReadExec
+    // to use coalesced partitions rather than bypassing it.
+    val numPartitions = 200
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString,
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB",
+      SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") {
+      withParquetTable((0 until 100).map(i => (i, i % 5)), "small_tbl") {
+        withParquetTable((0 until 10000).map(i => (i, i + 2)), "large_tbl") {
+          val query =
+            """SELECT /*+ BROADCAST(a) */ *
+              |FROM (SELECT /*+ REBALANCE(_1) */ * FROM small_tbl) a
+              |JOIN large_tbl b ON a._1 = b._1""".stripMargin
+
+          val (_, cometPlan) = checkSparkAnswerAndOperator(
+            sql(query),
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+
+          // The shuffle partitions feeding the broadcast should be coalesced by
+          // AQE. AQEShuffleReadExec.executeColumnar() lazily builds its shuffleRDD
+          // and, as a side effect, sets the "numPartitions" driver metric to
+          // partitionSpecs.length. If the broadcast collect bypasses the wrapper
+          // (the bug this test guards against), executeColumnar is never called
+          // and the metric stays at its initial 0.
+          val readExecs = collect(cometPlan) { case r: AQEShuffleReadExec => r }
+          assert(readExecs.nonEmpty, "Expected AQEShuffleReadExec in plan")
+          readExecs.foreach { r =>
+            val coalesced = r.metrics("numPartitions").value
+            assert(
+              coalesced > 0,
+              "AQEShuffleReadExec.numPartitions metric was never updated; the " +
+                "broadcast collect likely bypassed AQEShuffleReadExec")
+            assert(
+              coalesced < numPartitions,
+              s"Expected AQE to coalesce shuffle partitions below $numPartitions, " +
+                s"got $coalesced")
+          }
+        }
+      }
+    }
+  }
 }