spark/src/main/scala/org/apache/spark/sql/comet/CometBroadcastExchangeExec.scala
@@ -35,8 +35,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPar
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.{ColumnarToRowExec, SparkPlan, SQLExecution}
-import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, ShuffleQueryStageExec}
-import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike, ReusedExchangeExec}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -110,6 +109,12 @@ case class CometBroadcastExchangeExec(
   @transient
   private lazy val maxBroadcastRows = 512000000
 
+  private def getByteArrayRdd(plan: SparkPlan): RDD[(Long, ChunkedByteBuffer)] = {
+    plan.executeColumnar().mapPartitionsInternal { iter =>
+      Utils.serializeBatches(iter)
+    }
+  }
+
   def getNumPartitions(): Int = {
     child.executeColumnar().getNumPartitions
   }
@@ -123,32 +128,7 @@
         setJobGroupOrTag(sparkContext, this)
         val beforeCollect = System.nanoTime()
 
-        val countsAndBytes = child match {
-          case c: CometPlan => CometExec.getByteArrayRdd(c).collect()
-          case AQEShuffleReadExec(s: ShuffleQueryStageExec, _)
-              if s.plan.isInstanceOf[CometPlan] =>
-            CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
-          case s: ShuffleQueryStageExec if s.plan.isInstanceOf[CometPlan] =>
-            CometExec.getByteArrayRdd(s.plan.asInstanceOf[CometPlan]).collect()
-          case ReusedExchangeExec(_, plan) if plan.isInstanceOf[CometPlan] =>
-            CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
-          case AQEShuffleReadExec(ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _), _)
-              if plan.isInstanceOf[CometPlan] =>
-            CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
-          case ShuffleQueryStageExec(_, ReusedExchangeExec(_, plan), _)
-              if plan.isInstanceOf[CometPlan] =>
-            CometExec.getByteArrayRdd(plan.asInstanceOf[CometPlan]).collect()
-          case _ =>
-            // Non-Comet child (e.g., RowToColumnar -> LocalTableScan). Happens when
-            // AQE re-optimizes inside an ASPE and replaces the original Comet scan
-            // with a Spark-native node (e.g., empty broadcast triggers LocalTableScan).
-            logWarning(
-              "CometBroadcastExchangeExec child is not CometPlan: " +
-                s"${child.getClass.getSimpleName}. " +
-                "Wrapping in CometSparkToColumnarExec for Arrow serialization.")
-            val cometChild = CometSparkToColumnarExec(ColumnarToRowExec(child))
-            CometExec.getByteArrayRdd(cometChild).collect()
-        }
+        val countsAndBytes = getByteArrayRdd(child).collect()
 
         val numRows = countsAndBytes.map(_._1).sum
         val input = countsAndBytes.iterator.map(countAndBytes => countAndBytes._2)
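
Context for the change above: the removed pattern match unwrapped AQEShuffleReadExec / ShuffleQueryStageExec / ReusedExchangeExec and collected from the underlying Comet plan directly, so when AQE coalesced the shuffle partitions the broadcast still read the original, un-coalesced layout. Collecting through child.executeColumnar() keeps the AQE read node in the execution path. For readers unfamiliar with AQE partition coalescing, here is a minimal standalone sketch in plain Spark, without Comet; the object name, table name, local master, and the printed expectation are illustrative assumptions, not part of this patch.

// Standalone illustration of AQE shuffle-partition coalescing (assumes a local
// SparkSession, tiny in-memory data, and the default advisory partition size).
import org.apache.spark.sql.SparkSession

object AqeCoalescingDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
      .config("spark.sql.shuffle.partitions", "200")
      .getOrCreate()
    import spark.implicits._

    (0 until 100).map(i => (i, i % 5)).toDF().createOrReplaceTempView("small_tbl")
    // REBALANCE introduces a shuffle; with 100 tiny rows, AQE typically coalesces the
    // 200 configured shuffle partitions down to a handful when the plan executes.
    val rebalanced = spark.sql("SELECT /*+ REBALANCE(_1) */ * FROM small_tbl")
    println(s"partitions seen downstream: ${rebalanced.rdd.getNumPartitions} (configured: 200)")
    spark.stop()
  }
}

The new test in CometJoinSuite below exercises the same behavior through the Comet broadcast path and additionally checks the read node's numPartitions metric.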
spark/src/test/scala/org/apache/comet/exec/CometJoinSuite.scala (47 additions, 0 deletions)
@@ -26,6 +26,7 @@ import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
 import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec}
+import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
 import org.apache.spark.sql.internal.SQLConf
 
 import org.apache.comet.CometConf
@@ -517,4 +518,50 @@ class CometJoinSuite extends CometTestBase {
       }
     }
   }
+
+  test("Broadcast exchange respects AQE shuffle partition coalescing") {
+    // When a shuffle feeds into a broadcast exchange, AQE may coalesce the shuffle
+    // partitions. The broadcast collect should execute through the AQEShuffleReadExec
+    // to use coalesced partitions rather than bypassing it.
+    val numPartitions = 200
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString,
+      SQLConf.PREFER_SORTMERGEJOIN.key -> "false",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB",
+      SQLConf.COALESCE_PARTITIONS_ENABLED.key -> "true") {
+      withParquetTable((0 until 100).map(i => (i, i % 5)), "small_tbl") {
+        withParquetTable((0 until 10000).map(i => (i, i + 2)), "large_tbl") {
+          val query =
+            """SELECT /*+ BROADCAST(a) */ *
+              |FROM (SELECT /*+ REBALANCE(_1) */ * FROM small_tbl) a
+              |JOIN large_tbl b ON a._1 = b._1""".stripMargin
+
+          val (_, cometPlan) = checkSparkAnswerAndOperator(
+            sql(query),
+            Seq(classOf[CometBroadcastExchangeExec], classOf[CometBroadcastHashJoinExec]))
+
+          // The shuffle partitions feeding the broadcast should be coalesced by
+          // AQE. AQEShuffleReadExec.executeColumnar() lazily builds its shuffleRDD
+          // and, as a side effect, sets the "numPartitions" driver metric to
+          // partitionSpecs.length. If the broadcast collect bypasses the wrapper
+          // (the bug this test guards against), executeColumnar is never called
+          // and the metric stays at its initial 0.
+          val readExecs = collect(cometPlan) { case r: AQEShuffleReadExec => r }
+          assert(readExecs.nonEmpty, "Expected AQEShuffleReadExec in plan")
+          readExecs.foreach { r =>
+            val coalesced = r.metrics("numPartitions").value
+            assert(
+              coalesced > 0,
+              "AQEShuffleReadExec.numPartitions metric was never updated; the " +
+                "broadcast collect likely bypassed AQEShuffleReadExec")
+            assert(
+              coalesced < numPartitions,
+              s"Expected AQE to coalesce shuffle partitions below $numPartitions, " +
+                s"got $coalesced")
+          }
+        }
+      }
+    }
+  }
 }
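
For context on the metric assertion in the test: as the inline comment notes, AQEShuffleReadExec only populates its numPartitions driver metric when its shuffle RDD is first built, i.e. when the node itself executes. The same probe can be reproduced in plain Spark without the Comet test harness; a hedged sketch follows, with a local session, an illustrative REBALANCE query, and object/table names that are assumptions rather than anything from this patch.

// Probe the "numPartitions" metric on AQE shuffle read nodes after execution
// (Spark 3.2+, where the node is named AQEShuffleReadExec).
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec}

object NumPartitionsMetricProbe {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .config("spark.sql.adaptive.enabled", "true")
      .getOrCreate()
    import spark.implicits._

    (0 until 100).map(i => (i, i % 5)).toDF().createOrReplaceTempView("t")
    val df = spark.sql("SELECT /*+ REBALANCE(_1) */ * FROM t")
    df.collect() // force execution so AQE finalizes the plan and driver metrics are set

    // Unwrap the adaptive plan and read the metric off any AQE shuffle read nodes.
    val finalPlan = df.queryExecution.executedPlan match {
      case a: AdaptiveSparkPlanExec => a.executedPlan
      case p => p
    }
    finalPlan.collect { case r: AQEShuffleReadExec => r }.foreach { r =>
      println(s"numPartitions metric = ${r.metrics("numPartitions").value}")
    }
    spark.stop()
  }
}

The test above performs the equivalent check on the plan returned by checkSparkAnswerAndOperator, which has already executed the query by the time the metric is read.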