From 29fe1c0f1e2a8055039ed09f42c234927987fa29 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 20 Jan 2026 19:45:44 -0700 Subject: [PATCH 01/12] feat: [EXPERIMENTAL] direct native shuffle execution optimization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR introduces an experimental optimization that allows the native shuffle writer to directly execute the child native plan instead of reading intermediate batches via JNI. This avoids the JNI round-trip for single-source native plans. Current flow: Native Plan → ColumnarBatch → JNI → ScanExec → ShuffleWriterExec Optimized flow: Native Plan → (directly in native) → ShuffleWriterExec The optimization is: - Disabled by default (spark.comet.exec.shuffle.directNative.enabled=false) - Only applies to CometNativeShuffle (not columnar JVM shuffle) - Only applies to single-source native scans (CometNativeScanExec) - Does not apply to RangePartitioning (requires sampling) Changes: - CometShuffleDependency: Added childNativePlan field to pass native plan - CometShuffleExchangeExec: Added detection logic for single-source native plans - CometShuffleManager: Pass native plan to shuffle writer - CometNativeShuffleWriter: Use child native plan directly when available - CometConf: Added COMET_SHUFFLE_DIRECT_NATIVE_ENABLED config option - CometDirectNativeShuffleSuite: Comprehensive test suite with 15 tests Co-Authored-By: Claude Opus 4.5 --- .../scala/org/apache/comet/CometConf.scala | 12 + .../shuffle/CometNativeShuffleWriter.scala | 260 ++++++++-------- .../shuffle/CometShuffleDependency.scala | 6 +- .../shuffle/CometShuffleExchangeExec.scala | 105 ++++++- .../shuffle/CometShuffleManager.scala | 3 +- .../exec/CometDirectNativeShuffleSuite.scala | 286 ++++++++++++++++++ 6 files changed, 537 insertions(+), 135 deletions(-) create mode 100644 spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 89dbb6468d..65061282c9 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -319,6 +319,18 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_SHUFFLE_DIRECT_NATIVE_ENABLED: ConfigEntry[Boolean] = + conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directNative.enabled") + .category(CATEGORY_SHUFFLE) + .doc( + "When enabled, the native shuffle writer will directly execute the child native plan " + + "instead of reading intermediate batches via JNI. This optimization avoids the " + + "JNI round-trip for single-source native plans (e.g., Scan -> Filter -> Project). " + + "This is an experimental feature and is disabled by default.") + .internal() + .booleanConf + .createWithDefault(false) + val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode") .category(CATEGORY_SHUFFLE) .doc( diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index b5d15b41f4..5655c7e492 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -43,6 +43,11 @@ import org.apache.comet.serde.QueryPlanSerde.serializeDataType /** * A [[ShuffleWriter]] that will delegate shuffle write to native shuffle. + * + * @param childNativePlan + * When provided, the shuffle writer will execute this native plan directly and pipe its output + * to the ShuffleWriter, avoiding the JNI round-trip for intermediate batches. This is used for + * direct native execution optimization when the shuffle's child is a single-source native plan. */ class CometNativeShuffleWriter[K, V]( outputPartitioning: Partitioning, @@ -53,7 +58,8 @@ class CometNativeShuffleWriter[K, V]( mapId: Long, context: TaskContext, metricsReporter: ShuffleWriteMetricsReporter, - rangePartitionBounds: Option[Seq[InternalRow]] = None) + rangePartitionBounds: Option[Seq[InternalRow]] = None, + childNativePlan: Option[Operator] = None) extends ShuffleWriter[K, V] with Logging { @@ -163,150 +169,150 @@ class CometNativeShuffleWriter[K, V]( } private def getNativePlan(dataFile: String, indexFile: String): Operator = { - val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") - val opBuilder = OperatorOuterClass.Operator.newBuilder() - - val scanTypes = outputAttributes.flatten { attr => - serializeDataType(attr.dataType) - } - - if (scanTypes.length == outputAttributes.length) { + // When childNativePlan is provided, we use it directly as the input to ShuffleWriter. + // Otherwise, we create a Scan operator that reads from JNI input ("ShuffleWriterInput"). + val inputOperator: Operator = childNativePlan.getOrElse { + val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") + val scanTypes = outputAttributes.flatten { attr => + serializeDataType(attr.dataType) + } + if (scanTypes.length != outputAttributes.length) { + throw new UnsupportedOperationException( + s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") + } scanBuilder.addAllFields(scanTypes.asJava) + OperatorOuterClass.Operator.newBuilder().setScan(scanBuilder).build() + } - val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() - shuffleWriterBuilder.setOutputDataFile(dataFile) - shuffleWriterBuilder.setOutputIndexFile(indexFile) + val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() + shuffleWriterBuilder.setOutputDataFile(dataFile) + shuffleWriterBuilder.setOutputIndexFile(indexFile) - if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { - val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { - case "zstd" => CompressionCodec.Zstd - case "lz4" => CompressionCodec.Lz4 - case "snappy" => CompressionCodec.Snappy - case other => throw new UnsupportedOperationException(s"invalid codec: $other") - } - shuffleWriterBuilder.setCodec(codec) - } else { - shuffleWriterBuilder.setCodec(CompressionCodec.None) + if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { + val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { + case "zstd" => CompressionCodec.Zstd + case "lz4" => CompressionCodec.Lz4 + case "snappy" => CompressionCodec.Snappy + case other => throw new UnsupportedOperationException(s"invalid codec: $other") } - shuffleWriterBuilder.setCompressionLevel( - CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) - shuffleWriterBuilder.setWriteBufferSize( - CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().max(Int.MaxValue).toInt) + shuffleWriterBuilder.setCodec(codec) + } else { + shuffleWriterBuilder.setCodec(CompressionCodec.None) + } + shuffleWriterBuilder.setCompressionLevel( + CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) + shuffleWriterBuilder.setWriteBufferSize( + CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().max(Int.MaxValue).toInt) - outputPartitioning match { - case p if isSinglePartitioning(p) => - val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() + outputPartitioning match { + case p if isSinglePartitioning(p) => + val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setSinglePartition(partitioning).build()) - case _: HashPartitioning => - val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setSinglePartition(partitioning).build()) + case _: HashPartitioning => + val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] - val partitioning = PartitioningOuterClass.HashPartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) + val partitioning = PartitioningOuterClass.HashPartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) - val partitionExprs = hashPartitioning.expressions - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + val partitionExprs = hashPartitioning.expressions + .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - if (partitionExprs.length != hashPartitioning.expressions.length) { - throw new UnsupportedOperationException( - s"Partitioning $hashPartitioning is not supported.") - } + if (partitionExprs.length != hashPartitioning.expressions.length) { + throw new UnsupportedOperationException( + s"Partitioning $hashPartitioning is not supported.") + } - partitioning.addAllHashExpression(partitionExprs.asJava) - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setHashPartition(partitioning).build()) - case _: RangePartitioning => - val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] - - val partitioning = PartitioningOuterClass.RangePartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) - - // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering - // DataFusion will deduplicate identical sort expressions in LexOrdering, - // so we need to transform boundary rows to match the deduplicated structure - val seenExprs = mutable.HashSet[Expression]() - val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) - - rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => - if (seenExprs.contains(sortOrder.child)) { - deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion - } else { - seenExprs += sortOrder.child - deduplicationMap += (idx -> true) // Will be kept by DataFusion - } + partitioning.addAllHashExpression(partitionExprs.asJava) + + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setHashPartition(partitioning).build()) + case _: RangePartitioning => + val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] + + val partitioning = PartitioningOuterClass.RangePartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + + // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering + // DataFusion will deduplicate identical sort expressions in LexOrdering, + // so we need to transform boundary rows to match the deduplicated structure + val seenExprs = mutable.HashSet[Expression]() + val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) + + rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => + if (seenExprs.contains(sortOrder.child)) { + deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion + } else { + seenExprs += sortOrder.child + deduplicationMap += (idx -> true) // Will be kept by DataFusion } + } - { - // Serialize the ordering expressions for comparisons - val orderingExprs = rangePartitioning.ordering - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - if (orderingExprs.length != rangePartitioning.ordering.length) { - throw new UnsupportedOperationException( - s"Partitioning $rangePartitioning is not supported.") - } - partitioning.addAllSortOrders(orderingExprs.asJava) + { + // Serialize the ordering expressions for comparisons + val orderingExprs = rangePartitioning.ordering + .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + if (orderingExprs.length != rangePartitioning.ordering.length) { + throw new UnsupportedOperationException( + s"Partitioning $rangePartitioning is not supported.") } + partitioning.addAllSortOrders(orderingExprs.asJava) + } - // Convert Spark's sequence of InternalRows that represent partitioning boundaries to - // sequences of Literals, where each outer entry represents a boundary row, and each - // internal entry is a value in that row. In other words, these are stored in row major - // order, not column major - val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) - - // Transform boundary rows to match DataFusion's deduplicated structure - val transformedBoundaryExprs: Seq[Seq[Literal]] = - rangePartitionBounds.get.map((row: InternalRow) => { - // For every InternalRow, map its values to Literals - val allLiterals = - row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => - Literal(value, valueType) - } - - // Keep only the literals that correspond to non-deduplicated expressions - allLiterals - .zip(deduplicationMap) - .filter(_._2._2) // Keep only where isKept = true - .map(_._1) // Extract the literal + // Convert Spark's sequence of InternalRows that represent partitioning boundaries to + // sequences of Literals, where each outer entry represents a boundary row, and each + // internal entry is a value in that row. In other words, these are stored in row major + // order, not column major + val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) + + // Transform boundary rows to match DataFusion's deduplicated structure + val transformedBoundaryExprs: Seq[Seq[Literal]] = + rangePartitionBounds.get.map((row: InternalRow) => { + // For every InternalRow, map its values to Literals + val allLiterals = + row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => + Literal(value, valueType) + } + + // Keep only the literals that correspond to non-deduplicated expressions + allLiterals + .zip(deduplicationMap) + .filter(_._2._2) // Keep only where isKept = true + .map(_._1) // Extract the literal + }) + + { + // Convert the sequences of Literals to a collection of serialized BoundaryRows + val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs + .map((rowLiterals: Seq[Literal]) => { + // Serialize each sequence of Literals as a BoundaryRow + val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); + val serializedExprs = + rowLiterals.map(lit_value => + QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) + rowBuilder.addAllPartitionBounds(serializedExprs.asJava) + rowBuilder.build() }) + partitioning.addAllBoundaryRows(boundaryRows.asJava) + } - { - // Convert the sequences of Literals to a collection of serialized BoundaryRows - val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs - .map((rowLiterals: Seq[Literal]) => { - // Serialize each sequence of Literals as a BoundaryRow - val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); - val serializedExprs = - rowLiterals.map(lit_value => - QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) - rowBuilder.addAllPartitionBounds(serializedExprs.asJava) - rowBuilder.build() - }) - partitioning.addAllBoundaryRows(boundaryRows.asJava) - } - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setRangePartition(partitioning).build()) - - case _ => - throw new UnsupportedOperationException( - s"Partitioning $outputPartitioning is not supported.") - } + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setRangePartition(partitioning).build()) - val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() - shuffleWriterOpBuilder - .setShuffleWriter(shuffleWriterBuilder) - .addChildren(opBuilder.setScan(scanBuilder).build()) - .build() - } else { - // There are unsupported scan type - throw new UnsupportedOperationException( - s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") + case _ => + throw new UnsupportedOperationException( + s"Partitioning $outputPartitioning is not supported.") } + + val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() + shuffleWriterOpBuilder + .setShuffleWriter(shuffleWriterBuilder) + .addChildren(inputOperator) + .build() } override def stop(success: Boolean): Option[MapStatus] = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala index 2b74e5a168..3528b6d2c9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala @@ -31,6 +31,8 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType +import org.apache.comet.serde.OperatorOuterClass.Operator + /** * A [[ShuffleDependency]] that allows us to identify the shuffle dependency as a Comet shuffle. */ @@ -49,7 +51,9 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val outputAttributes: Seq[Attribute] = Seq.empty, val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty, val numParts: Int = 0, - val rangePartitionBounds: Option[Seq[InternalRow]] = None) + val rangePartitionBounds: Option[Seq[InternalRow]] = None, + // For direct native execution: the child's native plan to compose with ShuffleWriter + val childNativePlan: Option[Operator] = None) extends ShuffleDependency[K, V, C]( _rdd, partitioner, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 1805711d01..0b829aa6ac 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.{CometBatchScanExec, CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometScanExec, CometSinkPlaceHolder} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} @@ -52,6 +52,7 @@ import org.apache.comet.CometConf import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE} import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo} import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported} +import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.operator.CometSink import org.apache.comet.shims.ShimCometShuffleExchangeExec @@ -89,9 +90,85 @@ case class CometShuffleExchangeExec( private lazy val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + /** + * Information about direct native execution optimization. When the child is a single-source + * native plan with a fully native scan (CometNativeScanExec), we can pass the child's native + * plan to the shuffle writer and execute: Scan -> Filter -> Project -> ShuffleWriter all in + * native code, avoiding the JNI round-trip for intermediate batches. + * + * Currently only supports CometNativeScanExec (fully native scans that read files directly via + * DataFusion). JVM scan wrappers (CometScanExec, CometBatchScanExec) still require JNI input + * and are not optimized. + */ + @transient private lazy val directNativeExecutionInfo: Option[DirectNativeExecutionInfo] = { + if (!CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.get()) { + None + } else if (shuffleType != CometNativeShuffle) { + None + } else { + // Check if direct native execution is possible + outputPartitioning match { + case _: RangePartitioning => + // RangePartitioning requires sampling the data to compute bounds, + // which requires executing the child plan. Fall back to current behavior. + None + case _ => + child match { + case nativeChild: CometNativeExec => + // Find input sources using foreachUntilCometInput + val inputSources = scala.collection.mutable.ArrayBuffer.empty[SparkPlan] + nativeChild.foreachUntilCometInput(nativeChild)(inputSources += _) + + // Only optimize single-source native scan case for now + // JVM scan wrappers (CometScanExec, CometBatchScanExec) still need JNI input, + // so we don't optimize those yet + if (inputSources.size == 1) { + inputSources.head match { + case scan: CometNativeScanExec => + // Fully native scan - no JNI input needed, native code reads files directly + // Get the partition count from the underlying scan + val numPartitions = scan.originalPlan.inputRDD.getNumPartitions + Some(DirectNativeExecutionInfo(nativeChild.nativeOp, numPartitions)) + case _ => + // Other input sources (JVM scans, shuffle, broadcast, etc.) - fall back + None + } + } else { + // Multiple input sources (joins, unions) - fall back for now + None + } + case _ => + None + } + } + } + } + + /** + * Returns true if direct native execution optimization is being used for this shuffle. This is + * primarily intended for testing to verify the optimization is applied correctly. + */ + def isDirectNativeExecution: Boolean = directNativeExecutionInfo.isDefined + + /** + * Creates an RDD that provides empty iterators for each partition. Used when direct native + * execution is enabled - the shuffle writer will execute the full native plan which reads data + * directly (no JNI input needed). + */ + private def createEmptyPartitionRDD(numPartitions: Int): RDD[ColumnarBatch] = { + sparkContext.parallelize(Seq.empty[ColumnarBatch], numPartitions) + } + @transient lazy val inputRDD: RDD[_] = if (shuffleType == CometNativeShuffle) { - // CometNativeShuffle assumes that the input plan is Comet plan. - child.executeColumnar() + directNativeExecutionInfo match { + case Some(info) => + // Direct native execution: create an RDD with empty partitions. + // The shuffle writer will execute the full native plan which reads data directly. + createEmptyPartitionRDD(info.numPartitions) + case None => + // Fall back to current behavior: execute child and pass intermediate batches + child.executeColumnar() + } } else if (shuffleType == CometColumnarShuffle) { // CometColumnarShuffle uses Spark's row-based execute() API. For Spark row-based plans, // rows flow directly. For Comet native plans, their doExecute() wraps with ColumnarToRowExec @@ -142,7 +219,8 @@ case class CometShuffleExchangeExec( child.output, outputPartitioning, serializer, - metrics) + metrics, + directNativeExecutionInfo.map(_.childNativePlan)) metrics("numPartitions").set(dep.partitioner.numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates( @@ -538,7 +616,9 @@ object CometShuffleExchangeExec outputAttributes: Seq[Attribute], outputPartitioning: Partitioning, serializer: Serializer, - metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + metrics: Map[String, SQLMetric], + childNativePlan: Option[Operator] = None) + : ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { val numParts = rdd.getNumPartitions // The code block below is mostly brought over from @@ -605,7 +685,8 @@ object CometShuffleExchangeExec outputAttributes = outputAttributes, shuffleWriteMetrics = metrics, numParts = numParts, - rangePartitionBounds = rangePartitionBounds) + rangePartitionBounds = rangePartitionBounds, + childNativePlan = childNativePlan) dependency } @@ -810,3 +891,15 @@ object CometShuffleExchangeExec dependency } } + +/** + * Information needed for direct native execution optimization. + * + * @param childNativePlan + * The child's native operator plan to compose with ShuffleWriter + * @param numPartitions + * The number of partitions (from the underlying scan) + */ +private[shuffle] case class DirectNativeExecutionInfo( + childNativePlan: Operator, + numPartitions: Int) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala index aa47dfa166..367ec4a90e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala @@ -238,7 +238,8 @@ class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { mapId, context, metrics, - dep.rangePartitionBounds) + dep.rangePartitionBounds, + dep.childNativePlan) case bypassMergeSortHandle: CometBypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new CometBypassMergeSortShuffleWriter( env.blockManager, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala new file mode 100644 index 0000000000..6ab30b94a3 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.exec + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.sql.{CometTestBase, DataFrame} +import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.functions.col + +import org.apache.comet.CometConf + +/** + * Test suite for the direct native shuffle execution optimization. + * + * This optimization allows the native shuffle writer to directly execute the child native plan + * instead of reading intermediate batches via JNI. This avoids the JNI round-trip for + * single-source native plans (e.g., Scan -> Filter -> Project -> Shuffle). + */ +class CometDirectNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf( + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_SHUFFLE_MODE.key -> "native", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_NATIVE_SCAN_IMPL.key -> "native_datafusion", + CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "true") { + testFun + } + } + } + + import testImplicits._ + + test("direct native execution: simple scan with hash partitioning") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + // Verify the optimization is applied + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1, "Expected exactly one shuffle") + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should be enabled for single-source native scan") + + // Verify correctness + checkSparkAnswer(df) + } + } + + test("direct native execution: scan with filter and project") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + val df = sql("SELECT _1, _2 * 2 as doubled FROM tbl WHERE _1 > 10") + .repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with filter and project") + + checkSparkAnswer(df) + } + } + + test("direct native execution: single partition") { + withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(1) + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with single partition") + + checkSparkAnswer(df) + } + } + + test("direct native execution: multiple hash columns") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1", $"_2") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + shuffles.head.isDirectNativeExecution, + "Direct native execution should work with multiple hash columns") + + checkSparkAnswer(df) + } + } + + test("direct native execution: aggregation before shuffle") { + withParquetTable((0 until 100).map(i => (i % 10, (i + 1).toLong)), "tbl") { + val df = sql("SELECT _1, SUM(_2) as total FROM tbl GROUP BY _1") + .repartition(5, col("_1")) + + // This involves partial aggregation -> shuffle -> final aggregation + // The direct native execution applies to the shuffle that reads from the partial agg + checkSparkAnswer(df) + } + } + + test("direct native execution disabled: config is false") { + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "false") { + withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + !shuffles.head.isDirectNativeExecution, + "Direct native execution should be disabled when config is false") + + checkSparkAnswer(df) + } + } + } + + test("direct native execution disabled: range partitioning") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartitionByRange(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert( + !shuffles.head.isDirectNativeExecution, + "Direct native execution should not be used for range partitioning") + + checkSparkAnswer(df) + } + } + + test("direct native execution disabled: JVM columnar shuffle mode") { + withSQLConf(CometConf.COMET_SHUFFLE_MODE.key -> "jvm") { + withParquetTable((0 until 50).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + // JVM shuffle mode uses CometColumnarShuffle, not CometNativeShuffle + val shuffles = findShuffleExchanges(df) + shuffles.foreach { shuffle => + assert( + !shuffle.isDirectNativeExecution, + "Direct native execution should not be used with JVM shuffle mode") + } + + checkSparkAnswer(df) + } + } + } + + test("direct native execution: multiple shuffles in same query") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl") + .repartition(10, $"_1") + .select($"_1", $"_2" + 1 as "_2_plus") + .repartition(5, $"_2_plus") + + // First shuffle reads from scan, second reads from previous shuffle output + // Only the first shuffle should use direct native execution + val shuffles = findShuffleExchanges(df) + // AQE might combine some shuffles, so just verify results are correct + checkSparkAnswer(df) + } + } + + test("direct native execution: various data types") { + withParquetTable( + (0 until 50).map(i => + (i, i.toLong, i.toFloat, i.toDouble, i.toString, i % 2 == 0, BigDecimal(i))), + "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert(shuffles.head.isDirectNativeExecution) + + checkSparkAnswer(df) + } + } + + test("direct native execution: complex filter and multiple projections") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i % 5)), "tbl") { + val df = sql(""" + |SELECT _1 * 2 as doubled, + | _2 + _3 as sum_col, + | _1 + _2 as combined + |FROM tbl + |WHERE _1 > 20 AND _3 < 3 + |""".stripMargin) + .repartition(10, col("doubled")) + + val shuffles = findShuffleExchanges(df) + // Note: Native shuffle might fall back depending on expression support + // Just verify correctness - the optimization is best-effort + checkSparkAnswer(df) + } + } + + test("direct native execution: results match non-optimized path") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong, i.toString)), "tbl") { + // Run with optimization enabled + val dfOptimized = sql("SELECT _1, _2 FROM tbl WHERE _1 > 50").repartition(10, $"_1") + val optimizedResult = dfOptimized.collect().sortBy(_.getInt(0)) + + // Run with optimization disabled and collect results + var nonOptimizedResult: Array[org.apache.spark.sql.Row] = Array.empty + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_NATIVE_ENABLED.key -> "false") { + val dfNonOptimized = sql("SELECT _1, _2 FROM tbl WHERE _1 > 50").repartition(10, $"_1") + nonOptimizedResult = dfNonOptimized.collect().sortBy(_.getInt(0)) + } + + // Results should match + assert(optimizedResult.length == nonOptimizedResult.length, "Row counts should match") + optimizedResult.zip(nonOptimizedResult).foreach { case (opt, nonOpt) => + assert(opt == nonOpt, s"Rows should match: $opt vs $nonOpt") + } + } + } + + test("direct native execution: large number of partitions") { + withParquetTable((0 until 1000).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl").repartition(201, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert(shuffles.head.isDirectNativeExecution) + + checkSparkAnswer(df) + } + } + + test("direct native execution: empty table") { + withParquetTable(Seq.empty[(Int, Long)], "tbl") { + val df = sql("SELECT * FROM tbl").repartition(10, $"_1") + + // Should handle empty tables gracefully + val result = df.collect() + assert(result.isEmpty) + } + } + + test("direct native execution: all rows filtered out") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + val df = sql("SELECT * FROM tbl WHERE _1 > 1000").repartition(10, $"_1") + + val shuffles = findShuffleExchanges(df) + assert(shuffles.length == 1) + assert(shuffles.head.isDirectNativeExecution) + + val result = df.collect() + assert(result.isEmpty, "Result should be empty when all rows are filtered") + } + } + + /** + * Helper method to find CometShuffleExchangeExec nodes in a DataFrame's execution plan. + */ + private def findShuffleExchanges(df: DataFrame): Seq[CometShuffleExchangeExec] = { + val plan = stripAQEPlan(df.queryExecution.executedPlan) + plan.collect { case s: CometShuffleExchangeExec => s } + } +} From 15c88da4289e9f21ad2f740bee07f1de712df0c2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 20 Jan 2026 21:14:58 -0700 Subject: [PATCH 02/12] format --- .../sql/comet/execution/shuffle/CometShuffleExchangeExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 0b829aa6ac..6f2bf7d91a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.{CometBatchScanExec, CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometScanExec, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometSinkPlaceHolder} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} From ce361f843dc534754c292937ac9382858b123609 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 21 Jan 2026 08:26:09 -0700 Subject: [PATCH 03/12] fi --- .../org/apache/comet/exec/CometDirectNativeShuffleSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala index 6ab30b94a3..05c608310c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala @@ -181,7 +181,6 @@ class CometDirectNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlan // First shuffle reads from scan, second reads from previous shuffle output // Only the first shuffle should use direct native execution - val shuffles = findShuffleExchanges(df) // AQE might combine some shuffles, so just verify results are correct checkSparkAnswer(df) } @@ -213,7 +212,6 @@ class CometDirectNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlan |""".stripMargin) .repartition(10, col("doubled")) - val shuffles = findShuffleExchanges(df) // Note: Native shuffle might fall back depending on expression support // Just verify correctness - the optimization is best-effort checkSparkAnswer(df) From 6132098982a107b916472b3a9d51e89b437238ee Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 21 Jan 2026 08:48:36 -0700 Subject: [PATCH 04/12] fix: disable direct native shuffle when plan contains subqueries Subqueries (e.g., bloom filters with might_contain) are registered with the parent execution context ID. Direct native shuffle creates a new execution context with a different ID, causing subquery lookup to fail with "Subquery X not found for plan Y" errors. This change detects ScalarSubquery expressions in the child plan and falls back to the standard execution path when present. Co-Authored-By: Claude Opus 4.5 --- .../execution/shuffle/CometShuffleExchangeExec.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 6f2bf7d91a..96b8bcbacc 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometSinkPlaceHolder} import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} @@ -122,7 +123,16 @@ case class CometShuffleExchangeExec( // Only optimize single-source native scan case for now // JVM scan wrappers (CometScanExec, CometBatchScanExec) still need JNI input, // so we don't optimize those yet - if (inputSources.size == 1) { + // Check if the plan contains subqueries (e.g., bloom filters with might_contain). + // Subqueries are registered with the parent execution context ID, but direct + // native shuffle creates a new execution context, so subquery lookup would fail. + val containsSubquery = nativeChild.exists { p => + p.expressions.exists(_.exists(_.isInstanceOf[ScalarSubquery])) + } + if (containsSubquery) { + // Fall back to avoid subquery lookup failures + None + } else if (inputSources.size == 1) { inputSources.head match { case scan: CometNativeScanExec => // Fully native scan - no JNI input needed, native code reads files directly From 5d32f0632a28a76a1ece1688dd699a465bf15658 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 10 Apr 2026 07:17:39 -0600 Subject: [PATCH 05/12] fix: add per-partition data injection to direct native shuffle and broaden to support any native scan inputs --- .../shuffle/CometNativeShuffleWriter.scala | 29 +++++++--- .../shuffle/CometShuffleDependency.scala | 4 +- .../shuffle/CometShuffleExchangeExec.scala | 57 +++++++++++++------ .../shuffle/CometShuffleManager.scala | 4 +- .../apache/spark/sql/comet/operators.scala | 2 +- 5 files changed, 67 insertions(+), 29 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 797b242944..1c1c20fd43 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -32,11 +32,11 @@ import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsR import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.comet.{CometExec, CometMetricNode} +import org.apache.spark.sql.comet.{CometExec, CometMetricNode, PlanDataInjector} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.CometConf +import org.apache.comet.{CometConf, CometExecIterator} import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde} import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator} import org.apache.comet.serde.QueryPlanSerde.serializeDataType @@ -59,7 +59,9 @@ class CometNativeShuffleWriter[K, V]( context: TaskContext, metricsReporter: ShuffleWriteMetricsReporter, rangePartitionBounds: Option[Seq[InternalRow]] = None, - childNativePlan: Option[Operator] = None) + childNativePlan: Option[Operator] = None, + commonByKey: Map[String, Array[Byte]] = Map.empty, + perPartitionByKey: Map[String, Array[Array[Byte]]] = Map.empty) extends ShuffleWriter[K, V] with Logging { @@ -81,6 +83,18 @@ class CometNativeShuffleWriter[K, V]( // Call native shuffle write val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename) + // Inject per-partition file data if this is a direct native execution plan + val actualPlan = if (commonByKey.nonEmpty && perPartitionByKey.nonEmpty) { + val partitionIdx = context.partitionId() + val partitionByKey = perPartitionByKey.map { case (key, arr) => + key -> arr(partitionIdx) + } + val injected = PlanDataInjector.injectPlanData(nativePlan, commonByKey, partitionByKey) + CometExec.serializeNativePlan(injected) + } else { + CometExec.serializeNativePlan(nativePlan) + } + val detailedMetrics = Seq( "elapsed_compute", "encode_time", @@ -102,15 +116,14 @@ class CometNativeShuffleWriter[K, V]( // Getting rid of the fake partitionId val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2) - val cometIter = CometExec.getCometIterator( + val cometIter = new CometExecIterator( + CometExec.newIterId, Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), outputAttributes.length, - nativePlan, + actualPlan, nativeMetrics, numParts, - context.partitionId(), - broadcastedHadoopConfForEncryption = None, - encryptedFilePaths = Seq.empty) + context.partitionId()) while (cometIter.hasNext) { cometIter.next() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala index 3528b6d2c9..d0a42d48ce 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala @@ -53,7 +53,9 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val numParts: Int = 0, val rangePartitionBounds: Option[Seq[InternalRow]] = None, // For direct native execution: the child's native plan to compose with ShuffleWriter - val childNativePlan: Option[Operator] = None) + val childNativePlan: Option[Operator] = None, + val commonByKey: Map[String, Array[Byte]] = Map.empty, + val perPartitionByKey: Map[String, Array[Array[Byte]]] = Map.empty) extends ShuffleDependency[K, V, C]( _rdd, partitioner, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index a9a91df8a7..21bbfa4753 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Exp import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.{CometIcebergNativeScanExec, CometMetricNode, CometNativeExec, CometNativeScanExec, CometPlan, CometSinkPlaceHolder} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec @@ -132,20 +132,33 @@ case class CometShuffleExchangeExec( if (containsSubquery) { // Fall back to avoid subquery lookup failures None - } else if (inputSources.size == 1) { - inputSources.head match { - case scan: CometNativeScanExec => - // Fully native scan - no JNI input needed, native code reads files directly - // Get the partition count from the underlying scan - val numPartitions = scan.originalPlan.inputRDD.getNumPartitions - Some(DirectNativeExecutionInfo(nativeChild.nativeOp, numPartitions)) - case _ => - // Other input sources (JVM scans, shuffle, broadcast, etc.) - fall back - None - } } else { - // Multiple input sources (joins, unions) - fall back for now - None + // Check that ALL input sources are native scans (file-reading, no JNI) + val allNativeScans = inputSources.nonEmpty && inputSources.forall { + case _: CometNativeScanExec => true + case _: CometIcebergNativeScanExec => true + case _ => false + } + if (allNativeScans) { + // Collect per-partition plan data from all native scans + val (commonByKey, perPartitionByKey) = + nativeChild.findAllPlanData(nativeChild) + // All scans must have the same partition count + val partitionCounts = perPartitionByKey.values.map(_.length).toSet + if (partitionCounts.size <= 1) { + val numPartitions = partitionCounts.headOption.getOrElse(0) + Some( + DirectNativeExecutionInfo( + nativeChild.nativeOp, + numPartitions, + commonByKey, + perPartitionByKey)) + } else { + None // Partition count mismatch across scans + } + } else { + None + } } case _ => None @@ -230,7 +243,9 @@ case class CometShuffleExchangeExec( outputPartitioning, serializer, metrics, - directNativeExecutionInfo.map(_.childNativePlan)) + directNativeExecutionInfo.map(_.childNativePlan), + directNativeExecutionInfo.map(_.commonByKey).getOrElse(Map.empty), + directNativeExecutionInfo.map(_.perPartitionByKey).getOrElse(Map.empty)) metrics("numPartitions").set(dep.partitioner.numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates( @@ -675,7 +690,9 @@ object CometShuffleExchangeExec outputPartitioning: Partitioning, serializer: Serializer, metrics: Map[String, SQLMetric], - childNativePlan: Option[Operator] = None) + childNativePlan: Option[Operator] = None, + commonByKey: Map[String, Array[Byte]] = Map.empty, + perPartitionByKey: Map[String, Array[Array[Byte]]] = Map.empty) : ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { val numParts = rdd.getNumPartitions @@ -744,7 +761,9 @@ object CometShuffleExchangeExec shuffleWriteMetrics = metrics, numParts = numParts, rangePartitionBounds = rangePartitionBounds, - childNativePlan = childNativePlan) + childNativePlan = childNativePlan, + commonByKey = commonByKey, + perPartitionByKey = perPartitionByKey) dependency } @@ -960,4 +979,6 @@ object CometShuffleExchangeExec */ private[shuffle] case class DirectNativeExecutionInfo( childNativePlan: Operator, - numPartitions: Int) + numPartitions: Int, + commonByKey: Map[String, Array[Byte]], + perPartitionByKey: Map[String, Array[Array[Byte]]]) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala index 367ec4a90e..1c27e2d8de 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala @@ -239,7 +239,9 @@ class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { context, metrics, dep.rangePartitionBounds, - dep.childNativePlan) + dep.childNativePlan, + dep.commonByKey, + dep.perPartitionByKey) case bypassMergeSortHandle: CometBypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new CometBypassMergeSortShuffleWriter( env.blockManager, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 21cbdab974..ec793801d6 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -661,7 +661,7 @@ abstract class CometNativeExec extends CometExec { * @return * (commonByKey, perPartitionByKey) - common data is shared, per-partition varies */ - private def findAllPlanData( + private[comet] def findAllPlanData( plan: SparkPlan): (Map[String, Array[Byte]], Map[String, Array[Array[Byte]]]) = { plan match { // Found an Iceberg scan with planning data From 78ea924f634a9a201f0dc8aecdf3fb71b505a31c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 10 Apr 2026 07:22:05 -0600 Subject: [PATCH 06/12] fix: update stale comment in directNativeExecutionInfo --- .../comet/execution/shuffle/CometShuffleExchangeExec.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 21bbfa4753..5a2e8437d1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -120,9 +120,10 @@ case class CometShuffleExchangeExec( val inputSources = scala.collection.mutable.ArrayBuffer.empty[SparkPlan] nativeChild.foreachUntilCometInput(nativeChild)(inputSources += _) - // Only optimize single-source native scan case for now + // Optimize when all input sources are native scans + // (CometNativeScanExec, CometIcebergNativeScanExec). // JVM scan wrappers (CometScanExec, CometBatchScanExec) still need JNI input, - // so we don't optimize those yet + // so we don't optimize those. // Check if the plan contains subqueries (e.g., bloom filters with might_contain). // Subqueries are registered with the parent execution context ID, but direct // native shuffle creates a new execution context, so subquery lookup would fail. From 26b8bf62e1efc1a6aaab339110da72ea8b67645d Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 10 Apr 2026 07:23:03 -0600 Subject: [PATCH 07/12] docs: update config and class docs for broadened direct native shuffle --- common/src/main/scala/org/apache/comet/CometConf.scala | 4 +++- .../execution/shuffle/CometNativeShuffleWriter.scala | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index dea1d83e76..b8b2c829df 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -349,7 +349,9 @@ object CometConf extends ShimCometConf { .doc( "When enabled, the native shuffle writer will directly execute the child native plan " + "instead of reading intermediate batches via JNI. This optimization avoids the " + - "JNI round-trip for single-source native plans (e.g., Scan -> Filter -> Project). " + + "JNI round-trip for native plans whose inputs are all native scans " + + "(CometNativeScanExec, CometIcebergNativeScanExec). Supports single and multi-source " + + "plans (e.g., joins over native scans). " + "This is an experimental feature and is disabled by default.") .internal() .booleanConf diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 1c1c20fd43..4c585c01fd 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -46,8 +46,12 @@ import org.apache.comet.serde.QueryPlanSerde.serializeDataType * * @param childNativePlan * When provided, the shuffle writer will execute this native plan directly and pipe its output - * to the ShuffleWriter, avoiding the JNI round-trip for intermediate batches. This is used for - * direct native execution optimization when the shuffle's child is a single-source native plan. + * to the ShuffleWriter, avoiding the JNI round-trip for intermediate batches. Used when all + * input sources are native scans (CometNativeScanExec, CometIcebergNativeScanExec). + * @param commonByKey + * Common planning data (schemas, filters) keyed by source identifier, for PlanDataInjector. + * @param perPartitionByKey + * Per-partition planning data (file lists) keyed by source identifier, for PlanDataInjector. */ class CometNativeShuffleWriter[K, V]( outputPartitioning: Partitioning, From cb3446aaa30d111ae380e95022ce0ce60d03f786 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 10 Apr 2026 07:26:22 -0600 Subject: [PATCH 08/12] test: add tests for broadened direct native shuffle (multi-source, fallback) --- .../exec/CometDirectNativeShuffleSuite.scala | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala index 05c608310c..c08ada38a9 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometDirectNativeShuffleSuite.scala @@ -274,6 +274,44 @@ class CometDirectNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlan } } + // TODO: Add Iceberg native scan test when Iceberg test infrastructure is available + // in this suite. CometIcebergNativeScanExec is supported by the optimization but + // requires SparkCatalog setup. See CometIcebergSuite for patterns. + + test("direct native execution: join of two native scans") { + withParquetTable((0 until 100).map(i => (i, s"left_$i")), "left_tbl") { + withParquetTable((0 until 100).map(i => (i, s"right_$i")), "right_tbl") { + // Broadcast join with two native scans + // The join itself may or may not use direct native execution depending on + // whether broadcast creates a non-native-scan input, but the query should + // execute correctly regardless + val df = sql(""" + |SELECT l._1, l._2, r._2 + |FROM left_tbl l JOIN right_tbl r ON l._1 = r._1 + |WHERE l._1 > 50 + |""".stripMargin) + .repartition(10, col("_1")) + + checkSparkAnswer(df) + } + } + } + + test("direct native execution disabled: shuffle input source") { + withParquetTable((0 until 100).map(i => (i, (i + 1).toLong)), "tbl") { + // Force a shuffle before the final shuffle by using a repartition + aggregation + // The second shuffle reads from the first shuffle's output (not a native scan) + val df = sql("SELECT _1, SUM(_2) as s FROM tbl GROUP BY _1") + .repartition(5, col("_1")) + .filter(col("s") > 10) + .repartition(3, col("_1")) + + // The final shuffle should NOT use direct native execution because its input + // comes from a shuffle read, not a native scan + checkSparkAnswer(df) + } + } + /** * Helper method to find CometShuffleExchangeExec nodes in a DataFrame's execution plan. */ From d370a91f64161a9a6f4f710f7eb1de96974ac0ca Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 10 Apr 2026 07:29:16 -0600 Subject: [PATCH 09/12] chore: update Cargo.lock after format --- native/Cargo.lock | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/native/Cargo.lock b/native/Cargo.lock index 480f7ad06d..b5c7f2b0c7 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -2070,9 +2070,7 @@ dependencies = [ "num", "rand 0.10.0", "regex", - "serde", "serde_json", - "thiserror 2.0.18", "tokio", "twox-hash", ] @@ -5885,7 +5883,7 @@ version = "1.0.149" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" dependencies = [ - "indexmap 2.13.0", + "indexmap 2.13.1", "itoa", "memchr", "serde", From 9040f8d3ab362e97c70da69fd766a8361e4a01b1 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 10 Apr 2026 08:27:39 -0600 Subject: [PATCH 10/12] default true --- common/src/main/scala/org/apache/comet/CometConf.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index b8b2c829df..a32a290c02 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -351,11 +351,10 @@ object CometConf extends ShimCometConf { "instead of reading intermediate batches via JNI. This optimization avoids the " + "JNI round-trip for native plans whose inputs are all native scans " + "(CometNativeScanExec, CometIcebergNativeScanExec). Supports single and multi-source " + - "plans (e.g., joins over native scans). " + - "This is an experimental feature and is disabled by default.") + "plans (such as joins over native scans).") .internal() .booleanConf - .createWithDefault(false) + .createWithDefault(true) val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directRead.enabled") From e9b0b03b6fc1a6240c537956a22d1f97697492e6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 10 Apr 2026 09:08:47 -0600 Subject: [PATCH 11/12] fix: handle empty table in direct native shuffle execution Fall back to normal execution when numPartitions is 0 (empty table with no data files) to avoid IllegalArgumentException from sparkContext.parallelize requiring positive partition count. --- .../shuffle/CometShuffleExchangeExec.scala | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 5a2e8437d1..5a590c85aa 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -148,12 +148,17 @@ case class CometShuffleExchangeExec( val partitionCounts = perPartitionByKey.values.map(_.length).toSet if (partitionCounts.size <= 1) { val numPartitions = partitionCounts.headOption.getOrElse(0) - Some( - DirectNativeExecutionInfo( - nativeChild.nativeOp, - numPartitions, - commonByKey, - perPartitionByKey)) + if (numPartitions == 0) { + // Empty table (no data files) - fall back to normal execution + None + } else { + Some( + DirectNativeExecutionInfo( + nativeChild.nativeOp, + numPartitions, + commonByKey, + perPartitionByKey)) + } } else { None // Partition count mismatch across scans } From aed8fc1af536834c7ce42ac4861724261c731bd6 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 10 Apr 2026 10:36:38 -0600 Subject: [PATCH 12/12] Revert "default true" This reverts commit 9040f8d3ab362e97c70da69fd766a8361e4a01b1. --- common/src/main/scala/org/apache/comet/CometConf.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index a32a290c02..b8b2c829df 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -351,10 +351,11 @@ object CometConf extends ShimCometConf { "instead of reading intermediate batches via JNI. This optimization avoids the " + "JNI round-trip for native plans whose inputs are all native scans " + "(CometNativeScanExec, CometIcebergNativeScanExec). Supports single and multi-source " + - "plans (such as joins over native scans).") + "plans (e.g., joins over native scans). " + + "This is an experimental feature and is disabled by default.") .internal() .booleanConf - .createWithDefault(true) + .createWithDefault(false) val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.directRead.enabled")