From 1c681b1cd5745e570fcadd743a72da22ba0ae793 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 12 May 2026 14:50:14 -0400 Subject: [PATCH 1/3] Bring over TaskContext and num_rows fixes from #4267. --- .github/workflows/pr_build_linux.yml | 1 + .github/workflows/pr_build_macos.yml | 1 + .../org/apache/comet/udf/CometUdfBridge.java | 48 ++++- .../scala/org/apache/comet/udf/CometUDF.scala | 9 +- .../spark/comet/CometTaskContextShim.scala | 41 ++++ native/core/src/execution/jni_api.rs | 21 +- native/core/src/execution/planner.rs | 26 ++- native/jni-bridge/src/comet_udf_bridge.rs | 2 +- native/spark-expr/src/jvm_udf/mod.rs | 24 ++- .../org/apache/comet/CometExecIterator.scala | 5 +- .../main/scala/org/apache/comet/Native.scala | 5 +- .../spark/comet/udf/CometUdfBridgeSuite.scala | 182 ++++++++++++++++++ 12 files changed, 339 insertions(+), 26 deletions(-) create mode 100644 common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala create mode 100644 spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index 5c1ae2dc47..b8b482596d 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -365,6 +365,7 @@ jobs: org.apache.spark.sql.comet.CometShuffleFallbackStickinessSuite org.apache.spark.sql.comet.CometDecimalArithmeticViewSuite org.apache.comet.objectstore.NativeConfigSuite + org.apache.spark.comet.udf.CometUdfBridgeSuite - name: "expressions" value: | org.apache.comet.CometExpressionSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 29eca594a2..4d3e0d4611 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -213,6 +213,7 @@ jobs: org.apache.spark.sql.comet.CometShuffleFallbackStickinessSuite org.apache.spark.sql.comet.CometDecimalArithmeticViewSuite org.apache.comet.objectstore.NativeConfigSuite + 
org.apache.spark.comet.udf.CometUdfBridgeSuite - name: "expressions" value: | org.apache.comet.CometExpressionSuite diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java index aed53c57df..0db4ae0c50 100644 --- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java +++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java @@ -27,6 +27,8 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; +import org.apache.spark.TaskContext; +import org.apache.spark.comet.CometTaskContextShim; /** * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method @@ -48,13 +50,45 @@ public class CometUdfBridge { * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input) * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result + * @param numRows row count of the current batch. Mirrors DataFusion's {@code + * ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a + * zero-arg non-deterministic ScalaUDF) ever sees. + * @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or + * {@code null} outside a Spark task. Installed as the thread-local for the duration of the + * call when the current thread has none, so partition-sensitive built-ins ({@code Rand}, + * {@code Uuid}, {@code MonotonicallyIncreasingID}) work from Tokio workers. Cleared in {@code + * finally} to avoid leaking across worker reuse. 
*/ public static void evaluate( String udfClassName, long[] inputArrayPtrs, long[] inputSchemaPtrs, long outArrayPtr, - long outSchemaPtr) { + long outSchemaPtr, + int numRows, + TaskContext taskContext) { + boolean installedTaskContext = false; + if (taskContext != null && TaskContext.get() == null) { + CometTaskContextShim.set(taskContext); + installedTaskContext = true; + } + try { + evaluateInternal( + udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows); + } finally { + if (installedTaskContext) { + CometTaskContextShim.unset(); + } + } + } + + private static void evaluateInternal( + String udfClassName, + long[] inputArrayPtrs, + long[] inputSchemaPtrs, + long outArrayPtr, + long outSchemaPtr, + int numRows) { CometUDF udf = INSTANCES.computeIfAbsent( udfClassName, @@ -84,23 +118,17 @@ public static void evaluate( inputs[i] = Data.importVector(allocator, inArr, inSch, null); } - result = udf.evaluate(inputs); + result = udf.evaluate(inputs, numRows); if (!(result instanceof FieldVector)) { throw new RuntimeException( "CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName()); } - // Result length must match the longest input. Scalar (length-1) inputs - // are allowed to be shorter, but a vector input bounds the output. 
- int expectedLen = 0; - for (ValueVector v : inputs) { - expectedLen = Math.max(expectedLen, v.getValueCount()); - } - if (result.getValueCount() != expectedLen) { + if (result.getValueCount() != numRows) { throw new RuntimeException( "CometUDF.evaluate() returned " + result.getValueCount() + " rows, expected " - + expectedLen); + + numRows); } ArrowArray outArr = ArrowArray.wrap(outArrayPtr); ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr); diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala index 29186f0a2c..5b6652d90a 100644 --- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala +++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala @@ -27,11 +27,16 @@ import org.apache.arrow.vector.ValueVector * * - Vector arguments arrive at the row count of the current batch. * - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0. - * - The returned vector's length must match the longest input. + * - The returned vector's length must match `numRows`. + * + * `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count. + * UDFs that always have at least one batch-length input can derive length from the inputs and + * ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF) + * need `numRows` to know how many rows to produce. * * Implementations must have a public no-arg constructor and must be stateless: a single instance * per class is cached and shared across native worker threads for the lifetime of the JVM. 
*/ trait CometUDF { - def evaluate(inputs: Array[ValueVector]): ValueVector + def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector } diff --git a/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala new file mode 100644 index 0000000000..9218fc5e78 --- /dev/null +++ b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.comet + +import org.apache.spark.TaskContext + +/** + * Package-private access shim for `TaskContext.setTaskContext` / `TaskContext.unset`. + * + * Both methods are declared `protected[spark]` on Spark's `TaskContext` companion, so they are + * reachable from code inside the `org.apache.spark` package tree but not from `org.apache.comet`. + * The Comet JVM UDF bridge needs to set the thread-local `TaskContext` on its caller thread (a + * Tokio worker thread with no `TaskContext`) so the user's UDF body and any partition-sensitive + * built-ins (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, etc.) see the driving Spark task's + * `TaskContext`. 
This shim lives in `org.apache.spark.comet` so it can call through to the + * protected methods, and exposes plain public forwarders the bridge (which lives in + * `org.apache.comet.udf`) can use. + */ +object CometTaskContextShim { + + def set(taskContext: TaskContext): Unit = TaskContext.setTaskContext(taskContext) + + def unset(): Unit = TaskContext.unset() +} diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 5d3dbb8266..f5b04cc51d 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -306,6 +306,13 @@ struct ExecutionContext { pub tracing_memory_metric_name: String, /// Pre-computed tracing event name for executePlan calls pub tracing_event_name: String, + /// Spark `TaskContext` captured on the driving Spark task thread at `createPlan` time. + /// Threaded into every JVM scalar UDF the planner builds so the JNI bridge can install it + /// as the thread-local `TaskContext` for the Tokio worker running the UDF. `None` when no + /// driving Spark task is present (unit tests, direct native driver runs). The `Arc` is + /// cheap to clone; the underlying `Global` releases its JNI global ref on drop + /// via `jni`'s `Drop` impl. + pub task_context: Option<Arc<Global<JObject<'static>>>>, } /// Accept serialized query plan and return the address of the native query plan. @@ -332,6 +339,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( task_attempt_id: jlong, task_cpus: jlong, key_unwrapper_obj: JObject, + task_context_obj: JObject, ) -> jlong { try_unwrap_or_throw(&e, |env| { // Deserialize Spark configs @@ -453,6 +461,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( String::new() }; + // Capture the driving Spark task's TaskContext as a JNI global reference when + // non-null. The `Arc<Global<JObject>>` releases its global ref on drop, so cleanup + // is automatic when the ExecutionContext drops.
+ let task_context = if !task_context_obj.is_null() { + Some(Arc::new(jni_new_global_ref!(env, task_context_obj)?)) + } else { + None + }; + let exec_context = Box::new(ExecutionContext { id, task_attempt_id, @@ -479,6 +496,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( "thread_{rust_thread_id}_comet_memory_reserved" ), tracing_event_name, + task_context, }); Ok(Box::into_raw(exec_context) as i64) @@ -703,7 +721,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let start = Instant::now(); let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) - .with_exec_id(exec_context_id); + .with_exec_id(exec_context_id) + .with_task_context(exec_context.task_context.clone()); let (scans, shuffle_scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(), diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 478c7a8d98..b00f140026 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -183,6 +183,9 @@ pub struct PhysicalPlanner { partition: i32, session_ctx: Arc, query_context_registry: Arc, + /// Captured at `createPlan` time on `ExecutionContext`; see that struct for the + /// propagation rationale. `None` when no driving Spark task is available. 
+ task_context: Option<Arc<Global<JObject<'static>>>>, } impl Default for PhysicalPlanner { @@ -198,16 +201,24 @@ impl PhysicalPlanner { session_ctx, partition, query_context_registry: datafusion_comet_spark_expr::create_query_context_map(), + task_context: None, } } - pub fn with_exec_id(self, exec_context_id: i64) -> Self { - Self { - exec_context_id, - partition: self.partition, - session_ctx: Arc::clone(&self.session_ctx), - query_context_registry: Arc::clone(&self.query_context_registry), - } + pub fn with_exec_id(mut self, exec_context_id: i64) -> Self { + self.exec_context_id = exec_context_id; + self + } + + /// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned + /// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as + /// the thread-local on the Tokio worker driving the UDF. + pub fn with_task_context( + mut self, + task_context: Option<Arc<Global<JObject<'static>>>>, + ) -> Self { + self.task_context = task_context; + self } /// Return session context of this planner.
@@ -735,6 +746,7 @@ impl PhysicalPlanner { args, return_type, udf.return_nullable, + self.task_context.clone(), ))) } expr => Err(GeneralError(format!("Not implemented: {expr:?}"))), diff --git a/native/jni-bridge/src/comet_udf_bridge.rs b/native/jni-bridge/src/comet_udf_bridge.rs index 89cd8ee514..e531d20cb1 100644 --- a/native/jni-bridge/src/comet_udf_bridge.rs +++ b/native/jni-bridge/src/comet_udf_bridge.rs @@ -41,7 +41,7 @@ impl<'a> CometUdfBridge<'a> { method_evaluate: env.get_static_method_id( JNIString::new(Self::JVM_CLASS), jni::jni_str!("evaluate"), - jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"), + jni::jni_sig!("(Ljava/lang/String;[J[JJJILorg/apache/spark/TaskContext;)V"), )?, method_evaluate_ret: ReturnType::Primitive(Primitive::Void), class, diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs index 668a2b6727..4ed25de6ee 100644 --- a/native/spark-expr/src/jvm_udf/mod.rs +++ b/native/spark-expr/src/jvm_udf/mod.rs @@ -31,7 +31,7 @@ use datafusion::physical_expr::PhysicalExpr; use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError}; use datafusion_comet_jni_bridge::JVMClasses; -use jni::objects::{JObject, JValue}; +use jni::objects::{Global, JObject, JValue}; /// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI. /// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`. @@ -41,6 +41,14 @@ pub struct JvmScalarUdfExpr { args: Vec>, return_type: DataType, return_nullable: bool, + /// Captured at `createPlan` time and threaded here by the planner. Passed through the + /// JNI bridge so `CometUdfBridge.evaluate` can install it as the Tokio worker's + /// thread-local `TaskContext`. Without this, partition-sensitive built-ins inside a UDF + /// tree (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, user code reading + /// `TaskContext.get()`) see `null` and seed / branch incorrectly. 
`None` when no driving + /// Spark task is available; the bridge then leaves whatever `TaskContext.get()` already + /// returns in place. + task_context: Option<Arc<Global<JObject<'static>>>>, } impl JvmScalarUdfExpr { @@ -49,12 +57,14 @@ impl JvmScalarUdfExpr { args: Vec<Arc<dyn PhysicalExpr>>, return_type: DataType, return_nullable: bool, + task_context: Option<Arc<Global<JObject<'static>>>>, ) -> Self { Self { class_name, args, return_type, return_nullable, + task_context, } } } @@ -186,7 +196,14 @@ impl PhysicalExpr for JvmScalarUdfExpr { .set_region(env, 0, &in_sch_ptrs) .map_err(|e| CometError::JNI { source: e })?; - // Call CometUdfBridge.evaluate(String, long[], long[], long, long) + // Pass a null jobject when no TaskContext was propagated so the bridge's null-guard + // leaves the worker thread's current TaskContext.get() in place. The borrow must + // outlive `call_static_method_unchecked`. + let null_task_context = JObject::null(); + let task_context_ref: &JObject = match &self.task_context { + Some(gref) => gref.as_obj(), + None => &null_task_context, + }; let ret = unsafe { env.call_static_method_unchecked( &bridge.class, @@ -198,6 +215,8 @@ impl PhysicalExpr for JvmScalarUdfExpr { JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(), JValue::Long(out_arr_ptr).as_jni(), JValue::Long(out_sch_ptr).as_jni(), + JValue::Int(batch.num_rows() as i32).as_jni(), + JValue::Object(task_context_ref).as_jni(), ], ) }; @@ -234,6 +253,7 @@ impl PhysicalExpr for JvmScalarUdfExpr { children, self.return_type.clone(), self.return_nullable, + self.task_context.clone(), ))) } } diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index a93564811c..6140eca553 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -127,7 +127,10 @@ class CometExecIterator( memoryConfig.memoryLimitPerTask, taskAttemptId, taskCPUs, - keyUnwrapper) + keyUnwrapper, + // Propagated to
Tokio workers running JVM UDFs so they see this Spark task's + // TaskContext. See CometUdfBridge.evaluate. + TaskContext.get()) } private var nextBatch: Option[ColumnarBatch] = None diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index c003bcd138..3cfa51b6e1 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -21,7 +21,7 @@ package org.apache.comet import java.nio.ByteBuffer -import org.apache.spark.CometTaskMemoryManager +import org.apache.spark.{CometTaskMemoryManager, TaskContext} import org.apache.spark.sql.comet.CometMetricNode import org.apache.comet.parquet.CometFileKeyUnwrapper @@ -69,7 +69,8 @@ class Native extends NativeBase { memoryLimitPerTask: Long, taskAttemptId: Long, taskCPUs: Long, - keyUnwrapper: CometFileKeyUnwrapper): Long + keyUnwrapper: CometFileKeyUnwrapper, + taskContext: TaskContext): Long // scalastyle:on /** diff --git a/spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala b/spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala new file mode 100644 index 0000000000..31da4d4ef8 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.comet.udf + +import java.io.File + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.arrow.c.{ArrowArray, ArrowSchema, Data} +import org.apache.arrow.vector.{IntVector, ValueVector} +import org.apache.spark.TaskContext + +import org.apache.comet.CometArrowAllocator +import org.apache.comet.udf.{CometUDF, CometUdfBridge} + +/** + * JVM-side unit tests for `CometUdfBridge.evaluate`. Exercises the new `numRows` length contract + * and the `TaskContext` install / unset behavior added to the bridge. The native -> JNI side of + * the change is validated by Rust compiling against the new `JvmScalarUdfExpr::new` signature and + * the updated `jni_sig!` for `CometUdfBridge.evaluate`; an end-to-end JNI round-trip will land + * with the dispatcher PR that introduces production serde routing for the bridge. + */ +class CometUdfBridgeSuite extends AnyFunSuite with BeforeAndAfterAll { + + // Surefire sets `java.io.tmpdir` to `${project.build.directory}/tmp`, which Maven does not + // auto-create. Arrow's `JniLoader` extracts `libarrow_cdata_jni.{dylib,so}` from the + // `arrow-c-data` jar via `File.createTempFile`, which fails with `No such file or directory` + // if the dir is missing. Other Comet suites avoid the issue because they enter Arrow C Data + // through native code; this suite calls `Data.exportVector` directly from the JVM. 
+ override def beforeAll(): Unit = { + super.beforeAll() + new File(System.getProperty("java.io.tmpdir")).mkdirs() + } + + private def runEvaluate( + udfClass: String, + numRows: Int, + taskContext: TaskContext): ValueVector = { + val outArr = ArrowArray.allocateNew(CometArrowAllocator) + val outSch = ArrowSchema.allocateNew(CometArrowAllocator) + try { + CometUdfBridge.evaluate( + udfClass, + new Array[Long](0), + new Array[Long](0), + outArr.memoryAddress(), + outSch.memoryAddress(), + numRows, + taskContext) + Data.importVector(CometArrowAllocator, outArr, outSch, null) + } finally { + outArr.close() + outSch.close() + } + } + + test("evaluate uses numRows as the result length contract for zero-input UDFs") { + // Pre-numRows the bridge derived expected length from max input length, which is 0 when + // there are no inputs, so a zero-arg UDF could not produce any rows. The fix passes + // numRows through and uses it as the contract. + val out = runEvaluate(classOf[RowCountTestUDF].getName, 7, null).asInstanceOf[IntVector] + try { + assert(out.getValueCount === 7) + (0 until 7).foreach(i => assert(out.get(i) === 42)) + } finally { + out.close() + } + } + + test("evaluate installs a propagated TaskContext when the worker thread has none") { + val prior = TaskContext.get() + if (prior != null) TaskContext.unset() + try { + val propagated = TaskContext.empty() + RecordTaskContextUDF.reset() + val out = runEvaluate(classOf[RecordTaskContextUDF].getName, 1, propagated) + out.close() + assert( + RecordTaskContextUDF.observed === propagated, + "bridge should install the propagated TaskContext as the thread-local for the call") + assert( + TaskContext.get() === null, + "bridge must clear the thread-local in finally so Tokio workers do not leak it") + } finally { + if (prior != null) TaskContext.setTaskContext(prior) + } + } + + test("evaluate leaves the thread-local alone when no TaskContext is propagated") { + val prior = TaskContext.get() + if (prior != null) 
TaskContext.unset() + try { + RecordTaskContextUDF.reset() + val out = runEvaluate(classOf[RecordTaskContextUDF].getName, 1, null) + out.close() + assert( + RecordTaskContextUDF.observed === null, + "no TaskContext propagated and none on thread, so the UDF body must observe null") + } finally { + if (prior != null) TaskContext.setTaskContext(prior) + } + } + + test("evaluate does not overwrite an existing thread-local TaskContext") { + // Spark task threads (as opposed to Tokio workers) already have a TaskContext installed + // by Spark; the bridge must not stomp on it with the propagated reference. + val prior = TaskContext.get() + val installed = TaskContext.empty() + TaskContext.setTaskContext(installed) + try { + val propagated = TaskContext.empty() + RecordTaskContextUDF.reset() + val out = runEvaluate(classOf[RecordTaskContextUDF].getName, 1, propagated) + out.close() + assert( + RecordTaskContextUDF.observed === installed, + "current thread already has a TaskContext; bridge must leave it in place") + assert( + TaskContext.get() === installed, + "thread-local must still be the originally-installed context after evaluate returns") + } finally { + TaskContext.unset() + if (prior != null) TaskContext.setTaskContext(prior) + } + } +} + +/** Zero-input UDF: returns `numRows` rows of the constant 42. */ +class RowCountTestUDF extends CometUDF { + override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = { + val out = new IntVector("out", CometArrowAllocator) + out.allocateNew(numRows) + var i = 0 + while (i < numRows) { + out.set(i, 42) + i += 1 + } + out.setValueCount(numRows) + out + } +} + +object RecordTaskContextUDF { + // Volatile because the bridge is allowed to call from any thread; the assertion thread + // needs to observe whatever evaluate() wrote. + @volatile var observed: TaskContext = _ + def reset(): Unit = { observed = null } +} + +/** Records what `TaskContext.get()` returned at evaluate time, for assertion. 
*/ +class RecordTaskContextUDF extends CometUDF { + override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = { + RecordTaskContextUDF.observed = TaskContext.get() + val out = new IntVector("out", CometArrowAllocator) + out.allocateNew(numRows) + var i = 0 + while (i < numRows) { + out.set(i, 0) + i += 1 + } + out.setValueCount(numRows) + out + } +} From 204aacde3629666b9e72542d940d165a4e32041d Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 12 May 2026 15:14:26 -0400 Subject: [PATCH 2/3] Address PR feedback. Always use stored TaskContext. --- .../org/apache/comet/udf/CometUdfBridge.java | 25 ++++++++++++------- .../spark/comet/udf/CometUdfBridgeSuite.scala | 19 +++++++------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java index 0db4ae0c50..5e76819810 100644 --- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java +++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java @@ -54,10 +54,11 @@ public class CometUdfBridge { * ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a * zero-arg non-deterministic ScalaUDF) ever sees. * @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or - * {@code null} outside a Spark task. Installed as the thread-local for the duration of the - * call when the current thread has none, so partition-sensitive built-ins ({@code Rand}, - * {@code Uuid}, {@code MonotonicallyIncreasingID}) work from Tokio workers. Cleared in {@code - * finally} to avoid leaking across worker reuse. + * {@code null} outside a Spark task. Treated as ground truth for the call: installed as the + * thread-local on entry, with the prior value (if any) saved and restored in {@code finally}. 
+ * Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code + * MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext + * left on a worker by a previous task. */ public static void evaluate( String udfClassName, @@ -67,17 +68,23 @@ public static void evaluate( long outSchemaPtr, int numRows, TaskContext taskContext) { - boolean installedTaskContext = false; - if (taskContext != null && TaskContext.get() == null) { + // Save-and-restore rather than only-install-if-null: the propagated context is the ground + // truth for this call. Any value already on the thread is either (a) the same object on a + // Spark task thread, or (b) stale from a prior task on a reused Tokio worker. + TaskContext prior = TaskContext.get(); + if (taskContext != null) { CometTaskContextShim.set(taskContext); - installedTaskContext = true; } try { evaluateInternal( udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows); } finally { - if (installedTaskContext) { - CometTaskContextShim.unset(); + if (taskContext != null) { + if (prior != null) { + CometTaskContextShim.set(prior); + } else { + CometTaskContextShim.unset(); + } } } } diff --git a/spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala b/spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala index 31da4d4ef8..370d0838ad 100644 --- a/spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala +++ b/spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala @@ -119,23 +119,22 @@ class CometUdfBridgeSuite extends AnyFunSuite with BeforeAndAfterAll { } } - test("evaluate does not overwrite an existing thread-local TaskContext") { - // Spark task threads (as opposed to Tokio workers) already have a TaskContext installed - // by Spark; the bridge must not stomp on it with the propagated reference. 
+ test("evaluate overwrites a stale thread-local TaskContext and restores it after") { + // The thread-local on a reused Tokio worker may be stale from a previous task, so the + // bridge treats the propagated TaskContext as ground truth: install it, save the prior, + // restore the prior in finally. val prior = TaskContext.get() - val installed = TaskContext.empty() - TaskContext.setTaskContext(installed) + val stale = TaskContext.empty() + TaskContext.setTaskContext(stale) try { val propagated = TaskContext.empty() RecordTaskContextUDF.reset() val out = runEvaluate(classOf[RecordTaskContextUDF].getName, 1, propagated) out.close() assert( - RecordTaskContextUDF.observed === installed, - "current thread already has a TaskContext; bridge must leave it in place") - assert( - TaskContext.get() === installed, - "thread-local must still be the originally-installed context after evaluate returns") + RecordTaskContextUDF.observed === propagated, + "bridge must install the propagated TaskContext over whatever was on the thread") + assert(TaskContext.get() === stale, "bridge must restore the prior thread-local in finally") } finally { TaskContext.unset() if (prior != null) TaskContext.setTaskContext(prior) From 23af390d6f453106336d13ecbd0a6b8ad287a1ce Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 12 May 2026 17:01:29 -0400 Subject: [PATCH 3/3] Remove test suite that doesn't work in CI due to Arrow shading. 
--- .github/workflows/pr_build_linux.yml | 1 - .github/workflows/pr_build_macos.yml | 1 - .../spark/comet/udf/CometUdfBridgeSuite.scala | 181 ------------------ 3 files changed, 183 deletions(-) delete mode 100644 spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index 3af57a5e28..78b9481cb1 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -366,7 +366,6 @@ jobs: org.apache.spark.sql.comet.CometShuffleFallbackStickinessSuite org.apache.spark.sql.comet.CometDecimalArithmeticViewSuite org.apache.comet.objectstore.NativeConfigSuite - org.apache.spark.comet.udf.CometUdfBridgeSuite - name: "expressions" value: | org.apache.comet.CometExpressionSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index d999766f22..e77cc8f720 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -214,7 +214,6 @@ jobs: org.apache.spark.sql.comet.CometShuffleFallbackStickinessSuite org.apache.spark.sql.comet.CometDecimalArithmeticViewSuite org.apache.comet.objectstore.NativeConfigSuite - org.apache.spark.comet.udf.CometUdfBridgeSuite - name: "expressions" value: | org.apache.comet.CometExpressionSuite diff --git a/spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala b/spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala deleted file mode 100644 index 370d0838ad..0000000000 --- a/spark/src/test/scala/org/apache/spark/comet/udf/CometUdfBridgeSuite.scala +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.spark.comet.udf - -import java.io.File - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.funsuite.AnyFunSuite - -import org.apache.arrow.c.{ArrowArray, ArrowSchema, Data} -import org.apache.arrow.vector.{IntVector, ValueVector} -import org.apache.spark.TaskContext - -import org.apache.comet.CometArrowAllocator -import org.apache.comet.udf.{CometUDF, CometUdfBridge} - -/** - * JVM-side unit tests for `CometUdfBridge.evaluate`. Exercises the new `numRows` length contract - * and the `TaskContext` install / unset behavior added to the bridge. The native -> JNI side of - * the change is validated by Rust compiling against the new `JvmScalarUdfExpr::new` signature and - * the updated `jni_sig!` for `CometUdfBridge.evaluate`; an end-to-end JNI round-trip will land - * with the dispatcher PR that introduces production serde routing for the bridge. - */ -class CometUdfBridgeSuite extends AnyFunSuite with BeforeAndAfterAll { - - // Surefire sets `java.io.tmpdir` to `${project.build.directory}/tmp`, which Maven does not - // auto-create. Arrow's `JniLoader` extracts `libarrow_cdata_jni.{dylib,so}` from the - // `arrow-c-data` jar via `File.createTempFile`, which fails with `No such file or directory` - // if the dir is missing. 
Other Comet suites avoid the issue because they enter Arrow C Data - // through native code; this suite calls `Data.exportVector` directly from the JVM. - override def beforeAll(): Unit = { - super.beforeAll() - new File(System.getProperty("java.io.tmpdir")).mkdirs() - } - - private def runEvaluate( - udfClass: String, - numRows: Int, - taskContext: TaskContext): ValueVector = { - val outArr = ArrowArray.allocateNew(CometArrowAllocator) - val outSch = ArrowSchema.allocateNew(CometArrowAllocator) - try { - CometUdfBridge.evaluate( - udfClass, - new Array[Long](0), - new Array[Long](0), - outArr.memoryAddress(), - outSch.memoryAddress(), - numRows, - taskContext) - Data.importVector(CometArrowAllocator, outArr, outSch, null) - } finally { - outArr.close() - outSch.close() - } - } - - test("evaluate uses numRows as the result length contract for zero-input UDFs") { - // Pre-numRows the bridge derived expected length from max input length, which is 0 when - // there are no inputs, so a zero-arg UDF could not produce any rows. The fix passes - // numRows through and uses it as the contract. 
- val out = runEvaluate(classOf[RowCountTestUDF].getName, 7, null).asInstanceOf[IntVector] - try { - assert(out.getValueCount === 7) - (0 until 7).foreach(i => assert(out.get(i) === 42)) - } finally { - out.close() - } - } - - test("evaluate installs a propagated TaskContext when the worker thread has none") { - val prior = TaskContext.get() - if (prior != null) TaskContext.unset() - try { - val propagated = TaskContext.empty() - RecordTaskContextUDF.reset() - val out = runEvaluate(classOf[RecordTaskContextUDF].getName, 1, propagated) - out.close() - assert( - RecordTaskContextUDF.observed === propagated, - "bridge should install the propagated TaskContext as the thread-local for the call") - assert( - TaskContext.get() === null, - "bridge must clear the thread-local in finally so Tokio workers do not leak it") - } finally { - if (prior != null) TaskContext.setTaskContext(prior) - } - } - - test("evaluate leaves the thread-local alone when no TaskContext is propagated") { - val prior = TaskContext.get() - if (prior != null) TaskContext.unset() - try { - RecordTaskContextUDF.reset() - val out = runEvaluate(classOf[RecordTaskContextUDF].getName, 1, null) - out.close() - assert( - RecordTaskContextUDF.observed === null, - "no TaskContext propagated and none on thread, so the UDF body must observe null") - } finally { - if (prior != null) TaskContext.setTaskContext(prior) - } - } - - test("evaluate overwrites a stale thread-local TaskContext and restores it after") { - // The thread-local on a reused Tokio worker may be stale from a previous task, so the - // bridge treats the propagated TaskContext as ground truth: install it, save the prior, - // restore the prior in finally. 
- val prior = TaskContext.get() - val stale = TaskContext.empty() - TaskContext.setTaskContext(stale) - try { - val propagated = TaskContext.empty() - RecordTaskContextUDF.reset() - val out = runEvaluate(classOf[RecordTaskContextUDF].getName, 1, propagated) - out.close() - assert( - RecordTaskContextUDF.observed === propagated, - "bridge must install the propagated TaskContext over whatever was on the thread") - assert(TaskContext.get() === stale, "bridge must restore the prior thread-local in finally") - } finally { - TaskContext.unset() - if (prior != null) TaskContext.setTaskContext(prior) - } - } -} - -/** Zero-input UDF: returns `numRows` rows of the constant 42. */ -class RowCountTestUDF extends CometUDF { - override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = { - val out = new IntVector("out", CometArrowAllocator) - out.allocateNew(numRows) - var i = 0 - while (i < numRows) { - out.set(i, 42) - i += 1 - } - out.setValueCount(numRows) - out - } -} - -object RecordTaskContextUDF { - // Volatile because the bridge is allowed to call from any thread; the assertion thread - // needs to observe whatever evaluate() wrote. - @volatile var observed: TaskContext = _ - def reset(): Unit = { observed = null } -} - -/** Records what `TaskContext.get()` returned at evaluate time, for assertion. */ -class RecordTaskContextUDF extends CometUDF { - override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = { - RecordTaskContextUDF.observed = TaskContext.get() - val out = new IntVector("out", CometArrowAllocator) - out.allocateNew(numRows) - var i = 0 - while (i < numRows) { - out.set(i, 0) - i += 1 - } - out.setValueCount(numRows) - out - } -}