diff --git a/.github/workflows/pr_benchmark_check.yml b/.github/workflows/pr_benchmark_check.yml index b07cc03c34..a879493a7f 100644 --- a/.github/workflows/pr_benchmark_check.yml +++ b/.github/workflows/pr_benchmark_check.yml @@ -84,9 +84,7 @@ jobs: ${{ runner.os }}-benchmark-maven- - name: Check Scala compilation and linting - # Pin to spark-4.0 (Scala 2.13.16) because the default profile is now - # spark-4.1 / Scala 2.13.17, and semanticdb-scalac_2.13.17 is not yet - # published, which breaks `-Psemanticdb`. See pr_build_linux.yml for - # the same exclusion in the main lint matrix. + # Pinned to spark-4.0 because semanticdb-scalac_2.13.17 (spark-4.1 default) + # is not yet published, which breaks the -Psemanticdb scalafix lint. run: | - ./mvnw -B compile test-compile scalafix:scalafix -Dscalafix.mode=CHECK -Psemanticdb -Pspark-4.0 -DskipTests + ./mvnw -B compile test-compile scalafix:scalafix -Dscalafix.mode=CHECK -Pspark-4.0 -Psemanticdb -DskipTests diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index 5c1ae2dc47..6e6a526f71 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -309,6 +309,7 @@ jobs: org.apache.comet.CometFuzzAggregateSuite org.apache.comet.CometFuzzIcebergSuite org.apache.comet.CometFuzzMathSuite + org.apache.comet.CometCodegenDispatchFuzzSuite org.apache.comet.DataGeneratorSuite - name: "shuffle" value: | @@ -386,6 +387,8 @@ jobs: org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite + org.apache.comet.CometCodegenDispatchSmokeSuite + org.apache.comet.CometCodegenSourceSuite - name: "sql" value: | org.apache.spark.sql.CometToPrettyStringSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 29eca594a2..901a3b7f5c 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -157,6 +157,7 @@ jobs: org.apache.comet.CometFuzzAggregateSuite org.apache.comet.CometFuzzIcebergSuite org.apache.comet.CometFuzzMathSuite + org.apache.comet.CometCodegenDispatchFuzzSuite org.apache.comet.DataGeneratorSuite - name: "shuffle" value: | @@ -233,6 +234,8 @@ jobs: org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite + org.apache.comet.CometCodegenDispatchSmokeSuite + org.apache.comet.CometCodegenSourceSuite - name: "sql" value: | org.apache.spark.sql.CometToPrettyStringSuite diff --git a/common/src/main/java/org/apache/comet/udf/CometBatchKernel.java b/common/src/main/java/org/apache/comet/udf/CometBatchKernel.java new file mode 100644 index 0000000000..cfa61c9715 --- /dev/null +++ b/common/src/main/java/org/apache/comet/udf/CometBatchKernel.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Abstract base extended by the Janino-compiled batch kernel emitted by {@code + * CometBatchKernelCodegen}. The generated subclass extends {@code CometInternalRow} (so Spark's + * {@code BoundReference.genCode} can call {@code this.getUTF8String(ord)} directly) and carries + * typed input fields baked at codegen time, one per input column. Expression evaluation plus Arrow + * read/write fuse into one method per expression tree. + * + *

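 * <p>A hand-written sketch of that shape for an {@code upper(col0)}-style expression over one
 * VarChar input column (simplified: the real emitted source uses fully-qualified, shaded Arrow
 * class names, one typed getter per supported type, and Spark-generated expression code):
 *
 * <pre>{@code
 * final class SpecificCometBatchKernel extends CometBatchKernel {
 *   private VarCharVector col0;   // typed input field baked at codegen time
 *   private int rowIdx;
 *
 *   SpecificCometBatchKernel(Object[] references) { super(references); }
 *
 *   @Override
 *   public UTF8String getUTF8String(int ordinal) {  // BoundReference.genCode reads land here
 *     return UTF8String.fromBytes(col0.get(rowIdx));
 *   }
 *
 *   @Override
 *   public void process(ValueVector[] inputs, FieldVector output, int numRows) {
 *     col0 = (VarCharVector) inputs[0];
 *     VarCharVector out = (VarCharVector) output;
 *     for (int i = 0; i < numRows; i++) {
 *       rowIdx = i;
 *       if (col0.isNull(i)) { out.setNull(i); continue; }
 *       byte[] bytes = getUTF8String(0).toUpperCase().getBytes();
 *       out.setSafe(i, bytes, 0, bytes.length);
 *     }
 *   }
 * }
 * }</pre>
 *
 * <p>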
Input scope: any {@code ValueVector[]}; the generated subclass casts each slot to the concrete + * Arrow type the compile-time schema specified. Output is a generic {@code FieldVector}; the + * generated subclass casts to the concrete type matching the bound expression's {@code dataType}. + * Widen input support by adding vector classes to the getter switch in {@code + * CometBatchKernelCodegen.emitTypedGetters}; widen output support by adding cases in {@code + * CometBatchKernelCodegen.allocateOutput} and {@code emitOutputWriter}. + */ +public abstract class CometBatchKernel extends CometInternalRow { + + protected final Object[] references; + + protected CometBatchKernel(Object[] references) { + this.references = references; + } + + /** + * Process one batch. + * + * @param inputs Arrow input vectors; length and concrete classes must match the schema the kernel + * was compiled against + * @param output Arrow output vector; caller allocates to the expression's {@code dataType} + * @param numRows number of rows in this batch + */ + public abstract void process(ValueVector[] inputs, FieldVector output, int numRows); + + /** + * Run partition-dependent initialization. The generated subclass overrides this to execute + * statements collected via {@code CodegenContext.addPartitionInitializationStatement}, for + * example reseeding {@code Rand}'s {@code XORShiftRandom} from {@code seed + partitionIndex}. + * Deterministic expressions leave this as a no-op. + * + *

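 * <p>For illustration, the override generated for a {@code rand(42)} child looks roughly like
 * the following (hand-written sketch; the real body is whatever {@code ctx.initPartition()}
 * collected, and the field name is generated):
 *
 * <pre>{@code
 * @Override
 * public void init(int partitionIndex) {
 *   // registered by Rand's doGenCode via CodegenContext.addPartitionInitializationStatement
 *   this.rng = new org.apache.spark.util.random.XORShiftRandom(42L + partitionIndex);
 * }
 * }</pre>
 *
 * <p>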
The caller must invoke this before the first {@code process} call of each partition. The + * generated subclass is not thread-safe across concurrent {@code process} calls, so kernels are + * allocated per dispatcher invocation and init is run once on the fresh instance. + */ + public void init(int partitionIndex) {} +} diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java index aed53c57df..4e8662829f 100644 --- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java +++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java @@ -19,7 +19,8 @@ package org.apache.comet.udf; -import java.util.concurrent.ConcurrentHashMap; +import java.util.LinkedHashMap; +import java.util.Map; import org.apache.arrow.c.ArrowArray; import org.apache.arrow.c.ArrowSchema; @@ -27,6 +28,8 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; +import org.apache.spark.TaskContext; +import org.apache.spark.comet.CometTaskContextShim; /** * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method @@ -35,10 +38,23 @@ */ public class CometUdfBridge { - // Process-wide cache of UDF instances keyed by class name. CometUDF - // implementations are required to be stateless (see CometUDF), so a - // single shared instance per class is safe across native worker threads. - private static final ConcurrentHashMap INSTANCES = new ConcurrentHashMap<>(); + // Per-thread, bounded LRU of UDF instances keyed by class name. Comet + // native execution threads (Tokio/DataFusion worker pool) are reused + // across tasks within an executor, so the effective lifetime of cached + // entries is the worker thread (i.e. the executor JVM). Fine for + // stateless UDFs; future stateful UDFs would need explicit per-task + // isolation. + private static final int CACHE_CAPACITY = 64; + + private static final ThreadLocal> INSTANCES = + ThreadLocal.withInitial( + () -> + new LinkedHashMap(CACHE_CAPACITY, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > CACHE_CAPACITY; + } + }); /** * Called from native via JNI. @@ -48,30 +64,67 @@ public class CometUdfBridge { * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input) * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result + * @param numRows number of rows in the current batch. Mirrors DataFusion's {@code + * ScalarFunctionArgs.number_rows} and gives UDFs an explicit batch-size signal for cases + * where no input arg is a batch-length array (e.g. a zero-arg non-deterministic ScalaUDF). + * UDFs that already read size from their input vectors can ignore it. + * @param taskContext Spark {@link TaskContext} captured on the driving Spark task thread and + * passed through from native. May be {@code null} when the bridge is invoked outside a Spark + * task (unit tests, direct native driver runs). When non-null and the current thread has no + * {@code TaskContext} of its own, the bridge installs it as the thread-local for the duration + * of the UDF call so the UDF body (including partition-sensitive built-ins like {@code Rand} + * / {@code Uuid} / {@code MonotonicallyIncreasingID} that read the partition index via {@code + * TaskContext.get().partitionId()}) sees the real context rather than null. 
The thread-local + * is cleared in a {@code finally} so Tokio workers don't leak a stale TaskContext across + * invocations. */ public static void evaluate( String udfClassName, long[] inputArrayPtrs, long[] inputSchemaPtrs, long outArrayPtr, - long outSchemaPtr) { - CometUDF udf = - INSTANCES.computeIfAbsent( - udfClassName, - name -> { - try { - // Resolve via the executor's context classloader so user-supplied UDF jars - // (added via spark.jars / --jars) are visible. - ClassLoader cl = Thread.currentThread().getContextClassLoader(); - if (cl == null) { - cl = CometUdfBridge.class.getClassLoader(); - } - return (CometUDF) - Class.forName(name, true, cl).getDeclaredConstructor().newInstance(); - } catch (ReflectiveOperationException e) { - throw new RuntimeException("Failed to instantiate CometUDF: " + name, e); - } - }); + long outSchemaPtr, + int numRows, + TaskContext taskContext) { + boolean installedTaskContext = false; + if (taskContext != null && TaskContext.get() == null) { + CometTaskContextShim.set(taskContext); + installedTaskContext = true; + } + try { + evaluateInternal( + udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows); + } finally { + if (installedTaskContext) { + CometTaskContextShim.unset(); + } + } + } + + private static void evaluateInternal( + String udfClassName, + long[] inputArrayPtrs, + long[] inputSchemaPtrs, + long outArrayPtr, + long outSchemaPtr, + int numRows) { + LinkedHashMap cache = INSTANCES.get(); + CometUDF udf = cache.get(udfClassName); + if (udf == null) { + try { + // Resolve via the executor's context classloader so user-supplied UDF jars + // (added via spark.jars / --jars) are visible. + ClassLoader cl = Thread.currentThread().getContextClassLoader(); + if (cl == null) { + cl = CometUdfBridge.class.getClassLoader(); + } + udf = + (CometUDF) Class.forName(udfClassName, true, cl).getDeclaredConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new RuntimeException("Failed to instantiate CometUDF: " + udfClassName, e); + } + cache.put(udfClassName, udf); + } BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator(); @@ -84,23 +137,17 @@ public static void evaluate( inputs[i] = Data.importVector(allocator, inArr, inSch, null); } - result = udf.evaluate(inputs); + result = udf.evaluate(inputs, numRows); if (!(result instanceof FieldVector)) { throw new RuntimeException( "CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName()); } - // Result length must match the longest input. Scalar (length-1) inputs - // are allowed to be shorter, but a vector input bounds the output. 
- int expectedLen = 0; - for (ValueVector v : inputs) { - expectedLen = Math.max(expectedLen, v.getValueCount()); - } - if (result.getValueCount() != expectedLen) { + if (result.getValueCount() != numRows) { throw new RuntimeException( "CometUDF.evaluate() returned " + result.getValueCount() + " rows, expected " - + expectedLen); + + numRows); } ArrowArray outArr = ArrowArray.wrap(outArrayPtr); ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr); diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 9b376837f7..dcc6359304 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -380,6 +380,46 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val REGEXP_ENGINE_RUST = "rust" + val REGEXP_ENGINE_JAVA = "java" + + val COMET_REGEXP_ENGINE: ConfigEntry[String] = + conf("spark.comet.exec.regexp.engine") + .category(CATEGORY_EXEC) + .doc( + "Experimental. Selects the engine used to evaluate supported regular-expression " + + s"expressions. `$REGEXP_ENGINE_RUST` uses the native DataFusion regexp engine. " + + s"`$REGEXP_ENGINE_JAVA` routes through a JVM-side UDF (java.util.regex.Pattern) for " + + "Spark-compatible semantics, at the cost of JNI roundtrips per batch. Expressions " + + "routed when set to java: rlike, regexp_extract, regexp_extract_all, regexp_replace, " + + "regexp_instr, and split.") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValues(Set(REGEXP_ENGINE_RUST, REGEXP_ENGINE_JAVA)) + .createWithDefault(REGEXP_ENGINE_JAVA) + + val CODEGEN_DISPATCH_AUTO = "auto" + val CODEGEN_DISPATCH_DISABLED = "disabled" + val CODEGEN_DISPATCH_FORCE = "force" + + val COMET_CODEGEN_DISPATCH_MODE: ConfigEntry[String] = + conf("spark.comet.exec.codegenDispatch.mode") + .category(CATEGORY_EXEC) + .doc("Controls whether Comet routes eligible scalar expressions through the Arrow-direct " + + "codegen dispatcher (`CometCodegenDispatchUDF`) rather than through a native " + + s"DataFusion implementation or falling back to Spark. `$CODEGEN_DISPATCH_AUTO` lets " + + "each expression's serde decide its preferred path based on measured evidence " + + "(e.g. for regex, codegen is preferred when " + + s"spark.comet.exec.regexp.engine=$REGEXP_ENGINE_JAVA). " + + s"`$CODEGEN_DISPATCH_DISABLED` never uses codegen dispatch. `$CODEGEN_DISPATCH_FORCE` " + + "inverts the chain: every serde tries codegen first and falls through to its next " + + "preferred path only when `canHandle` rejects the expression. Useful for debugging " + + "and benchmarking.") + .stringConf + .transform(_.toLowerCase(Locale.ROOT)) + .checkValues(Set(CODEGEN_DISPATCH_AUTO, CODEGEN_DISPATCH_DISABLED, CODEGEN_DISPATCH_FORCE)) + .createWithDefault(CODEGEN_DISPATCH_AUTO) + val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.native.shuffle.partitioning.hash.enabled") .category(CATEGORY_SHUFFLE) diff --git a/common/src/main/scala/org/apache/comet/udf/CometArrayData.scala b/common/src/main/scala/org/apache/comet/udf/CometArrayData.scala new file mode 100644 index 0000000000..36e11546e7 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometArrayData.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +import org.apache.comet.shims.CometInternalRowShim + +/** + * Shim base for Comet-owned [[ArrayData]] views used by the Arrow-direct codegen kernel. + * + * Provides `UnsupportedOperationException` defaults for every abstract method on `ArrayData` and + * `SpecializedGetters`. Codegen emits a concrete subclass per complex-typed input column, + * overriding only the small set of getters the element type requires (e.g. `numElements`, + * `isNullAt`, and `getUTF8String` for an `ArrayType(StringType)` input). + * + * Pattern mirrors [[CometInternalRow]]: centralize the boilerplate throws so the codegen- emitted + * subclasses stay short, and absorb forward-compat breakage if Spark adds abstract methods to + * `ArrayData` in a future version. + * + * Mixes in [[CometInternalRowShim]] for the same reason `CometInternalRow` does: Spark 4.x adds + * new abstract getters (`getVariant`, `getGeography`, `getGeometry`) on `SpecializedGetters` that + * both `InternalRow` and `ArrayData` inherit. The shim is per-profile and provides throwing + * defaults only on the profiles that declare those methods abstract. + */ +abstract class CometArrayData extends ArrayData with CometInternalRowShim { + + override def numElements(): Int = unsupported("numElements") + override def isNullAt(ordinal: Int): Boolean = unsupported("isNullAt") + + override def getBoolean(ordinal: Int): Boolean = unsupported("getBoolean") + override def getByte(ordinal: Int): Byte = unsupported("getByte") + override def getShort(ordinal: Int): Short = unsupported("getShort") + override def getInt(ordinal: Int): Int = unsupported("getInt") + override def getLong(ordinal: Int): Long = unsupported("getLong") + override def getFloat(ordinal: Int): Float = unsupported("getFloat") + override def getDouble(ordinal: Int): Double = unsupported("getDouble") + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + unsupported("getDecimal") + override def getUTF8String(ordinal: Int): UTF8String = unsupported("getUTF8String") + override def getBinary(ordinal: Int): Array[Byte] = unsupported("getBinary") + override def getInterval(ordinal: Int): CalendarInterval = unsupported("getInterval") + override def getStruct(ordinal: Int, numFields: Int): InternalRow = unsupported("getStruct") + override def getArray(ordinal: Int): ArrayData = unsupported("getArray") + override def getMap(ordinal: Int): MapData = unsupported("getMap") + + /** + * Generic `get(ordinal, dataType)` dispatcher. 
Spark codegen sometimes calls this rather than + * the typed getter (`SafeProjection` uses it when deserializing struct-valued ScalaUDF args, + * for example); leaving it as a throw leaks NPEs once callers catch the + * `UnsupportedOperationException` and propagate null. Dispatches to the typed getter matching + * `dataType`; a null entry returns `null` outright. + */ + override def get(ordinal: Int, dataType: DataType): AnyRef = { + if (isNullAt(ordinal)) return null + dataType match { + case BooleanType => java.lang.Boolean.valueOf(getBoolean(ordinal)) + case ByteType => java.lang.Byte.valueOf(getByte(ordinal)) + case ShortType => java.lang.Short.valueOf(getShort(ordinal)) + case IntegerType | DateType => java.lang.Integer.valueOf(getInt(ordinal)) + case LongType | TimestampType | TimestampNTZType => + java.lang.Long.valueOf(getLong(ordinal)) + case FloatType => java.lang.Float.valueOf(getFloat(ordinal)) + case DoubleType => java.lang.Double.valueOf(getDouble(ordinal)) + case _: StringType => getUTF8String(ordinal) + case BinaryType => getBinary(ordinal) + case dt: DecimalType => getDecimal(ordinal, dt.precision, dt.scale) + case st: StructType => getStruct(ordinal, st.size) + case _: ArrayType => getArray(ordinal) + case _: MapType => getMap(ordinal) + case other => unsupported(s"get for dataType $other") + } + } + + override def setNullAt(i: Int): Unit = unsupported("setNullAt") + override def update(i: Int, value: Any): Unit = unsupported("update") + + override def copy(): ArrayData = unsupported("copy") + override def array: Array[Any] = unsupported("array") + override def toString(): String = { + val n = + try numElements().toString + catch { case _: Throwable => "?" } + s"${getClass.getSimpleName}(numElements=$n)" + } + + protected def unsupported(method: String): Nothing = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: $method not implemented for this array shape") +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegen.scala b/common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegen.scala new file mode 100644 index 0000000000..8d18f4297e --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegen.scala @@ -0,0 +1,642 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.udf + +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector} +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, RegExpReplace, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeFormatter, CodegenContext, CodeGenerator, CodegenFallback, ExprCode, GeneratedClass} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, StringType} + +import org.apache.comet.shims.CometExprTraitShim + +/** + * Compiles a bound [[Expression]] plus an input schema into a specialized [[CometBatchKernel]] + * that fuses Arrow input reads, expression evaluation, and Arrow output writes into one + * Janino-compiled method per (expression, schema) pair. + * + * Input- and output-side emission live in [[CometBatchKernelCodegenInput]] and + * [[CometBatchKernelCodegenOutput]]. This file is the orchestrator: the [[ArrowColumnSpec]] + * vocabulary, [[canHandle]] / [[allocateOutput]] / [[compile]] / [[generateSource]] entry points, + * and the cross-cutting kernel-shape decisions (null-intolerant short-circuit, CSE variant, + * per-expression specialized emitters). + * + * The generated kernel '''is''' the `InternalRow` that Spark's `BoundReference.genCode` reads + * from. `ctx.INPUT_ROW = "row"` plus `InternalRow row = this;` inside `process` routes every + * `row.getUTF8String(ord)` to the kernel's own typed getter (final method, constant ordinal; JIT + * devirtualizes and folds the switch). `row` rather than `this` because Spark's + * `splitExpressions` uses INPUT_ROW as a helper-method parameter name and `this` is a reserved + * Java keyword. + * + * For the full feature list (type surface, optimizations, cache layers, specialized emitters, + * open work items), see `docs/source/contributor-guide/jvm_udf_dispatch.md`. + */ +object CometBatchKernelCodegen extends Logging with CometExprTraitShim { + + /** + * Per-column compile-time invariants. The concrete Arrow vector class and whether the column is + * nullable are baked into the generated kernel's typed fields and branches. Part of the cache + * key: different vector classes or nullability produce different kernels. + * + * Sealed hierarchy so that complex types (array/map/struct) can carry their nested element + * shape recursively. Today scalar, array, and struct specs exist; map cases will land as an + * additional subclass when the emitter covers them. A companion `apply` / `unapply` preserves + * the original scalar-only construction and extractor shape so existing callers don't need to + * change. + */ + sealed trait ArrowColumnSpec { + def vectorClass: Class[_ <: ValueVector] + def nullable: Boolean + } + + object ArrowColumnSpec { + + /** Convenience constructor producing a [[ScalarColumnSpec]]. */ + def apply(vectorClass: Class[_ <: ValueVector], nullable: Boolean): ArrowColumnSpec = + ScalarColumnSpec(vectorClass, nullable) + + /** + * Backward-compatible extractor for the common scalar case. Callers that want array / struct + * / future map specs should pattern match on the subclass directly. 
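 * A minimal illustration (hypothetical column shape): an `array<string>` input column is
 * described by
 * {{{
 * val spec: ArrowColumnSpec = ArrayColumnSpec(
 *   nullable = true,
 *   elementSparkType = StringType,
 *   element = ScalarColumnSpec(classOf[VarCharVector], nullable = true))
 * }}}
 * Only the scalar leaf matches the `ArrowColumnSpec(cls, nullable)` extractor; the outer
 * `ArrayColumnSpec` has to be matched on its subclass.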
+ */ + def unapply(spec: ArrowColumnSpec): Option[(Class[_ <: ValueVector], Boolean)] = spec match { + case ScalarColumnSpec(c, n) => Some((c, n)) + case _ => None + } + } + + /** Scalar column: one Arrow vector class per row slot, no nested structure. */ + final case class ScalarColumnSpec(vectorClass: Class[_ <: ValueVector], nullable: Boolean) + extends ArrowColumnSpec + + /** + * Array column: an Arrow `ListVector` wrapping a child spec. `elementSparkType` is the Spark + * `DataType` of the element so the nested-class getter emitter can choose the right template + * (e.g. `getUTF8String` for `StringType`, `getInt` for `IntegerType`). The child spec carries + * the Arrow child vector class. Nested arrays (`Array>`) work by the child being + * itself an `ArrayColumnSpec`. + */ + final case class ArrayColumnSpec( + nullable: Boolean, + elementSparkType: DataType, + element: ArrowColumnSpec) + extends ArrowColumnSpec { + override def vectorClass: Class[_ <: ValueVector] = classOf[ListVector] + } + + /** + * Struct column: an Arrow `StructVector` wrapping N typed child specs. Each entry carries the + * Spark field name (for schema identification in the cache key), the Spark `DataType` of the + * field (so per-field emitters pick the right read/write template), the child `ArrowColumnSpec` + * (so nested shapes like `Struct>` compose by trait-level recursion), and the + * field's `nullable` bit (so non-nullable fields elide their per-row null check at source + * level). Nested structs (`Struct>`) work by the child being itself a + * `StructColumnSpec`. + */ + final case class StructColumnSpec(nullable: Boolean, fields: Seq[StructFieldSpec]) + extends ArrowColumnSpec { + override def vectorClass: Class[_ <: ValueVector] = classOf[StructVector] + } + + /** One field entry on a [[StructColumnSpec]]. */ + final case class StructFieldSpec( + name: String, + sparkType: DataType, + nullable: Boolean, + child: ArrowColumnSpec) + + /** + * Map column: an Arrow `MapVector` (subclass of `ListVector`) whose data vector is a + * `StructVector` with a key field at ordinal 0 and a value field at ordinal 1. `key` and + * `value` are themselves `ArrowColumnSpec` so nested shapes (`Map, Array>`, + * `Map, ...>`) compose by trait-level recursion. Nullable map entries are controlled + * per-column by the outer map's validity; nullable keys and values are carried in the child + * specs' `nullable` bit. + */ + final case class MapColumnSpec( + nullable: Boolean, + keySparkType: DataType, + valueSparkType: DataType, + key: ArrowColumnSpec, + value: ArrowColumnSpec) + extends ArrowColumnSpec { + override def vectorClass: Class[_ <: ValueVector] = classOf[MapVector] + } + + /** + * Resolve an Arrow vector class by its simple name, using the same classloader the codegen uses + * internally. Intended for tests: the `common` module shades `org.apache.arrow` to + * `org.apache.comet.shaded.arrow`, so `classOf[VarCharVector]` at a call site in an unshaded + * module refers to a different [[Class]] object than the one the codegen compares against. + * Callers pass a simple name and get back the class the production code actually uses. 
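 * A sketch of the intended test-side usage (hypothetical test code):
 * {{{
 * // Resolve the (possibly shaded) class the production codegen actually compares against.
 * val varchar = CometBatchKernelCodegen.vectorClassBySimpleName("VarCharVector")
 * val spec = CometBatchKernelCodegen.ArrowColumnSpec(varchar, nullable = true)
 * }}}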
+ */ + def vectorClassBySimpleName(name: String): Class[_ <: ValueVector] = name match { + case "BitVector" => classOf[BitVector] + case "TinyIntVector" => classOf[TinyIntVector] + case "SmallIntVector" => classOf[SmallIntVector] + case "IntVector" => classOf[IntVector] + case "BigIntVector" => classOf[BigIntVector] + case "Float4Vector" => classOf[Float4Vector] + case "Float8Vector" => classOf[Float8Vector] + case "DecimalVector" => classOf[DecimalVector] + case "DateDayVector" => classOf[DateDayVector] + case "TimeStampMicroVector" => classOf[TimeStampMicroVector] + case "TimeStampMicroTZVector" => classOf[TimeStampMicroTZVector] + case "VarCharVector" => classOf[VarCharVector] + case "VarBinaryVector" => classOf[VarBinaryVector] + case other => throw new IllegalArgumentException(s"unknown Arrow vector class: $other") + } + + /** + * Result of compiling a bound [[Expression]] into a Janino kernel. The `factory` is the Spark + * [[GeneratedClass]] produced by Janino and is safe to share across threads and partitions: it + * holds no mutable state. The `freshReferences` closure regenerates the references array each + * time a new kernel instance is allocated. + * + * Why not cache a single `references` array: some expressions (notably [[ScalaUDF]]) embed + * stateful Spark `ExpressionEncoder` serializers into `references` via `ctx.addReferenceObj`. + * Those serializers reuse an internal `UnsafeRow` / `byte[]` buffer per `.apply(...)` call and + * are not thread-safe. If two kernels on different partitions shared one serializer instance, + * they would race on that buffer and produce garbage. Re-running `genCode` per kernel + * allocation costs microseconds; Janino compile costs milliseconds. Cache the expensive piece, + * refresh the cheap one, stay correct. + * + * Mirrors Spark `WholeStageCodegenExec`: compile once per plan, instantiate per partition, call + * `init(partitionIndex)` once, iterate. + */ + final case class CompiledKernel(factory: GeneratedClass, freshReferences: () => Array[Any]) { + def newInstance(): CometBatchKernel = + factory.generate(freshReferences()).asInstanceOf[CometBatchKernel] + } + + /** + * Plan-time predicate: can the codegen dispatcher handle this bound expression end to end? If + * it returns `None`, the serde is free to emit the codegen proto. If it returns `Some(reason)`, + * the serde must fall back (usually via `withInfo(...) + None`) so Spark runs the expression + * rather than crashing in the Janino compile at execute time. + * + * Checks: + * - every `BoundReference`'s data type is in + * [[CometBatchKernelCodegenInput.isSupportedInputType]] (i.e. the kernel has a typed getter + * for it) + * - the overall `expr.dataType` is in [[CometBatchKernelCodegenOutput.isSupportedOutputType]] + * (i.e. `allocateOutput` and `emitWrite` know how to materialize it) + * - the expression is scalar (no `AggregateFunction`, no generators). These never reach a + * scalar serde, but we belt-and-suspenders anyway. + * + * Intermediate node types are '''not''' checked. Spark's `doGenCode` materializes intermediates + * in local variables; only the leaves (which read from the row) and the root (which writes to + * the output vector) touch Arrow. 
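 * A sketch of the intended serde-side call pattern (`emitCodegenDispatchProto` and the
 * surrounding names are hypothetical stand-ins, not part of this change):
 * {{{
 * CometBatchKernelCodegen.canHandle(boundExpr) match {
 *   case None         => Some(emitCodegenDispatchProto(boundExpr)) // free to emit the codegen proto
 *   case Some(reason) => withInfo(expr, reason); None              // record reason, let Spark run it
 * }
 * }}}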
+ */ + def canHandle(boundExpr: Expression): Option[String] = { + if (!CometBatchKernelCodegenOutput.isSupportedOutputType(boundExpr.dataType)) { + return Some(s"codegen dispatch: unsupported output type ${boundExpr.dataType}") + } + // Reject expressions that can't be safely compiled or cached: + // - AggregateFunction / Generator: non-scalar bridge shape. + // - CodegenFallback: opts out of `doGenCode`, which our compile path assumes works. + // Passing one in would emit interpreted-eval glue that our kernel can't splice cleanly. + // - Unevaluable: unresolved plan markers. Shouldn't reach a serde, but cheap to guard. + // `isCodegenInertUnevaluable` lets the shim exclude version-specific leaves that are + // `Unevaluable` but never touched by codegen (e.g. Spark 4.0's `ResolvedCollation`, which + // lives in `Collate.collation` as a type marker; `Collate.genCode` delegates to its child). + // + // Nondeterministic and stateful expressions are accepted: the dispatcher allocates one + // kernel instance per partition (per `CometCodegenDispatchUDF.ensureKernel`) and calls + // `init(partitionIndex)` once on partition entry, so per-row state on `Rand`, + // `MonotonicallyIncreasingID`, etc. advances correctly across batches in the same + // partition and resets across partitions. + // + // `ExecSubqueryExpression` (e.g. `ScalarSubquery`, `InSubqueryExec`) is also accepted, and + // works correctly via a four-link invariant: + // 1. The surrounding Comet operator inherits `SparkPlan.waitForSubqueries`, which calls + // `updateResult()` on every `ExecSubqueryExpression` in its `expressions` before the + // operator's compute path ever reaches the JVM UDF bridge. + // 2. `ScalarSubquery.result` (and equivalents on other subquery expressions) is a plain + // mutable field on the case class. `@volatile` affects cross-thread visibility but + // not serializability: Java/Kryo serializers include it. + // 3. `SparkEnv.closureSerializer` captures the populated `result` value in the bytes + // that travel through `CometCodegenDispatchUDF`'s arg-0 transport. + // 4. The dispatcher's cache key is those exact bytes (see + // `CometCodegenDispatchUDF.CacheKey`). Different `result` values produce different + // bytes, hence different cache entries, hence a fresh compile per distinct subquery + // value. No cross-query staleness. + // + // If any of those four links breaks (a different cache-key derivation that drops `result`; + // a Comet operator that bypasses `waitForSubqueries`; a transport that strips `@volatile` + // fields), subquery correctness regresses. Keep this invariant intact when refactoring the + // cache-key or transport layers. + boundExpr.find { + case _: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction => true + case _: org.apache.spark.sql.catalyst.expressions.Generator => true + case _: CodegenFallback => true + case u: Unevaluable if isCodegenInertUnevaluable(u) => false + case _: Unevaluable => true + case _ => false + } match { + case Some(bad) => + return Some( + s"codegen dispatch: expression ${bad.getClass.getSimpleName} not supported " + + "(aggregate, generator, codegen-fallback, or unevaluable)") + case None => + } + val badRef = boundExpr.collectFirst { + case b: BoundReference if !CometBatchKernelCodegenInput.isSupportedInputType(b.dataType) => + b + } + badRef.map(b => + s"codegen dispatch: unsupported input type ${b.dataType} at ordinal ${b.ordinal}") + } + + /** + * Allocate an Arrow output vector matching the expression's `dataType`. 
Thin forwarder to + * [[CometBatchKernelCodegenOutput.allocateOutput]]. Kept on this object as part of the public + * API so external callers (`CometCodegenDispatchUDF`) do not have to know about the internal + * split. + */ + def allocateOutput( + dataType: DataType, + name: String, + numRows: Int, + estimatedBytes: Int = -1): FieldVector = + CometBatchKernelCodegenOutput.allocateOutput(dataType, name, numRows, estimatedBytes) + + /** + * Output of [[generateSource]]. `body` is the raw Java source Janino will compile; `code` is + * the post-`stripOverlappingComments` wrapper Janino actually takes as input; `references` are + * the runtime objects the generated constructor pulls from via `ctx.addReferenceObj` (cached + * patterns, replacement strings, etc.). Tests inspect `body` to assert the shape of the + * generated source. See `CometCodegenSourceSuite` for examples. + */ + final case class GeneratedSource(body: String, code: CodeAndComment, references: Array[Any]) + + /** + * Generate the Java source for a kernel without compiling it. Factored out of [[compile]] so + * tests can assert on the emitted source (null short-circuit present, non-nullable `isNullAt` + * returns literal `false`, specialized emitter engaged, etc.) without paying for Janino. + */ + def generateSource( + boundExpr: Expression, + inputSchema: Seq[ArrowColumnSpec]): GeneratedSource = { + val ctx = new CodegenContext + // `BoundReference.genCode` emits `${ctx.INPUT_ROW}.getUTF8String(ord)`. We alias a local + // `row` to `this` at the top of `process` so those reads resolve to the kernel's own typed + // getters (virtual dispatch on a concrete final class, JIT devirtualizes + folds the + // switch). `row` rather than `this` because Spark's `splitExpressions` uses INPUT_ROW as the + // parameter name of any helper method it emits; `this` is a reserved keyword, so using it + // as a parameter name produces `private UTF8String helper(InternalRow this)` which Janino + // rejects. + ctx.INPUT_ROW = "row" + + val baseClass = classOf[CometBatchKernel].getName + // Resolve shaded Arrow class names at compile time so generated source + // matches the abstract method signature after Maven relocation. + val valueVectorClass = classOf[ValueVector].getName + val fieldVectorClass = classOf[FieldVector].getName + + // Pick the per-row body. Specialized emitters get priority; the default reuses + // Spark's doGenCode. + // + // `outputSetup` holds once-per-batch declarations (typed child-vector casts for complex + // output) that `emitOutputWriter` factors out of the per-row body so they do not repeat on + // every row. Scalar outputs return an empty string here. Specialized emitters (like + // RegExpReplace) do not need setup because they write directly to the root `output`. + // + // TODO(method-size): perRowBody is inlined inside process's for-loop and not split. + // Sufficiently deep trees can exceed Janino's 64KB method size; wrap in + // ctx.splitExpressionsWithCurrentInputs when hit. See + // docs/source/contributor-guide/jvm_udf_dispatch.md#open-items. + val (concreteOutClass, outputSetup, perRowBody) = boundExpr match { + case rr: RegExpReplace if canSpecializeRegExpReplace(rr) => + (classOf[VarCharVector].getName, "", specializedRegExpReplaceBody(ctx, rr, inputSchema)) + case _ => + // Class-field CSE. 
`generateExpressions` runs `subexpressionElimination` under the + // hood, which populates `ctx.subexprFunctions` with per-row helper calls that write + // common subexpression results into `addMutableState`-allocated fields; the returned + // `ExprCode` then references those fields. `subexprFunctionsCode` is the concatenated + // helper invocation block, spliced into the per-row body by `defaultBody` (inside the + // NullIntolerant else-branch when that short-circuit fires, otherwise before + // `ev.code`). See the "Subexpression elimination" section of the object-level + // Scaladoc for why we use this variant rather than the WSCG one. + val ev = if (SQLConf.get.subexpressionEliminationEnabled) { + ctx.generateExpressions(Seq(boundExpr), doSubexpressionElimination = true).head + } else { + boundExpr.genCode(ctx) + } + val subExprsCode = ctx.subexprFunctionsCode + val (cls, setup, snippet) = + CometBatchKernelCodegenOutput.emitOutputWriter(boundExpr.dataType, ev.value, ctx) + (cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode)) + } + + val typedFieldDecls = CometBatchKernelCodegenInput.emitInputFieldDecls(inputSchema) + val typedInputCasts = CometBatchKernelCodegenInput.emitInputCasts(inputSchema) + val decimalTypeByOrdinal = CometBatchKernelCodegenInput.decimalPrecisionByOrdinal(boundExpr) + val getters = + CometBatchKernelCodegenInput.emitTypedGetters(inputSchema, decimalTypeByOrdinal) + val nested = CometBatchKernelCodegenInput.emitNestedClasses(inputSchema) + val getArrayMethod = CometBatchKernelCodegenInput.emitGetArrayMethod(inputSchema) + val getStructMethod = CometBatchKernelCodegenInput.emitGetStructMethod(inputSchema) + val getMapMethod = CometBatchKernelCodegenInput.emitGetMapMethod(inputSchema) + + val codeBody = + s""" + |public java.lang.Object generate(Object[] references) { + | return new SpecificCometBatchKernel(references); + |} + | + |final class SpecificCometBatchKernel extends $baseClass { + | + | ${ctx.declareMutableStates()} + | + | $typedFieldDecls + | private int rowIdx; + | + | public SpecificCometBatchKernel(Object[] references) { + | super(references); + | ${ctx.initMutableStates()} + | } + | + | @Override + | public void init(int partitionIndex) { + | ${ctx.initPartition()} + | } + | + | $getters + | $getArrayMethod + | $getStructMethod + | $getMapMethod + | + | @Override + | public void process( + | $valueVectorClass[] inputs, + | $fieldVectorClass outRaw, + | int numRows) { + | $concreteOutClass output = ($concreteOutClass) outRaw; + | $typedInputCasts + | $outputSetup + | // Alias the kernel as `row` so Spark-generated `${ctx.INPUT_ROW}.method()` reads + | // resolve to the kernel's own typed getters. Helper methods that Spark splits off + | // via `splitExpressions` also take `InternalRow row` as a parameter; we pass `this` + | // implicitly since callers substitute INPUT_ROW which we've set to `row`. 
+ | org.apache.spark.sql.catalyst.InternalRow row = this; + | for (int i = 0; i < numRows; i++) { + | this.rowIdx = i; + | $perRowBody + | } + | } + | + | ${ctx.declareAddedFunctions()} + | + |$nested + |} + """.stripMargin + + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + GeneratedSource(code.body, code, ctx.references.toArray) + } + + def compile(boundExpr: Expression, inputSchema: Seq[ArrowColumnSpec]): CompiledKernel = { + val src = generateSource(boundExpr, inputSchema) + val (clazz, _) = + try { + CodeGenerator.compile(src.code) + } catch { + case t: Throwable => + logError( + s"CometBatchKernelCodegen: compile failed for ${boundExpr.getClass.getSimpleName}. " + + s"Generated source follows:\n${src.body}", + t) + throw t + } + // One log per unique (expr, schema) compile; the caller caches the result so subsequent + // batches with the same shape reuse this compile. + val specialized = boundExpr match { + case _: RegExpReplace + if canSpecializeRegExpReplace(boundExpr.asInstanceOf[RegExpReplace]) => + " [specialized]" + case _ => "" + } + logInfo( + s"CometBatchKernelCodegen: compiled ${boundExpr.getClass.getSimpleName}$specialized " + + s"-> ${boundExpr.dataType} inputs=" + + inputSchema + .map(s => s"${s.vectorClass.getSimpleName}${if (s.nullable) "?" else ""}") + .mkString(",")) + // Freshen references per kernel allocation. See the `CompiledKernel` scaladoc for why. + // `generateSource` is pure with respect to its inputs (no hidden state) and produces a + // layout-compatible references array each time because the expression and schema are + // fixed. + val freshReferences: () => Array[Any] = () => + generateSource(boundExpr, inputSchema).references + CompiledKernel(clazz, freshReferences) + } + + /** + * Can this `RegExpReplace` instance be handled by the specialized emitter? Requires a direct + * column reference as subject, non-null foldable pattern and replacement, and offset of 1. + * Other shapes fall back to the default `doGenCode` path. + */ + private def canSpecializeRegExpReplace(rr: RegExpReplace): Boolean = { + val subjectIsBound = + rr.subject.isInstanceOf[BoundReference] && rr.subject.dataType == StringType + val patternOk = + rr.regexp.foldable && rr.regexp.dataType == StringType && rr.regexp.eval() != null + val replOk = rr.rep.foldable && rr.rep.dataType == StringType && rr.rep.eval() != null + val posIsOne = rr.pos match { + case Literal(v: Int, _) => v == 1 + case _ => false + } + subjectIsBound && patternOk && replOk && posIsOne + } + + /** + * Emit the per-row body for `RegExpReplace`. Per-row shape: read Arrow subject bytes, decode to + * Java `String`, run `Matcher.replaceAll` with a cached `Pattern` and the replacement String, + * re-encode to bytes, write to Arrow. + * + * ==Why this specialization exists== + * + * The default path runs `boundExpr.genCode(ctx)` and wraps it with kernel-side getter reads and + * a `UTF8String -> bytes -> Arrow` write. For `RegExpReplace` specifically, Spark's generated + * code does not stay in `UTF8String` space: `java.util.regex.Matcher` requires a + * `CharSequence`, so the generated code materializes a Java `String` from the input + * `UTF8String` (a UTF-8 decode, allocating a `char[]`), runs the matcher, then wraps the result + * String back into a `UTF8String` (a UTF-8 encode, allocating a `byte[]`). 
The per-row shape + * is: + * + * {{{ + * default: Arrow bytes -> UTF8String -> String -> Matcher -> + * String -> UTF8String -> bytes -> Arrow + * }}} + * + * On a wide-match workload (every character of the row gets replaced, so the output is the full + * row length), the round trip added ~44% per-row cost versus a tight byte-oriented loop with + * shape: + * + * {{{ + * specialized: Arrow bytes -> String -> Matcher -> String -> bytes -> Arrow + * }}} + * + * This specialization emits the byte-oriented shape directly. No `UTF8String` appears in the + * generated per-row loop. The expression remains a first-class citizen of the dispatcher + * (plan-time serde, schema-keyed caching, zero-config for the caller). + * + * ==When to add a specialization== + * + * The general rule: specialize when an expression's `doGenCode` output shape forces conversions + * that an Arrow-aware byte-oriented implementation does not pay. The common case is expressions + * whose implementation requires a Java `String` (anything using `java.util.regex` and some + * `DateTimeFormatter` expressions), because Spark's `UTF8String <-> String` round-trip is not + * free for wide outputs. Keep specializations minimal so comparisons stay honest. Avoid + * layering speculative optimizations; let the default-path optimization menu handle the common + * cases. + */ + private def specializedRegExpReplaceBody( + ctx: CodegenContext, + rr: RegExpReplace, + inputSchema: Seq[ArrowColumnSpec]): String = { + val subjectOrd = rr.subject.asInstanceOf[BoundReference].ordinal + val subjectClass = inputSchema(subjectOrd).vectorClass + require( + subjectClass == classOf[VarCharVector], + "specializedRegExpReplaceBody expects VarCharVector at ordinal " + + s"$subjectOrd, got ${subjectClass.getSimpleName}") + + val patternStr = rr.regexp.eval().toString + val replStr = rr.rep.eval().toString + val compiledPattern = java.util.regex.Pattern.compile(patternStr) + + // addReferenceObj adds a class-level field initialized from references[] in the constructor, + // so the Pattern and replacement String are resolved once, not per row. + val patternRef = + ctx.addReferenceObj("pattern", compiledPattern, "java.util.regex.Pattern") + val replRef = ctx.addReferenceObj("replacement", replStr, "java.lang.String") + + val sb = ctx.freshName("sb") + val s = ctx.freshName("s") + val r = ctx.freshName("r") + val rb = ctx.freshName("rb") + + s""" + |if (this.col$subjectOrd.isNull(i)) { + | output.setNull(i); + |} else { + | byte[] $sb = this.col$subjectOrd.get(i); + | String $s = new String($sb, java.nio.charset.StandardCharsets.UTF_8); + | String $r = $patternRef.matcher($s).replaceAll($replRef); + | byte[] $rb = $r.getBytes(java.nio.charset.StandardCharsets.UTF_8); + | output.setSafe(i, $rb, 0, $rb.length); + |} + """.stripMargin + } + + /** + * Per-row body for the default (non-specialized) path. + * + * For expressions that implement the `NullIntolerant` marker trait (null in any input -> null + * output), emits a short-circuit that skips expression evaluation entirely when any input + * column is null in the current row. This saves the full `ev.code` cost for null rows, not just + * the output setNull call. Does not change behavior, only performance. + * + * For other expressions, the standard shape applies: evaluate the expression, then check + * `ev.isNull` to decide between `setNull` and a write. Null semantics are handled internally by + * Spark's generated `ev.code`. 
+ * + * `subExprsCode` is the CSE helper-invocation block (see the "Subexpression elimination" + * section of the object-level Scaladoc). It writes common subexpression results into class + * fields that `ev.code` reads, so it must run before `ev.code`. In the NullIntolerant short- + * circuit case it is placed inside the else branch, skipping CSE evaluation for null rows as + * well as main-body evaluation. In the default case it precedes `ev.code`. Empty string when + * CSE is disabled or the tree has no common subexpressions. + */ + private def defaultBody( + boundExpr: Expression, + ev: ExprCode, + writeSnippet: String, + subExprsCode: String): String = { + boundExpr match { + case _ if isNullIntolerant(boundExpr) && allNullIntolerant(boundExpr) => + // Every node from root to leaf is either NullIntolerant or a leaf. That transitively + // guarantees "any BoundReference null at this row -> whole expression null", so we can + // short-circuit on the union of input ordinals. Breaking the chain with a non-null- + // propagating node like `coalesce` or `if` produces the wrong result (coalesce(null,x) + // is x, not null), so the check above rejects those shapes and falls through to the + // default branch which runs Spark's own null-aware ev.code. + val inputOrdinals = + boundExpr.collect { case b: BoundReference => b.ordinal }.distinct + val nullCheck = + if (inputOrdinals.isEmpty) "false" + else inputOrdinals.map(ord => s"this.col$ord.isNull(i)").mkString(" || ") + s""" + |if ($nullCheck) { + | output.setNull(i); + |} else { + | $subExprsCode + | ${ev.code} + | $writeSnippet + |} + """.stripMargin + case _ => + // Optimization: NonNullableOutputShortCircuit. + // When the bound expression declares `nullable = false`, the `if (ev.isNull)` branch is + // dead and HotSpot may or may not fold it (it depends on whether the expression's + // `doGenCode` made `ev.isNull` a `FalseLiteral` or a variable whose value is + // false-at-runtime but not a compile-time constant from Spark's side). Drop the guard + // at source level so we don't depend on JIT folding and keep the generated body + // minimal. + if (!boundExpr.nullable) { + s""" + |$subExprsCode + |${ev.code} + |$writeSnippet + """.stripMargin + } else { + s""" + |$subExprsCode + |${ev.code} + |if (${ev.isNull}) { + | output.setNull(i); + |} else { + | $writeSnippet + |} + """.stripMargin + } + } + } + + /** + * True iff every node in the expression tree is either `NullIntolerant` or a leaf we can safely + * consider null-propagating (`BoundReference` and `Literal`). Used to gate the `NullIntolerant` + * short-circuit in [[defaultBody]]: the short-circuit collects `BoundReference` ordinals from + * the whole tree and skips `ev.code` when any of them is null, which is only correct when every + * path from a leaf to the root propagates nulls. A non- propagating node (`Coalesce`, `If`, + * `CaseWhen`, `Concat`, etc.) anywhere in the tree invalidates this assumption: `coalesce(null, + * x)` is `x`, not null, so pre-nulling on any input null would produce the wrong result. 
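 * Illustrative expressions (not taken from this change):
 * {{{
 * // qualifies: Length and Upper both propagate nulls, so a null input forces a null result
 * allNullIntolerant(Length(Upper(BoundReference(0, StringType, nullable = true))))
 * // does not qualify: coalesce(null, "x") is "x", so pre-nulling would change the result
 * allNullIntolerant(Coalesce(Seq(BoundReference(0, StringType, nullable = true), Literal("x"))))
 * }}}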
+ */ + private def allNullIntolerant(expr: Expression): Boolean = + !expr.exists { + case _: BoundReference | _: Literal => false + case other => !isNullIntolerant(other) + } +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegenInput.scala b/common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegenInput.scala new file mode 100644 index 0000000000..b4cdfd4595 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegenInput.scala @@ -0,0 +1,1095 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import scala.collection.mutable + +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector} +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} + +import org.apache.comet.udf.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec} +import org.apache.comet.vector.CometPlainVector + +/** + * Input-side emitters for the Arrow-direct codegen kernel. Everything that generates source for + * reading Arrow input into Spark's typed getter surface lives here: kernel field declarations, + * per-batch input casts, top-level typed-getter switches, nested `InputArray_${path}` / + * `InputStruct_${path}` / `InputMap_${path}` classes at every complex level, and the input-side + * type-support gate. + * + * ==Path encoding for nested complex types== + * + * Each position in a spec tree has a unique path string, used as the suffix on typed vector + * fields and as the identifier on nested classes. Starting from the column ordinal: + * + * - root: `col${ord}` + * - array element of `P`: `${P}_e` + * - struct field `fi` of `P`: `${P}_f${fi}` + * - map key of `P`: `${P}_k` + * - map value of `P`: `${P}_v` + * + * ==Nested-class composition== + * + * A nested class at path `P` represents a Spark `ArrayData`, `InternalRow`, or `MapData` view of + * its Arrow vector. For any complex child one level down, the class holds a pre-allocated + * instance of the corresponding inner nested class and routes `getArray` / `getStruct` / `getMap` + * / `keyArray` / `valueArray` calls to that instance after resetting it. N-deep nesting falls out + * of this: each level only knows about its immediate children. 
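 * For example (hypothetical schema; simplified, the emitted declarations use fully-qualified
 * shaded class names), an `array<string>` column at ordinal 0 produces:
 * {{{
 *   private ListVector col0;                 // root path col0
 *   private VarCharVector col0_e;            // array-element path col0_e
 *   private long col0_e_valueAddr;           // cached buffer addresses for inline unsafe reads
 *   private long col0_e_offsetAddr;
 *   private final InputArray_col0 col0_arrayData = new InputArray_col0();
 * }}}
 *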
+ * + * ==Unified reset protocol== + * + * `InputArray_${path}` and `InputMap_${path}` classes both take `reset(int startIdx, int length)` + * and simply capture the slice. Callers (kernel top-level switches, outer complex-getter routers, + * map `keyArray` / `valueArray` returns) compute `(startIdx, length)` from the appropriate parent + * offsets before calling `reset`. This unifies the view shape across list-backed arrays and map + * key/value slices. Structs stay flat-indexed: `InputStruct_${path}` has `reset(int rowIdx)` that + * just captures the outer row index. + * + * Paired with [[CometBatchKernelCodegenOutput]], which handles the symmetric output side. + */ +private[udf] object CometBatchKernelCodegenInput { + + /** + * Input types the kernel has a typed getter for. Recursive: `ArrayType(inner)` supported when + * `inner` is supported; `StructType` when every field is; `MapType` when key and value types + * are both supported. + */ + def isSupportedInputType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType => true + case FloatType | DoubleType => true + case _: DecimalType => true + case _: StringType | _: BinaryType => true + case DateType | TimestampType | TimestampNTZType => true + case ArrayType(inner, _) => isSupportedInputType(inner) + case st: StructType => st.fields.forall(f => isSupportedInputType(f.dataType)) + case mt: MapType => isSupportedInputType(mt.keyType) && isSupportedInputType(mt.valueType) + case _ => false + } + + /** + * Emit the kernel's typed vector-field declarations for every level of every input column's + * spec tree. Top-level complex columns additionally get an instance-field declaration for the + * pre-allocated nested class. Instance fields for nested-class children one level down live + * inside the parent nested class. + */ + def emitInputFieldDecls(inputSchema: Seq[ArrowColumnSpec]): String = { + val lines = new mutable.ArrayBuffer[String]() + inputSchema.zipWithIndex.foreach { case (spec, ord) => + val path = s"col$ord" + collectVectorFieldDecls(path, spec, lines) + collectTopLevelInstanceDecl(path, spec, lines) + } + lines.mkString("\n ") + } + + /** + * Primitive Arrow vector classes that we wrap in [[CometPlainVector]] at the kernel's input- + * cast time. `CometPlainVector.get*` reads use `Platform.get*` against a `final long` buffer + * address, so JIT inlines them to branchless reads with no per-call `ArrowBuf` dereference. + * `CometPlainVector.getBoolean` also includes a bit-packed data-byte cache that collapses 8 + * sequential bit reads to 1 byte read. + * + * Not wrapped: `DecimalVector` (kernel emits inline unsafe reads keyed on compile-time + * precision, so the fast/slow split stays branchless in the emitted Java rather than branching + * at runtime inside `CometPlainVector.getDecimal`), `VarCharVector` / `VarBinaryVector` (kernel + * emits inline unsafe reads to avoid the redundant `isNullAt` check inside + * `CometPlainVector.getUTF8String` / `getBinary`). 
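 * For illustration (simplified; the emitted source uses fully-qualified shaded names), the
 * per-batch casts for an `int` column at ordinal 0 (wrapped) and a `string` column at ordinal 1
 * (not wrapped, read via inline unsafe accesses) come out as:
 * {{{
 *   this.col0 = new CometPlainVector(inputs[0], true);
 *   this.col1 = (VarCharVector) inputs[1];
 *   this.col1_valueAddr = this.col1.getDataBuffer().memoryAddress();
 *   this.col1_offsetAddr = this.col1.getOffsetBuffer().memoryAddress();
 * }}}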
+ */ + private val primitiveArrowClasses: Set[Class[_]] = Set( + classOf[BitVector], + classOf[TinyIntVector], + classOf[SmallIntVector], + classOf[IntVector], + classOf[BigIntVector], + classOf[Float4Vector], + classOf[Float8Vector], + classOf[DateDayVector], + classOf[TimeStampMicroVector], + classOf[TimeStampMicroTZVector]) + + private def wrapsInCometPlainVector(cls: Class[_]): Boolean = + primitiveArrowClasses.contains(cls) + + /** + * Non-wrapped scalar columns that want a cached data-buffer address for inline unsafe reads. + * `DecimalVector` uses it for the short-precision fast path (`Platform.getLong`); + * `VarCharVector` / `VarBinaryVector` use it as the base address for `UTF8String.fromAddress` / + * `Platform.copyMemory`. See the unsafe-emitter block at the bottom of this file for why we + * inline rather than reuse `CometPlainVector`. + */ + private def needsValueAddrField(cls: Class[_]): Boolean = + cls == classOf[DecimalVector] || + cls == classOf[VarCharVector] || + cls == classOf[VarBinaryVector] + + /** Variable-width columns also want the offset-buffer address cached for `Platform.getInt`. */ + private def needsOffsetAddrField(cls: Class[_]): Boolean = + cls == classOf[VarCharVector] || cls == classOf[VarBinaryVector] + + /** + * Java method name for the null check on a column's typed field. Primitive scalars wrapped in + * [[CometPlainVector]] expose `isNullAt`; Arrow typed fields (complex containers, + * `DecimalVector`, `VarCharVector`, `VarBinaryVector`) expose `isNull`. Both read the validity + * bitmap. + */ + private def nullCheckMethod(spec: ArrowColumnSpec): String = spec match { + case sc: ScalarColumnSpec if wrapsInCometPlainVector(sc.vectorClass) => "isNullAt" + case _ => "isNull" + } + + private val cometPlainVectorName: String = classOf[CometPlainVector].getName + + private def collectVectorFieldDecls( + path: String, + spec: ArrowColumnSpec, + out: mutable.ArrayBuffer[String]): Unit = spec match { + case sc: ScalarColumnSpec => + // Primitive scalar columns (at any nesting depth) are wrapped in CometPlainVector so + // per-row reads go through JIT-inlined Platform.get* against a cached buffer address. + // DecimalVector / VarCharVector / VarBinaryVector stay on the Arrow typed field but + // cache data- and (variable-width) offset-buffer addresses for inline unsafe reads. + val fieldClass = + if (wrapsInCometPlainVector(sc.vectorClass)) cometPlainVectorName + else sc.vectorClass.getName + out += s"private $fieldClass $path;" + if (needsValueAddrField(sc.vectorClass)) { + out += s"private long ${path}_valueAddr;" + } + if (needsOffsetAddrField(sc.vectorClass)) { + out += s"private long ${path}_offsetAddr;" + } + case ar: ArrayColumnSpec => + out += s"private ${classOf[ListVector].getName} $path;" + collectVectorFieldDecls(s"${path}_e", ar.element, out) + case st: StructColumnSpec => + out += s"private ${classOf[StructVector].getName} $path;" + st.fields.zipWithIndex.foreach { case (f, fi) => + collectVectorFieldDecls(s"${path}_f$fi", f.child, out) + } + case mp: MapColumnSpec => + out += s"private ${classOf[MapVector].getName} $path;" + // Key and value vectors live at `${P}_k_e` / `${P}_v_e` so the `InputArray_${P}_k` / + // `InputArray_${P}_v` synthetic classes (which follow the array-element convention of + // reading from `${path}_e`) resolve their element reads correctly. 
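+ // Illustrative: a `map<int, long>` column at ordinal 3 declares `col3` (MapVector) plus
+ // `col3_k_e` / `col3_v_e`, both CometPlainVector-wrapped primitive element vectors.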
+ collectVectorFieldDecls(s"${path}_k_e", mp.key, out) + collectVectorFieldDecls(s"${path}_v_e", mp.value, out) + } + + private def collectTopLevelInstanceDecl( + path: String, + spec: ArrowColumnSpec, + out: mutable.ArrayBuffer[String]): Unit = spec match { + case _: ScalarColumnSpec => () + case _: ArrayColumnSpec => + out += s"private final InputArray_$path ${path}_arrayData = new InputArray_$path();" + case _: StructColumnSpec => + out += s"private final InputStruct_$path ${path}_structData = new InputStruct_$path();" + case _: MapColumnSpec => + out += s"private final InputMap_$path ${path}_mapData = new InputMap_$path();" + } + + /** + * Emit the per-batch cast statements. For a map column, casts the outer `MapVector`, then casts + * the inner `StructVector` (via a local variable) to extract key and value children via + * `getChildByOrdinal(0)` / `(1)`. For arrays, casts the outer `ListVector` and recurses via + * `getDataVector()`. For structs, casts the outer `StructVector` and recurses via + * `getChildByOrdinal(fi)`. + */ + def emitInputCasts(inputSchema: Seq[ArrowColumnSpec]): String = { + val lines = new mutable.ArrayBuffer[String]() + inputSchema.zipWithIndex.foreach { case (spec, ord) => + val path = s"col$ord" + collectCasts(path, spec, s"inputs[$ord]", lines) + } + lines.mkString("\n ") + } + + private def collectCasts( + path: String, + spec: ArrowColumnSpec, + source: String, + out: mutable.ArrayBuffer[String]): Unit = spec match { + case sc: ScalarColumnSpec => + if (wrapsInCometPlainVector(sc.vectorClass)) { + // Wrap in CometPlainVector so per-row reads go through Platform.get* against a final + // long buffer address. JIT inlines the one-liner getters, treating the address as a + // register-cached constant across the process loop. useDecimal128 = true matches + // Spark's 128-bit decimal storage. + out += s"this.$path = new $cometPlainVectorName($source, true);" + } else { + out += s"this.$path = (${sc.vectorClass.getName}) $source;" + } + if (needsValueAddrField(sc.vectorClass)) { + out += s"this.${path}_valueAddr = this.$path.getDataBuffer().memoryAddress();" + } + if (needsOffsetAddrField(sc.vectorClass)) { + out += s"this.${path}_offsetAddr = this.$path.getOffsetBuffer().memoryAddress();" + } + case ar: ArrayColumnSpec => + out += s"this.$path = (${classOf[ListVector].getName}) $source;" + collectCasts(s"${path}_e", ar.element, s"this.$path.getDataVector()", out) + case st: StructColumnSpec => + out += s"this.$path = (${classOf[StructVector].getName}) $source;" + st.fields.zipWithIndex.foreach { case (f, fi) => + collectCasts(s"${path}_f$fi", f.child, s"this.$path.getChildByOrdinal($fi)", out) + } + case mp: MapColumnSpec => + // MapVector's data vector is a StructVector with key at child 0 and value at child 1. + // Grab the struct through a local var and pull out the typed children. The key / value + // vectors live at the `_k_e` / `_v_e` paths so the synthetic `InputArray_${P}_k` / + // `InputArray_${P}_v` classes read them via the standard array-element convention. + val structLocal = s"${path}__mapStruct" + out += s"this.$path = (${classOf[MapVector].getName}) $source;" + out += s"${classOf[StructVector].getName} $structLocal = " + + s"(${classOf[StructVector].getName}) this.$path.getDataVector();" + collectCasts(s"${path}_k_e", mp.key, s"$structLocal.getChildByOrdinal(0)", out) + collectCasts(s"${path}_v_e", mp.value, s"$structLocal.getChildByOrdinal(1)", out) + } + + /** + * Emit the kernel's typed-getter overrides. 
Spark's `InternalRow` provides the base virtual + * method; the `@Override` on a final class gives the JIT enough information to devirtualize. + * Each getter switches on the column ordinal so the call site (with an inlined constant ordinal + * from `BoundReference.genCode`) folds down to a single branch. + * + * `decimalTypeByOrdinal` lets the decimal getter specialize per ordinal: when a + * `BoundReference` of `DecimalType(precision <= 18)` is the only decimal read at that ordinal, + * the emitted case skips the `BigDecimal` allocation and reads the unscaled long directly. + * + * TODO(unsafe-readers): primitive `v.get(i)` performs a bounds check that is redundant given `i + * in [0, numRows)`. See `docs/source/contributor-guide/jvm_udf_dispatch.md#open-items`. + */ + def emitTypedGetters( + inputSchema: Seq[ArrowColumnSpec], + decimalTypeByOrdinal: Map[Int, Option[DecimalType]]): String = { + val withOrd = inputSchema.zipWithIndex + + val isNullCases = withOrd.map { case (spec, ord) => + if (!spec.nullable) { + s" case $ord: return false;" + } else { + // CometPlainVector exposes `isNullAt`; Arrow-typed fields expose `isNull`. Both check + // the validity bitmap with the same semantics. + val method = spec.vectorClass match { + case cls if wrapsInCometPlainVector(cls) => "isNullAt" + case _ => "isNull" + } + s" case $ord: return this.col$ord.$method(this.rowIdx);" + } + } + + val booleanCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[BitVector] => + s" case $ord: return this.col$ord.getBoolean(this.rowIdx);" + } + val byteCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[TinyIntVector] => + s" case $ord: return this.col$ord.getByte(this.rowIdx);" + } + val shortCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[SmallIntVector] => + s" case $ord: return this.col$ord.getShort(this.rowIdx);" + } + val intCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) + if cls == classOf[IntVector] || cls == classOf[DateDayVector] => + s" case $ord: return this.col$ord.getInt(this.rowIdx);" + } + val longCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) + if cls == classOf[BigIntVector] || + cls == classOf[TimeStampMicroVector] || + cls == classOf[TimeStampMicroTZVector] => + s" case $ord: return this.col$ord.getLong(this.rowIdx);" + } + val floatCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[Float4Vector] => + s" case $ord: return this.col$ord.getFloat(this.rowIdx);" + } + val doubleCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[Float8Vector] => + s" case $ord: return this.col$ord.getDouble(this.rowIdx);" + } + val decimalCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[DecimalVector] => + val known = decimalTypeByOrdinal.getOrElse(ord, None) + val valueAddr = s"this.col${ord}_valueAddr" + val slowField = s"this.col$ord" + val fastPath = emitDecimalFastBodyUnsafe(valueAddr, "this.rowIdx", " ") + val slowPath = emitDecimalSlowBody(slowField, "this.rowIdx", " ") + val body = known match { + case Some(dt) if dt.precision <= 18 => fastPath + case Some(_) => slowPath + case None => + s""" if (precision <= 18) { + |$fastPath + | } else { + |$slowPath + | }""".stripMargin + } + s""" case $ord: { + |$body + | }""".stripMargin + } + val binaryCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[VarBinaryVector] => + s""" case $ord: { + 
|${emitBinaryBodyUnsafe( + s"this.col${ord}_valueAddr", + s"this.col${ord}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + } + val utf8Cases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[VarCharVector] => + s""" case $ord: { + |${emitUtf8BodyUnsafe( + s"this.col${ord}_valueAddr", + s"this.col${ord}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + } + + Seq( + emitOrdinalSwitch("public boolean isNullAt(int ordinal)", "isNullAt", isNullCases), + emitOrdinalSwitch("public boolean getBoolean(int ordinal)", "getBoolean", booleanCases), + emitOrdinalSwitch("public byte getByte(int ordinal)", "getByte", byteCases), + emitOrdinalSwitch("public short getShort(int ordinal)", "getShort", shortCases), + emitOrdinalSwitch("public int getInt(int ordinal)", "getInt", intCases), + emitOrdinalSwitch("public long getLong(int ordinal)", "getLong", longCases), + emitOrdinalSwitch("public float getFloat(int ordinal)", "getFloat", floatCases), + emitOrdinalSwitch("public double getDouble(int ordinal)", "getDouble", doubleCases), + emitOrdinalSwitch( + "public org.apache.spark.sql.types.Decimal getDecimal(" + + "int ordinal, int precision, int scale)", + "getDecimal", + decimalCases), + emitOrdinalSwitch("public byte[] getBinary(int ordinal)", "getBinary", binaryCases), + emitOrdinalSwitch( + "public org.apache.spark.unsafe.types.UTF8String getUTF8String(int ordinal)", + "getUTF8String", + utf8Cases)).mkString + } + + /** + * Build a per-ordinal map of the `DecimalType` observed on `BoundReference`s in the bound + * expression. Used by [[emitTypedGetters]] to emit a compile-time-specialized `getDecimal` case + * per ordinal. + */ + def decimalPrecisionByOrdinal(boundExpr: Expression): Map[Int, Option[DecimalType]] = { + boundExpr + .collect { + case b: BoundReference if b.dataType.isInstanceOf[DecimalType] => + b.ordinal -> b.dataType.asInstanceOf[DecimalType] + } + .groupBy(_._1) + .map { case (ord, pairs) => + val distinct = pairs.map(_._2).toSet + ord -> (if (distinct.size == 1) Some(distinct.head) else None) + } + } + + /** + * Emit every nested class needed for every complex level of every input column. For an + * `ArrayColumnSpec` we emit `InputArray_${path}`; for a `StructColumnSpec` + * `InputStruct_${path}`; for a `MapColumnSpec` `InputMap_${path}` plus the `InputArray` classes + * for the key and value slices (because Spark's `MapData.keyArray()` / `valueArray()` return + * `ArrayData` - same view shape as any other array). + */ + def emitNestedClasses(inputSchema: Seq[ArrowColumnSpec]): String = { + val out = new mutable.ArrayBuffer[String]() + inputSchema.zipWithIndex.foreach { case (spec, ord) => + collectNestedClasses(s"col$ord", spec, out) + } + out.mkString("\n") + } + + private def collectNestedClasses( + path: String, + spec: ArrowColumnSpec, + out: mutable.ArrayBuffer[String]): Unit = spec match { + case _: ScalarColumnSpec => () + case ar: ArrayColumnSpec => + out += emitArrayClass(path, ar) + collectNestedClasses(s"${path}_e", ar.element, out) + case st: StructColumnSpec => + out += emitStructClass(path, st) + st.fields.zipWithIndex.foreach { case (f, fi) => + collectNestedClasses(s"${path}_f$fi", f.child, out) + } + case mp: MapColumnSpec => + out += emitMapClass(path, mp) + // Emit InputArray_${path}_k and InputArray_${path}_v - the ArrayData views returned by + // `MapData.keyArray()` / `valueArray()`. 
They follow the standard array-element + // convention: each reads from `${classPath}_e` which maps to the key / value vector + // emitted at `${path}_k_e` / `${path}_v_e` by [[collectVectorFieldDecls]]. Instance + // fields for complex key / value elements (one level deeper) live inside these array + // classes via [[instanceDeclaration]]. + out += emitArrayClass( + s"${path}_k", + ArrayColumnSpec(nullable = true, elementSparkType = mp.keySparkType, element = mp.key)) + out += emitArrayClass( + s"${path}_v", + ArrayColumnSpec( + nullable = true, + elementSparkType = mp.valueSparkType, + element = mp.value)) + // Recurse into the key / value specs at their canonical paths (${path}_k_e / + // ${path}_v_e) so nested complex keys / values get their own nested classes. + collectNestedClasses(s"${path}_k_e", mp.key, out) + collectNestedClasses(s"${path}_v_e", mp.value, out) + } + + // --------------------------------------------------------------------------------------------- + // Shared helpers for complex-getter routing. A "list-backed child reset" computes + // `(startIdx, length)` for an inner instance from a ListVector / MapVector's offsets at a + // parent-provided index and calls `reset(startIdx, length)`. + // --------------------------------------------------------------------------------------------- + + private def emitListBackedChildReset( + parentVectorPath: String, + indexExpr: String, + innerInstanceField: String): String = + s""" int __idx = $indexExpr; + | int __s = $parentVectorPath.getElementStartIndex(__idx); + | int __e = $parentVectorPath.getElementEndIndex(__idx); + | $innerInstanceField.reset(__s, __e - __s);""".stripMargin + + /** + * Emit one `InputArray_${path}` nested class. Unified slice-based reset: callers pass + * `(startIdx, length)` directly. + * + * Key/value arrays of a map share this exact shape - the instance fields for their complex + * elements (if any) are emitted from [[emitArrayElementGetter]]; the vector fields they read + * from are at `${path}_e` (following the array-element path convention), which maps to + * `col${N}_k_e` or `col${N}_v_e` when the array represents a map key/value slice. + * + * NOTE: when this class is used for a map's key or value view and the underlying key/value is + * scalar, there is no `${path}_e` vector field - the map's key/value vector sits at `${path}` + * itself (e.g. `col0_k`). See [[emitArrayElementGetter]] for how that is handled: scalar + * element emission reads from `${path}_e`, but for map views the element vector IS the path + * itself. We rename the element path in [[emitMapClass]] below. + */ + private def emitArrayClass(path: String, spec: ArrayColumnSpec): String = { + val baseClassName = classOf[CometArrayData].getName + val elemPath = s"${path}_e" + val innerInstance = instanceDeclaration(elemPath, spec.element) + val isNullAt = + s""" @Override + | public boolean isNullAt(int i) { + | return $elemPath.${nullCheckMethod(spec.element)}(startIndex + i); + | }""".stripMargin + val elementGetter = emitArrayElementGetter(path, spec) + s""" private final class InputArray_$path extends $baseClassName { + | private int startIndex; + | private int length; + |$innerInstance + | + | void reset(int startIdx, int len) { + | this.startIndex = startIdx; + | this.length = len; + | } + | + | @Override + | public int numElements() { + | return length; + | } + | + |$isNullAt + | + |$elementGetter + | } + |""".stripMargin + } + + /** + * Emit the element getter body for a nested `InputArray_${path}`. 
Scalar element -> direct + * typed read. Complex element -> `getArray(i)` / `getStruct(i, n)` / `getMap(i)` that resets + * the inner instance. + */ + private def emitArrayElementGetter(path: String, spec: ArrayColumnSpec): String = { + val elemPath = s"${path}_e" + spec.element match { + case _: ScalarColumnSpec => + emitArrayElementScalarGetter(spec.elementSparkType, elemPath) + case _: ArrayColumnSpec => + val reset = emitListBackedChildReset(elemPath, "startIndex + i", s"${elemPath}_arrayData") + s""" @Override + | public org.apache.spark.sql.catalyst.util.ArrayData getArray(int i) { + |$reset + | return ${elemPath}_arrayData; + | }""".stripMargin + case _: StructColumnSpec => + s""" @Override + | public org.apache.spark.sql.catalyst.InternalRow getStruct(int i, int numFields) { + | ${elemPath}_structData.reset(startIndex + i); + | return ${elemPath}_structData; + | }""".stripMargin + case _: MapColumnSpec => + val reset = emitListBackedChildReset(elemPath, "startIndex + i", s"${elemPath}_mapData") + s""" @Override + | public org.apache.spark.sql.catalyst.util.MapData getMap(int i) { + |$reset + | return ${elemPath}_mapData; + | }""".stripMargin + } + } + + /** + * Emit the scalar-element getter override for a nested `InputArray_${path}`. Only the getter + * matching the element type is overridden; any other getter inherits the base class's + * `UnsupportedOperationException`. + */ + private def emitArrayElementScalarGetter(elemType: DataType, childField: String): String = + elemType match { + case BooleanType => + s""" @Override + | public boolean getBoolean(int i) { + | return $childField.getBoolean(startIndex + i); + | }""".stripMargin + case ByteType => + s""" @Override + | public byte getByte(int i) { + | return $childField.getByte(startIndex + i); + | }""".stripMargin + case ShortType => + s""" @Override + | public short getShort(int i) { + | return $childField.getShort(startIndex + i); + | }""".stripMargin + case IntegerType | DateType => + s""" @Override + | public int getInt(int i) { + | return $childField.getInt(startIndex + i); + | }""".stripMargin + case LongType | TimestampType | TimestampNTZType => + s""" @Override + | public long getLong(int i) { + | return $childField.getLong(startIndex + i); + | }""".stripMargin + case FloatType => + s""" @Override + | public float getFloat(int i) { + | return $childField.getFloat(startIndex + i); + | }""".stripMargin + case DoubleType => + s""" @Override + | public double getDouble(int i) { + | return $childField.getDouble(startIndex + i); + | }""".stripMargin + case dt: DecimalType => + val body = + if (dt.precision <= 18) { + emitDecimalFastBodyUnsafe(s"${childField}_valueAddr", "startIndex + i", " ") + } else { + emitDecimalSlowBody(childField, "startIndex + i", " ") + } + s""" @Override + | public org.apache.spark.sql.types.Decimal getDecimal( + | int i, int precision, int scale) { + |$body + | }""".stripMargin + case _: StringType => + s""" @Override + | public org.apache.spark.unsafe.types.UTF8String getUTF8String(int i) { + |${emitUtf8BodyUnsafe( + s"${childField}_valueAddr", + s"${childField}_offsetAddr", + "startIndex + i", + " ")} + | }""".stripMargin + case BinaryType => + s""" @Override + | public byte[] getBinary(int i) { + |${emitBinaryBodyUnsafe( + s"${childField}_valueAddr", + s"${childField}_offsetAddr", + "startIndex + i", + " ")} + | }""".stripMargin + case other => + throw new UnsupportedOperationException( + s"nested ArrayData: unsupported element type $other") + } + + /** + * Emit the kernel's `@Override public 
ArrayData getArray(int ordinal)` method. Each case reads + * `(startIdx, length)` from the outer `ListVector`'s offsets at the current row and calls the + * pre-allocated instance's unified `reset(startIdx, length)`. + */ + def emitGetArrayMethod(inputSchema: Seq[ArrowColumnSpec]): String = { + val cases = inputSchema.zipWithIndex.collect { case (_: ArrayColumnSpec, ord) => + val reset = + emitListBackedChildReset(s"this.col$ord", "this.rowIdx", s"this.col${ord}_arrayData") + s""" case $ord: { + |$reset + | return this.col${ord}_arrayData; + | }""".stripMargin + } + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | public org.apache.spark.sql.catalyst.util.ArrayData getArray(int ordinal) { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "getArray out of range: " + ordinal); + | } + | } + |""".stripMargin + } + } + + /** + * Emit one `InputStruct_${path}` nested class. Flat-indexed: `reset(int outerRowIdx)` just + * captures the index. Scalar getters switch on field ordinal; complex getters route to inner + * instances (offsets computed for array/map children; rowIdx passed through for struct + * children). + */ + private def emitStructClass(path: String, spec: StructColumnSpec): String = { + val baseClassName = classOf[CometInternalRow].getName + val innerInstances = spec.fields.zipWithIndex + .flatMap { case (f, fi) => + val fieldPath = s"${path}_f$fi" + Some(instanceDeclaration(fieldPath, f.child)).filter(_.nonEmpty) + } + .mkString("\n") + val isNullCases = spec.fields.zipWithIndex.map { + case (f, fi) if !f.nullable => + s" case $fi: return false;" + case (f, fi) => + s" case $fi: return ${path}_f$fi.${nullCheckMethod(f.child)}(this.rowIdx);" + } + val scalarGetters = emitStructScalarGetters(path, spec) + val complexGetters = emitStructComplexGetters(path, spec) + s""" private final class InputStruct_$path extends $baseClassName { + | private int rowIdx; + |$innerInstances + | + | void reset(int outerRowIdx) { + | this.rowIdx = outerRowIdx; + | } + | + | @Override + | public int numFields() { + | return ${spec.fields.length}; + | } + | + | @Override + | public boolean isNullAt(int ordinal) { + | switch (ordinal) { + |${isNullCases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "InputStruct_$path.isNullAt out of range: " + ordinal); + | } + | } + | + |$scalarGetters + |$complexGetters + | } + |""".stripMargin + } + + private def emitStructScalarGetters(path: String, spec: StructColumnSpec): String = { + val withOrd = spec.fields.zipWithIndex + val scalarOrd = withOrd.filter { case (f, _) => f.child.isInstanceOf[ScalarColumnSpec] } + + def fieldReadScalar(fi: Int, dt: DataType): String = dt match { + case BooleanType => + s" case $fi: return ${path}_f$fi.getBoolean(this.rowIdx);" + case ByteType => + s" case $fi: return ${path}_f$fi.getByte(this.rowIdx);" + case ShortType => + s" case $fi: return ${path}_f$fi.getShort(this.rowIdx);" + case IntegerType | DateType => + s" case $fi: return ${path}_f$fi.getInt(this.rowIdx);" + case LongType | TimestampType | TimestampNTZType => + s" case $fi: return ${path}_f$fi.getLong(this.rowIdx);" + case FloatType => + s" case $fi: return ${path}_f$fi.getFloat(this.rowIdx);" + case DoubleType => + s" case $fi: return ${path}_f$fi.getDouble(this.rowIdx);" + case BinaryType => + s""" case $fi: { + |${emitBinaryBodyUnsafe( + s"${path}_f${fi}_valueAddr", + s"${path}_f${fi}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + case _: StringType => + s""" 
case $fi: { + |${emitUtf8BodyUnsafe( + s"${path}_f${fi}_valueAddr", + s"${path}_f${fi}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + case _: DecimalType => + throw new IllegalStateException("decimal handled separately") + case other => + throw new UnsupportedOperationException( + s"nested InputStruct getter: unsupported field type $other") + } + + val booleanCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == BooleanType => + fieldReadScalar(fi, BooleanType) + } + val byteCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == ByteType => + fieldReadScalar(fi, ByteType) + } + val shortCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == ShortType => + fieldReadScalar(fi, ShortType) + } + val intCases = scalarOrd.collect { + case (f, fi) if f.sparkType == IntegerType || f.sparkType == DateType => + fieldReadScalar(fi, IntegerType) + } + val longCases = scalarOrd.collect { + case (f, fi) + if f.sparkType == LongType || f.sparkType == TimestampType || + f.sparkType == TimestampNTZType => + fieldReadScalar(fi, LongType) + } + val floatCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == FloatType => + fieldReadScalar(fi, FloatType) + } + val doubleCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == DoubleType => + fieldReadScalar(fi, DoubleType) + } + val binaryCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == BinaryType => + fieldReadScalar(fi, BinaryType) + } + val utf8Cases = scalarOrd.collect { + case (f, fi) if f.sparkType.isInstanceOf[StringType] => fieldReadScalar(fi, f.sparkType) + } + + val decimalCases = scalarOrd.collect { + case (f, fi) if f.sparkType.isInstanceOf[DecimalType] => + val dt = f.sparkType.asInstanceOf[DecimalType] + val field = s"${path}_f$fi" + val body = + if (dt.precision <= 18) { + emitDecimalFastBodyUnsafe(s"${field}_valueAddr", "this.rowIdx", " ") + } else { + emitDecimalSlowBody(field, "this.rowIdx", " ") + } + s""" case $fi: { + |$body + | }""".stripMargin + } + + Seq( + structSwitch("public boolean getBoolean(int ordinal)", "getBoolean", booleanCases), + structSwitch("public byte getByte(int ordinal)", "getByte", byteCases), + structSwitch("public short getShort(int ordinal)", "getShort", shortCases), + structSwitch("public int getInt(int ordinal)", "getInt", intCases), + structSwitch("public long getLong(int ordinal)", "getLong", longCases), + structSwitch("public float getFloat(int ordinal)", "getFloat", floatCases), + structSwitch("public double getDouble(int ordinal)", "getDouble", doubleCases), + structSwitch( + "public org.apache.spark.sql.types.Decimal getDecimal(" + + "int ordinal, int precision, int scale)", + "getDecimal", + decimalCases), + structSwitch("public byte[] getBinary(int ordinal)", "getBinary", binaryCases), + structSwitch( + "public org.apache.spark.unsafe.types.UTF8String getUTF8String(int ordinal)", + "getUTF8String", + utf8Cases)).mkString + } + + private def emitStructComplexGetters(path: String, spec: StructColumnSpec): String = { + val getArrayCases = spec.fields.zipWithIndex.collect { + case (f, fi) if f.child.isInstanceOf[ArrayColumnSpec] => + val fieldPath = s"${path}_f$fi" + val reset = emitListBackedChildReset(fieldPath, "this.rowIdx", s"${fieldPath}_arrayData") + s""" case $fi: { + |$reset + | return ${fieldPath}_arrayData; + | }""".stripMargin + } + val getStructCases = spec.fields.zipWithIndex.collect { + case (f, fi) if f.child.isInstanceOf[StructColumnSpec] => + val fieldPath = s"${path}_f$fi" + s""" case $fi: { + | 
${fieldPath}_structData.reset(this.rowIdx); + | return ${fieldPath}_structData; + | }""".stripMargin + } + val getMapCases = spec.fields.zipWithIndex.collect { + case (f, fi) if f.child.isInstanceOf[MapColumnSpec] => + val fieldPath = s"${path}_f$fi" + val reset = emitListBackedChildReset(fieldPath, "this.rowIdx", s"${fieldPath}_mapData") + s""" case $fi: { + |$reset + | return ${fieldPath}_mapData; + | }""".stripMargin + } + Seq( + structSwitch( + "public org.apache.spark.sql.catalyst.util.ArrayData getArray(int ordinal)", + "getArray", + getArrayCases), + structSwitch( + "public org.apache.spark.sql.catalyst.InternalRow getStruct(int ordinal, int numFields)", + "getStruct", + getStructCases), + structSwitch( + "public org.apache.spark.sql.catalyst.util.MapData getMap(int ordinal)", + "getMap", + getMapCases)).mkString + } + + /** + * Emit one `InputMap_${path}` nested class. Holds the slice `(startIndex, length)` and routes + * `keyArray()` / `valueArray()` through pre-allocated `InputArray_${path}_k` / + * `InputArray_${path}_v` instances (emitted by [[collectNestedClasses]]). + */ + private def emitMapClass(path: String, spec: MapColumnSpec): String = { + val _ = spec // key/value arrays declared via path convention below + val baseClassName = classOf[CometMapData].getName + val keyPath = s"${path}_k" + val valPath = s"${path}_v" + s""" private final class InputMap_$path extends $baseClassName { + | private int startIndex; + | private int length; + | private final InputArray_$keyPath ${keyPath}_arrayData = new InputArray_$keyPath(); + | private final InputArray_$valPath ${valPath}_arrayData = new InputArray_$valPath(); + | + | void reset(int startIdx, int len) { + | this.startIndex = startIdx; + | this.length = len; + | } + | + | @Override + | public int numElements() { + | return length; + | } + | + | @Override + | public org.apache.spark.sql.catalyst.util.ArrayData keyArray() { + | ${keyPath}_arrayData.reset(this.startIndex, this.length); + | return ${keyPath}_arrayData; + | } + | + | @Override + | public org.apache.spark.sql.catalyst.util.ArrayData valueArray() { + | ${valPath}_arrayData.reset(this.startIndex, this.length); + | return ${valPath}_arrayData; + | } + | } + |""".stripMargin + } + + /** + * Emit the kernel's top-level `@Override public MapData getMap(int ordinal)` method when the + * input schema has at least one map-typed column at the top level; empty string otherwise. + */ + def emitGetMapMethod(inputSchema: Seq[ArrowColumnSpec]): String = { + val cases = inputSchema.zipWithIndex.collect { case (_: MapColumnSpec, ord) => + val reset = + emitListBackedChildReset(s"this.col$ord", "this.rowIdx", s"this.col${ord}_mapData") + s""" case $ord: { + |$reset + | return this.col${ord}_mapData; + | }""".stripMargin + } + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | public org.apache.spark.sql.catalyst.util.MapData getMap(int ordinal) { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "getMap out of range: " + ordinal); + | } + | } + |""".stripMargin + } + } + + /** + * Return the inner-instance field declaration for one complex spec at the given path, or an + * empty string for a scalar spec. Used inside nested-class bodies to declare pre-allocated + * child-view instances. 
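+ * For example, an array spec at path `col2_e` (illustrative) yields
+ * `private final InputArray_col2_e col2_e_arrayData = new InputArray_col2_e();`.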
+ */ + private def instanceDeclaration(path: String, spec: ArrowColumnSpec): String = spec match { + case _: ScalarColumnSpec => "" + case _: ArrayColumnSpec => + s" private final InputArray_$path ${path}_arrayData = new InputArray_$path();" + case _: StructColumnSpec => + s" private final InputStruct_$path ${path}_structData = new InputStruct_$path();" + case _: MapColumnSpec => + s" private final InputMap_$path ${path}_mapData = new InputMap_$path();" + } + + private def structSwitch(methodSig: String, label: String, cases: Seq[String]): String = { + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | $methodSig { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "$label out of range: " + ordinal); + | } + | } + """.stripMargin + } + } + + /** + * Emit the kernel's top-level `@Override public InternalRow getStruct(int ordinal, int + * numFields)` method when the input schema has at least one struct-typed column. + */ + def emitGetStructMethod(inputSchema: Seq[ArrowColumnSpec]): String = { + val cases = inputSchema.zipWithIndex.collect { case (_: StructColumnSpec, ord) => + s""" case $ord: { + | this.col${ord}_structData.reset(this.rowIdx); + | return this.col${ord}_structData; + | }""".stripMargin + } + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | public org.apache.spark.sql.catalyst.InternalRow getStruct(int ordinal, int numFields) { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "getStruct out of range: " + ordinal); + | } + | } + |""".stripMargin + } + } + + private def emitOrdinalSwitch(methodSig: String, label: String, cases: Seq[String]): String = { + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | $methodSig { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "$label out of range: " + ordinal); + | } + | } + """.stripMargin + } + } + + // ------------------------------------------------------------------------------------------- + // Scalar-read body templates. Each helper emits the per-type read statements parameterized on + // a Java expression for the row/slot index (`idx`), the cached buffer address(es) for unsafe + // reads (`valueAddr`, `offsetAddr`), or the Arrow typed field (`field`) for the slow-path + // decimal case that still needs `getObject`. `ind` is the per-line indent prefix; + // continuation lines add four spaces. Callers wrap the output in switch cases or method + // overrides. + // + // The VarChar / VarBinary unsafe emitters below duplicate what CometPlainVector.getUTF8String + // and getBinary do today, with two differences: they skip CometPlainVector's internal + // isNullAt (redundant here because the kernel's caller already handled it) and they read the + // offset-buffer address from a kernel-cached field rather than re-dereferencing the ArrowBuf. + // Once apache/datafusion-comet#4280 (offsetBufferAddress caching) and #4279 (validity-bitmap + // byte cache) land, both differences stop mattering and `emitUtf8BodyUnsafe` / + // `emitBinaryBodyUnsafe` can be deleted in favor of `CometPlainVector` reuse for variable- + // width. The decimal-fast variant has its own motivation (compile-time precision + // specialization) unrelated to those issues. + // ------------------------------------------------------------------------------------------- + + /** Parenthesize `idx` when it contains whitespace, to keep `(long) idx * 16L` well-formed. 
*/ + private def castableIdx(idx: String): String = if (idx.contains(' ')) s"($idx)" else idx + + private def emitDecimalSlowBody(field: String, idx: String, ind: String): String = { + val cont = ind + " " + s"""${ind}java.math.BigDecimal bd = $field.getObject($idx); + |${ind}return org.apache.spark.sql.types.Decimal$$.MODULE$$ + |$cont.apply(bd, precision, scale);""".stripMargin + } + + private def emitDecimalFastBodyUnsafe(valueAddr: String, idx: String, ind: String): String = { + val cont = ind + " " + val i = castableIdx(idx) + s"""${ind}long unscaled = org.apache.spark.unsafe.Platform.getLong(null, + |$cont$valueAddr + (long) $i * 16L); + |${ind}return org.apache.spark.sql.types.Decimal$$.MODULE$$ + |$cont.createUnsafe(unscaled, precision, scale);""".stripMargin + } + + private def emitUtf8BodyUnsafe( + valueAddr: String, + offsetAddr: String, + idx: String, + ind: String): String = { + val cont = ind + " " + val i = castableIdx(idx) + s"""${ind}int s = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + (long) $i * 4L); + |${ind}int e = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + ((long) $i + 1L) * 4L); + |${ind}return org.apache.spark.unsafe.types.UTF8String + |$cont.fromAddress(null, $valueAddr + s, e - s);""".stripMargin + } + + private def emitBinaryBodyUnsafe( + valueAddr: String, + offsetAddr: String, + idx: String, + ind: String): String = { + val cont = ind + " " + val i = castableIdx(idx) + s"""${ind}int s = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + (long) $i * 4L); + |${ind}int e = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + ((long) $i + 1L) * 4L); + |${ind}int len = e - s; + |${ind}byte[] out = new byte[len]; + |${ind}org.apache.spark.unsafe.Platform.copyMemory(null, $valueAddr + s, out, + |${cont}org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET, len); + |${ind}return out;""".stripMargin + } +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegenOutput.scala b/common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegenOutput.scala new file mode 100644 index 0000000000..4dd4d02497 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegenOutput.scala @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.udf + +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, FieldVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, VarBinaryVector, VarCharVector} +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} + +import org.apache.comet.CometArrowAllocator + +/** + * Output-side emitters for the Arrow-direct codegen kernel. Everything that writes a computed + * value into an Arrow output vector lives here: [[allocateOutput]], [[emitOutputWriter]] (the + * entry point for the kernel's top-level write), [[emitWrite]] (recursive per-type write), the + * output vector-class lookup, and the output-side type-support gate. + * + * Paired with [[CometBatchKernelCodegenInput]], which handles the symmetric input side. + */ +private[udf] object CometBatchKernelCodegenOutput { + + /** + * Output types [[allocateOutput]] and [[emitOutputWriter]] can materialize. Recursive: complex + * types are supported when their children are. + */ + def isSupportedOutputType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType => true + case FloatType | DoubleType => true + case _: DecimalType => true + case _: StringType | _: BinaryType => true + case DateType | TimestampType | TimestampNTZType => true + case ArrayType(inner, _) => isSupportedOutputType(inner) + case st: StructType => st.fields.forall(f => isSupportedOutputType(f.dataType)) + case mt: MapType => + isSupportedOutputType(mt.keyType) && isSupportedOutputType(mt.valueType) + case _ => false + } + + /** + * Allocate an Arrow output vector matching `dataType`. Delegates field and vector construction + * to [[Utils.toArrowField]] + `Field.createVector`, which is the pattern the rest of Comet uses + * to go Spark -> Arrow and handles complex-type wiring (including Arrow's non-null-key and + * non-null-entries invariants on `MapVector`). + * + * For variable-length scalar outputs (`StringType`, `BinaryType`), callers can pass + * `estimatedBytes` to pre-size the data buffer and avoid `setSafe` reallocation mid-loop. The + * hint is only applied when the root vector is `VarCharVector` or `VarBinaryVector`; inside a + * `ListVector` / `StructVector` / `MapVector`, the parent's `allocateNew` reallocates child + * buffers at default size, so a leaf hint would be lost. + * + * Closes the vector on any failure between construction and return so a partially-initialized + * tree does not leak buffers back to the allocator. 
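+ *
+ * Illustrative call for a string result over a 1024-row batch, pre-sizing roughly 32 bytes
+ * per row:
+ * {{{
+ *   val out = allocateOutput(StringType, "codegen_result", 1024, estimatedBytes = 32 * 1024)
+ * }}}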
+ */ + def allocateOutput( + dataType: DataType, + name: String, + numRows: Int, + estimatedBytes: Int = -1): FieldVector = { + val field = Utils.toArrowField(name, dataType, nullable = true, "UTC") + val vec = field.createVector(CometArrowAllocator).asInstanceOf[FieldVector] + try { + vec.setInitialCapacity(numRows) + vec match { + case v: VarCharVector if estimatedBytes > 0 => + v.allocateNew(estimatedBytes.toLong, numRows) + case v: VarBinaryVector if estimatedBytes > 0 => + v.allocateNew(estimatedBytes.toLong, numRows) + case _ => + vec.allocateNew() + } + vec + } catch { + case t: Throwable => + try vec.close() + catch { case _: Throwable => () } + throw t + } + } + + /** + * Split output for a complex-type write: `setup` holds once-per-batch declarations (typed + * child-vector casts) and lives outside the per-row for-loop; `perRow` holds the statements + * executed for each row. Scalar writes have empty setup. + */ + private case class OutputEmit(setup: String, perRow: String) + + /** + * Returns `(concreteVectorClassName, batchSetup, perRowSnippet)` for the expression's output + * type at the root of the generated kernel. `output` is already cast to + * `concreteVectorClassName` in `process`'s prelude, so `emitWrite`'s complex-type branches can + * hoist child casts straight off `output` without re-casting it per row. + */ + def emitOutputWriter( + dataType: DataType, + valueTerm: String, + ctx: CodegenContext): (String, String, String) = { + val cls = outputVectorClass(dataType) + val emit = emitWrite("output", "i", valueTerm, dataType, ctx) + (cls, emit.setup, emit.perRow) + } + + /** + * Concrete Arrow vector class name for the given output type. The name is used to cast `outRaw` + * to the right type at the top of the generated `process` method, so that subsequent writes + * through `emitWrite` can call vector-specific methods without further casts. + */ + private def outputVectorClass(dataType: DataType): String = dataType match { + case BooleanType => classOf[BitVector].getName + case ByteType => classOf[TinyIntVector].getName + case ShortType => classOf[SmallIntVector].getName + case IntegerType => classOf[IntVector].getName + case LongType => classOf[BigIntVector].getName + case FloatType => classOf[Float4Vector].getName + case DoubleType => classOf[Float8Vector].getName + case _: DecimalType => classOf[DecimalVector].getName + case _: StringType => classOf[VarCharVector].getName + case BinaryType => classOf[VarBinaryVector].getName + case DateType => classOf[DateDayVector].getName + case TimestampType => classOf[TimeStampMicroTZVector].getName + case TimestampNTZType => classOf[TimeStampMicroVector].getName + case _: ArrayType => classOf[ListVector].getName + case _: StructType => classOf[StructVector].getName + case _: MapType => classOf[MapVector].getName + case other => + throw new UnsupportedOperationException( + s"CometBatchKernelCodegen.outputVectorClass: unsupported output type $other") + } + + /** + * Composable write emitter. Returns an [[OutputEmit]] whose `setup` declares once-per-batch + * typed child-vector casts (hoisted above the `process` for-loop) and whose `perRow` writes the + * value produced by `source` into `targetVec` at index `idx`. `targetVec` is assumed to be + * already typed to the concrete Arrow vector class for `dataType` at the call site (via the + * prelude cast in `process` for the root, or via a setup cast declared by the caller for nested + * children). 
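+ *
+ * Illustrative: for a root `StructType` output, `setup` holds one typed
+ * `output.getChildByOrdinal(fi)` cast per field and `perRow` calls
+ * `output.setIndexDefined(i)` before writing each field.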
+ * + * Scalars emit `perRow` only; complex types (`ArrayType` / `StructType` / `MapType`) emit both + * setup (child-vector casts) and perRow (loops, null guards, recursive writes). Inner + * `emitWrite` calls return their own setup, which the outer caller concatenates so child-of- + * child casts bubble up to the batch prelude. + */ + private def emitWrite( + targetVec: String, + idx: String, + source: String, + dataType: DataType, + ctx: CodegenContext): OutputEmit = dataType match { + case BooleanType => + OutputEmit("", s"$targetVec.set($idx, $source ? 1 : 0);") + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | DateType | + TimestampType | TimestampNTZType => + // All scalar primitives and date/time types share the direct `set(idx, value)` shape. + // Spark's codegen already emits the correct primitive Java type for each; Arrow's + // typed vectors accept the matching primitive in their `set` overloads. + OutputEmit("", s"$targetVec.set($idx, $source);") + case dt: DecimalType => + // Optimization: DecimalOutputShortFastPath. + // For precision <= 18 the unscaled value fits in a signed long; pass it straight to + // `DecimalVector.setSafe(int, long)` and skip the `java.math.BigDecimal` allocation + // `setSafe(int, BigDecimal)` requires. For p > 18 the BigDecimal path is unavoidable. + val write = + if (dt.precision <= 18) s"$targetVec.setSafe($idx, $source.toUnscaledLong());" + else s"$targetVec.setSafe($idx, $source.toJavaBigDecimal());" + OutputEmit("", write) + case _: StringType => + // Optimization: Utf8OutputOnHeapShortcut. + // `UTF8String` is internally a `(base, offset, numBytes)` view. When the base is a + // `byte[]` (common case: Spark string functions allocate results on-heap), pass the + // existing byte[] directly to `VarCharVector.setSafe(int, byte[], int, int)` via the + // encoded offset and skip the redundant `getBytes()` allocation. Off-heap passthrough + // (rare on output side) falls back to `getBytes()`. + val bBase = ctx.freshName("utfBase") + val bLen = ctx.freshName("utfLen") + val bArr = ctx.freshName("utfArr") + OutputEmit( + "", + s"""Object $bBase = $source.getBaseObject(); + |int $bLen = $source.numBytes(); + |if ($bBase instanceof byte[]) { + | $targetVec.setSafe($idx, (byte[]) $bBase, + | (int) ($source.getBaseOffset() + | - org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET), + | $bLen); + |} else { + | byte[] $bArr = $source.getBytes(); + | $targetVec.setSafe($idx, $bArr, 0, $bArr.length); + |}""".stripMargin) + case BinaryType => + // Spark's BinaryType value is already a `byte[]`. + OutputEmit("", s"$targetVec.setSafe($idx, $source, 0, $source.length);") + case ArrayType(elementType, _) => + // Complex-type output: recursive per-row write. + // Spark's `doGenCode` for ArrayType-returning expressions produces an `ArrayData` value + // (usually `GenericArrayData` / `UnsafeArrayData`). We iterate its elements, write each + // one into the Arrow `ListVector`'s child, and bracket with `startNewValue` / + // `endValue`. The element write recurses through `emitWrite` on the list's child vector, + // so any scalar we support becomes a valid array element. Nested complex types (Array of + // Array, Array of Struct) work by the same recursion. `targetVec` is a `ListVector` at + // the call site (either `output` at root or a hoisted child cast); we only need to cast + // its data vector, and that cast goes into setup. 
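+ // Illustrative per-row shape for `array<int>`: read `numElements`, `startNewValue(i)`,
+ // then per element either `setNull` or `outListChild.set(cidx + j, arr.getInt(j))`,
+ // closed by `endValue(i, n)` (identifiers are freshName-mangled in real output).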
+ val childVar = ctx.freshName("outListChild") + val childClass = outputVectorClass(elementType) + val arrVar = ctx.freshName("arr") + val nVar = ctx.freshName("n") + val childIdx = ctx.freshName("cidx") + val jVar = ctx.freshName("j") + val elemSource = emitSpecializedGetterExpr(arrVar, jVar, elementType) + val inner = emitWrite(childVar, s"$childIdx + $jVar", elemSource, elementType, ctx) + val setup = + (s"$childClass $childVar = ($childClass) $targetVec.getDataVector();" +: + Seq(inner.setup).filter(_.nonEmpty)).mkString("\n") + val perRow = + s"""org.apache.spark.sql.catalyst.util.ArrayData $arrVar = $source; + |int $nVar = $arrVar.numElements(); + |int $childIdx = $targetVec.startNewValue($idx); + |for (int $jVar = 0; $jVar < $nVar; $jVar++) { + | if ($arrVar.isNullAt($jVar)) { + | $childVar.setNull($childIdx + $jVar); + | } else { + | ${inner.perRow} + | } + |} + |$targetVec.endValue($idx, $nVar);""".stripMargin + OutputEmit(setup, perRow) + case st: StructType => + // Complex-type output: recursive per-row write to a StructVector. + // Spark's `doGenCode` for StructType-returning expressions produces an `InternalRow` + // value (`GenericInternalRow` / `UnsafeRow` / ScalaUDF encoder output). Typed child-vector + // casts are hoisted to setup (once per batch); the per-row body references the hoisted + // names. `StructVector` writes are flat-indexed (same `$idx` as the struct's outer slot). + // + // Branchless optimization: for each field whose `nullable == false` on the + // [[StructType]], we skip the `row.isNullAt($fi)` guard at source level. Non-nullable + // fields in Spark are a contract that the producer does not emit nulls for that field, + // and matching that contract here lets HotSpot emit a straight write path per field + // rather than a branch. + val rowVar = ctx.freshName("row") + val perField = st.fields.zipWithIndex.map { case (field, fi) => + val childVar = ctx.freshName("outStructChild") + val childClass = outputVectorClass(field.dataType) + val childDecl = + s"$childClass $childVar = ($childClass) $targetVec.getChildByOrdinal($fi);" + val fieldSource = emitSpecializedGetterExpr(rowVar, fi.toString, field.dataType) + val inner = emitWrite(childVar, idx, fieldSource, field.dataType, ctx) + val write = + if (!field.nullable) { + inner.perRow + } else { + s"""if ($rowVar.isNullAt($fi)) { + | $childVar.setNull($idx); + |} else { + | ${inner.perRow} + |}""".stripMargin + } + val perFieldSetup = (Seq(childDecl) ++ Seq(inner.setup).filter(_.nonEmpty)).mkString("\n") + (perFieldSetup, write) + } + val setup = perField.map(_._1).mkString("\n") + val perFieldWrites = perField.map(_._2).mkString("\n") + val perRow = + s"""org.apache.spark.sql.catalyst.InternalRow $rowVar = $source; + |$targetVec.setIndexDefined($idx); + |$perFieldWrites""".stripMargin + OutputEmit(setup, perRow) + case mt: MapType => + // Complex-type output: recursive per-row write to a MapVector. + // Spark's `doGenCode` for MapType-returning expressions produces a `MapData` value + // (`ArrayBasedMapData` / `UnsafeMapData` / ScalaUDF encoder output). Typed child-vector + // casts for the entries struct and the key/value children are hoisted to setup (once per + // batch); the per-row body references them. + // + // Per-row shape: + // 1. Read keyArray / valueArray from the MapData source. + // 2. Open a new map entry via `startNewValue(idx)`; returns the base index into the + // entries StructVector for this row's key/value pairs. + // 3. 
For each key/value pair: set the entries struct slot defined (map values can be + // null, but the struct slot itself is defined), write the key (always non-null by + // Spark/Arrow invariant), then write the value with a null-guard on + // `vals.isNullAt(j)`. Both writes recurse through `emitWrite`. + // 4. Close the map entry with `endValue(idx, n)`. + val entriesVar = ctx.freshName("outMapEntries") + val keyVar = ctx.freshName("outMapKey") + val valVar = ctx.freshName("outMapVal") + val mapSrc = ctx.freshName("mapSrc") + val keyArr = ctx.freshName("keyArr") + val valArr = ctx.freshName("valArr") + val nVar = ctx.freshName("n") + val childIdx = ctx.freshName("cidx") + val jVar = ctx.freshName("j") + val structClass = classOf[StructVector].getName + val keyClass = outputVectorClass(mt.keyType) + val valClass = outputVectorClass(mt.valueType) + val keySrcExpr = emitSpecializedGetterExpr(keyArr, jVar, mt.keyType) + val valSrcExpr = emitSpecializedGetterExpr(valArr, jVar, mt.valueType) + val keyEmit = emitWrite(keyVar, s"$childIdx + $jVar", keySrcExpr, mt.keyType, ctx) + val valEmit = emitWrite(valVar, s"$childIdx + $jVar", valSrcExpr, mt.valueType, ctx) + val setup = + (Seq( + s"$structClass $entriesVar = ($structClass) $targetVec.getDataVector();", + s"$keyClass $keyVar = ($keyClass) $entriesVar.getChildByOrdinal(0);", + s"$valClass $valVar = ($valClass) $entriesVar.getChildByOrdinal(1);") ++ + Seq(keyEmit.setup, valEmit.setup).filter(_.nonEmpty)).mkString("\n") + val perRow = + s"""org.apache.spark.sql.catalyst.util.MapData $mapSrc = $source; + |org.apache.spark.sql.catalyst.util.ArrayData $keyArr = $mapSrc.keyArray(); + |org.apache.spark.sql.catalyst.util.ArrayData $valArr = $mapSrc.valueArray(); + |int $nVar = $mapSrc.numElements(); + |int $childIdx = $targetVec.startNewValue($idx); + |for (int $jVar = 0; $jVar < $nVar; $jVar++) { + | $entriesVar.setIndexDefined($childIdx + $jVar); + | ${keyEmit.perRow} + | if ($valArr.isNullAt($jVar)) { + | $valVar.setNull($childIdx + $jVar); + | } else { + | ${valEmit.perRow} + | } + |} + |$targetVec.endValue($idx, $nVar);""".stripMargin + OutputEmit(setup, perRow) + case other => + throw new UnsupportedOperationException( + s"CometBatchKernelCodegen.emitWrite: unsupported output type $other") + } + + /** + * Java expression that reads a typed value out of a Spark `SpecializedGetters` reference (which + * both `ArrayData` and `InternalRow` implement) at a given ordinal/index. Used by the + * `ArrayType` and `StructType` branches of [[emitWrite]] to source each element / field for its + * recursive inner write. 
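+ * For example, with an illustrative target `arr` and index `j`, `IntegerType` reads as
+ * `arr.getInt(j)` and `DecimalType(10, 2)` as `arr.getDecimal(j, 10, 2)`.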
+ */ + private def emitSpecializedGetterExpr(target: String, idx: String, elemType: DataType): String = + elemType match { + case BooleanType => s"$target.getBoolean($idx)" + case ByteType => s"$target.getByte($idx)" + case ShortType => s"$target.getShort($idx)" + case IntegerType | DateType => s"$target.getInt($idx)" + case LongType | TimestampType | TimestampNTZType => s"$target.getLong($idx)" + case FloatType => s"$target.getFloat($idx)" + case DoubleType => s"$target.getDouble($idx)" + case dt: DecimalType => s"$target.getDecimal($idx, ${dt.precision}, ${dt.scale})" + case _: StringType => s"$target.getUTF8String($idx)" + case BinaryType => s"$target.getBinary($idx)" + case ArrayType(_, _) => s"$target.getArray($idx)" + case _: MapType => s"$target.getMap($idx)" + case _: StructType => + val numFields = elemType.asInstanceOf[StructType].fields.length + s"$target.getStruct($idx, $numFields)" + case other => + throw new UnsupportedOperationException( + s"CometBatchKernelCodegen.emitSpecializedGetterExpr: unsupported type $other") + } +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometCodegenDispatchUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometCodegenDispatchUDF.scala new file mode 100644 index 0000000000..5be5dc25d5 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometCodegenDispatchUDF.scala @@ -0,0 +1,358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.nio.ByteBuffer +import java.util.{Collections, LinkedHashMap} +import java.util.concurrent.atomic.AtomicLong + +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, DecimalVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarBinaryVector, VarCharVector} +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression} +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.types.{BinaryType, DataType, StringType} + +import org.apache.comet.udf.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec, StructFieldSpec} + +/** + * Arrow-direct codegen dispatcher. For each (bound Spark `Expression`, input Arrow schema) pair, + * compiles a specialized [[CometBatchKernel]] on first encounter and caches the compile. + * Subsequent batches with the same expression and schema reuse the cached compile. + * + * Arg 0 is a `VarBinaryVector` scalar carrying the serialized Expression bytes (produced on the + * driver by Spark's closure serializer). 
Args 1..N are the data columns the `BoundReference`s + * refer to, in ordinal order. The bytes self-describe the expression so the path works in cluster + * mode without executor-side state. + * + * Three caches compose at different scopes: the JVM-wide compile cache on the companion + * (`kernelCache`); a per-thread UDF instance map in `CometUdfBridge.INSTANCES`; and per-partition + * kernel instance state on this object (`activeKernel`, `activeKey`, `activePartition`) managed + * by [[ensureKernel]]. See `docs/source/contributor-guide/jvm_udf_dispatch.md` for the rationale + * and why none of the layers can be collapsed. + */ +class CometCodegenDispatchUDF extends CometUDF { + + override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = { + require( + inputs.length >= 1, + "CometCodegenDispatchUDF requires at least 1 input (serialized expression), " + + s"got ${inputs.length}") + val exprVec = inputs(0).asInstanceOf[VarBinaryVector] + require( + exprVec.getValueCount >= 1 && !exprVec.isNull(0), + "CometCodegenDispatchUDF requires non-null serialized expression bytes at arg 0") + val bytes = exprVec.get(0) + + // TODO(dict-encoded): kernels assume materialized inputs; dict-encoded vectors would fail the + // cast in `specFor` below. See docs/source/contributor-guide/jvm_udf_dispatch.md#open-items. + + val numDataCols = inputs.length - 1 + val dataCols = new Array[ValueVector](numDataCols) + val specs = new Array[ArrowColumnSpec](numDataCols) + var di = 0 + while (di < numDataCols) { + val v = inputs(di + 1) + dataCols(di) = v + specs(di) = specFor(v) + di += 1 + } + val n = numRows + val specsSeq = specs.toIndexedSeq + + val key = CometCodegenDispatchUDF.CacheKey(ByteBuffer.wrap(bytes), specsSeq) + val entry = CometCodegenDispatchUDF.lookupOrCompile(key, bytes, specsSeq) + + val partitionId = CometCodegenDispatchUDF.currentPartitionIndex() + val kernel = ensureKernel(entry.compiled, key, partitionId) + + val out = CometBatchKernelCodegen.allocateOutput( + entry.outputType, + "codegen_result", + n, + estimatedOutputBytes(entry.outputType, dataCols)) + try { + kernel.process(dataCols, out, n) + out.setValueCount(n) + out + } catch { + case t: Throwable => + try out.close() + catch { case _: Throwable => () } + throw t + } + } + + /** + * Per-partition kernel instance cache. The dispatcher's compile cache (on the companion object) + * is JVM-wide and stores the compiled `GeneratedClass`. The kernel '''instance''', however, + * holds per-row mutable state for non-deterministic and stateful expressions (`Rand`'s + * `XORShiftRandom`, `MonotonicallyIncreasingID`'s counter, etc.). That state must advance + * across batches in one partition and reset across partitions. Allocating per batch (the prior + * model) reset state every batch and was wrong; allocating per partition is right. + * + * `CometCodegenDispatchUDF` is per-thread via `CometUdfBridge.INSTANCES`, and Spark tasks are + * single-threaded on a partition, so plain instance fields are safe without synchronisation. A + * different partition or a different cached expression flowing through the same thread triggers + * a fresh allocation; same partition + same expression reuses the kernel. 
+ */ + private var activeKernel: CometBatchKernel = _ + private var activeKey: CometCodegenDispatchUDF.CacheKey = _ + private var activePartition: Int = -1 + + private def ensureKernel( + compiled: CometBatchKernelCodegen.CompiledKernel, + key: CometCodegenDispatchUDF.CacheKey, + partitionId: Int): CometBatchKernel = { + if (activeKernel == null || activePartition != partitionId || activeKey != key) { + activeKernel = compiled.newInstance() + activeKernel.init(partitionId) + activeKey = key + activePartition = partitionId + } + activeKernel + } + + /** + * Did any row in this Arrow vector set the null bit? The cache key carries this per column, so + * a batch with no nulls and a later batch with nulls map to different keys and different + * compiles, no correctness risk from flipping this. The tighter `nullable=false` compile lets + * the kernel emit `return false` from its `isNullAt` switch and, once paired with the + * BoundReference tree rewrite in `lookupOrCompile`, lets Spark's `BoundReference.genCode` skip + * the null branch at source level rather than relying on JIT constant-folding. + * + * Trade-off: if real workloads flip a column's nullability frequently across batches, each + * expression caches up to `2^numCols` variants and the bounded LRU churns. The common case is + * stable per-column nullability per query, which keeps variance at one kernel per expression. + */ + private def nullable(v: ValueVector): Boolean = v.getNullCount != 0 + + /** + * Build the compile-time spec for one input Arrow vector. Recurses on complex types; scalars + * produce a [[ScalarColumnSpec]] carrying the concrete Arrow vector class and nullability. + * Spark `DataType`s on complex children come from [[Utils.fromArrowField]] so the Arrow -> + * Spark mapping stays in one place. + */ + private def specFor(v: ValueVector): ArrowColumnSpec = v match { + case map: MapVector => + // MapVector extends ListVector; match it first. Its data vector is a StructVector with + // child 0 = key and child 1 = value. 
+ val struct = map.getDataVector.asInstanceOf[StructVector] + val keyVec = struct.getChildByOrdinal(0).asInstanceOf[ValueVector] + val valueVec = struct.getChildByOrdinal(1).asInstanceOf[ValueVector] + MapColumnSpec( + nullable = nullable(map), + keySparkType = Utils.fromArrowField(keyVec.getField), + valueSparkType = Utils.fromArrowField(valueVec.getField), + key = specFor(keyVec), + value = specFor(valueVec)) + case list: ListVector => + val child = list.getDataVector + ArrayColumnSpec(nullable(list), Utils.fromArrowField(child.getField), specFor(child)) + case struct: StructVector => + val fieldSpecs = (0 until struct.size()).map { fi => + val childVec = struct.getChildByOrdinal(fi).asInstanceOf[ValueVector] + val field = struct.getField.getChildren.get(fi) + StructFieldSpec( + name = field.getName, + sparkType = Utils.fromArrowField(field), + nullable = field.isNullable, + child = specFor(childVec)) + } + StructColumnSpec(nullable(struct), fieldSpecs) + case _: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector | _: BigIntVector | + _: Float4Vector | _: Float8Vector | _: DecimalVector | _: VarCharVector | + _: VarBinaryVector | _: DateDayVector | _: TimeStampMicroVector | + _: TimeStampMicroTZVector => + ScalarColumnSpec(v.getClass.asInstanceOf[Class[_ <: ValueVector]], nullable(v)) + case other => + throw new UnsupportedOperationException( + s"CometCodegenDispatchUDF: unsupported Arrow vector ${other.getClass.getSimpleName}") + } + + /** + * Estimate output byte capacity for variable-length output types. Sums the data-buffer sizes of + * variable-length input vectors as an upper bound for typical transform expressions (replace, + * upper, lower, substring, concat on the same inputs). Underestimates are still corrected by + * `setSafe`; this just reduces the odds of mid-loop reallocation. + */ + private def estimatedOutputBytes(outputType: DataType, dataCols: Array[ValueVector]): Int = { + outputType match { + case _: StringType | _: BinaryType => + var sum = 0 + var i = 0 + while (i < dataCols.length) { + dataCols(i) match { + case v: VarCharVector => sum += v.getDataBuffer.writerIndex().toInt + case v: VarBinaryVector => sum += v.getDataBuffer.writerIndex().toInt + case _ => // no size hint for fixed-width vector types + } + i += 1 + } + sum + case _ => -1 + } + } +} + +object CometCodegenDispatchUDF { + + private val CacheCapacity: Int = 128 + + /** + * Cache key: serialized expression bytes plus per-column compile-time invariants. + * + * `hashCode` walks `bytesKey` per lookup, so for large ScalaUDF closures it scales with closure + * size. TODO(perf-cache-key): see + * `docs/source/contributor-guide/jvm_udf_dispatch.md#open-items` for possible optimizations if + * a workload makes this hot. + */ + final case class CacheKey(bytesKey: ByteBuffer, specs: IndexedSeq[ArrowColumnSpec]) + + private case class CacheEntry( + compiled: CometBatchKernelCodegen.CompiledKernel, + outputType: DataType) + + private val kernelCache: java.util.Map[CacheKey, CacheEntry] = + Collections.synchronizedMap( + new LinkedHashMap[CacheKey, CacheEntry](CacheCapacity, 0.75f, true) { + override def removeEldestEntry( + eldest: java.util.Map.Entry[CacheKey, CacheEntry]): Boolean = + size() > CacheCapacity + }) + + // Observability counters. Incremented under the `kernelCache.synchronized` block in + // `lookupOrCompile` so counter increments and cache mutations cannot interleave. Read via + // [[stats]]; reset via [[resetStats]] for tests. 
+ private val compileCount = new AtomicLong(0) + private val cacheHitCount = new AtomicLong(0) + + /** + * Snapshot of dispatcher cache counters and current size. Intended for tests, logging, and + * future integration with Spark SQL metrics. Not thread-synchronized across the three fields + * (each read is atomic, but they are not read atomically together); snapshots taken during + * concurrent activity may show a consistent individual-field view but a slightly inconsistent + * combined view. Fine for reporting, not for assertions that require cross-field invariants. + */ + final case class DispatcherStats(compileCount: Long, cacheHitCount: Long, cacheSize: Int) { + def totalLookups: Long = compileCount + cacheHitCount + def hitRate: Double = + if (totalLookups == 0) 0.0 else cacheHitCount.toDouble / totalLookups.toDouble + } + + /** Returns a snapshot of cache counters and current size. Cheap; safe to call anytime. */ + def stats(): DispatcherStats = + DispatcherStats(compileCount.get(), cacheHitCount.get(), kernelCache.size()) + + /** Reset counters to zero. Leaves the compile cache intact. Intended for tests. */ + def resetStats(): Unit = { + compileCount.set(0) + cacheHitCount.set(0) + } + + /** + * Test-facing snapshot of compiled kernel signatures currently in the cache. Each entry is the + * pair `(input Arrow vector classes in ordinal order, output Spark DataType)` the kernel + * compiled against. Lets tests assert that the dispatcher actually specialized on the types it + * was expected to, not just that the query returned a correct result (which Spark would do + * regardless of how the kernel was shaped). + * + * Drops the `ArrowColumnSpec.nullable` bit to keep assertions robust to per-batch nullability + * variance: test data with no nulls compiles with `nullable=false` and the same expression run + * against data with nulls would cache a second variant. Tests assert on vector class and output + * type; both variants satisfy the same assertion. + */ + def snapshotCompiledSignatures(): Set[(IndexedSeq[Class[_ <: ValueVector]], DataType)] = { + kernelCache.synchronized { + import scala.jdk.CollectionConverters._ + kernelCache + .entrySet() + .asScala + .iterator + .map { e => + (e.getKey.specs.map(_.vectorClass), e.getValue.outputType) + } + .toSet + } + } + + private def lookupOrCompile( + key: CacheKey, + bytes: Array[Byte], + specs: IndexedSeq[ArrowColumnSpec]): CacheEntry = { + kernelCache.synchronized { + val existing = kernelCache.get(key) + if (existing != null) { + cacheHitCount.incrementAndGet() + existing + } else { + // Use a classloader that can see Spark classes. The Comet native runtime calls us on a + // Tokio worker thread where the context classloader may not be set to Spark's task + // loader, so fall back to the loader that loaded `Expression` itself if needed. + val loader = Option(Thread.currentThread().getContextClassLoader) + .getOrElse(classOf[Expression].getClassLoader) + val rawExpr = SparkEnv.get.closureSerializer + .newInstance() + .deserialize[Expression](ByteBuffer.wrap(bytes), loader) + // Tighten BoundReference.nullable based on the observed batch. The plan-time value is + // conservative (the column may be null somewhere in the query's execution), but for + // this specific batch we know. Rewriting lets Spark's `BoundReference.genCode` skip the + // `isNull` branch at source level rather than leaving it to JIT constant-folding. 
+ // Correctness is preserved by the cache key: a later batch with nulls on this column has + // a different `specs`, so it hits a different kernel compiled with nullable=true. + val boundExpr = rewriteBoundReferences(rawExpr, specs) + val compiled = CometBatchKernelCodegen.compile(boundExpr, specs) + val entry = CacheEntry(compiled, boundExpr.dataType) + kernelCache.put(key, entry) + compileCount.incrementAndGet() + entry + } + } + } + + /** + * Walk the bound expression tree and rewrite any `BoundReference(ord, dt, nullable=true)` to + * `nullable=false` when the corresponding input column in `specs` is non-nullable for this + * batch. Only tightens; never relaxes. Expressions outside the `BoundReference` leaves are + * unchanged. + */ + private def rewriteBoundReferences( + expr: Expression, + specs: IndexedSeq[ArrowColumnSpec]): Expression = { + expr.transform { + case BoundReference(ord, dt, true) + if ord >= 0 && ord < specs.length && !specs(ord).nullable => + BoundReference(ord, dt, nullable = false) + // Fall through unchanged: non-BoundReference nodes and BoundReferences that are already + // non-nullable or point at a nullable column in this batch. + case other => other + } + } + + /** + * Partition index for the generated kernel's `init`. Expressions whose `doGenCode` calls + * `addPartitionInitializationStatement` (e.g. `Rand`, `Randn`, `Uuid`) reseed mutable state + * from this. Falls back to 0 when the dispatcher is exercised outside a Spark task (unit tests) + * so an absent `TaskContext` does not fail the call; the result is still deterministic for that + * fallback. + */ + private def currentPartitionIndex(): Int = + Option(TaskContext.get()).map(_.partitionId()).getOrElse(0) +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometInternalRow.scala b/common/src/main/scala/org/apache/comet/udf/CometInternalRow.scala new file mode 100644 index 0000000000..0007499ea1 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometInternalRow.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +import org.apache.comet.shims.CometInternalRowShim + +/** + * Shim base for Comet-owned [[InternalRow]] getters used by the Arrow-direct codegen kernel. 
+ * + * Provides `throw new UnsupportedOperationException` defaults for every abstract method declared + * by `InternalRow` and `SpecializedGetters`. Concrete subclasses (`CometBatchKernel` and its + * generated subclasses) override only the getters they actually support for their input shape. + * + * Purpose: keep subclasses free of boilerplate throws, and absorb forward-compat breakage if + * Spark adds abstract methods to `InternalRow` in a future version. Add the defaulted override + * here once, all subclasses recompile. + */ +abstract class CometInternalRow extends InternalRow with CometInternalRowShim { + + override def numFields: Int = unsupported("numFields") + override def isNullAt(ordinal: Int): Boolean = unsupported("isNullAt") + + override def getBoolean(ordinal: Int): Boolean = unsupported("getBoolean") + override def getByte(ordinal: Int): Byte = unsupported("getByte") + override def getShort(ordinal: Int): Short = unsupported("getShort") + override def getInt(ordinal: Int): Int = unsupported("getInt") + override def getLong(ordinal: Int): Long = unsupported("getLong") + override def getFloat(ordinal: Int): Float = unsupported("getFloat") + override def getDouble(ordinal: Int): Double = unsupported("getDouble") + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + unsupported("getDecimal") + override def getUTF8String(ordinal: Int): UTF8String = unsupported("getUTF8String") + override def getBinary(ordinal: Int): Array[Byte] = unsupported("getBinary") + override def getInterval(ordinal: Int): CalendarInterval = unsupported("getInterval") + override def getStruct(ordinal: Int, numFields: Int): InternalRow = unsupported("getStruct") + override def getArray(ordinal: Int): ArrayData = unsupported("getArray") + override def getMap(ordinal: Int): MapData = unsupported("getMap") + + /** + * Generic `get(ordinal, dataType)` dispatcher. Required because `SpecializedGetters` declares + * it abstract and some Spark codegen paths (notably `SafeProjection` for deserializing + * `ScalaUDF` struct arguments) call it instead of the typed getter. Dispatches to the typed + * getter matching `dataType`; a null entry returns `null` outright. Unsupported types fall + * through to the shared throw. 
+ */ + override def get(ordinal: Int, dataType: DataType): AnyRef = { + if (isNullAt(ordinal)) return null + dataType match { + case BooleanType => java.lang.Boolean.valueOf(getBoolean(ordinal)) + case ByteType => java.lang.Byte.valueOf(getByte(ordinal)) + case ShortType => java.lang.Short.valueOf(getShort(ordinal)) + case IntegerType | DateType => java.lang.Integer.valueOf(getInt(ordinal)) + case LongType | TimestampType | TimestampNTZType => + java.lang.Long.valueOf(getLong(ordinal)) + case FloatType => java.lang.Float.valueOf(getFloat(ordinal)) + case DoubleType => java.lang.Double.valueOf(getDouble(ordinal)) + case _: StringType => getUTF8String(ordinal) + case BinaryType => getBinary(ordinal) + case dt: DecimalType => getDecimal(ordinal, dt.precision, dt.scale) + case st: StructType => getStruct(ordinal, st.size) + case _: ArrayType => getArray(ordinal) + case _: MapType => getMap(ordinal) + case other => unsupported(s"get for dataType $other") + } + } + + override def setNullAt(i: Int): Unit = unsupported("setNullAt") + override def update(i: Int, value: Any): Unit = unsupported("update") + override def copy(): InternalRow = unsupported("copy") + + protected def unsupported(method: String): Nothing = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: $method not implemented for this row shape") +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala b/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala deleted file mode 100644 index 5e020ae74a..0000000000 --- a/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.udf - -import java.util.UUID -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark.sql.catalyst.expressions.Expression - -/** - * Thread-safe registry bridging plan-time Spark expressions to execution-time UDF lookup. At plan - * time the serde layer registers a lambda expression under a unique key; at execution time the - * UDF retrieves it by that key (passed as a scalar argument). - */ -object CometLambdaRegistry { - - private val registry = new ConcurrentHashMap[String, Expression]() - - def register(expression: Expression): String = { - val key = UUID.randomUUID().toString - registry.put(key, expression) - key - } - - def get(key: String): Expression = { - val expr = registry.get(key) - if (expr == null) { - throw new IllegalStateException( - s"Lambda expression not found in registry for key: $key. 
" + - "This indicates a lifecycle issue between plan creation and execution.") - } - expr - } - - def remove(key: String): Unit = { - registry.remove(key) - } - - // Visible for testing - def size(): Int = registry.size() -} diff --git a/common/src/main/scala/org/apache/comet/udf/CometMapData.scala b/common/src/main/scala/org/apache/comet/udf/CometMapData.scala new file mode 100644 index 0000000000..fc99844110 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometMapData.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} + +/** + * Shim base for Comet-owned [[MapData]] views used by the Arrow-direct codegen kernel. Provides + * `UnsupportedOperationException` defaults for every abstract method on `MapData`; the codegen- + * emitted `InputMap_${path}` subclass overrides `numElements`, `keyArray`, and `valueArray`. + * + * Pairs with [[CometArrayData]] and [[CometInternalRow]]. `MapData` does not extend + * `SpecializedGetters` (unlike `ArrayData` / `InternalRow`), so no version-specific shim is + * needed here. + */ +abstract class CometMapData extends MapData { + + override def numElements(): Int = unsupported("numElements") + override def keyArray(): ArrayData = unsupported("keyArray") + override def valueArray(): ArrayData = unsupported("valueArray") + override def copy(): MapData = unsupported("copy") + + override def toString(): String = { + val n = + try numElements().toString + catch { case _: Throwable => "?" } + s"${getClass.getSimpleName}(numElements=$n)" + } + + protected def unsupported(method: String): Nothing = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: $method not implemented for this map shape") +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala index 29186f0a2c..98cb519c1b 100644 --- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala +++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala @@ -27,11 +27,16 @@ import org.apache.arrow.vector.ValueVector * * - Vector arguments arrive at the row count of the current batch. * - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0. - * - The returned vector's length must match the longest input. + * - The returned vector's length must match `numRows`. * - * Implementations must have a public no-arg constructor and must be stateless: a single instance - * per class is cached and shared across native worker threads for the lifetime of the JVM. + * `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count. 
+ * UDFs that always have at least one batch-length input can read length from it and ignore + * `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF through the + * codegen dispatcher) need `numRows` to know how many rows to produce. + * + * Implementations must have a public no-arg constructor and should be stateless: instances are + * cached per executor thread for the lifetime of the JVM. */ trait CometUDF { - def evaluate(inputs: Array[ValueVector]): ValueVector + def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector } diff --git a/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala new file mode 100644 index 0000000000..9218fc5e78 --- /dev/null +++ b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.comet + +import org.apache.spark.TaskContext + +/** + * Package-private access shim for `TaskContext.setTaskContext` / `TaskContext.unset`. + * + * Both methods are declared `protected[spark]` on Spark's `TaskContext` companion, so they are + * reachable from code inside the `org.apache.spark` package tree but not from `org.apache.comet`. + * The Comet JVM UDF bridge needs to set the thread-local `TaskContext` on its caller thread (a + * Tokio worker thread with no `TaskContext`) so the user's UDF body and any partition-sensitive + * built-ins (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, etc.) see the driving Spark task's + * `TaskContext`. This shim lives in `org.apache.spark.comet` so it can call through to the + * protected methods, and exposes plain public forwarders the bridge (which lives in + * `org.apache.comet.udf`) can use. + */ +object CometTaskContextShim { + + def set(taskContext: TaskContext): Unit = TaskContext.setTaskContext(taskContext) + + def unset(): Unit = TaskContext.unset() +} diff --git a/common/src/main/spark-3.x/org/apache/comet/shims/CometExprTraitShim.scala b/common/src/main/spark-3.x/org/apache/comet/shims/CometExprTraitShim.scala new file mode 100644 index 0000000000..3d039879d5 --- /dev/null +++ b/common/src/main/spark-3.x/org/apache/comet/shims/CometExprTraitShim.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** + * Per-profile view of expression traits that shifted shape across Spark versions. Spark 3.x has a + * `NullIntolerant` marker trait and no scalar-expression `Stateful` concept at all (the notion + * was added in 4.x as a boolean method on `Expression`). Routing checks through one shim lets the + * dispatcher ask "is this expression null-intolerant / stateful" without sprinkling version + * pattern matches through the codebase. + */ +trait CometExprTraitShim { + def isNullIntolerant(expr: Expression): Boolean = expr.isInstanceOf[NullIntolerant] + + // No scalar `Stateful` trait in 3.x. Aggregate/window/generator stateful cases are rejected + // elsewhere in `canHandle`, so treating all scalar expressions as non-stateful here is + // conservative-correct on this profile. + def isStateful(expr: Expression): Boolean = false + + // No collation / `ResolvedCollation` concept in 3.x, so no `Unevaluable` leaf slips past the + // dispatcher's guard here. + def isCodegenInertUnevaluable(expr: Expression): Boolean = false +} diff --git a/common/src/main/spark-3.x/org/apache/comet/shims/CometInternalRowShim.scala b/common/src/main/spark-3.x/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..e71d301d48 --- /dev/null +++ b/common/src/main/spark-3.x/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +/** + * Per-profile extension point mixed into `CometInternalRow` and `CometArrayData`. Spark 4.x added + * new abstract getters on `SpecializedGetters` (`getVariant` in 4.0, `getGeography` and + * `getGeometry` in 4.1) that both `InternalRow` and `ArrayData` concrete subclasses must + * implement. Spark 3.x has none of these; this trait is empty so the shared classes compile + * unchanged on that profile. 
+ */ +trait CometInternalRowShim diff --git a/common/src/main/spark-4.0/org/apache/comet/shims/CometInternalRowShim.scala b/common/src/main/spark-4.0/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..20c6d47816 --- /dev/null +++ b/common/src/main/spark-4.0/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.unsafe.types.VariantVal + +/** + * Throwing-default implementations for `SpecializedGetters` methods that were added in Spark 4.0: + * `getVariant`. The Janino-generated kernel subclasses `CometInternalRow` (rows) and + * `CometArrayData` (array inputs), and each must satisfy every abstract method on the interface; + * without these defaults the compiled class fails its abstract-method check at class-load time. + * `GeographyVal` and `GeometryVal` were added in 4.1, so this profile's shim does not override + * those getters. + */ +trait CometInternalRowShim { + def getVariant(ordinal: Int): VariantVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getVariant not supported") +} diff --git a/common/src/main/spark-4.1/org/apache/comet/shims/CometInternalRowShim.scala b/common/src/main/spark-4.1/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..3d277e7505 --- /dev/null +++ b/common/src/main/spark-4.1/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, VariantVal} + +/** + * Throwing-default implementations for `SpecializedGetters` methods added in Spark 4.x: + * `getVariant` (4.0), `getGeography` and `getGeometry` (4.1). 
The Janino-generated kernel + * subclasses `CometInternalRow` (rows) and `CometArrayData` (array inputs), and each must satisfy + * every abstract method on the interface; without these defaults the compiled class fails its + * abstract-method check at class-load time. + */ +trait CometInternalRowShim { + def getVariant(ordinal: Int): VariantVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getVariant not supported") + + def getGeography(ordinal: Int): GeographyVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeography not supported") + + def getGeometry(ordinal: Int): GeometryVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeometry not supported") +} diff --git a/common/src/main/spark-4.2/org/apache/comet/shims/CometInternalRowShim.scala b/common/src/main/spark-4.2/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..3d277e7505 --- /dev/null +++ b/common/src/main/spark-4.2/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, VariantVal} + +/** + * Throwing-default implementations for `SpecializedGetters` methods added in Spark 4.x: + * `getVariant` (4.0), `getGeography` and `getGeometry` (4.1). The Janino-generated kernel + * subclasses `CometInternalRow` (rows) and `CometArrayData` (array inputs), and each must satisfy + * every abstract method on the interface; without these defaults the compiled class fails its + * abstract-method check at class-load time. + */ +trait CometInternalRowShim { + def getVariant(ordinal: Int): VariantVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getVariant not supported") + + def getGeography(ordinal: Int): GeographyVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeography not supported") + + def getGeometry(ordinal: Int): GeometryVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeometry not supported") +} diff --git a/common/src/main/spark-4.x/org/apache/comet/shims/CometExprTraitShim.scala b/common/src/main/spark-4.x/org/apache/comet/shims/CometExprTraitShim.scala new file mode 100644 index 0000000000..2d86258014 --- /dev/null +++ b/common/src/main/spark-4.x/org/apache/comet/shims/CometExprTraitShim.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.{Expression, ResolvedCollation} + +/** + * Spark 4.x replaced the `NullIntolerant` marker trait with a boolean method on `Expression`, and + * introduced a `stateful` boolean method covering scalar expressions that carry per-row state + * (e.g. `Rand`, `Uuid`). Neither concept exists as a trait in 4.x, so pattern matches against + * them would fail to compile. This shim routes the checks through the method form. + */ +trait CometExprTraitShim { + def isNullIntolerant(expr: Expression): Boolean = expr.nullIntolerant + def isStateful(expr: Expression): Boolean = expr.stateful + + // `ResolvedCollation` is an `Unevaluable` leaf that only lives in `Collate.collation` as a + // type-level marker. `Collate.genCode` passes through to its child and never touches the + // collation slot, so the leaf is never invoked in generated code. Spark 4.1 analyzes it away, + // but 4.0 leaves it in the tree, so the dispatcher's `Unevaluable` guard trips on 4.0 without + // this exemption. + def isCodegenInertUnevaluable(expr: Expression): Boolean = expr match { + case _: ResolvedCollation => true + case _ => false + } +} diff --git a/docs/source/contributor-guide/index.md b/docs/source/contributor-guide/index.md index 77c73d68da..c2fbec4f54 100644 --- a/docs/source/contributor-guide/index.md +++ b/docs/source/contributor-guide/index.md @@ -47,6 +47,7 @@ Benchmarking Guide Adding a New Operator Adding a New Expression Supported Spark Expressions +JVM UDF Dispatch Tracing Profiling Comet SQL Tests diff --git a/docs/source/contributor-guide/jvm_udf_dispatch.md b/docs/source/contributor-guide/jvm_udf_dispatch.md new file mode 100644 index 0000000000..f68753b9f4 --- /dev/null +++ b/docs/source/contributor-guide/jvm_udf_dispatch.md @@ -0,0 +1,336 @@ + + +# JVM UDF dispatch + +Comet offloads expressions that lack a native DataFusion implementation, or whose native implementation diverges from Spark's semantics, to JVM-side code that operates on Arrow batches passed through the C Data Interface. This preserves Spark compatibility on expressions that would otherwise force a whole-plan fallback to Spark. The tradeoff is a JNI roundtrip and per-batch JVM execution. + +The dispatch path is **Arrow-direct codegen via `CometCodegenDispatchUDF`** - one generic dispatcher that compiles a specialized kernel per bound Spark `Expression` plus input schema. Per-expression specialized emitters inside the dispatcher cover the cases where the default `doGenCode` output pays avoidable conversions; see [Specialized emitters](#specialized-emitters) below. + +The JNI bridge (`CometUdfBridge`) and proto schema (`JvmScalarUdf`) are generic enough to carry any `CometUDF` implementation, but the codebase today contains one: `CometCodegenDispatchUDF`. 
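+
+To make the bridge contract concrete before diving into the dispatcher, here is a minimal sketch of what a second `CometUDF` implementation carried over the same bridge and proto could look like. It is hypothetical (not in the codebase) and assumes only the trait contract documented in `CometUDF.scala`: batch-length data vectors, length-1 scalar (literal-folded) vectors read at index 0, a result sized to `numRows`, and a public no-arg constructor. The serde wiring that would route an expression to such a class is out of scope of the sketch.
+
+```scala
+import org.apache.arrow.vector.{IntVector, ValueVector}
+
+import org.apache.comet.udf.CometUDF
+
+// Hypothetical: adds a literal-folded delta (arg 1, length-1 vector) to an
+// integer column (arg 0, batch-length vector).
+class AddConstantUDF extends CometUDF {
+  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
+    val data = inputs(0).asInstanceOf[IntVector]         // batch-length data column
+    val delta = inputs(1).asInstanceOf[IntVector].get(0) // scalar argument, read at index 0
+    val out = new IntVector("add_constant", data.getAllocator)
+    out.allocateNew(numRows)
+    var i = 0
+    while (i < numRows) {
+      if (data.isNull(i)) out.setNull(i) else out.setSafe(i, data.get(i) + delta)
+      i += 1
+    }
+    out.setValueCount(numRows) // returned vector length must match numRows
+    out
+  }
+}
+```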
+ +## Arrow-direct codegen via `CometCodegenDispatchUDF` + +One UDF class handles any scalar Spark `Expression` in the supported type surface. For each `(boundExpr, inputSchema)` pair, it compiles a specialized `CometBatchKernel` subclass via Janino that fuses Arrow input reads, expression evaluation, and Arrow output writes into one method. The kernel is cached in a JVM-wide LRU. + +### Transport + +At plan time the serde binds the expression tree to its leaf `AttributeReference`s, serializes the bound `Expression` via Spark's closure serializer, and emits a `JvmScalarUdf` proto whose argument 0 is a `Literal(bytes, BinaryType)` holding the serialized Expression. Arguments 1..N are the raw data columns the `BoundReference`s refer to, in ordinal order. + +At execute time, `CometCodegenDispatchUDF.evaluate` reads the bytes from the `VarBinaryVector` at arg 0, computes a cache key from (bytes, per-column Arrow vector class, per-column nullability), and either reuses a cached `CompiledKernel` or compiles one on the miss path. + +The self-describing proto removes the driver-side state the original prototype relied on. Cluster-mode executors deserialize and compile locally. + +**Classloader caveat.** The Comet native runtime calls the UDF on a Tokio worker thread whose context classloader may not be Spark's task loader. `SparkEnv.get.closureSerializer.newInstance().deserialize[Expression](bytes)` without an explicit loader fails with `ClassNotFoundException` on Spark's expression classes. The dispatcher passes an explicit loader, falling back to the loader that loaded `Expression` if the thread context is null. + +### Compilation + +`CometBatchKernelCodegen.compile(boundExpr, inputSchema)` generates a Java source for a `SpecificCometBatchKernel` that: + +- Extends `CometBatchKernel`, which extends `CometInternalRow`, which extends Spark's `InternalRow`. The kernel **is** the `InternalRow` that Spark's `BoundReference.genCode` reads from. +- Sets `ctx.INPUT_ROW = "row"` at compile time and aliases `InternalRow row = this;` inside `process`, so Spark's generated body calls `row.getUTF8String(ord)` which resolves to the kernel's own typed getter. The getter is final, the ordinal is constant at the call site, and JIT devirtualizes and folds the switch. `row` rather than `this` because Spark's `splitExpressions` uses INPUT_ROW as a helper-method parameter name and `this` is a reserved Java keyword. +- Carries typed input fields `col0 .. colN`, one per bound column, cast at the top of `process` from the generic `ValueVector[]` to the concrete Arrow class baked in at compile time. +- Emits `isNullAt(ordinal)` and `getUTF8String(ordinal)` overrides whose switch cases are specialized per column. A column marked non-nullable compiles to `return false;`; a `VarCharVector` compiles to a zero-copy `UTF8String.fromAddress` read against the Arrow data buffer. +- Overrides `init(int partitionIndex)` with the statements collected by `ctx.addPartitionInitializationStatement`. Non-deterministic expressions (`Rand`, `Randn`, `Uuid`) register statements that reseed mutable state from `partitionIndex`; deterministic expressions leave `init` empty. +- Processes the batch in a tight loop that sets `this.rowIdx = i`, runs the expression body (either `boundExpr.genCode` for the default path or a specialized emitter), and writes to the typed output vector. 
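+
+For intuition, the following is a hand-written approximation of the kernel shape for something like `upper(c0)` over one nullable `VarCharVector` column. It is rendered in Scala for brevity and is not the emitted source: the real class is Java compiled through Janino, extends `CometBatchKernel` (whose method signatures are assumed here from the dispatcher's call sites), and reads strings zero-copy via `UTF8String.fromAddress` rather than copying bytes as the sketch does.
+
+```scala
+import org.apache.arrow.vector.{ValueVector, VarCharVector}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.unsafe.types.UTF8String
+
+import org.apache.comet.udf.CometInternalRow
+
+class UpperKernelSketch extends CometInternalRow {
+  private var col0: VarCharVector = _ // typed input field, concrete Arrow class baked in
+  private var rowIdx: Int = 0
+
+  def init(partitionIndex: Int): Unit = () // empty: upper() registers no partition init
+
+  override def isNullAt(ordinal: Int): Boolean = col0.isNull(rowIdx)
+  override def getUTF8String(ordinal: Int): UTF8String = UTF8String.fromBytes(col0.get(rowIdx))
+
+  def process(inputs: Array[ValueVector], out: ValueVector, numRows: Int): Unit = {
+    col0 = inputs(0).asInstanceOf[VarCharVector] // cast once per batch
+    val outVec = out.asInstanceOf[VarCharVector]
+    val row: InternalRow = this                  // what Spark's generated body reads from
+    var i = 0
+    while (i < numRows) {
+      rowIdx = i
+      if (row.isNullAt(0)) outVec.setNull(i)
+      else outVec.setSafe(i, row.getUTF8String(0).toUpperCase.getBytes)
+      i += 1
+    }
+    // the dispatcher calls out.setValueCount(numRows) after process returns
+  }
+}
+```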
+ +### Specialized emitters + +For expressions whose `doGenCode` forces conversions that a tighter byte-oriented loop could skip, the dispatcher has per-expression overrides that emit custom Java while staying inside the framework (same cache, same bridge, same serde entry). Today that is `RegExpReplace`: the default path goes `Arrow bytes -> UTF8String -> String -> Matcher -> String -> UTF8String -> bytes -> Arrow` because `java.util.regex.Matcher` requires a `CharSequence`. The specialized emitter writes the byte-oriented shape directly (`Arrow bytes -> String -> Matcher -> String -> bytes -> Arrow`). The `UTF8String` round-trip costs measurable time on wide-match workloads; see `specializedRegExpReplaceBody` for the benchmark rationale. + +Precedent for adding new specializations: match when an expression's `doGenCode` pays conversions an Arrow-aware byte-oriented loop would avoid. Keep the specialization minimal (no speculative layering beyond the conversions it exists to skip) so its value over the default path stays legible. + +### Caching + +Three cache layers compose at three different scopes. None is redundant: collapsing any pair would either lose correctness or pay an avoidable cost. + +1. **JVM-wide compile cache.** Value is `CompiledKernel(factory: GeneratedClass, freshReferences: () => Array[Any])`, keyed by `(ByteBuffer.wrap(bytes), IndexedSeq[ArrowColumnSpec])`. Bounded LRU via `Collections.synchronizedMap(LinkedHashMap(accessOrder=true))` with `removeEldestEntry`, capacity 128. Same shape as `IcebergPlanDataInjector.commonCache` in `spark/src/main/scala/org/apache/spark/sql/comet/operators.scala`. Amortizes the Janino compile cost across every thread and every query in the JVM. + +2. **Per-thread UDF instance cache.** `CometUdfBridge.INSTANCES` is a `ThreadLocal>` that hands each task thread its own `CometCodegenDispatchUDF`. Keeps cache layer 3's instance fields safe without synchronization. + +3. **Per-partition kernel instance cache.** Plain mutable fields (`activeKernel`, `activeKey`, `activePartition`) on each UDF instance, managed by `ensureKernel`. The compiled `GeneratedClass` produces a kernel instance, and the kernel carries per-row mutable state (`Rand`'s `XORShiftRandom`, `MonotonicallyIncreasingID`'s counter, `addMutableState` fields) that must advance across batches within a partition and reset across partitions. `ensureKernel` allocates a fresh kernel and calls `init(partitionIndex)` only when the partition or cache key changes; otherwise the same kernel handles every batch in the partition. + +Matches Spark `WholeStageCodegenExec`: compile once per plan, instantiate per partition, init, iterate. + +#### Why `freshReferences` is a closure, not a cached array + +`CompiledKernel` holds a closure that regenerates `references: Array[Any]` each time a new kernel is allocated, rather than caching a single shared array. Reason: some expressions (notably `ScalaUDF`) embed stateful Spark `ExpressionEncoder` serializers into `references` via `ctx.addReferenceObj`. Those serializers reuse an internal `UnsafeRow` / `byte[]` buffer per `.apply(...)` call and are not thread-safe. If two kernels on different partitions shared one serializer instance, they would race on that buffer and return garbage. + +Re-running `genCode(ctx)` per kernel allocation costs microseconds; Janino compile costs milliseconds. Caching only the expensive piece preserves correctness cheaply. 
A future optimization would be to distinguish expressions whose references are all immutable (most non-UDF expressions) from those that embed stateful converters, and cache the array in the immutable case; not worth the complexity today. + +### Plan-time dispatchability + +`CometBatchKernelCodegen.canHandle(boundExpr)` runs at serde time. It returns `None` when the dispatcher can compile the expression, `Some(reason)` when it cannot. Checks: + +- Output `dataType` is in the scalar set `allocateOutput` and `emitOutputWriter` cover. +- No `AggregateFunction` or `Generator` anywhere in the tree (scalar-only bridge). +- Every `BoundReference`'s data type is in the input set `emitTypedGetters` has a getter for. + +The serde calls `withInfo(original, reason) + None` on a `Some` result, so Spark falls back rather than the kernel compiler crashing at execute time. Intermediate node types are not checked - `doGenCode` materializes them in local variables; only leaves (row reads) and the root (output write) touch Arrow. + +### Observability + +`CometCodegenDispatchUDF.stats()` returns `DispatcherStats(compileCount, cacheHitCount, cacheSize)`. `hitRate` is derived. `resetStats()` clears the counters (not the cache) for test isolation. + +Counters are not yet surfaced anywhere user-visible. Candidates for future wiring: Spark SQL metrics on the hosting operator, a JMX MBean, a Spark accumulator, or a periodic log line. + +## User-defined scalar functions (ScalaUDF) + +The codegen dispatcher routes scalar `org.apache.spark.sql.catalyst.expressions.ScalaUDF` expressions through the same compile + per-partition-kernel pipeline as the regex serdes. The serde is `CometScalaUDF` in `spark/src/main/scala/org/apache/comet/serde/scalaUdf.scala`, registered in `QueryPlanSerde.miscExpressions`. + +Why it works without per-UDF handling: Spark's `ScalaUDF.doGenCode` already emits compilable Java that calls the user function via `ctx.addReferenceObj`. The compile path runs `boundExpr.genCode(ctx)` and picks this up unchanged. The serialized-bytes transport carries the function reference through Spark's closure serializer, the same machinery Spark uses to ship UDFs to executors. Per-partition kernel caching handles `ScalaUDF`'s `stateful=true`. + +Without this serde, any `ScalaUDF` in a plan forces Comet to fall back to Spark for the whole plan, losing acceleration on the surrounding operators. With it, scalar UDFs whose types fit the supported surface stay on the Comet path behind one JNI hop. 
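+
+A hypothetical end-to-end illustration of the user-facing side, assuming a `SparkSession` named `spark` with Comet enabled; the function, table, and column names are made up:
+
+```scala
+import org.apache.spark.sql.functions.{col, udf}
+
+// Plain Scala UDF over supported types; nothing Comet-specific is required.
+val maskEmail = udf { (s: String) =>
+  if (s == null) null else s.replaceAll("(^[^@]{2})[^@]*", "$1***")
+}
+spark.udf.register("mask_email", maskEmail)
+
+// Both forms produce a ScalaUDF expression that the serde serializes for the
+// dispatcher; the surrounding scan/filter/project stays on the native path,
+// with one JNI hop per batch for the UDF itself.
+val viaApi = spark.table("users").select(maskEmail(col("email")).as("masked"))
+val viaSql = spark.sql("SELECT mask_email(email) AS masked FROM users")
+```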
+ +### What's covered + +| What users write | Spark expression class | Route through codegen | +| --------------------------------------------------------------- | ------------------------------------------------------ | ------------------------------------------------------------- | +| `udf((x: T) => ...)` or `spark.udf.register` (Scala) | `ScalaUDF` | yes | +| `spark.udf.register("f", new UDF1[...]{...})` (Java) | `ScalaUDF` (Spark wraps the Java functional interface) | yes, transparently | +| `CREATE FUNCTION foo AS 'com.example.MyUDF'` (SQL registration) | `ScalaUDF` | yes, if the user class is reachable on the executor classpath | + +### What's not covered + +| What users write | Spark expression class | Why not | +| ------------------------------- | --------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------ | +| Aggregate UDF | `ScalaAggregator`, `TypedImperativeAggregate`, old `UserDefinedAggregateFunction` | accumulator-based; needs a different bridge contract (accumulate + merge + finalize) | +| Table UDF / generator | `UserDefinedTableFunction` | 1 row -> N rows; `canHandle` rejects `Generator` | +| Python `@udf` | `PythonUDF` | subprocess runtime, not JVM | +| Pandas `@pandas_udf` | `PandasUDF` | Arrow-via-subprocess runtime | +| Hive `GenericUDF` / `SimpleUDF` | `HiveGenericUDF` / `HiveSimpleUDF` | separate expression classes; would need their own serde | + +### Constraints within the ScalaUDF path + +- Input and output types must be in the supported surface (see [Type surface](#type-surface)). Nested types (`Struct`, `Array`, `Map`) are supported when their element types are supported. +- The user function must be closure-serializable. This is Spark's own requirement; a function that works with Spark's executor execution works here. +- User functions that touch `TaskContext` internals, accumulators, or broadcast variables in unusual ways may misbehave; the common case works. +- Stateful behavior: per-partition kernel caching resets kernel instance state on partition boundary, matching the contract most user UDFs assume (and matching Spark's own re-instantiation on some paths). UDFs that rely on long-lived JVM-wide state across partitions in the same executor see that state reset more often than before, which is rare and usually a latent bug in the UDF. + +### Mode knob interaction + +`spark.comet.exec.codegenDispatch.mode` controls routing: + +- `auto` (default) and `force`: ScalaUDFs go through the codegen dispatcher. +- `disabled`: `CometScalaUDF.convert` returns `None` and the plan falls back to Spark. + +There is no non-codegen Comet path for arbitrary user functions. 
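+
+For example, to change the routing for a single session (a sketch; the key and values are the ones documented above, with `auto` as the default):
+
+```scala
+// Prefer codegen dispatch wherever canHandle accepts the bound expression.
+spark.conf.set("spark.comet.exec.codegenDispatch.mode", "force")
+
+// Hand ScalaUDFs (and the other codegen-routed serdes) back to Spark.
+spark.conf.set("spark.comet.exec.codegenDispatch.mode", "disabled")
+```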
+ +## Type surface + +### Input (kernel getters) + +All scalar Spark types that map to a single Arrow vector: + +| Spark type | Arrow vector class | `InternalRow` getter | +| ----------------------------------------- | ---------------------------------------------------------- | -------------------------------------------------------- | +| BooleanType | BitVector | `getBoolean` | +| ByteType | TinyIntVector | `getByte` | +| ShortType | SmallIntVector | `getShort` | +| IntegerType, DateType | IntVector, DateDayVector | `getInt` | +| LongType, TimestampType, TimestampNTZType | BigIntVector, TimeStampMicroVector, TimeStampMicroTZVector | `getLong` | +| FloatType | Float4Vector | `getFloat` | +| DoubleType | Float8Vector | `getDouble` | +| DecimalType | DecimalVector | `getDecimal(ord, precision, scale)` | +| StringType | VarCharVector | `getUTF8String` (zero-copy via `UTF8String.fromAddress`) | +| BinaryType | VarBinaryVector | `getBinary` (allocates `byte[]`) | + +Widening: add cases to `CometBatchKernelCodegen.emitTypedGetters` and accept the new vector classes in `CometCodegenDispatchUDF.evaluate`'s input pattern match. + +### Output (writers + allocators) + +All scalar Spark types that map to a single Arrow vector: `Boolean`, `Byte`, `Short`, `Int`, `Long`, `Float`, `Double`, `Decimal`, `String`, `Binary`, `Date`, `Timestamp`, `TimestampNTZ`. Mirrors `ArrowWriters.createFieldWriter` so producer and consumer sides stay aligned. Widen by adding cases to `CometBatchKernelCodegen.allocateOutput` and `emitOutputWriter`. + +### Complex types + +`ArrayType`, `StructType`, and `MapType` are supported as both input and output, including arbitrary nesting (`Array>`, `Array`, `Struct`, `Map`, and so on). Each side of the pipeline handles them through recursion over the `ArrowColumnSpec` tree, with a path-suffix naming convention for the emitted fields and nested classes: `_e` for array element, `_f${fi}` for struct field, `_k` / `_v` for map key / value. N-deep nesting falls out of this because every level only knows about its immediate children. + +Output side (`CometBatchKernelCodegenOutput.emitWrite`): + +- `ArrayType` emits a `ListVector.startNewValue` / per-element loop / `endValue` triple; the per-element write recurses through `emitWrite` on the list's child vector. +- `StructType` casts each typed child vector once per row, writes each field via one recursive `emitWrite` call per field, and skips the `isNullAt` guard on non-nullable fields. +- `MapType` casts the entries `StructVector` once per row, writes each key / value pair with a per-value null guard (keys are non-nullable per Arrow invariant), and brackets with `startNewValue` / `endValue`. +- `allocateOutput` builds the complex `FieldVector` tree and recursively allocates child buffers, pre-sized from the input data-buffer estimate where applicable. + +Input side (`CometBatchKernelCodegenInput`): + +- Each complex input column produces a final nested class at every level: `InputArray_${path}` extends `CometArrayData`, `InputStruct_${path}` extends `CometInternalRow`, `InputMap_${path}` extends `CometMapData`. The class holds slice state (arrays / maps: `(startIndex, length)`; structs: `rowIdx`) and pre-allocated child-view instances for any complex child. Spark's generated `row.getArray(ord)` / `row.getStruct(ord, n)` / `row.getMap(ord)` resolves to the kernel's switch which resets and returns the pre-allocated instance. 
+- Scalar element reads go through the typed child-vector field with zero allocation: `UTF8String.fromAddress` for strings, the decimal128 short-precision fast path for `DecimalType(p <= 18)`, primitive direct reads for everything else. + +### Out of scope + +- Calendar interval types. +- Aggregates, window functions, generators - these need a different bridge signature than `CometUDF.evaluate`. + +## Regex family routing + +Regex serdes (`rlike`, `regexp_replace`, `regexp_extract`, `regexp_extract_all`, `regexp_instr`, `split` via `StringSplit`) route to codegen dispatch in the default `auto` mode when `spark.comet.exec.regexp.engine=java` (itself the default). Set `spark.comet.exec.codegenDispatch.mode=disabled` to fall back to Spark; set `mode=force` to prefer codegen regardless of the regex engine. + +#### Routing matrix + +Rows are the six regex-family expressions; columns are `(spark.comet.exec.regexp.engine, spark.comet.exec.codegenDispatch.mode)`. Cells name the path the serde takes. `Spark` means `convert` returns `None` and Spark executes the expression; `codegen` means the generated Janino kernel via `CometCodegenDispatchUDF`; `native Rust` means the DataFusion scalar function. + +| Expression | java, auto | java, force | java, disabled | rust, auto | rust, force | rust, disabled | +| ----------------------- | ---------- | ----------- | -------------- | ----------- | ----------- | -------------- | +| `rlike` | codegen | codegen | Spark | native Rust | codegen | native Rust | +| `regexp_replace` | codegen | codegen | Spark | native Rust | codegen | native Rust | +| `regexp_extract` | codegen | codegen | Spark | Spark | Spark | Spark | +| `regexp_extract_all` | codegen | codegen | Spark | Spark | Spark | Spark | +| `regexp_instr` | codegen | codegen | Spark | Spark | Spark | Spark | +| `split` (`StringSplit`) | codegen | codegen | Spark | native Rust | codegen | native Rust | + +Notes: + +- `force` always tries codegen first and only falls back to the non-codegen path if `canHandle` rejects the bound expression. For `rlike` / `regexp_replace` / `StringSplit` with `rust` engine, that fallback is native Rust. The matrix collapses to the common outcome. +- `auto` with the rust engine does not prefer codegen (it would bypass the native Rust path the user explicitly selected), so the `rust, auto` column matches `rust, disabled`. +- `regexp_extract` / `regexp_extract_all` / `regexp_instr` have no native Rust path; `getSupportLevel` declares them unsupported when engine is rust, so the cells read `Spark` regardless of dispatch mode. +- The rust-engine cells also depend on `spark.comet.expr.allow.incompat`: when `false` (default), the incompatibility listed in `getIncompatibleReasons` vetoes the cell and Spark executes the expression. The matrix describes what happens once the expression reaches `convert`. + +## Opting a new expression into codegen dispatch + +Adding a new Spark expression to the codegen dispatch path is a serde-only change when its input and output types are already in [Type surface](#type-surface). The pattern mirrors the regex-family serdes in `strings.scala` and the `ScalaUDF` serde in `scalaUdf.scala`. + +Steps: + +1. **Verify type coverage.** `CometBatchKernelCodegen.canHandle(boundExpr)` returns `None` iff every `BoundReference`'s data type is in `isSupportedInputType` and the root data type is in `isSupportedOutputType`. 
No extra work needed if the expression uses supported types; if not, widen the relevant case in `emitTypedGetters` / `emitWrite` / `allocateOutput` first. + +2. **Wrap `convert` in `pickWithMode`.** The serde's `override def convert(...)` routes through `CodegenDispatchSerdeHelpers.pickWithMode(viaCodegen, viaNonCodegen, preferCodegenInAuto)`. `viaCodegen` is the new helper (step 3). `viaNonCodegen` is either an existing native-DataFusion converter or `() => None` when the only Comet-side path is codegen. `preferCodegenInAuto` decides whether `auto` mode tries codegen first; set `true` when codegen is the intended primary path, `false` when the native path takes priority and codegen is a fallback. + +3. **Add the codegen helper.** `private def convertViaJvmUdfGenericCodegen(expr, inputs, binding): Option[Expr]`. Structure (same for every adoption): + - Any per-expression preconditions (literal-pattern check, offset check, etc.) that `canHandle` does not express. Return `None` with `withInfo` on failure so planning falls back cleanly. + - `val attrs = expr.collect { case a: AttributeReference => a }.distinct` - the bound tree's input columns in ordinal order. + - `val boundExpr = BindReferences.bindReference(expr, AttributeSeq(attrs))` - binds `AttributeReference` leaves to `BoundReference(ord, dt, nullable)`. + - `CodegenDispatchSerdeHelpers.serializedExpressionArg(expr, boundExpr, inputs, binding)` - gates on `canHandle`, serializes via Spark's closure serializer, wraps as a `Literal(bytes, BinaryType)` proto arg. Returns `None` and emits `withInfo` when `canHandle` rejects, so callers just `.getOrElse(return None)`. + - `val dataArgs = attrs.map(a => exprToProtoInternal(a, inputs, binding).getOrElse(return None))` - the raw data columns. + - `val returnType = serializeDataType(expr.dataType).getOrElse(return None)` - the expression's Spark output type. + - Build a `JvmScalarUdf` proto with `setClassName(classOf[CometCodegenDispatchUDF].getName)`, `addArgs(exprArg)` followed by `dataArgs.foreach(addArgs)`, `setReturnType`, `setReturnNullable(expr.nullable)`. Wrap in `ExprOuterClass.Expr` and return `Some(...)`. + +4. **Decide non-codegen routing.** Three cases in practice: + - Native DataFusion path exists (e.g. `regexp_replace` with `engine=rust`): keep the existing `convertViaNativeRegex`/equivalent and have `viaNonCodegen` call it. + - No native path, but there's a meaningful non-codegen alternative: write that converter (rare; only `RLike` was this case historically, now removed). + - No alternative: `viaNonCodegen = () => None`, and `mode=disabled` falls through to Spark. + +5. **Tests.** Add a smoke test in `CometCodegenDispatchSmokeSuite` using `assertCodegenDidWork` around a `checkSparkAnswerAndOperator`, plus `assertKernelSignaturePresent(Seq(classOf[...Vector]), OutputType)` to prove specialization reached the cache. If the expression has a new code path in `emitWrite` or `emitTypedGetters`, also add a source-level marker assertion in `CometCodegenSourceSuite` so future regressions don't silently lose the optimization. + +Once wired, the `auto | force | disabled` mode knob applies automatically and users can disable codegen per-session via `spark.comet.exec.codegenDispatch.mode`. + +## Optimizations + +Every optimization is compile-time specialized on `(bound expression, input schema)`; the emitted Java carries only the selected path at each site. Source-level tests in `CometCodegenSourceSuite` assert that each of these activates where expected. 
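+
+For orientation, a hedged sketch of what such a source-level assertion can look like.
+`generateSource(...).code.body` is the accessor referenced elsewhere in this guide, but the
+exact signature and the marker substrings below are assumptions, not the suite's actual code:
+
+```scala
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+import org.apache.comet.udf.CometBatchKernelCodegen
+
+// Bind a small regex expression over a single non-nullable string column.
+val subject = AttributeReference("s", StringType, nullable = false)()
+val expr    = RegExpReplace(subject, Literal("a+"), Literal("b"))
+val bound   = BindReferences.bindReference(expr, AttributeSeq(Seq(subject)))
+
+// Emit the kernel source and check that the expected specializations are present.
+val source = CometBatchKernelCodegen.generateSource(bound).code.body
+assert(!source.contains("isNullAt(0)")) // NonNullableIsNullAtElision: no null probe on the input
+assert(source.contains("Matcher"))      // RegExpReplaceSpecialized: java.util.regex path emitted
+```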
+ +### Input readers (`CometBatchKernelCodegenInput.emitTypedGetters` and the nested-class emitters) + +- **ZeroCopyUtf8Read** for `VarCharVector`. `UTF8String.fromAddress` wraps Arrow's data-buffer address with no `byte[]` allocation. `ViewVarCharVector` is not supported today; the dispatcher's `specFor` rejects it with a clear exception if a future upstream change produces one. +- **NonNullableIsNullAtElision** for non-nullable columns. `isNullAt(ord)` returns literal `false`, and `CometCodegenDispatchUDF.rewriteBoundReferences` tightens the `BoundReference.nullable` flag so Spark's `doGenCode` stops probing at source level too (not just at JIT time). +- **DecimalInputShortFastPath** for `DecimalType(p, _)` with `p <= 18`. Reads the low 8 bytes of the decimal128 slot as a signed long and wraps with `Decimal.createUnsafe`. The slow path (`getObject` + `Decimal.apply`) is emitted only for `p > 18`. + +### Output writers (`CometBatchKernelCodegenOutput`) + +- **DecimalOutputShortFastPath** for `DecimalType(p, _)` with `p <= 18`. Passes `Decimal.toUnscaledLong` to `DecimalVector.setSafe(int, long)`. Slow path via `toJavaBigDecimal()` is emitted only for `p > 18`. +- **Utf8OutputOnHeapShortcut** for `StringType`. When the `UTF8String` base is a `byte[]`, passes it directly to `VarCharVector.setSafe(int, byte[], int, int)` and skips the redundant `getBytes()` allocation. Off-heap fallback retains `getBytes()`. +- **PreSizedOutputBuffer** for variable-length output types. The caller passes an input-size-derived byte estimate to avoid `setSafe` reallocations mid-loop. + +### Kernel shape (`defaultBody` / `generateSource`) + +- **NullIntolerantShortCircuit**. Expression trees where every node is `NullIntolerant` or a leaf get a pre-body null check over the union of input ordinals; null rows skip both CSE evaluation and the main expression body. Correct only when every path from a leaf to the root propagates nulls; breaking the chain with `Coalesce` / `If` / `CaseWhen` / `Concat` falls through to the default branch which runs Spark's own null-aware `ev.code`. +- **NonNullableOutputShortCircuit**. Bound expressions with `nullable == false` drop the `if (ev.isNull) setNull` guard and write unconditionally at source level rather than depending on JIT constant-folding. +- **SubexpressionElimination** (when `spark.sql.subexpressionEliminationEnabled`). Common subtrees become helper methods writing into `addMutableState` fields. Class-field variant for the reason given in [Subexpression elimination (CSE)](#subexpression-elimination-cse) below. + +### Per-expression specializers + +- **RegExpReplaceSpecialized** for `RegExpReplace` with a direct `BoundReference` subject, foldable non-null pattern and replacement, and `pos == 1`. Emits `byte[] -> String -> Matcher -> String -> byte[]` directly, bypassing the `UTF8String` round-trip that default `doGenCode` forces. `java.util.regex.Matcher` requires a `CharSequence`, so the default path materializes a Java `String` from the input `UTF8String`, runs the matcher, then encodes back to `UTF8String`. The round-trip cost is measurable on wide-match workloads; see `specializedRegExpReplaceBody` for the benchmark rationale. + +The general rule for adding a new specialization: specialize when an expression's `doGenCode` pays conversions that an Arrow-aware byte-oriented implementation can skip. The common case is expressions that require a Java `String` (`java.util.regex`, some `DateTimeFormatter` expressions). 
Keep specializations minimal so comparisons stay honest. + +## Subexpression elimination (CSE) + +CSE hoists repeated subtrees into a single evaluation per row. Spark exposes two entry points: + +- `subexpressionElimination` (via `ctx.generateExpressions(..., doSubexpressionElimination = true)` + `ctx.subexprFunctionsCode`). Each common subexpression becomes a helper method that writes its result into class-level mutable state allocated via `addMutableState`. The main expression's `genCode` references those class fields. This is what `GeneratePredicate`, `GenerateMutableProjection`, and `GenerateUnsafeProjection` use. +- `subexpressionEliminationForWholeStageCodegen`. CSE results live in local variables declared in the caller's scope, and the main expression's `genCode` references those locals. Only safe when no helper method gets extracted between the locals' declaration site and their use. + +We use the **class-field** variant. The WSCG variant does not work in our shape without additional setup: Spark's arithmetic, string, and decimal expressions internally call `splitExpressionsWithCurrentInputs`, which splits into helper methods unless `currentVars` is non-null. In our kernel `currentVars` is null (we read from a row, not from materialized locals), so those splits fire and the helper bodies cannot see CSE-declared locals in the outer scope. The class-field variant sidesteps this because helper methods can read class fields freely. + +### Future WSCG-variant exploration + +Making the WSCG variant usable would require: + +- Setting `ctx.currentVars = Seq.fill(numInputs)(null)` before CSE. `BoundReference.genCode` checks `currentVars != null && currentVars(ord) != null`, so an all-null `currentVars` lets reads fall through to the `INPUT_ROW` path (what we want) while `splitExpressionsWithCurrentInputs` sees `currentVars != null` and declines to split. +- Verifying that direct `ctx.splitExpressions` calls (not the `-WithCurrentInputs` wrapper) in a handful of expressions (`hash`, `Cast`, `collectionOperations`, `ToStringBase`) remain self-contained. They pass explicit args to their split helpers, so they should be fine, but that is a per-expression audit. +- Benchmarking. The potential win is that CSE state lives in local variables rather than class fields, so HotSpot has more freedom to keep values in registers. Whether that wins over the class-field variant is unclear; CSE state is written once and read two or more times per row, and the expression work usually dominates. Not worth doing until a profile shows class-field access on the hot path. +- If the kernel ever gets integrated into Spark's `WholeStageCodegenExec` pipeline (rather than standing alone), the WSCG variant becomes the natural fit and this revisit is forced. Until then, the standalone-kernel shape matches Predicate/Projection/UnsafeRow generators, which use class-field CSE. + +## Open items + +Each item below has a `TODO` in the code at the referenced location. The code-side comment is a short pointer; this section carries the rationale. + +### Dictionary-encoded inputs + +`CometCodegenDispatchUDF.evaluate` (near the top). Comet's native scan and shuffle paths currently materialize dictionaries before the UDF bridge, so `v.getField.getDictionary != null` is not observed here today. If that invariant is ever relaxed upstream, the cast in `specFor` throws. Two ways to fix it at that point: + +- Materialize at the dispatcher via `CDataDictionaryProvider` (see `NativeUtil.importVector`). Simpler. 
+- Widen `emitTypedGetters` with a dict-index read plus a lookup into the dictionary vector. Faster on high-cardinality dictionaries but adds a cache-key dimension. + +### Cache-key hash cost + +`CometCodegenDispatchUDF.CacheKey`. `hashCode` walks `bytesKey` once per batch (`equals` again on hash collision). For small expressions (a few KB) this is single-digit microseconds and invisible; for large `ScalaUDF` closures with heavy encoders (tens to hundreds of KB) it could climb to tens of microseconds per batch. If a workload shows this on a profile, three alternatives worth exploring: + +1. Driver-side precomputed hash piggybacked through the Arrow transport as a small tag (e.g. 8 bytes). Executor uses the tag directly as the key. O(1) per batch, and the tag is tiny versus the full byte array. +2. Per-UDF-instance byte-identity fast path. `CometCodegenDispatchUDF` is per-thread; the expression is invariant for the life of one task. Memoize the last-seen `(Arrow data buffer address, offset, length)` tuple and skip the HashMap entirely when it matches. +3. Two-level cache with source-string outer tier. Keep bytes-based L1 as today; add an L2 keyed on `generateSource(expr).code.body` that stores only the Janino-compiled class. Captures the "same lambda, different closure identity" cross-query reuse case (e.g. the same `udf((i: Int) => i + 1)` registered across sessions produces identical source but different serialized bytes). + +None of these are worth doing until a profile shows lookup in the hot path. + +### Unsafe readers skipping Arrow bounds checks + +`CometBatchKernelCodegenInput.emitTypedGetters`. Primitive getters go through Arrow's typed `v.get(i)` which performs bounds checks. Inside the kernel's `process` loop `i` is always in `[0, numRows)`, so the check is redundant. Mirror `CometPlainVector`'s pattern (cache validity/value/offset buffer addresses, use direct `Platform.getInt` reads) behind a benchmark. + +### Per-row-body method-size splitting + +`CometBatchKernelCodegen.generateSource`. The per-row body lives inline inside `process`'s for-loop and is not split. Individual `doGenCode` implementations (`Concat`, `Cast`, `CaseWhen`) call `ctx.splitExpressionsWithCurrentInputs` internally, but the outer per-row body itself is never split. A sufficiently deep composed expression (multi-level ScalaUDF with heavy encoder converters per level) can push `process` past Janino's 64 KB method size limit, at which point compile fails. Mitigation when that ceiling is hit: wrap `perRowBody` in `ctx.splitExpressionsWithCurrentInputs(Seq(perRowBody), funcName = "evalRow", arguments = Seq(...))`. The `row`-as-`this` alias we install in `process` already covers that path. Skipped speculatively because today's workloads sit comfortably below the threshold and splitting unconditionally adds a function-call frame per row for the common case. + +## Known behavioral limitations + +- **`regexp_replace` on a collated subject** rejects at plan time: Spark wraps the pattern in `Collate(Literal, ...)` and the current `RegExpReplace` serde requires a bare `Literal`. Serde-side unwrap would unblock this. +- **`rlike` on ICU collations** (`UNICODE_CI` etc.) is a type mismatch in Spark itself (RLike contracts on `UTF8_BINARY`), not a Comet limitation. Binary collations like `UTF8_LCASE` work. +- **Observability sink**. `CometCodegenDispatchUDF.stats()` and `snapshotCompiledSignatures()` are test-facing; not yet wired to Spark SQL metrics, JMX, or periodic logging. 
+- **DataFusion alignment gaps in the bridge contract**: + - `arg_fields` - already covered by `ValueVector.getField()` on the JVM side. + - `return_field` - dispatcher derives it via `boundExpr.dataType`. + - `config_options` - session-level state like timezone / locale. Not plumbed across JNI. Would matter for TZ-aware or locale-sensitive UDFs. + - `ColumnarValue::Scalar` return - DataFusion lets a scalar function return one value broadcast to batch length. Arrow Java has no `ScalarValue` equivalent; adding it would need a new JVM wrapper type plus an FFI protocol extension. Small practical payoff (most UDFs produce row-varying output; true constants are folded at plan time), large surface change. +- **Benchmark observation** (`CometScalaUDFCompositionBenchmark`). On plans of shape `Scan -> Project[UDF] -> noop` or `Scan -> Project[UDF] -> SUM`, the dispatcher runs a few percent slower than "dispatcher disabled" (Spark row-based fallback) at 1M rows. Both paths do the same per-row work in the JVM and our path pays an extra JNI hop. The benefit is keeping the surrounding plan columnar when downstream operators would otherwise fall back, a shape the current benchmark does not exercise. A follow-up benchmark with expensive columnar operators around the UDF (filter + hash join + aggregate) would measure the plan-preservation effect. +- **Candidates for specialized emitters beyond `RegExpReplace`**. Other regex-family expressions (`regexp_extract`, `regexp_extract_all`, `regexp_instr`) pay the same `UTF8String <-> String` conversion chain Spark's `doGenCode` forces. `str_to_map` is another candidate. Audit pending. +- **Longer-term: full `WholeStageCodegenExec` integration**. Build a Spark plan tree (`ArrowOutputExec(ProjectExec(ColumnarToRowExec(BatchInputExec)))`) and let Spark's WSCG fuse everything through its own codegen machinery, reusing `CometVector` on the input side. Larger engineering footprint (custom `CodegenSupport` sink, plan construction inside JNI callbacks) but unlocks nested types and every Arrow input type without Comet-side getter maintenance. + +## File map + +- `common/src/main/scala/org/apache/comet/udf/CometCodegenDispatchUDF.scala` - dispatcher `CometUDF`, shared LRU, counters, `snapshotCompiledSignatures()`. +- `common/src/main/scala/org/apache/comet/udf/CometBatchKernelCodegen.scala` - Janino-based kernel compiler, `canHandle`, `allocateOutput`, `emitOutputWriter`, `emitTypedGetters`, `CompiledKernel` with `freshReferences` closure. +- `common/src/main/scala/org/apache/comet/udf/CometInternalRow.scala` - abstract `InternalRow` base with throwing defaults for unimplemented getters. +- `common/src/main/scala/org/apache/comet/udf/CometUDF.scala` - `CometUDF.evaluate(inputs, numRows)` contract. +- `common/src/main/java/org/apache/comet/udf/CometBatchKernel.java` - Java abstract base the generated subclass extends. +- `common/src/main/java/org/apache/comet/udf/CometUdfBridge.java` - JNI entry point; plumbs `numRows` through. +- `native/jni-bridge/src/comet_udf_bridge.rs` - JNI method ID lookup for `CometUdfBridge.evaluate`. +- `native/spark-expr/src/jvm_udf/mod.rs` - Rust-side `JvmScalarUdfExpr` calling the JVM bridge. +- `spark/src/main/scala/org/apache/comet/serde/strings.scala` - rlike / regexp_replace / regexp_extract / regexp_extract_all / regexp_instr / string_split serdes, `CodegenDispatchSerdeHelpers` (`canHandle` + serialization). +- `spark/src/main/scala/org/apache/comet/serde/scalaUdf.scala` - `ScalaUDF` serde routing user UDFs through the dispatcher. 
+- `spark/src/test/scala/org/apache/comet/CometCodegenDispatchSmokeSuite.scala` - smoke tests: mode knob, composition, `ScalaUDF`, type-surface, zero-column, signature assertions. +- `spark/src/test/scala/org/apache/comet/CometCodegenDispatchFuzzSuite.scala` - randomized string fuzz across null densities and a fixed regex pattern set. +- `spark/src/test/scala/org/apache/spark/sql/benchmark/CometScalaUDFCompositionBenchmark.scala` - benchmark comparing Spark, Comet native built-ins, dispatcher-disabled fallback, and codegen dispatch for composed `ScalaUDF` trees. diff --git a/docs/source/user-guide/latest/index.rst b/docs/source/user-guide/latest/index.rst index 314a0a51bd..237d45b858 100644 --- a/docs/source/user-guide/latest/index.rst +++ b/docs/source/user-guide/latest/index.rst @@ -43,6 +43,7 @@ to read more. Supported Data Types Supported Operators Supported Expressions + JVM UDF Dispatch Configuration Settings Compatibility Guide Understanding Comet Plans diff --git a/docs/source/user-guide/latest/jvm_udf_dispatch.md b/docs/source/user-guide/latest/jvm_udf_dispatch.md new file mode 100644 index 0000000000..65edff4a30 --- /dev/null +++ b/docs/source/user-guide/latest/jvm_udf_dispatch.md @@ -0,0 +1,75 @@ + + +# JVM UDF dispatch + +Comet can route scalar expressions that lack a native DataFusion implementation, or whose native implementation diverges from Spark, through a JVM-side kernel that processes Arrow batches directly. Surrounding native operators stay on the Comet path instead of forcing a whole-plan fallback to Spark. The tradeoff is a JNI roundtrip and per-batch JVM execution. + +## Supported expressions + +- User-defined scalar functions registered via `spark.udf.register` (Scala `UDF1`/`UDF2`/... or Java functional interfaces), `udf(...)`, or SQL `CREATE FUNCTION ... AS 'com.example.MyUDF'`. +- Regex family: `rlike`, `regexp_replace`, `regexp_extract`, `regexp_extract_all`, `regexp_instr`, and `split` with a literal regex pattern. + +Not supported: + +- Aggregate UDFs, table UDFs, generators. +- Python `@udf` and Pandas `@pandas_udf`. +- Hive `GenericUDF` / `SimpleUDF`. + +## Supported types + +Scalar: `Boolean`, `Byte`, `Short`, `Int`, `Long`, `Float`, `Double`, `Decimal`, `String`, `Binary`, `Date`, `Timestamp`, `TimestampNTZ`. + +Complex (as both input and output, including arbitrary nesting): `ArrayType`, `StructType`, `MapType`. + +## Configuration + +| Key | Default | Description | +| --------------------------------------- | ------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `spark.comet.exec.codegenDispatch.mode` | `auto` | `auto` routes through JVM codegen when it is the serde's primary path (regex with java engine, ScalaUDF). `force` routes through codegen whenever accepted. `disabled` never routes through codegen. | +| `spark.comet.exec.regexp.engine` | `java` | `java` uses the JVM codegen path for the regex family. `rust` prefers the native DataFusion engine where one exists and falls back to Spark otherwise. | + +## Regex routing + +Cells name the path the expression takes. `Spark` means the plan falls back to Spark. `codegen` means the JVM codegen dispatcher. `native` means the DataFusion scalar function. 
+ +| Expression | `java, auto` | `java, force` | `java, disabled` | `rust, auto` | `rust, force` | `rust, disabled` | +| -------------------- | ------------ | ------------- | ---------------- | ------------ | ------------- | ---------------- | +| `rlike` | codegen | codegen | Spark | native | codegen | native | +| `regexp_replace` | codegen | codegen | Spark | native | codegen | native | +| `regexp_extract` | codegen | codegen | Spark | Spark | Spark | Spark | +| `regexp_extract_all` | codegen | codegen | Spark | Spark | Spark | Spark | +| `regexp_instr` | codegen | codegen | Spark | Spark | Spark | Spark | +| `split` | codegen | codegen | Spark | native | codegen | native | + +`regexp_extract`, `regexp_extract_all`, and `regexp_instr` have no native DataFusion path, so rust-engine cells read `Spark` regardless of dispatch mode. Rust-engine cells also require `spark.comet.expr.allow.incompat=true` for patterns the rust engine evaluates incompatibly with Spark; otherwise the plan falls back to Spark. + +## Behavior notes + +- Non-deterministic expressions (`rand`, `uuid`, `monotonically_increasing_id`) produce per-partition sequences consistent with Spark. One kernel instance lives per partition; state is reset on partition boundaries. +- ScalaUDF bodies that read `TaskContext.get()` see the correct partition context even when executed on a Tokio worker thread. +- The user function must be closure-serializable. The same function that works with Spark's executor execution works here. + +## Known limitations + +- Dictionary-encoded inputs are not handled. Comet's native scan and shuffle paths materialize dictionaries before the dispatcher, so this is not a current failure mode. If you observe it, file an issue. +- `regexp_replace` on a collated subject rejects at plan time; Spark wraps the pattern in `Collate(Literal, ...)` and the serde requires a bare `Literal`. +- `rlike` on ICU collations (e.g. `UNICODE_CI`) is a type mismatch in Spark itself, not a Comet-specific limitation. Binary collations like `UTF8_LCASE` work. + +For internals (architecture, caching, compile-time specializations, open work items), see the contributor guide [JVM UDF Dispatch](../../contributor-guide/jvm_udf_dispatch.md) page. diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 5d3dbb8266..ecb05eb91f 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -306,6 +306,13 @@ struct ExecutionContext { pub tracing_memory_metric_name: String, /// Pre-computed tracing event name for executePlan calls pub tracing_event_name: String, + /// Spark `TaskContext` captured on the driving Spark task thread at `createPlan` time. + /// Threaded into every JVM scalar UDF the planner builds so the JNI bridge can install it + /// as the thread-local `TaskContext` for the Tokio worker running the UDF. `None` when no + /// driving Spark task is present (unit tests, direct native driver runs). The `Arc` is + /// cheap to clone; the underlying `Global` releases its JNI global ref on drop + /// via `jni`'s `Drop` impl. + pub task_context: Option>>>, } /// Accept serialized query plan and return the address of the native query plan. 
@@ -332,6 +339,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( task_attempt_id: jlong, task_cpus: jlong, key_unwrapper_obj: JObject, + task_context_obj: JObject, ) -> jlong { try_unwrap_or_throw(&e, |env| { // Deserialize Spark configs @@ -453,6 +461,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( String::new() }; + // Capture the driving Spark task's TaskContext as a JNI global reference when + // non-null. The `Arc>` releases its global ref on drop, so + // cleanup is automatic when the ExecutionContext drops. + let task_context = if !task_context_obj.is_null() { + Some(Arc::new(jni_new_global_ref!(env, task_context_obj)?)) + } else { + None + }; + let exec_context = Box::new(ExecutionContext { id, task_attempt_id, @@ -479,6 +496,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( "thread_{rust_thread_id}_comet_memory_reserved" ), tracing_event_name, + task_context, }); Ok(Box::into_raw(exec_context) as i64) @@ -703,7 +721,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let start = Instant::now(); let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) - .with_exec_id(exec_context_id); + .with_exec_id(exec_context_id) + .with_task_context(exec_context.task_context.clone()); let (scans, shuffle_scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(), diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 478c7a8d98..6b40ea435f 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -183,6 +183,12 @@ pub struct PhysicalPlanner { partition: i32, session_ctx: Arc, query_context_registry: Arc, + /// Spark `TaskContext` captured on the driving Spark task thread and stashed on the + /// [`ExecutionContext`] at `createPlan` time. Threaded into every [`JvmScalarUdfExpr`] the + /// planner builds so the JNI bridge can install it as the thread-local `TaskContext` on + /// the Tokio worker that drives the UDF. `None` when no driving Spark task is available + /// (unit tests, direct native driver runs). + task_context: Option>>>, } impl Default for PhysicalPlanner { @@ -198,6 +204,7 @@ impl PhysicalPlanner { session_ctx, partition, query_context_registry: datafusion_comet_spark_expr::create_query_context_map(), + task_context: None, } } @@ -207,6 +214,20 @@ impl PhysicalPlanner { partition: self.partition, session_ctx: Arc::clone(&self.session_ctx), query_context_registry: Arc::clone(&self.query_context_registry), + task_context: self.task_context, + } + } + + /// Attach a propagated Spark `TaskContext` global reference. Called by the JNI `executePlan` + /// entry with whatever was captured at `createPlan` time. The planner clones this `Option` + /// into every `JvmScalarUdfExpr` it builds. 
+ pub fn with_task_context(self, task_context: Option>>>) -> Self { + Self { + exec_context_id: self.exec_context_id, + partition: self.partition, + session_ctx: self.session_ctx, + query_context_registry: self.query_context_registry, + task_context, } } @@ -735,6 +756,7 @@ impl PhysicalPlanner { args, return_type, udf.return_nullable, + self.task_context.clone(), ))) } expr => Err(GeneralError(format!("Not implemented: {expr:?}"))), diff --git a/native/jni-bridge/src/comet_udf_bridge.rs b/native/jni-bridge/src/comet_udf_bridge.rs index 89cd8ee514..e531d20cb1 100644 --- a/native/jni-bridge/src/comet_udf_bridge.rs +++ b/native/jni-bridge/src/comet_udf_bridge.rs @@ -41,7 +41,7 @@ impl<'a> CometUdfBridge<'a> { method_evaluate: env.get_static_method_id( JNIString::new(Self::JVM_CLASS), jni::jni_str!("evaluate"), - jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"), + jni::jni_sig!("(Ljava/lang/String;[J[JJJILorg/apache/spark/TaskContext;)V"), )?, method_evaluate_ret: ReturnType::Primitive(Primitive::Void), class, diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index d72323c961..f95d3cc174 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -231,7 +231,8 @@ pub struct JVMClasses<'a> { /// acquire & release native memory. pub comet_task_memory_manager: CometTaskMemoryManager<'a>, /// The CometUdfBridge class used to dispatch JVM scalar UDFs. - /// `None` if the class is not on the classpath. + /// `None` if the class is not on the classpath; the JVM-UDF dispatch path + /// reports a clear error rather than crashing executor init. pub comet_udf_bridge: Option>, } @@ -304,6 +305,9 @@ impl JVMClasses<'_> { comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), comet_udf_bridge: { + // Optional: if the bridge class is absent (e.g. comet shading + // dropped org.apache.comet.udf.*), record None and clear the + // pending JVM exception so other JNI calls keep working. let bridge = CometUdfBridge::new(env).ok(); if env.exception_check() { env.exception_clear(); diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs index 668a2b6727..ddfad18a1a 100644 --- a/native/spark-expr/src/jvm_udf/mod.rs +++ b/native/spark-expr/src/jvm_udf/mod.rs @@ -31,7 +31,7 @@ use datafusion::physical_expr::PhysicalExpr; use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError}; use datafusion_comet_jni_bridge::JVMClasses; -use jni::objects::{JObject, JValue}; +use jni::objects::{Global, JObject, JValue}; /// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI. /// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`. @@ -41,6 +41,17 @@ pub struct JvmScalarUdfExpr { args: Vec>, return_type: DataType, return_nullable: bool, + /// Spark `TaskContext` captured on the driving Spark task thread, stashed in the + /// [`ExecutionContext`] at `createPlan` time, and threaded here by the planner. Passed + /// through the JNI bridge so [`CometUdfBridge.evaluate`] can install it as the + /// thread-local `TaskContext` on the Tokio worker that drives the UDF call. Without this, + /// partition-sensitive built-ins inside a user UDF tree (`Rand`, `Uuid`, + /// `MonotonicallyIncreasingID`, custom UDF code that reads + /// `TaskContext.get().partitionId()`) see a null `TaskContext` and seed / branch + /// incorrectly. 
`None` means the surrounding driver had no `TaskContext` to propagate + /// (unit tests, direct native driver runs); the bridge then leaves whatever + /// `TaskContext.get()` already returns in place. + task_context: Option>>>, } impl JvmScalarUdfExpr { @@ -49,12 +60,14 @@ impl JvmScalarUdfExpr { args: Vec>, return_type: DataType, return_nullable: bool, + task_context: Option>>>, ) -> Self { Self { class_name, args, return_type, return_nullable, + task_context, } } } @@ -110,10 +123,10 @@ impl PhysicalExpr for JvmScalarUdfExpr { } fn evaluate(&self, batch: &RecordBatch) -> DFResult { - // Step 1: evaluate child expressions to get Arrow arrays. Scalar children - // (e.g. literal patterns) are sent as length-1 vectors rather than expanded - // to batch-row count, so the JVM bridge does not pay an O(rows) copy for - // values that never vary across the batch. + // Scalar children (e.g. literal patterns) are sent as length-1 vectors rather than + // expanded to batch-row count, so the JVM bridge does not pay an O(rows) copy for + // values that never vary across the batch. The JVM side gets `numRows` directly via + // the bridge so it doesn't need the scalar to carry batch length. let arrays: Vec = self .args .iter() @@ -123,7 +136,6 @@ impl PhysicalExpr for JvmScalarUdfExpr { }) .collect::>()?; - // Step 2: allocate FFI structs on the Rust heap and collect their raw pointers. // The JVM writes into the out_array/out_schema slots and reads from the in_ slots. let in_ffi_arrays: Vec> = arrays .iter() @@ -147,7 +159,6 @@ impl PhysicalExpr for JvmScalarUdfExpr { .map(|b| b.as_ref() as *const FFI_ArrowSchema as i64) .collect(); - // Allocate output FFI slots. let mut out_array = Box::new(FFI_ArrowArray::empty()); let mut out_schema = Box::new(FFI_ArrowSchema::empty()); let out_arr_ptr = out_array.as_mut() as *mut FFI_ArrowArray as i64; @@ -156,22 +167,20 @@ impl PhysicalExpr for JvmScalarUdfExpr { let class_name = self.class_name.clone(); let n_args = arrays.len(); - // Step 3: attach a JNI env for this thread and call the static bridge method. JVMClasses::with_env(|env| { let bridge = JVMClasses::get().comet_udf_bridge.as_ref().ok_or_else(|| { CometError::from(ExecutionError::GeneralError( "JVM UDF bridge unavailable: org.apache.comet.udf.CometUdfBridge \ - class was not found on the JVM classpath." + class was not found on the JVM classpath. Set \ + spark.comet.exec.regexp.engine=rust to disable this path." .to_string(), )) })?; - // Build the JVM String for the class name. let jclass_name = env .new_string(&class_name) .map_err(|e| CometError::JNI { source: e })?; - // Build the long[] arrays for input pointers. let in_arr_java = env .new_long_array(n_args) .map_err(|e| CometError::JNI { source: e })?; @@ -186,7 +195,15 @@ impl PhysicalExpr for JvmScalarUdfExpr { .set_region(env, 0, &in_sch_ptrs) .map_err(|e| CometError::JNI { source: e })?; - // Call CometUdfBridge.evaluate(String, long[], long[], long, long) + // Resolve the TaskContext reference once before building the arg array so the + // borrow lives until `call_static_method_unchecked` returns. When no TaskContext + // was propagated, pass a null object so the bridge's null-guard leaves the thread- + // local alone. 
+ let null_task_context = JObject::null(); + let task_context_ref: &JObject = match &self.task_context { + Some(gref) => gref.as_obj(), + None => &null_task_context, + }; let ret = unsafe { env.call_static_method_unchecked( &bridge.class, @@ -198,6 +215,8 @@ impl PhysicalExpr for JvmScalarUdfExpr { JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(), JValue::Long(out_arr_ptr).as_jni(), JValue::Long(out_sch_ptr).as_jni(), + JValue::Int(batch.num_rows() as i32).as_jni(), + JValue::Object(task_context_ref).as_jni(), ], ) }; @@ -210,7 +229,6 @@ impl PhysicalExpr for JvmScalarUdfExpr { Ok(()) })?; - // Step 4: import the result from the FFI slots filled by the JVM. // SAFETY: `*out_array` moves the FFI_ArrowArray out of the Box (the heap // allocation is freed by the move), and `from_ffi` wraps it in an Arc that // keeps the JVM-installed release callback alive until the resulting @@ -218,7 +236,19 @@ impl PhysicalExpr for JvmScalarUdfExpr { // exactly once when the Box drops at end of scope. let result_data = unsafe { from_ffi(*out_array, &out_schema) } .map_err(|e| CometError::Arrow { source: e })?; - Ok(ColumnarValue::Array(make_array(result_data))) + let result_array = make_array(result_data); + + // The JVM may produce arrays with different field names (e.g. Arrow Java's + // ListVector uses "$data$" for child fields) than what DataFusion expects + // (e.g. "item"). Cast to the declared return_type to normalize schema. + let result_array = if result_array.data_type() != &self.return_type { + arrow::compute::cast(&result_array, &self.return_type) + .map_err(|e| CometError::Arrow { source: e })? + } else { + result_array + }; + + Ok(ColumnarValue::Array(result_array)) } fn children(&self) -> Vec<&Arc> { @@ -234,6 +264,7 @@ impl PhysicalExpr for JvmScalarUdfExpr { children, self.return_type.clone(), self.return_nullable, + self.task_context.clone(), ))) } } diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index a93564811c..f385d22700 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -127,7 +127,12 @@ class CometExecIterator( memoryConfig.memoryLimitPerTask, taskAttemptId, taskCPUs, - keyUnwrapper) + keyUnwrapper, + // Capture the Spark task thread's TaskContext at `createPlan` time. Stashed native-side + // in the ExecutionContext and passed through the JVM UDF bridge so that Tokio workers + // running JVM UDFs see the real `TaskContext` via their thread-local. See + // `CometUdfBridge.evaluate` and `CometTaskContextShim` for the receive side. 
+ TaskContext.get()) } private var nextBatch: Option[ColumnarBatch] = None diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index c003bcd138..3cfa51b6e1 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -21,7 +21,7 @@ package org.apache.comet import java.nio.ByteBuffer -import org.apache.spark.CometTaskMemoryManager +import org.apache.spark.{CometTaskMemoryManager, TaskContext} import org.apache.spark.sql.comet.CometMetricNode import org.apache.comet.parquet.CometFileKeyUnwrapper @@ -69,7 +69,8 @@ class Native extends NativeBase { memoryLimitPerTask: Long, taskAttemptId: Long, taskCPUs: Long, - keyUnwrapper: CometFileKeyUnwrapper): Long + keyUnwrapper: CometFileKeyUnwrapper, + taskContext: TaskContext): Long // scalastyle:on /** diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 2d138450e9..b3da6afd08 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -179,6 +179,9 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[Like] -> CometLike, classOf[Lower] -> CometLower, classOf[OctetLength] -> CometScalarFunction("octet_length"), + classOf[RegExpExtract] -> CometRegExpExtract, + classOf[RegExpExtractAll] -> CometRegExpExtractAll, + classOf[RegExpInStr] -> CometRegExpInStr, classOf[RegExpReplace] -> CometRegExpReplace, classOf[Reverse] -> CometReverse, classOf[RLike] -> CometRLike, @@ -255,6 +258,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[MakeDecimal] -> CometMakeDecimal, classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId, classOf[ScalarSubquery] -> CometScalarSubquery, + classOf[ScalaUDF] -> CometScalaUDF, classOf[SparkPartitionID] -> CometSparkPartitionId, classOf[SortOrder] -> CometSortOrder, classOf[StaticInvoke] -> CometStaticInvoke, diff --git a/spark/src/main/scala/org/apache/comet/serde/scalaUdf.scala b/spark/src/main/scala/org/apache/comet/serde/scalaUdf.scala new file mode 100644 index 0000000000..de9e2148a6 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/scalaUdf.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, ScalaUDF} + +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.ExprOuterClass.Expr + +/** + * Routes scalar `ScalaUDF` expressions (user-registered Scala and Java UDFs) through the + * Arrow-direct codegen dispatcher. `ScalaUDF.doGenCode` emits compilable Java that invokes the + * user function via `ctx.addReferenceObj`, so the codegen path picks it up unchanged: we + * serialize the bound tree, the closure serializer carries the function reference across the + * wire, and the Janino-compiled kernel loads the function and invokes it in a tight batch loop. + * + * Not covered here: + * - Aggregate UDFs (`ScalaAggregator`, `TypedImperativeAggregate`, old UDAF API) - different + * bridge contract. + * - Table UDFs (`UserDefinedTableFunction`) - generator shape; `canHandle` rejects. + * - Python / Pandas UDFs - different runtime. + * - Hive UDFs (`HiveGenericUDF` / `HiveSimpleUDF`) - separate expression classes; would need + * their own serde. + * + * Mode knob: `auto` prefers codegen because `ScalaUDF` has no native fallback; `disabled` returns + * `None` and the plan falls back to Spark. + */ +object CometScalaUDF extends CometExpressionSerde[ScalaUDF] { + + override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + CodegenDispatchSerdeHelpers.pickWithMode( + viaCodegen = + () => CodegenDispatchSerdeHelpers.buildJvmUdfExpr(expr, inputs, binding, expr.dataType), + viaNonCodegen = () => { + withInfo( + expr, + "codegen dispatch disabled; ScalaUDF has no native path so the plan falls back to Spark") + None + }, + preferCodegenInAuto = true) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/strings.scala b/spark/src/main/scala/org/apache/comet/serde/strings.scala index 968fe8cd69..a14abcb89d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/strings.scala +++ b/spark/src/main/scala/org/apache/comet/serde/strings.scala @@ -21,15 +21,134 @@ package org.apache.comet.serde import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} -import org.apache.spark.sql.types.{BinaryType, DataTypes, LongType, StringType} +import org.apache.spark.SparkEnv +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Cast, Concat, ConcatWs, Expression, GetJsonObject, If, InitCap, IsNull, Left, Length, Like, Literal, Lower, RegExpExtract, RegExpExtractAll, RegExpInStr, RegExpReplace, Right, RLike, StringLPad, StringRepeat, StringRPad, StringSplit, Substring, Upper} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, DataTypes, IntegerType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.expressions.{CometCast, CometEvalMode, RegExp} import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType} +import org.apache.comet.serde.QueryPlanSerde.{createBinaryExpr, exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, 
scalarFunctionExprToProtoWithReturnType, serializeDataType} +import org.apache.comet.udf.{CometBatchKernelCodegen, CometCodegenDispatchUDF} + +/** + * Helpers for wiring expressions through the [[CometCodegenDispatchUDF]] proto. The codegen + * dispatcher identifies the expression to evaluate by carrying serialized `Expression` bytes as + * its first argument, replacing the earlier driver-side-registry + UUID approach so the path + * works in cluster mode without executor-side state. + */ +private[serde] object CodegenDispatchSerdeHelpers { + + /** + * Serialize a bound `Expression` via Spark's closure serializer and wrap the bytes as a + * `Literal(bytes, BinaryType)` proto arg. The closure serializer respects the task context + * classloader (so user UDF jars are visible) and matches the machinery Spark uses to ship + * closures across the wire. + * + * Gated by [[CometBatchKernelCodegen.canHandle]]: if the bound expression has an unsupported + * input or output type, we log via `withInfo` and return `None` so the caller falls back. + * Prevents unsupported shapes from reaching the Janino compiler at execute time. + */ + def serializedExpressionArg( + original: Expression, + boundExpr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + CometBatchKernelCodegen.canHandle(boundExpr) match { + case Some(reason) => + withInfo(original, reason) + return None + case None => + } + val serializer = SparkEnv.get.closureSerializer.newInstance() + val buffer = serializer.serialize(boundExpr) + val bytes = new Array[Byte](buffer.remaining()) + buffer.get(bytes) + exprToProtoInternal(Literal(bytes, BinaryType), inputs, binding) + } + + /** + * Build the [[ExprOuterClass.Expr]] proto routing `expr` through [[CometCodegenDispatchUDF]]. + * Shared scaffold: collect the bound tree's `AttributeReference`s, bind, serialize the bound + * tree as arg 0, emit each attribute as a data arg, set the declared return type, wrap. All + * regex-family serdes and [[CometScalaUDF]] land here. + */ + def buildJvmUdfExpr( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean, + returnType: DataType): Option[Expr] = { + val attrs = expr.collect { case a: AttributeReference => a }.distinct + val boundExpr = BindReferences.bindReference(expr, AttributeSeq(attrs)) + + val exprArg = serializedExpressionArg(expr, boundExpr, inputs, binding) + .getOrElse(return None) + val dataArgs = + attrs.map(a => exprToProtoInternal(a, inputs, binding).getOrElse(return None)) + + val returnTypeProto = serializeDataType(returnType).getOrElse(return None) + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName(classOf[CometCodegenDispatchUDF].getName) + .addArgs(exprArg) + dataArgs.foreach(udfBuilder.addArgs) + udfBuilder + .setReturnType(returnTypeProto) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + } + + /** + * Validate a regex-literal value: non-null and syntactically compilable by + * `java.util.regex.Pattern`. Returns `Some(reason)` for the caller to pass to `withInfo` when + * the literal forces a Spark fallback, `None` when it is usable. 
+ */ + def validateRegexLiteral(value: Any): Option[String] = { + if (value == null) { + return Some("Null literal pattern is handled by Spark fallback") + } + try { + java.util.regex.Pattern.compile(value.toString) + None + } catch { + case e: java.util.regex.PatternSyntaxException => + Some(s"Invalid regex pattern: ${e.getDescription}") + } + } + + /** + * Chain-of-responsibility picker for expressions that have a codegen dispatcher path plus an + * optional non-codegen fallback (native DataFusion, Spark, etc.). Mode semantics: + * + * - `force`: try codegen first, fall back to `viaNonCodegen` if codegen rejects the + * expression. + * - `disabled`: never try codegen. + * - `auto`: try codegen first when `preferCodegenInAuto` is true, otherwise skip it. + * + * The picker returns `None` if every attempted path returns `None` (the serde should then emit + * `withInfo` + fallback higher up). `viaCodegen` already bakes in the `canHandle` check. + */ + def pickWithMode( + viaCodegen: () => Option[Expr], + viaNonCodegen: () => Option[Expr], + preferCodegenInAuto: Boolean): Option[Expr] = { + CometConf.COMET_CODEGEN_DISPATCH_MODE.get() match { + case CometConf.CODEGEN_DISPATCH_FORCE => + viaCodegen().orElse(viaNonCodegen()) + case CometConf.CODEGEN_DISPATCH_DISABLED => + viaNonCodegen() + case _ => + // auto: serde-declared preference within this mode. + if (preferCodegenInAuto) viaCodegen().orElse(viaNonCodegen()) else viaNonCodegen() + } + } +} object CometStringRepeat extends CometExpressionSerde[StringRepeat] { @@ -264,9 +383,36 @@ object CometLike extends CometExpressionSerde[Like] { object CometRLike extends CometExpressionSerde[RLike] { override def getIncompatibleReasons(): Seq[String] = Seq( - "Uses Rust regexp engine, which has different behavior to Java regexp engine") + s"When ${CometConf.COMET_REGEXP_ENGINE.key}=${CometConf.REGEXP_ENGINE_RUST}: " + + "Uses Rust regexp engine, which has different behavior to Java regexp engine") + + override def getSupportLevel(expr: RLike): SupportLevel = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + expr.right match { + case _: Literal => Compatible(None) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } else { + super.getSupportLevel(expr) + } + } override def convert(expr: RLike, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + val javaEngine = CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA + // Rust engine always uses the native DataFusion path regardless of codegen mode. Java + // engine uses the codegen dispatcher; `disabled` falls through to Spark by returning None. 
+ CodegenDispatchSerdeHelpers.pickWithMode( + viaCodegen = () => convertViaJvmUdfGenericCodegen(expr, inputs, binding), + viaNonCodegen = () => + if (javaEngine) None + else convertViaNativeRegex(expr, inputs, binding), + preferCodegenInAuto = javaEngine) + } + + private def convertViaNativeRegex( + expr: RLike, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { expr.right match { case Literal(pattern, DataTypes.StringType) => if (!RegExp.isSupportedPattern(pattern.toString) && @@ -291,6 +437,204 @@ object CometRLike extends CometExpressionSerde[RLike] { None } } + + private def convertViaJvmUdfGenericCodegen( + expr: RLike, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr.right match { + case Literal(value, DataTypes.StringType) => + CodegenDispatchSerdeHelpers.validateRegexLiteral(value) match { + case Some(reason) => withInfo(expr, reason); None + case None => + CodegenDispatchSerdeHelpers.buildJvmUdfExpr( + expr, + inputs, + binding, + DataTypes.BooleanType) + } + case _ => + withInfo(expr, "Only scalar regexp patterns are supported") + None + } + } +} + +object CometRegExpExtract extends CometExpressionSerde[RegExpExtract] { + + override def getSupportLevel(expr: RegExpExtract): SupportLevel = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + (expr.regexp, expr.idx) match { + case (_: Literal, _: Literal) => Compatible(None) + case (_: Literal, _) => + Unsupported(Some("Only scalar group index is supported")) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } else { + Unsupported( + Some( + s"regexp_extract requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}")) + } + } + + override def convert( + expr: RegExpExtract, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + if (CometConf.COMET_REGEXP_ENGINE.get() != CometConf.REGEXP_ENGINE_JAVA) { + withInfo( + expr, + s"regexp_extract requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}") + return None + } + // No native path exists for regexp_extract; the codegen dispatcher is the only Comet path. + // `disabled` mode falls through to Spark by returning None. 
+ CodegenDispatchSerdeHelpers.pickWithMode( + viaCodegen = () => convertViaJvmUdfGenericCodegen(expr, inputs, binding), + viaNonCodegen = () => None, + preferCodegenInAuto = true) + } + + private def convertViaJvmUdfGenericCodegen( + expr: RegExpExtract, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + (expr.regexp, expr.idx) match { + case (Literal(value, DataTypes.StringType), Literal(_, _: IntegerType)) => + CodegenDispatchSerdeHelpers.validateRegexLiteral(value) match { + case Some(reason) => withInfo(expr, reason); None + case None => + CodegenDispatchSerdeHelpers.buildJvmUdfExpr( + expr, + inputs, + binding, + DataTypes.StringType) + } + case _ => + withInfo(expr, "Only scalar regexp patterns and group index are supported") + None + } + } +} + +object CometRegExpExtractAll extends CometExpressionSerde[RegExpExtractAll] { + + override def getSupportLevel(expr: RegExpExtractAll): SupportLevel = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + (expr.regexp, expr.idx) match { + case (_: Literal, _: Literal) => Compatible(None) + case (_: Literal, _) => + Unsupported(Some("Only scalar group index is supported")) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } else { + Unsupported( + Some( + s"regexp_extract_all requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}")) + } + } + + override def convert( + expr: RegExpExtractAll, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + if (CometConf.COMET_REGEXP_ENGINE.get() != CometConf.REGEXP_ENGINE_JAVA) { + withInfo( + expr, + s"regexp_extract_all requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}") + return None + } + // No native path exists for regexp_extract_all; the codegen dispatcher is the only Comet + // path. `disabled` mode falls through to Spark by returning None. 
+ CodegenDispatchSerdeHelpers.pickWithMode( + viaCodegen = () => convertViaJvmUdfGenericCodegen(expr, inputs, binding), + viaNonCodegen = () => None, + preferCodegenInAuto = true) + } + + private def convertViaJvmUdfGenericCodegen( + expr: RegExpExtractAll, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + (expr.regexp, expr.idx) match { + case (Literal(value, DataTypes.StringType), Literal(_, _: IntegerType)) => + CodegenDispatchSerdeHelpers.validateRegexLiteral(value) match { + case Some(reason) => withInfo(expr, reason); None + case None => + CodegenDispatchSerdeHelpers.buildJvmUdfExpr( + expr, + inputs, + binding, + ArrayType(StringType, containsNull = true)) + } + case _ => + withInfo(expr, "Only scalar regexp patterns and group index are supported") + None + } + } +} + +object CometRegExpInStr extends CometExpressionSerde[RegExpInStr] { + + override def getSupportLevel(expr: RegExpInStr): SupportLevel = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + (expr.regexp, expr.idx) match { + case (_: Literal, _: Literal) => Compatible(None) + case (_: Literal, _) => + Unsupported(Some("Only scalar group index is supported")) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } else { + Unsupported( + Some( + s"regexp_instr requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}")) + } + } + + override def convert( + expr: RegExpInStr, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + if (CometConf.COMET_REGEXP_ENGINE.get() != CometConf.REGEXP_ENGINE_JAVA) { + withInfo( + expr, + s"regexp_instr requires ${CometConf.COMET_REGEXP_ENGINE.key}=" + + s"${CometConf.REGEXP_ENGINE_JAVA}") + return None + } + // No native path exists for regexp_instr; the codegen dispatcher is the only Comet path. + // `disabled` mode falls through to Spark by returning None. + CodegenDispatchSerdeHelpers.pickWithMode( + viaCodegen = () => convertViaJvmUdfGenericCodegen(expr, inputs, binding), + viaNonCodegen = () => None, + preferCodegenInAuto = true) + } + + private def convertViaJvmUdfGenericCodegen( + expr: RegExpInStr, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + (expr.regexp, expr.idx) match { + case (Literal(value, DataTypes.StringType), Literal(_, _: IntegerType)) => + CodegenDispatchSerdeHelpers.validateRegexLiteral(value) match { + case Some(reason) => withInfo(expr, reason); None + case None => + CodegenDispatchSerdeHelpers.buildJvmUdfExpr( + expr, + inputs, + binding, + DataTypes.IntegerType) + } + case _ => + withInfo(expr, "Only scalar regexp patterns and group index are supported") + None + } + } } object CometStringRPad extends CometExpressionSerde[StringRPad] { @@ -352,23 +696,28 @@ object CometStringLPad extends CometExpressionSerde[StringLPad] { object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { override def getIncompatibleReasons(): Seq[String] = Seq( - "Regexp pattern may not be compatible with Spark") + s"When ${CometConf.COMET_REGEXP_ENGINE.key}=${CometConf.REGEXP_ENGINE_RUST}: " + + "Regexp pattern may not be compatible with Spark") override def getUnsupportedReasons(): Seq[String] = Seq( "Only supports `regexp_replace` with an offset of 1 (no offset)") override def getSupportLevel(expr: RegExpReplace): SupportLevel = { - if (!RegExp.isSupportedPattern(expr.regexp.toString) && - !CometConf.isExprAllowIncompat("regexp")) { - withInfo( - expr, - s"Regexp pattern ${expr.regexp} is not compatible with Spark. 
" + - s"Set ${CometConf.getExprAllowIncompatConfigKey("regexp")}=true " + - "to allow it anyway.") - return Incompatible() - } expr.pos match { - case Literal(value, DataTypes.IntegerType) if value == 1 => Compatible() + case Literal(value, DataTypes.IntegerType) if value == 1 => + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + expr.regexp match { + case _: Literal => Compatible(None) + case _ => Unsupported(Some("Only scalar regexp patterns are supported")) + } + } else { + if (!RegExp.isSupportedPattern(expr.regexp.toString) && + !CometConf.isExprAllowIncompat("regexp")) { + Incompatible() + } else { + Compatible() + } + } case _ => Unsupported(Some("Comet only supports regexp_replace with an offset of 1 (no offset).")) } @@ -378,6 +727,30 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { expr: RegExpReplace, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + val javaEngine = CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA + // Rust engine always uses the native DataFusion path regardless of codegen mode. Java + // engine uses the codegen dispatcher; `disabled` falls through to Spark by returning None. + CodegenDispatchSerdeHelpers.pickWithMode( + viaCodegen = () => convertViaJvmUdfGenericCodegen(expr, inputs, binding), + viaNonCodegen = () => + if (javaEngine) None + else convertViaNativeRegex(expr, inputs, binding), + preferCodegenInAuto = javaEngine) + } + + private def convertViaNativeRegex( + expr: RegExpReplace, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + if (!RegExp.isSupportedPattern(expr.regexp.toString) && + !CometConf.isExprAllowIncompat("regexp")) { + withInfo( + expr, + s"Regexp pattern ${expr.regexp} is not compatible with Spark. 
" + + s"Set ${CometConf.getExprAllowIncompatConfigKey("regexp")}=true " + + "to allow it anyway.") + return None + } val subjectExpr = exprToProtoInternal(expr.subject, inputs, binding) val patternExpr = exprToProtoInternal(expr.regexp, inputs, binding) val replacementExpr = exprToProtoInternal(expr.rep, inputs, binding) @@ -392,6 +765,27 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { flagsExpr) optExprWithInfo(optExpr, expr, expr.subject, expr.regexp, expr.rep, expr.pos) } + + private def convertViaJvmUdfGenericCodegen( + expr: RegExpReplace, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr.regexp match { + case Literal(value, DataTypes.StringType) => + CodegenDispatchSerdeHelpers.validateRegexLiteral(value) match { + case Some(reason) => withInfo(expr, reason); None + case None => + CodegenDispatchSerdeHelpers.buildJvmUdfExpr( + expr, + inputs, + binding, + DataTypes.StringType) + } + case _ => + withInfo(expr, "Only scalar regexp patterns are supported") + None + } + } } /** @@ -402,15 +796,39 @@ object CometRegExpReplace extends CometExpressionSerde[RegExpReplace] { object CometStringSplit extends CometExpressionSerde[StringSplit] { override def getIncompatibleReasons(): Seq[String] = Seq( - "Regex engine differences between Java and Rust") - - override def getSupportLevel(expr: StringSplit): SupportLevel = - Incompatible(Some("Regex engine differences between Java and Rust")) + s"When ${CometConf.COMET_REGEXP_ENGINE.key}=${CometConf.REGEXP_ENGINE_RUST}: " + + "Regex engine differences between Java and Rust") + + override def getSupportLevel(expr: StringSplit): SupportLevel = { + if (CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA) { + expr.regex match { + case _: Literal => Compatible(None) + case _ => Unsupported(Some("Only scalar regex patterns are supported")) + } + } else { + Incompatible(Some("Regex engine differences between Java and Rust")) + } + } override def convert( expr: StringSplit, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + val javaEngine = CometConf.COMET_REGEXP_ENGINE.get() == CometConf.REGEXP_ENGINE_JAVA + // Rust engine always uses the native DataFusion path regardless of codegen mode. Java + // engine uses the codegen dispatcher; `disabled` falls through to Spark by returning None. 
+ CodegenDispatchSerdeHelpers.pickWithMode( + viaCodegen = () => convertViaJvmUdfGenericCodegen(expr, inputs, binding), + viaNonCodegen = () => + if (javaEngine) None + else convertViaNativeRegex(expr, inputs, binding), + preferCodegenInAuto = javaEngine) + } + + private def convertViaNativeRegex( + expr: StringSplit, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { val strExpr = exprToProtoInternal(expr.str, inputs, binding) val regexExpr = exprToProtoInternal(expr.regex, inputs, binding) val limitExpr = exprToProtoInternal(expr.limit, inputs, binding) @@ -423,6 +841,27 @@ object CometStringSplit extends CometExpressionSerde[StringSplit] { limitExpr) optExprWithInfo(optExpr, expr, expr.str, expr.regex, expr.limit) } + + private def convertViaJvmUdfGenericCodegen( + expr: StringSplit, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + expr.regex match { + case Literal(value, DataTypes.StringType) => + CodegenDispatchSerdeHelpers.validateRegexLiteral(value) match { + case Some(reason) => withInfo(expr, reason); None + case None => + CodegenDispatchSerdeHelpers.buildJvmUdfExpr( + expr, + inputs, + binding, + ArrayType(StringType, containsNull = false)) + } + case _ => + withInfo(expr, "Only scalar regex patterns are supported") + None + } + } } object CometGetJsonObject extends CometExpressionSerde[GetJsonObject] { diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace.sql deleted file mode 100644 index 967674a894..0000000000 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace.sql +++ /dev/null @@ -1,28 +0,0 @@ --- Licensed to the Apache Software Foundation (ASF) under one --- or more contributor license agreements. See the NOTICE file --- distributed with this work for additional information --- regarding copyright ownership. The ASF licenses this file --- to you under the Apache License, Version 2.0 (the --- "License"); you may not use this file except in compliance --- with the License. You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, --- software distributed under the License is distributed on an --- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --- KIND, either express or implied. See the License for the --- specific language governing permissions and limitations --- under the License. - -statement -CREATE TABLE test_regexp_replace(s string) USING parquet - -statement -INSERT INTO test_regexp_replace VALUES ('100-200'), ('abc'), (''), (NULL), ('phone 123-456-7890') - -query expect_fallback(Regexp pattern) -SELECT regexp_replace(s, '(\\d+)', 'X') FROM test_regexp_replace - -query expect_fallback(Regexp pattern) -SELECT regexp_replace(s, '(\\d+)', 'X', 1) FROM test_regexp_replace diff --git a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_enabled.sql deleted file mode 100644 index 97b4917c33..0000000000 --- a/spark/src/test/resources/sql-tests/expressions/string/regexp_replace_enabled.sql +++ /dev/null @@ -1,35 +0,0 @@ --- Licensed to the Apache Software Foundation (ASF) under one --- or more contributor license agreements. See the NOTICE file --- distributed with this work for additional information --- regarding copyright ownership. 
The ASF licenses this file --- to you under the Apache License, Version 2.0 (the --- "License"); you may not use this file except in compliance --- with the License. You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, --- software distributed under the License is distributed on an --- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --- KIND, either express or implied. See the License for the --- specific language governing permissions and limitations --- under the License. - --- Test regexp_replace() with regexp allowIncompatible enabled (happy path) --- Config: spark.comet.expression.regexp.allowIncompatible=true - -statement -CREATE TABLE test_regexp_replace_enabled(s string) USING parquet - -statement -INSERT INTO test_regexp_replace_enabled VALUES ('100-200'), ('abc'), (''), (NULL), ('phone 123-456-7890') - -query -SELECT regexp_replace(s, '(\d+)', 'X') FROM test_regexp_replace_enabled - -query -SELECT regexp_replace(s, '(\d+)', 'X', 1) FROM test_regexp_replace_enabled - --- literal + literal + literal -query -SELECT regexp_replace('100-200', '(\d+)', 'X'), regexp_replace('abc', '(\d+)', 'X'), regexp_replace(NULL, '(\d+)', 'X') diff --git a/spark/src/test/resources/sql-tests/expressions/string/rlike.sql b/spark/src/test/resources/sql-tests/expressions/string/rlike.sql deleted file mode 100644 index 97350918ba..0000000000 --- a/spark/src/test/resources/sql-tests/expressions/string/rlike.sql +++ /dev/null @@ -1,31 +0,0 @@ --- Licensed to the Apache Software Foundation (ASF) under one --- or more contributor license agreements. See the NOTICE file --- distributed with this work for additional information --- regarding copyright ownership. The ASF licenses this file --- to you under the Apache License, Version 2.0 (the --- "License"); you may not use this file except in compliance --- with the License. You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, --- software distributed under the License is distributed on an --- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --- KIND, either express or implied. See the License for the --- specific language governing permissions and limitations --- under the License. - -statement -CREATE TABLE test_rlike(s string) USING parquet - -statement -INSERT INTO test_rlike VALUES ('hello'), ('12345'), (''), (NULL), ('Hello World'), ('abc123') - -query expect_fallback(Regexp pattern) -SELECT s RLIKE '^[0-9]+$' FROM test_rlike - -query expect_fallback(Regexp pattern) -SELECT s RLIKE '^[a-z]+$' FROM test_rlike - -query spark_answer_only -SELECT s RLIKE '' FROM test_rlike diff --git a/spark/src/test/resources/sql-tests/expressions/string/rlike_enabled.sql b/spark/src/test/resources/sql-tests/expressions/string/rlike_enabled.sql deleted file mode 100644 index 5b2bd05fb3..0000000000 --- a/spark/src/test/resources/sql-tests/expressions/string/rlike_enabled.sql +++ /dev/null @@ -1,38 +0,0 @@ --- Licensed to the Apache Software Foundation (ASF) under one --- or more contributor license agreements. See the NOTICE file --- distributed with this work for additional information --- regarding copyright ownership. The ASF licenses this file --- to you under the Apache License, Version 2.0 (the --- "License"); you may not use this file except in compliance --- with the License. 
You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, --- software distributed under the License is distributed on an --- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY --- KIND, either express or implied. See the License for the --- specific language governing permissions and limitations --- under the License. - --- Test RLIKE with regexp allowIncompatible enabled (happy path) --- Config: spark.comet.expression.regexp.allowIncompatible=true - -statement -CREATE TABLE test_rlike_enabled(s string) USING parquet - -statement -INSERT INTO test_rlike_enabled VALUES ('hello'), ('12345'), (''), (NULL), ('Hello World'), ('abc123') - -query -SELECT s RLIKE '^[0-9]+$' FROM test_rlike_enabled - -query -SELECT s RLIKE '^[a-z]+$' FROM test_rlike_enabled - -query -SELECT s RLIKE '' FROM test_rlike_enabled - --- literal arguments -query -SELECT 'hello' RLIKE '^[a-z]+$', '12345' RLIKE '^[a-z]+$', '' RLIKE '', NULL RLIKE 'a' diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 63936a94b7..0ab429a383 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -236,7 +236,12 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp test("ArrayInsertUnsupportedArgs") { // This test checks that the else branch in ArrayInsert // mapping to the comet is valid and fallback to spark is working fine. - withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[ArrayInsert]) -> "true") { + // Disable the codegen dispatcher so the `idx` ScalaUDF child returns None from its serde, + // which is what drives ArrayInsert's "unsupported arguments" branch. With the dispatcher + // enabled, ScalaUDF routes through codegen and the whole plan runs native. + withSQLConf( + CometConf.COMET_CODEGEN_DISPATCH_MODE.key -> CometConf.CODEGEN_DISPATCH_DISABLED, + CometConf.getExprAllowIncompatConfigKey(classOf[ArrayInsert]) -> "true") { withTempDir { dir => val path = new Path(dir.toURI.toString, "test.parquet") makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = false, 10000) @@ -247,7 +252,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)")) checkSparkAnswerAndFallbackReasons( df.select("arrUnsupportedArgs"), - Set("scalaudf is not supported", "unsupported arguments for ArrayInsert")) + Set("codegen dispatch disabled", "unsupported arguments for ArrayInsert")) } } } diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenDispatchFuzzSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenDispatchFuzzSuite.scala new file mode 100644 index 0000000000..03d19d0bb2 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCodegenDispatchFuzzSuite.scala @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import scala.util.Random + +import org.apache.spark.SparkConf +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + +import org.apache.comet.udf.CometCodegenDispatchUDF + +/** + * Randomized tests for the Arrow-direct codegen dispatcher. Generates string inputs at varying + * null densities and runs a fixed set of regex patterns through both Spark and the codegen + * dispatcher, asserting results agree. Fixes a seed per test for reproducibility. + * + * Scope of this pass: the string surface the dispatcher currently exercises end to end (rlike and + * regexp_replace). Broader cross-type fuzz, including primitive inputs, multi-column expressions, + * and view-type variants, lands once more serdes route through codegen dispatch. + * + * Pinned to `mode=force` so every eligible query is guaranteed to route through the dispatcher + * rather than the hand-coded regex UDF, keeping the fuzz focused on the codegen path. + */ +class CometCodegenDispatchFuzzSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override protected def sparkConf: SparkConf = + super.sparkConf + .set(CometConf.COMET_REGEXP_ENGINE.key, CometConf.REGEXP_ENGINE_JAVA) + .set(CometConf.COMET_CODEGEN_DISPATCH_MODE.key, CometConf.CODEGEN_DISPATCH_FORCE) + + private val RowCount: Int = 512 + private val MaxStringLen: Int = 32 + + /** + * Characters the generator picks from. Mix of digits, lowercase, uppercase, and a couple of + * non-alphanumerics to exercise classes, anchors, and alternations. + */ + private val charPalette: Array[Char] = + ("0123456789" + + "abcdefghijklmnopqrstuvwxyz" + + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + "_-. ").toCharArray + + private def randomString(rng: Random): String = { + val len = rng.nextInt(MaxStringLen + 1) + val sb = new StringBuilder(len) + var i = 0 + while (i < len) { + sb.append(charPalette(rng.nextInt(charPalette.length))) + i += 1 + } + sb.toString + } + + /** + * Generate `RowCount` strings with the requested null density. Seeded for determinism. Empty + * strings and nulls are both part of the distribution when density > 0. + */ + private def generateSubjects(seed: Long, nullDensity: Double): Seq[String] = { + val rng = new Random(seed) + (0 until RowCount).map { _ => + if (rng.nextDouble() < nullDensity) null + else randomString(rng) + } + } + + /** + * Resets dispatcher stats, runs `f`, then asserts the codegen path actually ran for at least + * one batch. Without this, a silent serde fallback would let the fuzz pass trivially because + * both Spark and whatever-Comet-ran-instead agree with Spark. + */ + private def assertCodegenRan(f: => Unit): Unit = { + CometCodegenDispatchUDF.resetStats() + f + val after = CometCodegenDispatchUDF.stats() + assert( + after.compileCount + after.cacheHitCount >= 1, + s"expected at least one codegen dispatcher invocation during this query, got $after") + } + + /** Create a temp table `t(s STRING)` populated with the given subjects, run `f`, then drop. 
*/ + private def withSubjectTable(subjects: Seq[String])(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (s STRING) USING parquet") + if (subjects.nonEmpty) { + val escaped = subjects.map { v => + if (v == null) "(NULL)" else s"('${v.replace("'", "''")}')" + } + // Insert in chunks so the generated VALUES list doesn't blow the SQL parser. + escaped.grouped(64).foreach { batch => + sql(s"INSERT INTO t VALUES ${batch.mkString(", ")}") + } + } + f + } + } + + // Regex patterns chosen to span common rlike shapes and the Java-only backreference feature. + // All are Spark-compatible under the java regex engine the codegen path uses. + private val rlikePatterns: Seq[String] = + Seq("\\\\d+", "^[a-z]", "[A-Z][0-9]+", "(ab){2,}", "^(\\\\w)\\\\1", "_.*\\\\.", "^$") + + // regexp_replace (pattern, replacement) pairs. Mix of no-match, narrow match, wide match. + private val regexpReplacePatterns: Seq[(String, String)] = Seq( + "\\\\d+" -> "N", + "[a-z]+" -> "L", + "[aeiouAEIOU]" -> "*", + "xyzzy" -> "", + "\\\\s+" -> "_") + + private val nullDensities: Seq[Double] = Seq(0.0, 0.1, 0.5, 1.0) + + for { + density <- nullDensities + pattern <- rlikePatterns + } { + test(s"rlike pattern='$pattern' nullDensity=$density") { + val subjects = generateSubjects(seed = pattern.hashCode.toLong ^ density.hashCode, density) + withSubjectTable(subjects) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql(s"SELECT s rlike '$pattern' FROM t")) + } + } + } + } + + for { + density <- nullDensities + (pattern, replacement) <- regexpReplacePatterns + } { + test(s"regexp_replace pattern='$pattern' replacement='$replacement' nullDensity=$density") { + val seed = (pattern + replacement).hashCode.toLong ^ density.hashCode + val subjects = generateSubjects(seed = seed, density) + withSubjectTable(subjects) { + assertCodegenRan { + checkSparkAnswerAndOperator( + sql(s"SELECT regexp_replace(s, '$pattern', '$replacement') FROM t")) + } + } + } + } + + /** + * Multi-column fuzz via expression composition. The rlike serde is single-input from its own + * point of view, but its subject can be an arbitrary sub-expression that references multiple + * columns. `concat(c1, c2) rlike 'pat'` is the simplest such shape, and it exercises the + * kernel's two-column `inputSchema` path plus the NullIntolerant short-circuit gating (Concat + * is not NullIntolerant, so the whole-tree guard in `defaultBody` must skip the short-circuit + * for this shape; Spark's own Concat codegen handles nulls correctly). 
+ */ + private def withTwoColumnTable(c1Values: Seq[String], c2Values: Seq[String])( + f: => Unit): Unit = { + require( + c1Values.length == c2Values.length, + s"columns must be same length: c1=${c1Values.length}, c2=${c2Values.length}") + withTable("t") { + sql("CREATE TABLE t (c1 STRING, c2 STRING) USING parquet") + if (c1Values.nonEmpty) { + val rows = c1Values.zip(c2Values).map { case (a, b) => + val av = if (a == null) "NULL" else s"'${a.replace("'", "''")}'" + val bv = if (b == null) "NULL" else s"'${b.replace("'", "''")}'" + s"($av, $bv)" + } + rows.grouped(64).foreach { batch => + sql(s"INSERT INTO t VALUES ${batch.mkString(", ")}") + } + } + f + } + } + + private val twoColumnPatterns: Seq[String] = Seq("[0-9]+", "^[a-z]", "[A-Z][0-9]+") + private val perColumnNullDensities: Seq[Double] = Seq(0.0, 0.25, 1.0) + + for { + d1 <- perColumnNullDensities + d2 <- perColumnNullDensities + pattern <- twoColumnPatterns + } { + test(s"concat(c1,c2) rlike '$pattern' nullDensity=($d1,$d2)") { + val seed = (pattern.hashCode.toLong ^ d1.hashCode) * 31 + d2.hashCode + val c1 = generateSubjects(seed, d1) + val c2 = generateSubjects(seed ^ 0x5f3759df, d2) + withTwoColumnTable(c1, c2) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql(s"SELECT concat(c1, c2) rlike '$pattern' FROM t")) + } + } + } + } + + /** + * Randomized decimal identity UDF. Spans both sides of the `Decimal.MAX_LONG_DIGITS` (18) + * boundary so each test hits one of the two specialized branches in the generated `getDecimal` + * getter. Precisions are chosen to exercise: small short-precision, boundary short-precision + * with varying scale, just-past-boundary long precision, and the max decimal128 precision. + */ + private def generateDecimals( + seed: Long, + precision: Int, + scale: Int, + nullDensity: Double): Seq[java.math.BigDecimal] = { + val rng = new Random(seed) + val intDigits = precision - scale + // `BigInt.apply(bits, rng)` samples uniformly on `[0, 2^bits - 1]`; bound to the decimal's + // integer-part range (10^intDigits - 1) so the result fits the schema. `BigInteger.bitLength` + // would overshoot slightly; min with the exact max is cheap insurance. + val intMax = BigInt(10).pow(intDigits) - 1 + val bits = math.max(intMax.bitLength, 1) + (0 until RowCount).map { _ => + if (rng.nextDouble() < nullDensity) null + else { + val mag = BigInt(bits, rng).min(intMax) + val signed = if (rng.nextBoolean()) -mag else mag + new java.math.BigDecimal(signed.bigInteger, scale) + } + } + } + + private def withDecimalTable(decimalType: String, values: Seq[java.math.BigDecimal])( + f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (d $decimalType) USING parquet") + if (values.nonEmpty) { + val rows = values.map { v => + if (v == null) "(NULL)" else s"(${v.toPlainString})" + } + rows.grouped(64).foreach { batch => + sql(s"INSERT INTO t VALUES ${batch.mkString(", ")}") + } + } + f + } + } + + // (precision, scale) pairs spanning both sides of the MAX_LONG_DIGITS=18 boundary. + private val decimalShapes: Seq[(Int, Int)] = Seq((9, 2), (18, 0), (18, 9), (19, 0), (38, 10)) + + for { + density <- nullDensities + (precision, scale) <- decimalShapes + } { + test(s"decimal identity precision=$precision scale=$scale nullDensity=$density") { + // Reuse one registered UDF name across iterations; Spark replaces by name. 
The Scala-side + // signature uses `BigDecimal`, which Spark encodes as DecimalType(38, 18); an implicit Cast + // from the column's DecimalType to the UDF's parameter type runs inside Spark's generated + // code, but the column read still goes through our kernel's `getDecimal` which is the path + // we're fuzzing. + spark.udf.register("dec_id_fuzz", (d: java.math.BigDecimal) => d) + val seed = ((precision * 31L) + scale) * 31L + density.hashCode + val values = generateDecimals(seed, precision, scale, density) + withDecimalTable(s"DECIMAL($precision, $scale)", values) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT dec_id_fuzz(d) FROM t")) + } + } + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenDispatchSmokeSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenDispatchSmokeSuite.scala new file mode 100644 index 0000000000..faac3643ea --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCodegenDispatchSmokeSuite.scala @@ -0,0 +1,1234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.arrow.vector.{BigIntVector, BitVector, DateDayVector, Float4Vector, Float8Vector, IntVector, SmallIntVector, TimeStampMicroTZVector, TimeStampMicroVector, TinyIntVector, ValueVector, VarCharVector} +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampNTZType, TimestampType} + +import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus +import org.apache.comet.udf.CometCodegenDispatchUDF + +/** + * Smoke tests for the Arrow-direct codegen dispatcher. Runs rlike and regexp_replace queries and + * asserts results match Spark. Widens to more expression shapes as the productionization plan + * lands supporting types and plan-time dispatchability. + */ +class CometCodegenDispatchSmokeSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override protected def sparkConf: SparkConf = + super.sparkConf + .set(CometConf.COMET_REGEXP_ENGINE.key, CometConf.REGEXP_ENGINE_JAVA) + // `auto` would also route rlike/regexp_replace to codegen when engine=java, but `force` + // guarantees it and exercises the codegen path regardless of future auto-mode tuning. 
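+      // Individual tests below override this per query where needed, e.g.
+      //   withSQLConf(CometConf.COMET_CODEGEN_DISPATCH_MODE.key ->
+      //     CometConf.CODEGEN_DISPATCH_DISABLED) { ... }
+      // to verify that `disabled` really bypasses the dispatcher.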
+ .set(CometConf.COMET_CODEGEN_DISPATCH_MODE.key, CometConf.CODEGEN_DISPATCH_FORCE) + + private def withSubjects(values: String*)(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (s STRING) USING parquet") + val rows = values + .map(v => if (v == null) "(NULL)" else s"('${v.replace("'", "''")}')") + .mkString(", ") + sql(s"INSERT INTO t VALUES $rows") + f + } + } + + test("rlike projection with null handling") { + withSubjects("abc123", "no digits", null, "mixed_42_data") { + checkSparkAnswerAndOperator(sql("SELECT s, s rlike '\\\\d+' AS m FROM t")) + } + } + + test("rlike predicate") { + withSubjects("abc123", "no digits", null, "mixed_42_data") { + checkSparkAnswerAndOperator(sql("SELECT s FROM t WHERE s rlike '\\\\d+'")) + } + } + + test("rlike with backreference (Java-only)") { + withSubjects("aa", "ab", "xyzzy", null) { + checkSparkAnswerAndOperator(sql("SELECT s, s rlike '^(\\\\w)\\\\1$' FROM t")) + } + } + + test("rlike on all-null column") { + withSubjects(null, null, null) { + checkSparkAnswerAndOperator(sql("SELECT s rlike '\\\\d+' FROM t")) + } + } + + test("rlike empty pattern matches every non-null row") { + withSubjects("a", "", null, "bc") { + checkSparkAnswerAndOperator(sql("SELECT s, s rlike '' FROM t")) + } + } + + test("regexp_replace digits with a token") { + withSubjects("abc123", "no digits", null, "mixed_42_data") { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, '\\\\d+', 'N') FROM t")) + } + } + + test("regexp_replace with empty replacement") { + withSubjects("abc123def", "no digits", null, "") { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, '\\\\d+', '') FROM t")) + } + } + + test("regexp_replace no-match preserves input") { + withSubjects("abc", "xyz", null) { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_replace(s, '\\\\d+', 'N') FROM t")) + } + } + + /** + * Composition smoke tests. Demonstrate that the codegen dispatcher handles nested expression + * trees in one compile per (tree, schema) pair, not one JNI hop per sub-expression. Each test + * wraps the query in `assertCodegenDidWork` to prove the codegen path ran rather than silently + * falling back to Spark. + */ + private def assertCodegenDidWork(f: => Unit): Unit = { + CometCodegenDispatchUDF.resetStats() + f + val after = CometCodegenDispatchUDF.stats() + assert( + after.compileCount + after.cacheHitCount >= 1, + s"expected codegen dispatcher activity, got $after") + } + + /** + * Stronger form of [[assertCodegenDidWork]] for composition tests. Asserts that the full + * expression subtree compiled into at most one kernel. The "one JNI crossing per nesting level" + * alternative (the PR description's foil) would produce one `(bytes, specs)` cache entry per + * nested sub-expression, so `compileCount` would be N and the cache would grow by N after the + * first batch. Asserting `compileCount <= 1` and `cacheSize` growth `<= 1` directly falsifies + * that shape. + * + * Uses `<=` rather than `==` because the compile cache is JVM-wide and shared across tests; a + * prior test that already compiled the same `(expression bytes, input schema)` pair will make + * this run a cache hit (`compileCount == 0`). The dispatcher-activity check guards against a + * silent fallback where the query runs through Spark and the first two assertions pass + * vacuously. 
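+   *
+   * Worked example: for `lvl3(lvl2(lvl1(s)))` the per-level foil would compile three kernels
+   * (one per ScalaUDF) and grow the cache by three entries after the first batch; the fused
+   * dispatcher compiles the chain at most once and grows the cache by at most one entry.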
+ */ + private def assertOneKernelForSubtree(f: => Unit): Unit = { + CometCodegenDispatchUDF.resetStats() + val sizeBefore = CometCodegenDispatchUDF.stats().cacheSize + f + val after = CometCodegenDispatchUDF.stats() + assert(after.compileCount <= 1, s"expected <= 1 compile for the composed subtree, got $after") + val grew = after.cacheSize - sizeBefore + assert(grew <= 1, s"expected cache to grow by <= 1 entry, grew by $grew; stats=$after") + assert( + after.compileCount + after.cacheHitCount >= 1, + s"expected codegen dispatcher activity, got $after") + } + + /** + * Assert that the dispatcher's compile cache contains a kernel compiled for the given input + * Arrow vector classes (in ordinal order) and output Spark `DataType`. This is a specialization + * check: the dispatcher is supposed to bake the concrete Arrow vector class into the generated + * kernel, and the cache key reflects that. If a future change accidentally loses that + * discrimination, `checkSparkAnswerAndOperator` would still pass (Spark computes the right + * answer) but this assertion would fail. + * + * Asserts presence in the cache, not newness. The cache is JVM-wide and shared across tests; if + * a prior test already compiled the same signature, that still counts. Combined with + * `assertCodegenDidWork` (which proves the dispatcher ran in this test), the pair gives both + * "this test exercised the dispatcher" and "the dispatcher's cache has a kernel of the expected + * shape". + * + * Compares by simple name because the `common` module shades `org.apache.arrow`, so a direct + * class-identity check against `classOf[VarCharVector]` at this call site (unshaded) misses the + * shaded classes the dispatcher actually uses internally. + */ + private def assertKernelSignaturePresent( + inputs: Seq[Class[_ <: ValueVector]], + output: DataType): Unit = { + val sigs = CometCodegenDispatchUDF.snapshotCompiledSignatures() + val expectedNames = inputs.map(_.getSimpleName).toIndexedSeq + val present = sigs.exists { case (cached, dt) => + dt == output && cached.map(_.getSimpleName) == expectedNames + } + assert( + present, + s"expected kernel signature $expectedNames -> $output; " + + s"cache had ${sigs.map { case (c, d) => (c.map(_.getSimpleName), d) }}") + } + + test("compose upper(s) rlike pattern") { + // The serde binds the whole tree, including the Upper, and ships it to the codegen + // dispatcher. Inside the kernel, Upper.doGenCode emits `this.getUTF8String(0).toUpperCase()` + // which feeds directly into the Matcher check. No second JNI hop for Upper. + withSubjects("Abc123", "NO DIGITS", null, "mixed_42") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT s, upper(s) rlike '[A-Z0-9]+' FROM t")) + } + } + } + + test("compose regexp_replace(upper(s), pattern, replacement)") { + // Upper as the subject of RegExpReplace defeats the specialized emitter (its fast path + // requires a direct BoundReference subject). Falls to the default path, which still compiles + // cleanly as one fused method because Spark's doGenCode for Upper -> RegExpReplace is + // self-contained. + withSubjects("Abc123", "no digits", null, "Mix42") { + assertCodegenDidWork { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_replace(upper(s), '[0-9]+', '#') FROM t")) + } + } + } + + test("compose upper(regexp_replace(s, pattern, replacement))") { + // Flip the nesting: RegExpReplace is inside, Upper is outside. 
Still one compile per + // (tree, schema) pair; the outer Upper's doGenCode consumes the RegExpReplace result as a + // UTF8String in the same generated method. Case conversion is enabled because the inputs + // are ASCII-only (the conf guards against locale-specific divergence, which does not apply + // here). + withSQLConf(CometConf.COMET_CASE_CONVERSION_ENABLED.key -> "true") { + withSubjects("Abc123", "no digits", null, "Mix42") { + assertCodegenDidWork { + checkSparkAnswerAndOperator( + sql("SELECT s, upper(regexp_replace(s, '[0-9]+', 'n')) FROM t")) + } + } + } + } + + test("compose substring(upper(s), 1, 3)") { + // Three levels: BoundReference, Upper, Substring. Substring takes two literal ints; its + // subject is the Upper result. Exercises multiple intermediate UTF8String operations in the + // generated fused method. + withSQLConf(CometConf.COMET_CASE_CONVERSION_ENABLED.key -> "true") { + withSubjects("abcdef", null, "X", "hello world") { + assertCodegenDidWork { + checkSparkAnswerAndOperator( + sql("SELECT s, substring(upper(s), 1, 3) rlike '^[A-Z]+$' FROM t")) + } + } + } + } + + test("regexp_extract (StringType output) routes through dispatcher") { + // regexp_extract has no native path in Comet, so the mode knob decides codegen vs + // hand-coded. Under the suite's `force` default, codegen runs. + withSubjects("abc123", "no digits", null, "mix42data") { + assertCodegenDidWork { + checkSparkAnswerAndOperator( + sql("SELECT s, regexp_extract(s, '([a-z]+)([0-9]+)', 2) FROM t")) + } + } + } + + test("regexp_instr (IntegerType output) routes through dispatcher") { + // regexp_instr exercises the IntegerType output writer end to end for the first time since + // Phase 2b added the allocator/writer; no prior end-to-end serde produced int output. + withSubjects("abc123", "no digits", null, "mix42data") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT s, regexp_instr(s, '[0-9]+', 0) FROM t")) + } + } + } + + /** + * Multi-column smoke tests. The dispatcher compiles the whole bound expression tree, including + * composed sub-expressions that reference multiple columns. Verify end-to-end correctness + * against Spark for a handful of representative shapes. + */ + private def withTwoStringCols(rows: (String, String)*)(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (c1 STRING, c2 STRING) USING parquet") + if (rows.nonEmpty) { + val tuples = rows.map { case (a, b) => + val av = if (a == null) "NULL" else s"'${a.replace("'", "''")}'" + val bv = if (b == null) "NULL" else s"'${b.replace("'", "''")}'" + s"($av, $bv)" + } + sql(s"INSERT INTO t VALUES ${tuples.mkString(", ")}") + } + f + } + } + + test("concat(c1, c2) rlike 'pat' compiles over two columns") { + // Concat is not NullIntolerant. The dispatcher's short-circuit guard should skip the + // whole-tree short-circuit and let Spark's Concat codegen handle nulls correctly. + withTwoStringCols(("abc", "123"), ("abc", null), (null, "123"), (null, null), ("zz", "zz")) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT concat(c1, c2) rlike '[a-z]+[0-9]+' FROM t")) + } + } + } + + test("concat(upper(c1), c2) rlike 'pat' nests Upper inside Concat") { + // Upper is NullIntolerant; Concat is not. The tree still has a non-NullIntolerant node so + // the short-circuit must not apply. Exercises mixed-trait composition. 
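+    // Rows like (null, "zz") and (null, null) are the interesting ones: with a
+    // non-NullIntolerant node in the tree, null inputs do not automatically imply a null
+    // result, so null handling is left to Spark's generated Concat code rather than a
+    // validity-based short-circuit over the raw input columns.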
+ withSQLConf(CometConf.COMET_CASE_CONVERSION_ENABLED.key -> "true") { + withTwoStringCols(("abc", "123"), ("abc", null), (null, "zz"), (null, null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT concat(upper(c1), c2) rlike '[A-Z]+' FROM t")) + } + } + } + } + + test("regexp_replace(c1, literal, c2-ignored-literal) two columns in tree") { + // Verifies that a second column reference outside the subject (here as a literal + // replacement) still routes through. Note: regexp_replace requires literal regex and + // replacement, so this is the only realistic two-column shape for that serde. + withTwoStringCols(("abc123", "Z"), ("xyz", null), (null, "foo")) { + assertCodegenDidWork { + checkSparkAnswerAndOperator( + sql("SELECT regexp_replace(concat(c1, c2), '[0-9]+', 'N') FROM t")) + } + } + } + + test("disabled mode bypasses the dispatcher") { + // In `disabled`, the rlike serde returns None and the expression falls back to Spark. The + // dispatcher's counters should not move. We check the result against Spark's answer but do + // not assert the operator is Comet for this query, because rlike itself runs on the JVM + // Spark path when the java-engine dispatcher is disabled. + val pattern = "disabled_mode_marker_[0-9]+" + CometCodegenDispatchUDF.resetStats() + withSQLConf( + CometConf.COMET_CODEGEN_DISPATCH_MODE.key -> CometConf.CODEGEN_DISPATCH_DISABLED) { + withSubjects("disabled_mode_marker_1", null) { + checkSparkAnswer(sql(s"SELECT s rlike '$pattern' FROM t")) + } + } + val after = CometCodegenDispatchUDF.stats() + assert( + after.compileCount == 0 && after.cacheHitCount == 0, + s"expected no dispatcher activity under disabled mode, got $after") + } + + test("auto mode prefers dispatcher when regex engine is java") { + // `auto` with engine=java should resolve to codegen (the serde's documented preference). Use + // a pattern unique to this test to guarantee a fresh compile. + val pattern = "auto_mode_marker_[0-9]+" + CometCodegenDispatchUDF.resetStats() + withSQLConf( + CometConf.COMET_CODEGEN_DISPATCH_MODE.key -> CometConf.CODEGEN_DISPATCH_AUTO, + CometConf.COMET_REGEXP_ENGINE.key -> CometConf.REGEXP_ENGINE_JAVA) { + withSubjects("auto_mode_marker_7", null) { + checkSparkAnswerAndOperator(sql(s"SELECT s rlike '$pattern' FROM t")) + } + } + val after = CometCodegenDispatchUDF.stats() + assert( + after.compileCount + after.cacheHitCount >= 1, + s"expected dispatcher activity under auto mode with java engine, got $after") + } + + test("per-batch nullability produces distinct compiles for null-present vs null-absent") { + // Same expression + same Arrow vector class + different observed nullability should hit + // different cache keys, because `ArrowColumnSpec.nullable` flips when the batch has no + // nulls. We don't assert on per-run deltas because Spark's partitioning can split the + // subject table so the first query alone sees both nullability variants across different + // partitions. Instead, assert the total invariant: across both queries we see at least two + // compiles, proving the cache key discriminated on nullability. 
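+    // Schematically, assuming the `(bytes, specs)` key noted on assertOneKernelForSubtree,
+    // the two compiles differ only in the spec's nullable flag:
+    //   (exprBytes, [VarCharVector, nullable = true])  -- batch containing a NULL subject
+    //   (exprBytes, [VarCharVector, nullable = false]) -- all-non-null batch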
+ val pattern = "nullability_marker_[0-9]+" + CometCodegenDispatchUDF.resetStats() + + withSubjects("nullability_marker_1", null, "nullability_marker_2") { + checkSparkAnswerAndOperator(sql(s"SELECT s rlike '$pattern' FROM t")) + } + withSubjects("nullability_marker_3", "nullability_marker_4") { + checkSparkAnswerAndOperator(sql(s"SELECT s rlike '$pattern' FROM t")) + } + val after = CometCodegenDispatchUDF.stats() + + assert( + after.compileCount >= 2, + "expected at least two compiles across both nullability distributions (one per " + + s"nullable=true/false variant); got $after") + } + + test("dispatcher stats increment on compile and hit") { + // Use a pattern no other test in this suite compiles, so the first run is guaranteed to be a + // cache miss regardless of test order. + val pattern = "stats_only_marker_[0-9]+" + CometCodegenDispatchUDF.resetStats() + withSubjects("stats_only_marker_42", "nope", null) { + checkSparkAnswerAndOperator(sql(s"SELECT s rlike '$pattern' FROM t")) + } + val firstRun = CometCodegenDispatchUDF.stats() + assert( + firstRun.compileCount >= 1, + s"expected compile count >= 1 after first query, got $firstRun") + assert(firstRun.cacheSize >= 1, s"expected cache size >= 1 after first query, got $firstRun") + + // Re-run the same expression against the same schema; should reuse the compiled kernel. + val compileBefore = firstRun.compileCount + withSubjects("stats_only_marker_9", null) { + checkSparkAnswerAndOperator(sql(s"SELECT s rlike '$pattern' FROM t")) + } + val secondRun = CometCodegenDispatchUDF.stats() + assert( + secondRun.cacheHitCount >= 1, + s"expected cache hits >= 1 after second query, got $secondRun") + assert( + secondRun.compileCount == compileBefore, + s"expected no additional compile on second query, got $secondRun vs $firstRun") + } + + /** + * Collation smoke test. Spark 4.x associates a collation id with each `StringType` instance. + * The codegen dispatcher's argument for handling collation is "Spark's own `doGenCode` for + * regex-on-string uses `CollationFactory` / `CollationSupport`, so we inherit the right + * semantics by reusing it". This test proves that end to end for the most common shape: `rlike` + * on a UTF8_LCASE-cast subject. The collation lives on the expression (`COLLATE` cast in SQL) + * rather than the column, so the parquet scan reads a default-collation column and stays + * native; only the Project carries the collated regex evaluation. + * + * Limits worth knowing about (separate work, not codegen-dispatch issues): + * - `regexp_replace` with a collated subject: Spark's analyzer wraps the regex literal in + * `Collate(Literal, ...)`. Our `RegExpReplace` serde's `getSupportLevel` requires a bare + * `Literal` for the pattern, so it rejects before the dispatcher is invoked. Widening the + * serde to unwrap `Collate(Literal, ...)` would unblock this; it's a serde-side change, not + * a codegen-side gap. + * - `rlike` on an ICU collation (UNICODE_CI etc.): Spark itself rejects with a type mismatch + * ("requires STRING, got STRING COLLATE UNICODE_CI"). RLike contracts on UTF8_BINARY + * semantics; binary collations like UTF8_LCASE work, ICU ones don't. 
+ */ + test("rlike on UTF8_LCASE-cast column matches case-insensitively") { + assume(isSpark40Plus, "non-default collations require Spark 4.0+") + withSubjects("Abc", "abc", "ABC", "xyz", null) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT s, (s COLLATE UTF8_LCASE) rlike 'abc' FROM t")) + } + } + } + + test("per-partition kernel preserves Nondeterministic state across batches") { + // Compose `monotonically_increasing_id()` with rlike so the dispatcher routes the + // composed tree (the inner expression by itself wouldn't have a serde). The expression + // also references `s` so the proto carries at least one data column, giving the bridge a + // row count signal. Per-partition kernel caching means the id counter advances across + // batches in one partition; without it, every batch would restart at 0 and we'd disagree + // with Spark on the right side of the rlike. The rlike pattern is permissive on purpose; + // we're testing state correctness, not regex matching. + val rows = (0 until 4096).map(i => s"row_$i") + withSubjects(rows: _*) { + assertCodegenDidWork { + checkSparkAnswerAndOperator( + sql("SELECT concat(s, cast(monotonically_increasing_id() as string)) rlike " + + "'^row_[0-9]+[0-9]+$' FROM t")) + } + } + } + + /** + * Scalar ScalaUDF smoke tests. These prove that user-registered UDFs route through the codegen + * dispatcher rather than forcing a whole-plan Spark fallback. Spark's `ScalaUDF.doGenCode` + * already emits compilable Java that calls the user function via `ctx.addReferenceObj`, so the + * dispatcher's compile path picks it up for free. Tests that user-registered UDFs route through + * the dispatcher rather than forcing whole-plan Spark fallback. + */ + + test("registered string ScalaUDF routes through dispatcher") { + spark.udf.register("shout", (s: String) => if (s == null) null else s.toUpperCase + "!") + withSubjects("Abc", "xyz", null, "mixed") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT shout(s) FROM t")) + } + } + } + + test("multi-arg ScalaUDF over string + literal routes through dispatcher") { + spark.udf.register( + "prepend", + (prefix: String, s: String) => if (s == null) null else prefix + s) + withSubjects("one", "two", null) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT prepend('[', s) FROM t")) + } + } + } + + test("ScalaUDF composed with an rlike subject") { + // Outer rlike binds the whole tree, including the ScalaUDF inside its subject. One + // compiled kernel handles rlike + user-code + Arrow reads in a single fused method. + spark.udf.register("wrap", (s: String) => if (s == null) null else s"|$s|") + withSubjects("abc", "def", null) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT wrap(s) rlike '^\\\\|[a-z]+\\\\|$' FROM t")) + } + } + } + + test("composed ScalaUDFs outer(inner(s)) fuse into one kernel") { + // Two user UDFs stacked, both operating on String. The dispatcher binds the whole tree and + // Spark's codegen emits two `ctx.addReferenceObj` calls inside one generated method. Races + // on the `ExpressionEncoder` serializers in `references` would show up here since each UDF + // contributes its own stateful serializer; the `freshReferences` closure in `CompiledKernel` + // is what keeps this correct across partitions. 
+ spark.udf.register("inner", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("outer", (s: String) => if (s == null) null else s"<$s>") + withSubjects("abc", null, "xyz", "MiXeD") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT outer(inner(s)) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), StringType) + } + } + + test("ScalaUDFs of different types compose: isShort(len(s))") { + // Exercises an input type transition: String -> Int -> Boolean. Two user UDFs with + // different I/O type shapes in one tree, one Janino compile. + spark.udf.register("len", (s: String) => if (s == null) -1 else s.length) + spark.udf.register("isShort", (i: Int) => i < 5) + withSubjects("ab", "abcdef", null, "hi") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT isShort(len(s)) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), BooleanType) + } + } + + test("three-deep ScalaUDF composition lvl3(lvl2(lvl1(s)))") { + // Three user UDFs stacked in one tree: String -> String -> String -> Int. The fused kernel + // carries three `ctx.addReferenceObj` calls. `assertOneKernelForSubtree` asserts that the + // whole chain collapses into a single compile rather than one per nesting level. + // Input rows intentionally exclude nulls: per-batch nullability is a cache-key dimension + // (`nullable()` reads `getNullCount != 0`), so a null-present batch compiles a second kernel + // specialized for `nullable=true`. Null handling through composed UDFs is covered by the + // other composition tests above. + spark.udf.register("lvl1", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("lvl2", (s: String) => if (s == null) null else s.reverse) + spark.udf.register("lvl3", (s: String) => if (s == null) -1 else s.length) + withSubjects("abc", "hello world", "x") { + assertOneKernelForSubtree { + checkSparkAnswerAndOperator(sql("SELECT lvl3(lvl2(lvl1(s))) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), IntegerType) + } + } + + test("multi-column ScalaUDF composition join(upperU(c1), lowerU(c2))") { + // One multi-arg user UDF consuming two other user UDFs, each on a different input column. + // The bound tree has two BoundReferences, and the kernel is specialized on two VarCharVector + // columns. `assertOneKernelForSubtree` asserts that the two-branch composition fuses into a + // single kernel rather than one per branch or one per UDF. + // Input rows intentionally exclude nulls (see note on the three-deep test above). + spark.udf.register("upperU", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("lowerU", (s: String) => if (s == null) null else s.toLowerCase) + spark.udf.register( + "joinU", + (a: String, b: String) => if (a == null || b == null) null else s"$a-$b") + withTwoStringCols(("Abc", "XYZ"), ("Foo", "bar"), ("baz", "Bar"), ("Hi", "Lo")) { + assertOneKernelForSubtree { + checkSparkAnswerAndOperator(sql("SELECT joinU(upperU(c1), lowerU(c2)) FROM t")) + } + assertKernelSignaturePresent( + Seq(classOf[VarCharVector], classOf[VarCharVector]), + StringType) + } + } + + /** + * Type-surface ScalaUDF tests. Each exercises a distinct Arrow input vector class plus the + * matching output writer through the full SQL -> serde -> dispatcher -> Janino -> kernel + * pipeline. Before ScalaUDF routing, non-string types were covered only by the direct-compile + * suite (since the regex serdes all produce string or boolean output). 
+ * + * Backed by parquet tables with declared column types rather than derived-from-range views: + * when the source column is a derived projection (e.g. `cast(id as int)` from `spark.range`), + * the optimizer folds the cast into the outer plan and the ScalaUDF's `BoundReference` ends up + * on the underlying long, not the projected int. A declared parquet column type keeps the + * `AttributeReference` on the expected type and the Arrow vector the dispatcher sees matches + * the UDF's signature. + */ + private def withTypedCol(sqlType: String, valueLiterals: String*)(f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (c $sqlType) USING parquet") + if (valueLiterals.nonEmpty) { + val rows = valueLiterals.map(v => s"($v)").mkString(", ") + sql(s"INSERT INTO t VALUES $rows") + } + f + } + } + + test("ScalaUDF on IntegerType (IntVector, getInt)") { + spark.udf.register("doubleIt", (i: Int) => i * 2) + withTypedCol("INT", "1", "2", "100") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT doubleIt(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[IntVector]), IntegerType) + } + } + + test("ScalaUDF on LongType (BigIntVector, getLong)") { + spark.udf.register("inc", (l: Long) => l + 1L) + withTypedCol("BIGINT", "1", "2", "100") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT inc(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[BigIntVector]), LongType) + } + } + + test("ScalaUDF on DoubleType (Float8Vector, getDouble)") { + spark.udf.register("halve", (d: Double) => d / 2.0) + withTypedCol("DOUBLE", "1.5", "2.5", "100.0") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT halve(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[Float8Vector]), DoubleType) + } + } + + test("ScalaUDF on FloatType (Float4Vector, getFloat)") { + spark.udf.register("scaleF", (f: Float) => f * 1.5f) + withTypedCol("FLOAT", "CAST(1.5 AS FLOAT)", "CAST(2.5 AS FLOAT)") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT scaleF(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[Float4Vector]), FloatType) + } + } + + test("ScalaUDF on BooleanType (BitVector, getBoolean)") { + spark.udf.register("neg", (b: Boolean) => !b) + withTypedCol("BOOLEAN", "TRUE", "FALSE", "TRUE") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT neg(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[BitVector]), BooleanType) + } + } + + test("ScalaUDF on ShortType (SmallIntVector, getShort)") { + spark.udf.register("incS", (s: Short) => (s + 1).toShort) + withTypedCol( + "SMALLINT", + "CAST(1 AS SMALLINT)", + "CAST(2 AS SMALLINT)", + "CAST(30000 AS SMALLINT)") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT incS(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[SmallIntVector]), ShortType) + } + } + + test("ScalaUDF on ByteType (TinyIntVector, getByte)") { + spark.udf.register("incB", (b: Byte) => (b + 1).toByte) + withTypedCol("TINYINT", "CAST(1 AS TINYINT)", "CAST(2 AS TINYINT)", "CAST(100 AS TINYINT)") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT incB(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[TinyIntVector]), ByteType) + } + } + + test("ScalaUDF on DateType (DateDayVector, getInt)") { + // Date input flows through the Int getter because DateType is physically int. The UDF takes + // java.sql.Date and Spark's encoder handles the int -> Date materialization. 
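+    // e.g. DATE'2024-01-01' arrives as the int 19723 (days since 1970-01-01); the kernel reads
+    // it via getInt and Spark's encoder materializes the java.sql.Date handed to the UDF.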
+ spark.udf.register( + "nextDay", + (d: java.sql.Date) => if (d == null) null else new java.sql.Date(d.getTime + 86400000L)) + withTypedCol("DATE", "DATE'2024-01-01'", "DATE'2024-06-15'", "DATE'1970-01-01'") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT nextDay(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[DateDayVector]), DateType) + } + } + + test("ScalaUDF on TimestampType (TimeStampMicroTZVector, getLong)") { + spark.udf.register( + "plusSecond", + (t: java.sql.Timestamp) => + if (t == null) null else new java.sql.Timestamp(t.getTime + 1000L)) + withTypedCol( + "TIMESTAMP", + "TIMESTAMP'2024-01-01 12:00:00'", + "TIMESTAMP'2024-06-15 23:59:59'") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT plusSecond(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[TimeStampMicroTZVector]), TimestampType) + } + } + + test("ScalaUDF on TimestampNTZType (TimeStampMicroVector, getLong)") { + spark.udf.register( + "plusDayNtz", + (ldt: java.time.LocalDateTime) => if (ldt == null) null else ldt.plusDays(1)) + withTypedCol( + "TIMESTAMP_NTZ", + "TIMESTAMP_NTZ'2024-01-01 12:00:00'", + "TIMESTAMP_NTZ'2024-06-15 23:59:59'") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT plusDayNtz(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[TimeStampMicroVector]), TimestampNTZType) + } + } + + test("ScalaUDF returning DateType") { + spark.udf.register("epochDay", (_: Int) => java.sql.Date.valueOf("1970-01-01")) + withTypedCol("INT", "1", "2", "3") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT epochDay(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[IntVector]), DateType) + } + } + + test("ScalaUDF returning TimestampType") { + spark.udf.register("mkTs", (s: Long) => new java.sql.Timestamp(s * 1000L)) + withTypedCol("BIGINT", "0", "1700000000", "1750000000") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT mkTs(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[BigIntVector]), TimestampType) + } + } + + test("ScalaUDF returning TimestampNTZType") { + spark.udf.register( + "mkTsNtz", + (s: Long) => java.time.LocalDateTime.ofEpochSecond(s, 0, java.time.ZoneOffset.UTC)) + withTypedCol("BIGINT", "0", "1700000000", "1750000000") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT mkTsNtz(c) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[BigIntVector]), TimestampNTZType) + } + } + + test("ScalaUDF returning a different type than its input") { + // String -> Int transition forces the output writer to switch from VarChar to Int. Exercises + // the `IntegerType` output path end to end from a user UDF (previously only regexp_instr + // covered it). + spark.udf.register("codePoint", (s: String) => if (s == null) 0 else s.codePointAt(0)) + withSubjects("abc", "A", null, "!") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT codePoint(s) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), IntegerType) + } + } + + test("ScalaUDF returning BinaryType (VarBinaryVector output writer)") { + // Binary output writer path, exercised here by a user UDF for the first time. Before this + // the writer only had direct-compile unit tests. 
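+    // e.g. bytes('abc') yields the three UTF-8 bytes 0x61 0x62 0x63, exercising the
+    // variable-width binary write path end to end.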
+ spark.udf.register("bytes", (s: String) => if (s == null) null else s.getBytes("UTF-8")) + withSubjects("abc", null, "hello") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT bytes(s) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), BinaryType) + } + } + + test("ScalaUDF returning ArrayType(StringType) (ListVector output writer)") { + // First use of the ArrayType output path end-to-end. The UDF returns a `Seq[String]`, + // which Spark encodes as `ArrayType(StringType, containsNull = true)`. The dispatcher's + // canHandle accepts it (ArrayType is supported when its element type is supported), + // allocateOutput builds a ListVector with an inner VarCharVector, and emitWrite recurses + // into the StringType case for the per-element UTF8 on-heap shortcut. End-to-end answer + // matches Spark. + spark.udf.register( + "splitComma", + (s: String) => if (s == null) null else s.split(",", -1).toSeq) + withSubjects("a,b,c", "x", null, "", "one,,three") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT splitComma(s) FROM t")) + } + } + } + + test("ScalaUDF returning ArrayType(IntegerType)") { + // Exercises ArrayType output with a primitive element. emitWrite's ArrayType case + // recurses into the IntegerType case for the inner write; no byte[] allocation involved. + spark.udf.register( + "asLengths", + (s: String) => if (s == null) null else s.split(",").map(_.length).toSeq) + withSubjects("a,bb,ccc", null, "xyzzy") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT asLengths(s) FROM t")) + } + } + } + + test("zero-column ScalaUDF produces one row per input row") { + // Non-deterministic (so Spark doesn't constant-fold) with a deterministic body (so + // Spark-vs-Comet comparison stays honest). The expression has no `AttributeReference`, + // so the serde produces an empty data-arg list and the dispatcher has no data column to + // read the batch size from. Guards the `numRows` path through the JNI bridge. + import org.apache.spark.sql.functions.udf + val alwaysHello = udf(() => "hello").asNondeterministic() + spark.udf.register("helloU", alwaysHello) + withSubjects("a", "b", null, "c") { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT helloU() FROM t")) + } + } + } + + /** + * Decimal tests. The dispatcher's `getDecimal` getter specializes on the `BoundReference`'s + * `DecimalType.precision` at source-generation time: precision <= 18 emits an unscaled-long + * fast path via `Decimal.createUnsafe`, precision > 18 emits a `BigDecimal + Decimal.apply` + * slow path. These smoke tests exercise both sides of the split end to end and verify Spark and + * Comet agree on correctness across typical decimal workloads. + */ + private def withDecimalTable(decimalType: String, values: Seq[String])(f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (d $decimalType) USING parquet") + val rows = values.map(v => if (v == null) "(NULL)" else s"($v)").mkString(", ") + if (values.nonEmpty) sql(s"INSERT INTO t VALUES $rows") + f + } + } + + test("ScalaUDF over Decimal(9, 2) (short precision, fast path)") { + // Short-precision identity UDF. The column's DecimalType has precision 9, so the generated + // getter for ordinal 0 emits only the unscaled-long fast path. 
The UDF's Scala-side signature + // uses `java.math.BigDecimal`, which Spark's encoder pins at DecimalType(38, 18); the implicit + // Cast from DECIMAL(9, 2) -> DECIMAL(38, 18) runs inside Spark's generated code, not via our + // kernel's getter, so the fast path still fires on the column read. + spark.udf.register("decId9_2", (d: java.math.BigDecimal) => d) + withDecimalTable("DECIMAL(9, 2)", Seq("0.00", "1.50", "-1.50", "9999.99", "-9999.99", null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT decId9_2(d) FROM t")) + } + } + } + + test("ScalaUDF over Decimal(18, 0) (max short precision, fast path)") { + // Boundary precision: 18 is the last value for which the unscaled representation fits in a + // signed 64-bit long. The fast path must still be selected. + spark.udf.register("decId18_0", (d: java.math.BigDecimal) => d) + withDecimalTable( + "DECIMAL(18, 0)", + Seq("0", "1", "-1", "999999999999999999", "-999999999999999999", null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT decId18_0(d) FROM t")) + } + } + } + + test("ScalaUDF over Decimal(18, 9) (max short precision with scale, fast path)") { + // Same precision as above but with scale 9 to exercise the fractional side of the long + // decimal. Spark `Decimal` stores both as the same unscaled long; only the `scale` parameter + // differs. + spark.udf.register("decId18_9", (d: java.math.BigDecimal) => d) + withDecimalTable( + "DECIMAL(18, 9)", + Seq("0.000000000", "1.123456789", "-1.123456789", "999999999.999999999", null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT decId18_9(d) FROM t")) + } + } + } + + test("ScalaUDF over Decimal(19, 0) (just past short precision, slow path)") { + // First precision where the unscaled value can exceed `Long.MAX_VALUE`. The generated getter + // must emit only the slow path; the fast-path marker must be absent in the compiled kernel. + spark.udf.register("decId19_0", (d: java.math.BigDecimal) => d) + withDecimalTable( + "DECIMAL(19, 0)", + Seq("0", "1", "-1", "9999999999999999999", "-9999999999999999999", null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT decId19_0(d) FROM t")) + } + } + } + + test("ScalaUDF over Decimal(38, 10) (max precision, slow path)") { + // Max decimal128 precision. Exercises the `getObject + Decimal.apply` branch and the + // end-to-end BigDecimal conversion path with a non-trivial scale. + spark.udf.register("decId38_10", (d: java.math.BigDecimal) => d) + withDecimalTable( + "DECIMAL(38, 10)", + Seq( + "0.0000000000", + "1.1234567890", + "-1.1234567890", + "9999999999999999999999999999.0000000000", + null)) { + assertCodegenDidWork { + checkSparkAnswerAndOperator(sql("SELECT decId38_10(d) FROM t")) + } + } + } + + test("ScalaUDF sees TaskContext.partitionId() per partition") { + // Direct probe: register a ScalaUDF that reads TaskContext.partitionId() and returns it. + // Spark's own task thread has TaskContext set, so each partition's rows carry that + // partition's index. For the dispatcher to match Spark, the invocation thread must see a + // live TaskContext. With the `createPlan`-time TaskContext capture + bridge-side + // `TaskContext.setTaskContext` install (see `CometUdfBridge.evaluate` and + // `CometTaskContextShim`), Tokio workers see the propagated TaskContext and the UDF + // returns the real partitionId. Without that propagation, `TaskContext.get()` returns null + // on the Tokio thread and the sentinel (-1) leaks through, diverging from Spark. 
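+ // Schematically, the bridge-side install described above amounts to (names illustrative):
+ // val prev = TaskContext.get()
+ // TaskContext.setTaskContext(capturedTaskContext)
+ // try kernel.evaluate(batch) finally TaskContext.setTaskContext(prev)
+ // so the UDF below observes the same partitionId Spark's own task thread would report.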
+ spark.udf.register( + "pid", + (_: Long) => { + val tc = TaskContext.get() + if (tc != null) tc.partitionId() else -1 + }) + val df = spark + .range(0, 1024, 1, numPartitions = 4) + .selectExpr("id", "pid(id) as p") + checkSparkAnswerAndOperator(df) + } + + test("ScalaUDF sees TaskContext from fully-native parquet plan") { + // The `spark.range`-based test above runs through `CometSparkRowToColumnar`, which executes + // on a Spark task thread where TaskContext is live even without explicit propagation. The + // fully-native path through `CometNativeScan` runs the JVM UDF bridge on a Tokio worker + // thread where TaskContext.get() would otherwise be null. This test forces that path by + // sourcing from a Parquet table written as multiple files (so the native read produces + // multiple partitions) and asserting the UDF still sees the per-partition TaskContext via + // the `createPlan`-time capture + bridge-side install. + spark.udf.register( + "pidP", + (_: Int) => { + val tc = TaskContext.get() + if (tc != null) tc.partitionId() else -1 + }) + withTable("t") { + sql("CREATE TABLE t (x INT) USING parquet") + // Multiple INSERT statements -> multiple parquet files -> multiple read splits -> + // multiple partitions. + sql("INSERT INTO t VALUES (1), (2), (3), (4)") + sql("INSERT INTO t VALUES (5), (6), (7), (8)") + sql("INSERT INTO t VALUES (9), (10), (11), (12)") + sql("INSERT INTO t VALUES (13), (14), (15), (16)") + checkSparkAnswerAndOperator(sql("SELECT x, pidP(x) AS p FROM t")) + } + } + + test("Rand seeded per partition across a multi-partition table") { + // Rand.doGenCode registers an XORShiftRandom via ctx.addMutableState and seeds it via + // ctx.addPartitionInitializationStatement. That init statement runs inside our kernel's + // `init(int partitionIndex)`, called once per kernel allocation. Spark seeds + // `XORShiftRandom(seed + partitionIndex)` per partition, so different partitions produce + // different sequences for the same seed. Matching Spark across partitions requires the + // kernel to see the real partition index, which the dispatcher derives from + // `TaskContext.get().partitionId()` — live on this path thanks to the bridge-level + // TaskContext propagation. Composing with a ScalaUDF (identity on Double here) forces the + // tree through codegen dispatch so the Rand evaluation runs inside our kernel's init + // rather than via Spark's normal codegen. + spark.udf.register("dblId", (d: Double) => d) + val df = spark + .range(0, 1024, 1, numPartitions = 4) + .selectExpr("id", "dblId(rand(42)) as r") + checkSparkAnswerAndOperator(df) + } + + test("ScalaUDF composed with reused scalar subquery across projection and filter") { + // The same scalar subquery appears in two sites: the projection (which the dispatcher + // compiles into a fused kernel) and the filter (a separate operator). Each site holds its + // own `ScalarSubquery` expression instance with its own `@volatile result` field. Each + // surrounding operator's inherited `SparkPlan.waitForSubqueries` populates its instance's + // `result` before the dispatcher's bridge serializes the expression. The populated value + // travels through closure serialization into the cache key's bytes, so different subquery + // values compile distinct kernels. Exercises the full subquery-correctness invariant + // documented on `CometBatchKernelCodegen.canHandle`. 
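+ // Restating the cache-key argument above (illustrative, not the literal key type): the key is
+ // derived from the serialized expression bytes, so two ScalarSubquery instances holding
+ // different populated results can never collide on a single cached kernel.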
+ spark.udf.register("addOne", (i: Int) => i + 1)
+ withTable("t", "t2") {
+ sql("CREATE TABLE t (x INT) USING parquet")
+ sql("INSERT INTO t VALUES (1), (2), (3), (4), (5)")
+ sql("CREATE TABLE t2 (v INT) USING parquet")
+ sql("INSERT INTO t2 VALUES (2), (4)")
+ checkSparkAnswerAndOperator(
+ sql("SELECT addOne(x) + (SELECT max(v) FROM t2) AS r " +
+ "FROM t WHERE addOne(x) < (SELECT max(v) FROM t2) * 2"))
+ }
+ }
+
+ /**
+ * ArrayType input. The dispatcher emits a nested `InputArray_col0` final class per array-typed
+ * input column; Spark's generated `getArray(ord)` resolves to our kernel's switch which returns
+ * the pre-allocated instance after resetting its start/length against the list's offsets.
+ * Element reads go through the typed child-vector field with no `ArrayData` copy or boxing.
+ *
+ * Each smoke test exercises the same serde/transport path at a different element type so the
+ * nested getter emitter's scalar-element cases are each covered: `StringType` (zero-copy
+ * `UTF8String.fromAddress`), `IntegerType` (primitive direct), and `DecimalType(p <= 18)`
+ * (decimal128 fast path).
+ */
+ private def withArrayTable(colType: String, insertRows: String)(f: => Unit): Unit = {
+ withTable("t") {
+ sql(s"CREATE TABLE t (a $colType) USING parquet")
+ sql(s"INSERT INTO t VALUES $insertRows")
+ f
+ }
+ }
+
+ test("ScalaUDF taking Seq[String] reads through nested ArrayData class") {
+ spark.udf.register(
+ "headOrNull",
+ (arr: Seq[String]) => if (arr == null || arr.isEmpty) null else arr.head)
+ withArrayTable(
+ "ARRAY<STRING>",
+ "(array('a', 'b', 'c')), (array('x')), (null), (array()), (array('alone'))") {
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT headOrNull(a) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF taking Seq[String] iterating all elements") {
+ spark.udf.register(
+ "concatArr",
+ (arr: Seq[String]) => if (arr == null) null else arr.mkString("|"))
+ withArrayTable(
+ "ARRAY<STRING>",
+ "(array('one', 'two', 'three')), (array('solo')), (null), (array())") {
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT concatArr(a) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF taking Seq[Int] hits primitive element getter") {
+ spark.udf.register("sumArr", (arr: Seq[Int]) => if (arr == null) -1 else arr.sum)
+ withArrayTable(
+ "ARRAY<INT>",
+ "(array(1, 2, 3)), (array(-5, 5)), (array()), (null), (array(42))") {
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT sumArr(a) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF taking Seq[BigDecimal] hits short-precision decimal fast path") {
+ // DecimalType(10, 2) is well inside p <= 18, so the nested-array `getDecimal` emits the
+ // unscaled-long fast path (see `emitNestedArrayElementGetter`). A `BigDecimal` UDF argument
+ // forces Spark's encoder to call `getDecimal(i, 10, 2)` on our nested ArrayData for each
+ // element, which exercises that code path end to end.
+ spark.udf.register(
+ "sumDecArr",
+ (arr: Seq[java.math.BigDecimal]) =>
+ if (arr == null) null
+ else {
+ var acc = java.math.BigDecimal.ZERO
+ arr.foreach(v => if (v != null) acc = acc.add(v))
+ acc
+ })
+ withArrayTable(
+ "ARRAY<DECIMAL(10, 2)>",
+ "(array(1.23, 4.56)), (array(-9.99)), (null), (array())") {
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT sumDecArr(a) FROM t"))
+ }
+ }
+ }
+
+ // =============================================================================================
+ // StructType + MapType + nested-composition smoke tests. Source tests prove the emitted Java
+ // is well-shaped; these tests prove Janino compiles it and the runtime roundtrip matches
+ // Spark.
+ // =============================================================================================
+
+ test("ScalaUDF composes with struct-field access reading Struct.age") {
+ // Keeps the UDF arg scalar (Int) but puts a `GetStructField` under it so the codegen
+ // dispatcher compiles the struct-input read path (`row.getStruct(0, 2).getInt(1)`).
+ spark.udf.register("doubleInt", (i: Int) => i * 2)
+ withTable("t") {
+ sql("CREATE TABLE t (s STRUCT<name: STRING, age: INT>) USING parquet")
+ sql(
+ "INSERT INTO t VALUES " +
+ "(named_struct('name', 'alice', 'age', 30)), " +
+ "(named_struct('name', 'bob', 'age', 42)), " +
+ "(null)")
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT doubleInt(s.age) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF taking full Struct value (case class arg)") {
+ // Case-class UDF arguments: test data must not include null top-level rows.
+ // `ScalaUDF.scalaConverter` applies Spark's `ExpressionEncoder.Deserializer` on every row
+ // to materialize the case-class instance. The generated deserializer has a
+ // `newInstance(NameAgePair)` step that throws `EXPRESSION_DECODING_FAILED` on a null input,
+ // independent of the dispatcher. Case-class UDF tests omit null top-level rows; other
+ // tests with plain `Seq` / `Map` args can include nulls because the deserializer hands null
+ // to the UDF body which handles it.
+ spark.udf.register("fmtPair", (r: NameAgePair) => s"${r.name}:${r.age}")
+ withTable("t") {
+ sql("CREATE TABLE t (s STRUCT<name: STRING, age: INT>) USING parquet")
+ sql(
+ "INSERT INTO t VALUES " +
+ "(named_struct('name', 'alice', 'age', 30)), " +
+ "(named_struct('name', 'bob', 'age', 42))")
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT fmtPair(s) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF returning Struct (case class output)") {
+ spark.udf.register("makePair", (i: Int) => NameAgePair(s"n$i", i))
+ withTypedCol("INT", "1", "2", "3") {
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT makePair(c) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF taking Map") {
+ spark.udf.register("sumMap", (m: Map[String, Int]) => if (m == null) -1 else m.values.sum)
+ withTable("t") {
+ sql("CREATE TABLE t (m MAP<STRING, INT>) USING parquet")
+ sql("INSERT INTO t VALUES (map('a', 1, 'b', 2)), (map()), (null)")
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT sumMap(m) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF returning Map") {
+ spark.udf.register(
+ "singletonMap",
+ (s: String, i: Int) => if (s == null) null else Map(s -> i))
+ withTable("t") {
+ sql("CREATE TABLE t (s STRING, i INT) USING parquet")
+ sql("INSERT INTO t VALUES ('a', 1), ('b', 2), (null, 3)")
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT singletonMap(s, i) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF taking Map<String, Array<Int>> exercises nested composition") {
+ spark.udf.register(
+ "totalLens",
+ (m: Map[String, Seq[Int]]) => if (m == null) -1 else m.values.flatten.sum)
+ withTable("t") {
+ sql("CREATE TABLE t (m MAP<STRING, ARRAY<INT>>) USING parquet")
+ sql(
+ "INSERT INTO t VALUES " +
+ "(map('a', array(1, 2, 3), 'b', array(10))), " +
+ "(map()), " +
+ "(null)")
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT totalLens(m) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF round-trips Array<Array<Int>> (nested array input + output)") {
+ // Exercises nested-array input reads and nested-list output writes in one call: the inner
+ // `InputArray_col0_e` class on the input side and the recursive emitWrite on the output.
+ spark.udf.register(
+ "reverseRows",
+ (arr: Seq[Seq[Int]]) => if (arr == null) null else arr.map(_.reverse))
+ withTable("t") {
+ sql("CREATE TABLE t (a ARRAY<ARRAY<INT>>) USING parquet")
+ sql(
+ "INSERT INTO t VALUES " +
+ "(array(array(1, 2, 3), array(4, 5))), " +
+ "(array(array())), " +
+ "(null)")
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT reverseRows(a) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF round-trips Struct<String, Array<Int>>") {
+ // Struct with a complex field on both sides: input reads go through InputStruct_col0 +
+ // InputArray_col0_f1, output writes through StructVector + ListVector.
+ // Null top-level rows omitted - case-class arg; see the note on `fmtPair` above.
+ spark.udf.register(
+ "growItems",
+ (r: NameItems) =>
+ if (r == null) null else NameItems(r.name, if (r.items == null) null else r.items :+ 0))
+ withTable("t") {
+ sql("CREATE TABLE t (s STRUCT<name: STRING, items: ARRAY<INT>>) USING parquet")
+ sql(
+ "INSERT INTO t VALUES " +
+ "(named_struct('name', 'a', 'items', array(1, 2))), " +
+ "(named_struct('name', 'b', 'items', array()))")
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT growItems(s) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF round-trips Map<String, Array<Int>> (nested value both sides)") {
+ // Map input read goes through InputMap_col0 + InputArray_col0_v (the complex-value side);
+ // output write emits MapVector + entries Struct + per-value ListVector inside the map's
+ // entries struct.
+ spark.udf.register(
+ "sortValues",
+ (m: Map[String, Seq[Int]]) =>
+ if (m == null) null
+ else m.map { case (k, v) => k -> (if (v == null) null else v.sorted) })
+ withTable("t") {
+ sql("CREATE TABLE t (m MAP<STRING, ARRAY<INT>>) USING parquet")
+ sql(
+ "INSERT INTO t VALUES " +
+ "(map('a', array(3, 1, 2), 'b', array(10))), " +
+ "(map()), " +
+ "(null)")
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT sortValues(m) FROM t"))
+ }
+ }
+ }
+
+ test("ScalaUDF round-trips Map<String, Struct<Int, String>>") {
+ // Struct value inside a map, both sides. Null top-level rows omitted - the map value is a
+ // case class; see the note on `fmtPair` above.
+ spark.udf.register(
+ "tagValues",
+ (m: Map[String, XyPair]) =>
+ if (m == null) null
+ else
+ m.map { case (k, v) => k -> (if (v == null) null else XyPair(v.x + 1, s"<${v.y}>")) })
+ withTable("t") {
+ sql("CREATE TABLE t (m MAP<STRING, STRUCT<x: INT, y: STRING>>) USING parquet")
+ sql(
+ "INSERT INTO t VALUES " +
+ "(map('a', named_struct('x', 1, 'y', 'one'))), " +
+ "(map())")
+ assertCodegenDidWork {
+ checkSparkAnswerAndOperator(sql("SELECT tagValues(m) FROM t"))
+ }
+ }
+ }
+}
+
+/**
+ * Case class used by the struct-input / struct-output smoke tests. Must be declared at file scope
+ * (not inside the test class) so Spark's TypeTag-based UDF encoder can resolve the Spark
+ * `StructType` schema from the Scala class.
+ */
+private case class NameAgePair(name: String, age: Int)
+
+private case class NameItems(name: String, items: Seq[Int])
+
+private case class XyPair(x: Int, y: String)
diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala
new file mode 100644
index 0000000000..2b8ca796b6
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala
@@ -0,0 +1,782 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Coalesce, Concat, CreateMap, ElementAt, Expression, GetStructField, LeafExpression, Length, Literal, Nondeterministic, Rand, RegExpReplace, RLike, Size, StringSplit, Unevaluable, Upper} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.types.{ArrayType, DataType, DecimalType, IntegerType, LongType, MapType, StringType, StructField, StructType} + +import org.apache.comet.udf.CometBatchKernelCodegen +import org.apache.comet.udf.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec, StructFieldSpec} + +// Resolve Arrow vector classes through the codegen object so tests see the same `Class` objects +// the shaded `common` module sees. A direct `classOf[org.apache.arrow.vector.VarCharVector]` here +// would be the unshaded class from the test classpath, which is not `==` to the shaded class the +// production pattern-matches against. + +/** + * Generated-source inspection tests. These exercise `CometBatchKernelCodegen.generateSource` and + * assert on the emitted Java directly, without invoking Janino. The goal is to catch regressions + * in the optimizations we claim the dispatcher applies: + * + * - `NullIntolerant` short-circuit wraps `ev.code` in `if (any-input-null) { setNull; } else { + * ev.code; write; }`. + * - Non-nullable column declaration emits `return false;` from `isNullAt(ord)` and, when the + * dispatcher rewrites the `BoundReference`, Spark's `doGenCode` stops emitting its own + * `row.isNullAt(ord)` probe. + * - Zero-copy string reads route through `UTF8String.fromAddress`. + * - The specialized `RegExpReplace` emitter engages for the shape its guard accepts. + * + * These are the smallest durable tests that the claimed optimizations actually reach the + * generated Java, and they document the shapes future contributors should preserve. 
+ */ +class CometCodegenSourceSuite extends AnyFunSuite { + + private val varCharVectorClass = + CometBatchKernelCodegen.vectorClassBySimpleName("VarCharVector") + + private val nullableString = ArrowColumnSpec(varCharVectorClass, nullable = true) + private val nonNullableString = ArrowColumnSpec(varCharVectorClass, nullable = false) + + private def gen( + expr: org.apache.spark.sql.catalyst.expressions.Expression, + specs: ArrowColumnSpec*): String = + CometBatchKernelCodegen.generateSource(expr, specs.toIndexedSeq).body + + test("non-nullable column emits literal-false isNullAt case") { + val expr = Length(BoundReference(0, StringType, nullable = false)) + val src = gen(expr, nonNullableString) + assert( + src.contains("case 0: return false;"), + s"expected non-nullable isNullAt to return literal false; got:\n$src") + } + + test("non-nullable BoundReference elides Spark's own isNullAt probe in the expression body") { + // When the BoundReference carries `nullable=false`, Spark's `doGenCode` skips the + // `row.isNullAt(ord)` branch at source level. This is the payoff of the tree-rewrite in + // `CometCodegenDispatchUDF.lookupOrCompile`: subsequent expressions over the same column + // compile to tighter source rather than relying on JIT to constant-fold `isNullAt`. + val expr = Length(BoundReference(0, StringType, nullable = false)) + val src = gen(expr, nonNullableString) + assert( + !src.contains("row.isNullAt(0)"), + s"expected Spark's BoundReference null probe to be elided; got:\n$src") + } + + test("nullable column emits delegated isNullAt case") { + val expr = Length(BoundReference(0, StringType, nullable = true)) + val src = gen(expr, nullableString) + assert( + src.contains("case 0: return this.col0.isNull(this.rowIdx);"), + s"expected nullable isNullAt to delegate to the Arrow vector; got:\n$src") + } + + test("VarCharVector getUTF8String uses zero-copy fromAddress") { + val expr = Length(BoundReference(0, StringType, nullable = true)) + val src = gen(expr, nullableString) + assert( + src.contains("org.apache.spark.unsafe.types.UTF8String"), + s"expected UTF8String reference; got:\n$src") + assert(src.contains(".fromAddress("), s"expected zero-copy fromAddress read; got:\n$src") + } + + test("NullIntolerant expression emits input-null short-circuit before ev.code") { + // RLike is NullIntolerant (a null subject returns null, not "did not match"). Expect the + // default body to prepend `if (this.col0.isNull(i)) { setNull; } else { ... }` so null rows + // skip the whole regex eval, not just the setNull write. + val expr = + RLike(BoundReference(0, StringType, nullable = true), Literal.create("\\d+", StringType)) + val src = gen(expr, nullableString) + assert( + src.contains("this.col0.isNull(i)"), + s"expected NullIntolerant short-circuit on input ordinal 0; got:\n$src") + assert( + src.contains("output.setNull(i);"), + s"expected setNull emission for short-circuited null rows; got:\n$src") + } + + test("specialized RegExpReplace emitter engages for BoundReference subject") { + val expr = RegExpReplace( + subject = BoundReference(0, StringType, nullable = true), + regexp = Literal.create("\\d+", StringType), + rep = Literal.create("N", StringType), + pos = Literal(1, IntegerType)) + val src = gen(expr, nullableString) + // The specialized path reads bytes directly and runs `Pattern.matcher(...).replaceAll(...)` + // without detouring through `UTF8String`. 
Key marker: no `UTF8String` on the subject read + // inside the loop; instead `inputs` or the typed column field with `.get(i)`. + assert( + src.contains(".matcher(") && src.contains(".replaceAll("), + s"expected specialized Matcher.replaceAll shape; got:\n$src") + assert( + src.contains("this.col0.get(i)"), + s"expected specialized path to read bytes directly from the typed column; got:\n$src") + } + + test("specialized RegExpReplace declines when subject is not a BoundReference") { + // Upper breaks the specialization guard; fall through to the default `doGenCode` path. + val expr = RegExpReplace( + subject = Upper(BoundReference(0, StringType, nullable = true)), + regexp = Literal.create("\\d+", StringType), + rep = Literal.create("N", StringType), + pos = Literal(1, IntegerType)) + val src = gen(expr, nullableString) + // The default path routes the subject read through the kernel's getters. Marker of the + // default path: the Upper child emits `row.getUTF8String(0)` / `row.isNullAt(0)` because + // `ctx.INPUT_ROW = "row"`. + assert( + src.contains("row.getUTF8String(0)") || src.contains("this.getUTF8String(0)"), + s"expected default path with row/kernel getter invocation; got:\n$src") + } + + test("NullIntolerant short-circuit emitted when every node is NullIntolerant") { + // RLike(Upper(BoundReference), Literal): RLike is NullIntolerant, Upper is NullIntolerant, + // BoundReference and Literal are leaves. Every path from a leaf to the root propagates + // nulls, so the short-circuit heuristic ("any input null -> output null") holds. + val expr = + RLike( + Upper(BoundReference(0, StringType, nullable = true)), + Literal.create("x", StringType)) + val src = gen(expr, nullableString) + assert( + src.contains("if (this.col0.isNull(i))"), + s"expected short-circuit on col0 when every node is NullIntolerant; got:\n$src") + } + + test("NullIntolerant short-circuit skipped when a non-NullIntolerant node breaks the chain") { + // Concat is not NullIntolerant; null in some args doesn't necessarily produce a null + // result. The short-circuit heuristic would be incorrect here (short-circuiting on c0 or c1 + // being null would skip evaluation, but Concat's null handling differs). Expect the + // default path without the `if (colX.isNull(i) || colY.isNull(i))` wrapper, letting Spark's + // own `ev.code` handle nulls correctly. 
+ val nullable1 = ArrowColumnSpec(varCharVectorClass, nullable = true) + val nullable2 = ArrowColumnSpec(varCharVectorClass, nullable = true) + val expr = RLike( + Concat( + Seq( + BoundReference(0, StringType, nullable = true), + BoundReference(1, StringType, nullable = true))), + Literal.create("x", StringType)) + val src = gen(expr, nullable1, nullable2) + assert( + !src.contains("this.col0.isNull(i) || this.col1.isNull(i)"), + "expected no pre-null short-circuit when Concat breaks the NullIntolerant chain; " + + s"got:\n$src") + } + + test("canHandle rejects CodegenFallback expressions") { + val expr = FakeCodegenFallback(BoundReference(0, StringType, nullable = true)) + val reason = CometBatchKernelCodegen.canHandle(expr) + assert(reason.isDefined, "expected canHandle to reject CodegenFallback") + assert( + reason.get.contains("FakeCodegenFallback"), + s"expected reason to name the rejected expression class; got: ${reason.get}") + } + + test("canHandle accepts Nondeterministic expressions (per-partition kernel handles state)") { + // Per-partition kernel instance caching in `CometCodegenDispatchUDF.ensureKernel` advances + // mutable state across batches in one partition, so Rand/Uuid/etc. produce the expected + // sequences. The previous canHandle rejection was conservative; with that caching in + // place, accepting Nondeterministic is correct. + val expr = FakeNondeterministic() + val reason = CometBatchKernelCodegen.canHandle(expr) + assert(reason.isEmpty, s"expected canHandle to accept Nondeterministic; got $reason") + } + + test("canHandle rejects Unevaluable expressions") { + val expr = FakeUnevaluable() + val reason = CometBatchKernelCodegen.canHandle(expr) + assert(reason.isDefined, "expected canHandle to reject Unevaluable") + assert( + reason.get.contains("FakeUnevaluable"), + s"expected reason to name the rejected expression class; got: ${reason.get}") + } + + test("CSE collapses a repeated subtree to one evaluation in the generated body") { + // `Add(Length(Upper(c0)), Length(Upper(c0)))` has `Length(Upper(c0))` as a common subtree. + // Length.doGenCode emits `$value.numChars()` on every Spark version the project targets, + // which makes it a stable activation marker. Upper's own doGenCode text drifts across + // versions (Spark 3.5 emits `UTF8String.toUpperCase()`, Spark 4 emits + // `CollationSupport.Upper.exec*` via collation-aware codegen), so we avoid it as a marker. + // When CSE fires, `Length(Upper(c0))` compiles into one `subExpr_*` helper whose body calls + // `numChars()` once; both uses in the `Add` read the cached result from mutable state. + // Without CSE, each Add child would emit its own `numChars()` call. + val upperOrd0 = Upper(BoundReference(0, StringType, nullable = true)) + val lenUpper = Length(upperOrd0) + val expr = Add(lenUpper, lenUpper) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nullableString)) + val occurrences = "\\.numChars\\(\\)".r.findAllIn(result.body).size + assert( + occurrences == 1, + "expected CSE to collapse repeated Length evaluation to 1 numChars() call, " + + s"got $occurrences; src=\n${CodeFormatter.format(result.code)}") + // Additional proof: CSE emitted a `subExpr_` helper method. Without CSE the generator would + // have inlined the repeated subtree into the main body with no helper at all. 
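+ // Roughly what the CSE'd body looks like (illustrative): a `subExpr_0(InternalRow row)`
+ // helper evaluates Length(Upper(c0)) once into mutable state, and both Add children read
+ // that cached value instead of re-running `.numChars()`.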
+ assert( + result.body.contains("subExpr_0(row)"), + s"expected CSE helper invocation; got:\n${CodeFormatter.format(result.code)}") + } + + test("CSE does not fire on non-deterministic expressions (regression guard)") { + // `Add(Rand(0), Rand(0))` is two structurally identical non-deterministic subtrees. CSE must + // not collapse them: each Rand call must produce an independent draw. Spark's CSE + // (`EquivalentExpressions.updateExprInMap`) filters non-deterministic expressions via + // `expr.deterministic`, so the two Rands stay separate. This test is a regression guard + // against Spark ever relaxing that check and against us accidentally applying CSE outside + // the `generateExpressions` path (which respects the filter). `Rand.doGenCode` emits one + // `$rng.nextDouble()` call per evaluation, so two Rands produce two `.nextDouble()` calls + // in the body; one-call output would indicate incorrect CSE. + val expr = Add(Rand(Literal(0L, LongType)), Rand(Literal(0L, LongType))) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq.empty) + val occurrences = "\\.nextDouble\\(\\)".r.findAllIn(result.body).size + assert( + occurrences == 2, + "expected two independent Rand evaluations (no CSE on nondeterministic), " + + s"got $occurrences; src=\n${CodeFormatter.format(result.code)}") + } + + test("DecimalVector getDecimal specializes to unscaled-long fast path for short precision") { + // Mirrors Spark's `UnsafeRow.getDecimal` split at `Decimal.MAX_LONG_DIGITS` (18), done at + // codegen time rather than at runtime. The dispatcher reads the `BoundReference`'s + // `DecimalType` at source-generation time and emits only the fast-path branch when + // `precision <= 18`. The fast path reads the low 8 bytes of the 16-byte Arrow decimal128 + // slot directly as a signed long via `ArrowBuf.getLong` and wraps with + // `Decimal.createUnsafe`, avoiding the `BigDecimal` allocation `DecimalVector.getObject` + // would perform. For precision > 18 the generator emits only the slow-path branch + // (`getObject + Decimal.apply`); see the companion test below. + val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(18, 2), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".createUnsafe("), + "expected Decimal.createUnsafe call on fast path; got:\n" + + CodeFormatter.format(result.code)) + assert( + result.body.contains("Platform.getLong(") && + result.body.contains("this.col0_valueAddr"), + "expected unsafe Platform.getLong against cached valueAddr; got:\n" + + CodeFormatter.format(result.code)) + assert( + !result.body.contains(".getObject("), + "expected specialized fast path (no BigDecimal fallback branch in source); got:\n" + + CodeFormatter.format(result.code)) + assert( + !result.body.contains("if (precision <= 18)"), + "expected no runtime precision branch for known short-precision column; got:\n" + + CodeFormatter.format(result.code)) + } + + test("DecimalVector getDecimal specializes to BigDecimal slow path for long precision") { + // Companion to the fast-path test. For `DecimalType(p, s)` with `p > 18`, the unscaled value + // can exceed 64 bits, so the generator emits only the `getObject + Decimal.apply` branch. + // The fast path markers must be absent so the generated source is minimal for this column. 
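+ // Illustrative shape of the slow path asserted below (not the literal emitted text): the
+ // getter materializes `this.col0.getObject(i)` as a java.math.BigDecimal and rewraps it via
+ // `Decimal.apply(...)`, since an unscaled value at precision 38 may not fit in a long.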
+ val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(38, 10), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".getObject(") && result.body.contains(".apply("), + s"expected BigDecimal slow path; got:\n${CodeFormatter.format(result.code)}") + assert( + !result.body.contains(".createUnsafe("), + "expected no fast-path emission for long-precision column; got:\n" + + CodeFormatter.format(result.code)) + assert( + !result.body.contains("if (precision <= 18)"), + "expected no runtime precision branch for known long-precision column; got:\n" + + CodeFormatter.format(result.code)) + } + + test("DecimalVector setSafe uses unscaled-long fast path for short-precision output") { + // The output writer specializes on the root expression's DecimalType precision. For + // precision <= 18 the Decimal's unscaled long is passed directly to + // `DecimalVector.setSafe(int, long)`, avoiding the BigDecimal allocation that + // `toJavaBigDecimal()` performs. Use a simple expression that produces a DecimalType output: + // `BoundReference(0, DecimalType(18, 2))` has output type DecimalType(18, 2), which is what + // the generator specializes on. + val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(18, 2), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".toUnscaledLong()"), + s"expected toUnscaledLong call on fast path; got:\n${CodeFormatter.format(result.code)}") + assert( + !result.body.contains(".toJavaBigDecimal("), + "expected no BigDecimal allocation for short-precision output; got:\n" + + CodeFormatter.format(result.code)) + } + + test("DecimalVector setSafe uses BigDecimal slow path for long-precision output") { + // Companion to the fast-path output test. Precision > 18 can have unscaled values exceeding + // 64 bits, so the writer must fall back to the BigDecimal path. + val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(38, 10), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".toJavaBigDecimal("), + s"expected BigDecimal slow path; got:\n${CodeFormatter.format(result.code)}") + assert( + !result.body.contains(".toUnscaledLong()"), + "expected no unscaled-long write for long-precision output; got:\n" + + CodeFormatter.format(result.code)) + } + + test("VarCharVector setSafe uses on-heap UTF8String shortcut") { + // The UTF8String output writer avoids the `byte[] b = $value.getBytes()` allocation when + // the UTF8String is on-heap by passing its backing byte[] directly to + // `VarCharVector.setSafe(int, byte[], int, int)`. Spark's string functions allocate their + // result on-heap, so this path hits for typical string expressions. Off-heap fallback + // (for passthrough of zero-copy input reads) stays as the else branch. + // + // Markers: `getBaseObject()` (inspecting the backing), `instanceof byte[]` (the branch), + // and `Platform.BYTE_ARRAY_OFFSET` (the on-heap offset math). 
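+ // Roughly (illustrative, not the exact emitted text):
+ // Object base = value.getBaseObject();
+ // if (base instanceof byte[]) {
+ //   output.setSafe(i, (byte[]) base,
+ //     (int) (value.getBaseOffset() - Platform.BYTE_ARRAY_OFFSET), value.numBytes());
+ // } else {
+ //   output.setSafe(i, value.getBytes());
+ // }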
+ val expr = Upper(BoundReference(0, StringType, nullable = true)) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nullableString)) + assert( + result.body.contains(".getBaseObject()"), + s"expected UTF8String.getBaseObject call; got:\n${CodeFormatter.format(result.code)}") + assert( + result.body.contains("instanceof byte[]"), + s"expected on-heap instanceof branch; got:\n${CodeFormatter.format(result.code)}") + assert( + result.body.contains("Platform.BYTE_ARRAY_OFFSET"), + "expected on-heap offset math via Platform.BYTE_ARRAY_OFFSET; got:\n" + + CodeFormatter.format(result.code)) + assert( + result.body.contains(".getBytes()"), + s"expected off-heap getBytes fallback; got:\n${CodeFormatter.format(result.code)}") + } + + test("non-nullable root expression omits the `if (isNull)` branch in default body") { + // When the bound expression claims `nullable = false`, the default body drops the + // `if (ev.isNull) output.setNull(i);` guard entirely. `Length` on a non-nullable column is + // itself non-nullable (Length.nullable = child.nullable = false), so the writer goes + // straight to the setSafe/set call. This test uses a non-NullIntolerant-short-circuit + // shape by wrapping Length in Coalesce, so we exercise the default branch of defaultBody + // rather than the NullIntolerant one. Actually, Length is NullIntolerant, so the NI branch + // fires; use an expression that's non-nullable but whose tree is not fully NullIntolerant + // to hit the default branch. `Coalesce(Seq(Length(col_non_null), Literal(0)))` has + // nullable=false (Coalesce is non-null when any child is) and Coalesce itself is not + // NullIntolerant, so the default branch runs. Assert `setNull` is absent. + val expr = Coalesce( + Seq(Length(BoundReference(0, StringType, nullable = false)), Literal(0, IntegerType))) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nonNullableString)) + assert( + !result.body.contains("output.setNull(i);"), + "expected no setNull for a non-nullable root expression; got:\n" + + CodeFormatter.format(result.code)) + } + + test("nullable root expression keeps the `if (isNull)` branch in default body") { + // Baseline: when the root expression is nullable, the setNull branch must still be emitted. + // Uses Coalesce with a nullable child so the Coalesce itself remains nullable. Guards the + // NonNullableOutputShortCircuit optimization against over-firing. + val expr = Coalesce( + Seq( + Length(BoundReference(0, StringType, nullable = true)), + BoundReference(1, IntegerType, nullable = true))) + val result = CometBatchKernelCodegen.generateSource( + expr, + IndexedSeq( + nullableString, + ArrowColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true))) + assert( + result.body.contains("output.setNull(i);"), + "expected setNull branch for a nullable root expression; got:\n" + + CodeFormatter.format(result.code)) + } + + test("ArrayType(StringType) output emits ListVector startNewValue/endValue recursion") { + // StringSplit produces ArrayType(StringType). emitWrite's ArrayType case should emit: + // - ListVector cast of output + // - child VarCharVector extraction via getDataVector + // - startNewValue + per-element loop + endValue + // - the per-element write recursing into the StringType case (which uses the UTF8 on-heap + // shortcut marker `instanceof byte[]`) + // Not asserting exact expression-specific text since Spark's StringSplit.doGenCode may drift + // across versions. 
Focus markers: ListVector cast, VarCharVector child cast, startNewValue, + // endValue, and the inner UTF8 shortcut branch. + val expr = + StringSplit( + BoundReference(0, StringType, nullable = true), + Literal.create(",", StringType), + Literal(-1, IntegerType)) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nullableString)) + val src = result.body + val formatted = CodeFormatter.format(result.code) + assert(src.contains("ListVector"), s"expected ListVector in emitted body; got:\n$formatted") + assert(src.contains(".startNewValue("), s"expected startNewValue call; got:\n$formatted") + assert(src.contains(".endValue("), s"expected endValue call; got:\n$formatted") + assert( + src.contains(".getDataVector()"), + s"expected child vector extraction; got:\n$formatted") + assert( + src.contains("instanceof byte[]"), + s"expected inner UTF8 on-heap shortcut for string elements; got:\n$formatted") + } + + test("MapType output emits MapVector startNewValue/endValue + per-pair writes") { + // CreateMap produces MapType(k, v). emitWrite's MapType case should emit: + // - MapVector cast of output + // - entries StructVector extraction + // - typed key / value child casts via getChildByOrdinal(0) / (1) + // - startNewValue / endValue bracketing + // - setIndexDefined on each struct entry + // - keyArray() / valueArray() retrieval from the MapData source + // - null-guard on the value write (key is always non-null per Arrow invariant) + val expr = CreateMap( + Seq( + Literal.create("a", StringType), + Literal(1, IntegerType), + Literal.create("b", StringType), + Literal(2, IntegerType))) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq.empty).body + Seq( + "MapVector", + "StructVector", + ".startNewValue(", + ".endValue(", + ".setIndexDefined(", + ".keyArray()", + ".valueArray()", + ".isNullAt(").foreach { marker => + assert(src.contains(marker), s"expected $marker in MapType output emission; got:\n$src") + } + } + + test("ArrayType(StringType) input emits InputArray_col0 nested class with UTF8 child getter") { + // Array input with string elements: the kernel must expose a `getArray(0)` that hands Spark's + // `doGenCode` a zero-allocation `ArrayData` view onto the Arrow `ListVector`'s child + // `VarCharVector`. Markers: the nested class declaration, a `reset(int)` bracketing the + // per-row slice, the typed child getter using `fromAddress`, and a `getArray` switch on the + // ordinal returning the pre-allocated instance. 
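+ // Illustrative shape of the kernel-level switch the asserts below look for (not the literal
+ // emitted source):
+ // public ArrayData getArray(int ordinal) {
+ //   switch (ordinal) {
+ //     case 0: this.col0_arrayData.reset(this.rowIdx); return this.col0_arrayData;
+ //   }
+ //   ...
+ // }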
+ val varCharChildSpec = ScalarColumnSpec(varCharVectorClass, nullable = true) + val arraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = StringType, element = varCharChildSpec) + val expr = Size(BoundReference(0, ArrayType(StringType), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + assert( + src.contains("class InputArray_col0"), + s"expected nested ArrayData class for array col0; got:\n$src") + assert( + src.contains("col0_e") && src.contains("col0_arrayData"), + s"expected typed child-vector field and pre-allocated ArrayData instance; got:\n$src") + assert( + src.contains("getElementStartIndex(") && src.contains("getElementEndIndex("), + s"expected list-offset reads inside `reset`; got:\n$src") + assert( + src.contains("public org.apache.spark.unsafe.types.UTF8String getUTF8String(int i)"), + s"expected element-type-specific UTF8String getter; got:\n$src") + assert( + src.contains(".fromAddress("), + s"expected zero-copy UTF8 read inside the nested ArrayData; got:\n$src") + assert( + src.contains("public org.apache.spark.sql.catalyst.util.ArrayData getArray(int ordinal)"), + s"expected kernel-level getArray switch; got:\n$src") + assert( + src.contains("col0_arrayData.reset("), + s"expected getArray to reset the pre-allocated instance; got:\n$src") + } + + test("ArrayType(IntegerType) input emits primitive int getter in nested class") { + val intChildSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val arraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = IntegerType, element = intChildSpec) + val expr = Size(BoundReference(0, ArrayType(IntegerType), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + assert( + src.contains("public int getInt(int i)"), + s"expected primitive int getter on nested array class; got:\n$src") + // Scalar-element fast path reads directly off the typed child vector; no BigDecimal / + // fromAddress scaffolding should leak in. + assert( + !src.contains(".fromAddress("), + s"int element getter should not wrap with UTF8 fromAddress; got:\n$src") + } + + test( + "ArrayType(DecimalType) short-precision input emits decimal128 fast-path via getLong in " + + "nested class") { + val decimalChildSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector"), + nullable = true) + val arraySpec = ArrayColumnSpec( + nullable = true, + elementSparkType = DecimalType(10, 2), + element = decimalChildSpec) + val expr = + ElementAt( + BoundReference(0, ArrayType(DecimalType(10, 2)), nullable = true), + Literal(1, IntegerType)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + // Fast path markers: reads the low 8 bytes of the decimal128 slot via getLong + createUnsafe. + // The slow path would go through getObject + Decimal.apply. 
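+ // Roughly (illustrative): the element getter reads the unscaled long straight out of the
+ // decimal128 value buffer, e.g. `Decimal.createUnsafe(buf.getLong(idx * 16L), 10, 2)`,
+ // skipping the BigDecimal allocation that `getObject` would perform.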
+ assert( + src.contains(".getLong(") && src.contains(".createUnsafe("), + s"expected decimal-input short-precision fast path in nested class; got:\n$src") + assert( + !src.contains(".getObject("), + s"short-precision decimal element should not use BigDecimal slow path; got:\n$src") + } + + test("ArrayType(DecimalType) long-precision input emits BigDecimal slow path in nested class") { + val decimalChildSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector"), + nullable = true) + val arraySpec = ArrayColumnSpec( + nullable = true, + elementSparkType = DecimalType(30, 2), + element = decimalChildSpec) + val expr = + ElementAt( + BoundReference(0, ArrayType(DecimalType(30, 2)), nullable = true), + Literal(1, IntegerType)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + assert( + src.contains(".getObject(") && src.contains("Decimal$.MODULE$"), + s"expected BigDecimal slow path for p>18 element; got:\n$src") + } + + // ============================================================================================ + // Nested-type tests. Each case verifies that a complex-within-complex shape emits a full + // nested-class tree (outer + inner), wired together through the path-suffix naming + // convention: `_e` for array element, `_f${fi}` for struct field fi. Scalar-element / scalar- + // field leaves reuse the typed-getter templates already covered by the single-depth tests. + // ============================================================================================ + + private def generate(expr: Expression, specs: IndexedSeq[ArrowColumnSpec]): String = + CometBatchKernelCodegen.generateSource(expr, specs).body + + test("Array> emits outer + inner array classes with _e_arrayData router") { + val innerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)) + val outerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = ArrayType(IntegerType), + element = innerArray) + val expr = Size(BoundReference(0, ArrayType(ArrayType(IntegerType)), nullable = true)) + val src = generate(expr, IndexedSeq(outerArray)) + assert( + src.contains("class InputArray_col0 ") && src.contains("class InputArray_col0_e "), + s"expected both outer and inner array classes; got:\n$src") + assert( + src.contains("col0_e_arrayData.reset("), + s"expected outer class to route getArray via inner instance reset; got:\n$src") + assert( + src.contains("public int getInt(int i)"), + s"expected innermost scalar getter for IntegerType element; got:\n$src") + } + + test("Array> emits array class routing getStruct via _e_structData") { + val innerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = true, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)))) + val outerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray), + element = innerStruct) + val elemType = StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray) + val expr = Size(BoundReference(0, ArrayType(elemType), nullable = true)) + val src = generate(expr, IndexedSeq(outerArray)) + assert( + src.contains("class InputArray_col0 ") && src.contains("class InputStruct_col0_e "), + s"expected array-of-struct nested classes; got:\n$src") + assert( 
+ src.contains("col0_e_structData.reset(startIndex + i)"), + s"expected array getStruct to route to inner struct instance; got:\n$src") + } + + test("Struct> emits outer + inner struct classes") { + val innerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = true, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)))) + val outerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "s", + StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray), + nullable = true, + innerStruct))) + val innerType = StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray) + val outerType = StructType(Seq(StructField("s", innerType, nullable = true)).toArray) + val expr = GetStructField( + GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("s")), + 0, + Some("a")) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("class InputStruct_col0 ") && src.contains("class InputStruct_col0_f0 "), + s"expected outer + inner struct classes; got:\n$src") + assert( + src.contains("col0_f0_structData.reset(this.rowIdx)"), + s"expected outer struct getStruct to route to inner instance; got:\n$src") + assert( + src.contains("public int getInt(int ordinal)"), + s"expected innermost getInt on InputStruct_col0_f0; got:\n$src") + } + + test("Struct> emits struct class routing getArray via _f0_arrayData") { + val innerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)) + val outerStruct = StructColumnSpec( + nullable = true, + fields = Seq(StructFieldSpec("a", ArrayType(IntegerType), nullable = true, innerArray))) + val structType = + StructType(Seq(StructField("a", ArrayType(IntegerType), nullable = true)).toArray) + val expr = Size(GetStructField(BoundReference(0, structType, nullable = true), 0, Some("a"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("class InputStruct_col0 ") && src.contains("class InputArray_col0_f0 "), + s"expected struct-of-array nested classes; got:\n$src") + assert( + src.contains("col0_f0_arrayData.reset("), + s"expected struct getArray to route to inner array instance; got:\n$src") + } + + test("Map emits InputMap_col0 + keyArray / valueArray views") { + val keySpec = ScalarColumnSpec(varCharVectorClass, nullable = true) + val valueSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val mapSpec = MapColumnSpec( + nullable = true, + keySparkType = StringType, + valueSparkType = IntegerType, + key = keySpec, + value = valueSpec) + val expr = Size(BoundReference(0, MapType(StringType, IntegerType), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(mapSpec)).body + assert( + src.contains("class InputMap_col0 "), + s"expected InputMap_col0 nested class; got:\n$src") + assert( + src.contains("class InputArray_col0_k ") && src.contains("class InputArray_col0_v "), + s"expected key/value array view classes; got:\n$src") + assert( + src.contains("col0_k_arrayData.reset(this.startIndex, this.length)"), + s"expected keyArray to reset with slice; got:\n$src") + assert( + src.contains("col0_v_arrayData.reset(this.startIndex, this.length)"), + s"expected valueArray to reset with slice; got:\n$src") + assert( + src.contains("public 
org.apache.spark.sql.catalyst.util.MapData getMap(int ordinal)"), + s"expected kernel-level getMap switch; got:\n$src") + assert( + src.contains("col0_mapData.reset("), + s"expected getMap to reset the pre-allocated map instance; got:\n$src") + } + + test("Map, Array> emits complex key and complex value views") { + val keyElem = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val keyArraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = IntegerType, element = keyElem) + val valueElem = ScalarColumnSpec(varCharVectorClass, nullable = true) + val valueArraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = StringType, element = valueElem) + val mapSpec = MapColumnSpec( + nullable = true, + keySparkType = ArrayType(IntegerType), + valueSparkType = ArrayType(StringType), + key = keyArraySpec, + value = valueArraySpec) + val expr = Size( + BoundReference(0, MapType(ArrayType(IntegerType), ArrayType(StringType)), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(mapSpec)).body + // Full chain of nested classes should appear: top-level map view, the key/value array + // views, and the inner array classes for each complex key/value element. + Seq( + "class InputMap_col0 ", + "class InputArray_col0_k ", + "class InputArray_col0_v ", + "class InputArray_col0_k_e ", + "class InputArray_col0_v_e ").foreach { marker => + assert(src.contains(marker), s"expected $marker in emission; got:\n$src") + } + } +} + +/** + * Minimal fake expressions for the `canHandle` rejection tests. Each opts into one of the marker + * traits whose presence forces a serde-level fallback. Bodies are unreachable; `canHandle` walks + * the tree structurally. + */ +private case class FakeCodegenFallback(child: Expression) + extends Expression + with CodegenFallback { + override def children: Seq[Expression] = Seq(child) + override def nullable: Boolean = true + override def dataType: DataType = StringType + override def eval(input: InternalRow): Any = null + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = copy(child = newChildren.head) +} + +private case class FakeNondeterministic() extends LeafExpression with Nondeterministic { + override def nullable: Boolean = true + override def dataType: DataType = IntegerType + override protected def initializeInternal(partitionIndex: Int): Unit = {} + override protected def evalInternal(input: InternalRow): Any = 0 + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw new UnsupportedOperationException("test fake; never reaches codegen") +} + +private case class FakeUnevaluable() extends LeafExpression with Unevaluable { + override def nullable: Boolean = true + override def dataType: DataType = IntegerType +} diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometScalaUDFCompositionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometScalaUDFCompositionBenchmark.scala new file mode 100644 index 0000000000..9485cb39e1 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometScalaUDFCompositionBenchmark.scala @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.benchmark + +import org.apache.spark.benchmark.Benchmark + +import org.apache.comet.CometConf + +/** + * Benchmark user-registered ScalaUDFs composed in trees, comparing the codegen dispatcher to the + * "feature off" baseline (where a user UDF forces the containing operator to Spark) and to + * Comet's native built-ins that are functionally equivalent. + * + * Four modes per composition: + * + * - '''Spark''': all Comet disabled. + * - '''Comet (native built-ins)''': the composition rewritten using Comet-native Spark + * built-ins (`upper`, `lower`, `reverse`, `concat`, `length`). Ceiling for what pure native + * can do. + * - '''Comet (user UDFs, dispatcher disabled)''': user UDFs with + * `codegenDispatch.mode=disabled`. `CometScalaUDF.convert` returns `None`, the ScalaUDF's + * Project falls back to Spark. This is the state before the dispatcher landed: any user UDF + * loses Comet acceleration on the whole hosting operator. + * - '''Comet (user UDFs, codegen dispatch)''': user UDFs with the dispatcher forced on. One + * Janino-compiled kernel per (tree, input schema) handles the whole composition in one JNI + * hop. + * + * Story the numbers should tell: dispatcher (mode 4) tracks native (mode 2) and beats + * dispatcher-disabled (mode 3) by the cost of the Spark fallback / ColumnarToRow hand-off. 
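+ *
+ * A minimal sketch of toggling the same two dispatcher modes in an ordinary session (the UDF
+ * name and table below are placeholders, not part of this benchmark; the confs are the ones
+ * `runModes` sets via `withSQLConf`):
+ * {{{
+ *   spark.udf.register("shout", (s: String) => if (s == null) null else s.toUpperCase)
+ *   spark.conf.set(CometConf.COMET_ENABLED.key, "true")
+ *   spark.conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
+ *   // dispatcher on: one Janino-compiled kernel per (tree, input schema)
+ *   spark.conf.set(CometConf.COMET_CODEGEN_DISPATCH_MODE.key, CometConf.CODEGEN_DISPATCH_FORCE)
+ *   spark.sql("SELECT shout(c1) FROM some_table").collect()
+ *   // dispatcher off: the hosting Project falls back to Spark (pre-dispatcher behaviour)
+ *   spark.conf.set(CometConf.COMET_CODEGEN_DISPATCH_MODE.key, CometConf.CODEGEN_DISPATCH_DISABLED)
+ * }}}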
+ * + * To run: + * {{{ + * SPARK_GENERATE_BENCHMARK_FILES=1 \ + * make benchmark-org.apache.spark.sql.benchmark.CometScalaUDFCompositionBenchmark + * }}} + */ +object CometScalaUDFCompositionBenchmark extends CometBenchmarkBase { + + private def registerThreeLevelUdfs(): Unit = { + spark.udf.register("lvl1_upper", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("lvl2_reverse", (s: String) => if (s == null) null else s.reverse) + spark.udf.register("lvl3_length", (s: String) => if (s == null) -1 else s.length) + } + + private def registerMultiColUdfs(): Unit = { + spark.udf.register("upperU", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("lowerU", (s: String) => if (s == null) null else s.toLowerCase) + spark.udf.register( + "joinU", + (a: String, b: String) => if (a == null || b == null) null else s"$a-$b") + } + + override def runCometBenchmark(mainArgs: Array[String]): Unit = { + runBenchmarkWithTable("scalaudf composition", 1024 * 1024) { v => + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql(s"SELECT REPEAT(CAST(value AS STRING), 10) AS c1 FROM $tbl")) + + registerThreeLevelUdfs() + runBenchmark("three-level composition: length(reverse(upper(c1)))") { + runModes( + name = "three-level", + cardinality = v, + nativeQuery = "SELECT length(reverse(upper(c1))) FROM parquetV1Table", + udfQuery = "SELECT lvl3_length(lvl2_reverse(lvl1_upper(c1))) FROM parquetV1Table") + } + } + } + + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql( + "SELECT REPEAT(CAST(value AS STRING), 10) AS c1, " + + s"CAST(value AS STRING) AS c2 FROM $tbl")) + + registerMultiColUdfs() + runBenchmark("multi-col composition: concat(upper(c1), '-', lower(c2))") { + runModes( + name = "multi-col", + cardinality = v, + nativeQuery = "SELECT concat(upper(c1), '-', lower(c2)) FROM parquetV1Table", + udfQuery = "SELECT joinU(upperU(c1), lowerU(c2)) FROM parquetV1Table") + } + } + } + + // Aggregate shape: SUM over the composition output. Picks up the cost of "dispatcher + // disabled" breaking the columnar pipeline around an aggregate, not just the Project + // itself. When the dispatcher is off, the Project falls back to Spark, which typically + // drags the surrounding HashAggregate off Comet's columnar path too (ColumnarToRow hand-off + // plus Spark's row-based aggregate). When the dispatcher is on, scan -> project -> agg + // stays columnar end to end. + withTempPath { dir => + withTempTable("parquetV1Table") { + prepareTable( + dir, + spark.sql(s"SELECT REPEAT(CAST(value AS STRING), 10) AS c1 FROM $tbl")) + + registerThreeLevelUdfs() + runBenchmark("agg over composition: SUM(length(reverse(upper(c1))))") { + runModes( + name = "agg-over-composition", + cardinality = v, + nativeQuery = "SELECT SUM(length(reverse(upper(c1)))) FROM parquetV1Table", + udfQuery = + "SELECT SUM(lvl3_length(lvl2_reverse(lvl1_upper(c1)))) FROM parquetV1Table") + } + } + } + } + } + + private def runModes( + name: String, + cardinality: Long, + nativeQuery: String, + udfQuery: String): Unit = { + val benchmark = new Benchmark(name, cardinality, output = output) + + benchmark.addCase("Spark") { _ => + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + spark.sql(udfQuery).noop() + } + } + + // Pure Comet-native rewrite of the composition using built-ins. Ceiling for native perf. + // Case conversion is enabled because upper/lower are in the tree. 
+ benchmark.addCase("Comet (native built-ins)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_CASE_CONVERSION_ENABLED.key -> "true") { + spark.sql(nativeQuery).noop() + } + } + + // User UDFs with dispatcher disabled. The ScalaUDF serde returns None, the hosting Project + // falls back to Spark. State of the world before the dispatcher landed: any ScalaUDF in a + // query sinks the containing operator. + benchmark.addCase("Comet (user UDFs, dispatcher disabled)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_CODEGEN_DISPATCH_MODE.key -> CometConf.CODEGEN_DISPATCH_DISABLED) { + spark.sql(udfQuery).noop() + } + } + + // User UDFs through the codegen dispatcher. One Janino-compiled kernel for the whole tree, + // one JNI hop per batch. + benchmark.addCase("Comet (user UDFs, codegen dispatch)") { _ => + withSQLConf( + CometConf.COMET_ENABLED.key -> "true", + CometConf.COMET_EXEC_ENABLED.key -> "true", + CometConf.COMET_CODEGEN_DISPATCH_MODE.key -> CometConf.CODEGEN_DISPATCH_FORCE) { + spark.sql(udfQuery).noop() + } + } + + benchmark.run() + } +}