diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
index aed53c57df..5e76819810 100644
--- a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
+++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
@@ -27,6 +27,8 @@
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.ValueVector;
+import org.apache.spark.TaskContext;
+import org.apache.spark.comet.CometTaskContextShim;
 
 /**
  * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
@@ -48,13 +50,52 @@ public class CometUdfBridge {
    * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
    * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
    * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
+   * @param numRows row count of the current batch. Mirrors DataFusion's {@code
+   *     ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a
+   *     zero-arg non-deterministic ScalaUDF) ever sees.
+   * @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread,
+   *     or {@code null} outside a Spark task. Treated as ground truth for the call: installed
+   *     as the thread-local on entry, with the prior value (if any) saved and restored in
+   *     {@code finally}. Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
+   *     MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale
+   *     TaskContext left on a worker by a previous task.
    */
   public static void evaluate(
       String udfClassName,
       long[] inputArrayPtrs,
       long[] inputSchemaPtrs,
       long outArrayPtr,
-      long outSchemaPtr) {
+      long outSchemaPtr,
+      int numRows,
+      TaskContext taskContext) {
+    // Save-and-restore rather than only-install-if-null: the propagated context is the ground
+    // truth for this call. Any value already on the thread is either (a) the same object on a
+    // Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
+    TaskContext prior = TaskContext.get();
+    if (taskContext != null) {
+      CometTaskContextShim.set(taskContext);
+    }
+    try {
+      evaluateInternal(
+          udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
+    } finally {
+      if (taskContext != null) {
+        if (prior != null) {
+          CometTaskContextShim.set(prior);
+        } else {
+          CometTaskContextShim.unset();
+        }
+      }
+    }
+  }
+
+  private static void evaluateInternal(
+      String udfClassName,
+      long[] inputArrayPtrs,
+      long[] inputSchemaPtrs,
+      long outArrayPtr,
+      long outSchemaPtr,
+      int numRows) {
     CometUDF udf =
         INSTANCES.computeIfAbsent(
             udfClassName,
@@ -84,23 +125,17 @@ public static void evaluate(
       inputs[i] = Data.importVector(allocator, inArr, inSch, null);
     }
-    result = udf.evaluate(inputs);
+    result = udf.evaluate(inputs, numRows);
     if (!(result instanceof FieldVector)) {
       throw new RuntimeException(
           "CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
     }
-    // Result length must match the longest input. Scalar (length-1) inputs
-    // are allowed to be shorter, but a vector input bounds the output.
-    int expectedLen = 0;
-    for (ValueVector v : inputs) {
-      expectedLen = Math.max(expectedLen, v.getValueCount());
-    }
-    if (result.getValueCount() != expectedLen) {
+    if (result.getValueCount() != numRows) {
       throw new RuntimeException(
           "CometUDF.evaluate() returned "
               + result.getValueCount()
               + " rows, expected "
-              + expectedLen);
+              + numRows);
     }
 
     ArrowArray outArr = ArrowArray.wrap(outArrayPtr);
     ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr);
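The `numRows` contract above is easiest to see from the zero-input side. Below is a minimal sketch of such a UDF (illustrative only, not part of this change; the class name `RowIndexUdf` is hypothetical and allocator lifecycle is elided):

```scala
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{IntVector, ValueVector}
import org.apache.comet.udf.CometUDF

// Hypothetical zero-input UDF: `inputs` is empty, so `numRows` is the only
// signal for how many rows the current batch has.
class RowIndexUdf extends CometUDF {
  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    // Sketch only: a real implementation would reuse a managed allocator
    // instead of creating (and leaking) a RootAllocator per call.
    val out = new IntVector("row_index", new RootAllocator())
    out.allocateNew(numRows)
    var i = 0
    while (i < numRows) {
      out.setSafe(i, i)
      i += 1
    }
    // CometUdfBridge rejects the result unless getValueCount() == numRows.
    out.setValueCount(numRows)
    out
  }
}
```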
diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
index 29186f0a2c..5b6652d90a 100644
--- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
+++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala
@@ -27,11 +27,16 @@ import org.apache.arrow.vector.ValueVector
  *
  * - Vector arguments arrive at the row count of the current batch.
  * - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
- * - The returned vector's length must match the longest input.
+ * - The returned vector's length must match `numRows`.
+ *
+ * `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count.
+ * UDFs that always have at least one batch-length input can derive length from the inputs and
+ * ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF)
+ * need `numRows` to know how many rows to produce.
  *
  * Implementations must have a public no-arg constructor and must be stateless: a single instance
  * per class is cached and shared across native worker threads for the lifetime of the JVM.
  */
 trait CometUDF {
-  def evaluate(inputs: Array[ValueVector]): ValueVector
+  def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
 }
diff --git a/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala
new file mode 100644
index 0000000000..9218fc5e78
--- /dev/null
+++ b/common/src/main/scala/org/apache/spark/comet/CometTaskContextShim.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.comet
+
+import org.apache.spark.TaskContext
+
+/**
+ * Package-private access shim for `TaskContext.setTaskContext` / `TaskContext.unset`.
+ *
+ * Both methods are declared `protected[spark]` on Spark's `TaskContext` companion, so they are
+ * reachable from code inside the `org.apache.spark` package tree but not from `org.apache.comet`.
+ * The Comet JVM UDF bridge needs to set the thread-local `TaskContext` on its caller thread (a
+ * Tokio worker thread with no `TaskContext`) so the user's UDF body and any partition-sensitive
+ * built-ins (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, etc.) see the driving Spark task's
+ * `TaskContext`. This shim lives in `org.apache.spark.comet` so it can call through to the
+ * protected methods, and exposes plain public forwarders the bridge (which lives in
+ * `org.apache.comet.udf`) can use.
+ */
+object CometTaskContextShim {
+
+  def set(taskContext: TaskContext): Unit = TaskContext.setTaskContext(taskContext)
+
+  def unset(): Unit = TaskContext.unset()
+}
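Restated in Scala for reference, the install/restore discipline the shim enables looks like the following helper (illustrative only, not part of this change; `withTaskContext` simply mirrors the logic in `CometUdfBridge.evaluate` above):

```scala
import org.apache.spark.TaskContext
import org.apache.spark.comet.CometTaskContextShim

// Illustrative restatement of the bridge's save-and-restore contract:
// install the propagated context for the duration of `body`, then put back
// whatever the thread had before (a live context on a Spark task thread,
// nothing on a bare Tokio worker).
object TaskContextScope {
  def withTaskContext[T](propagated: TaskContext)(body: => T): T = {
    val prior = TaskContext.get()
    CometTaskContextShim.set(propagated)
    try body
    finally {
      if (prior != null) CometTaskContextShim.set(prior)
      else CometTaskContextShim.unset()
    }
  }
}
```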
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 5d3dbb8266..f5b04cc51d 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -306,6 +306,13 @@ struct ExecutionContext {
     pub tracing_memory_metric_name: String,
     /// Pre-computed tracing event name for executePlan calls
     pub tracing_event_name: String,
+    /// Spark `TaskContext` captured on the driving Spark task thread at `createPlan` time.
+    /// Threaded into every JVM scalar UDF the planner builds so the JNI bridge can install it
+    /// as the thread-local `TaskContext` for the Tokio worker running the UDF. `None` when no
+    /// driving Spark task is present (unit tests, direct native driver runs). The `Arc` is
+    /// cheap to clone; the underlying `Global` releases its JNI global ref on drop
+    /// via `jni`'s `Drop` impl.
+    pub task_context: Option<Arc<Global<JObject<'static>>>>,
 }
 
 /// Accept serialized query plan and return the address of the native query plan.
@@ -332,6 +339,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
     task_attempt_id: jlong,
     task_cpus: jlong,
     key_unwrapper_obj: JObject,
+    task_context_obj: JObject,
 ) -> jlong {
     try_unwrap_or_throw(&e, |env| {
         // Deserialize Spark configs
@@ -453,6 +461,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
             String::new()
         };
 
+        // Capture the driving Spark task's TaskContext as a JNI global reference when
+        // non-null. The `Arc<Global<JObject>>` releases its global ref on drop, so cleanup
+        // is automatic when the ExecutionContext drops.
+        let task_context = if !task_context_obj.is_null() {
+            Some(Arc::new(jni_new_global_ref!(env, task_context_obj)?))
+        } else {
+            None
+        };
+
         let exec_context = Box::new(ExecutionContext {
             id,
             task_attempt_id,
@@ -479,6 +496,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
                 "thread_{rust_thread_id}_comet_memory_reserved"
             ),
             tracing_event_name,
+            task_context,
         });
 
         Ok(Box::into_raw(exec_context) as i64)
@@ -703,7 +721,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
             let start = Instant::now();
             let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
-                .with_exec_id(exec_context_id);
+                .with_exec_id(exec_context_id)
+                .with_task_context(exec_context.task_context.clone());
             let (scans, shuffle_scans, root_op) = planner.create_plan(
                 &exec_context.spark_plan,
                 &mut exec_context.input_sources.clone(),
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 478c7a8d98..b00f140026 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -183,6 +183,9 @@ pub struct PhysicalPlanner {
     partition: i32,
     session_ctx: Arc<SessionContext>,
     query_context_registry: Arc<QueryContextRegistry>,
+    /// Captured at `createPlan` time on `ExecutionContext`; see that struct for the
+    /// propagation rationale. `None` when no driving Spark task is available.
+    task_context: Option<Arc<Global<JObject<'static>>>>,
 }
 
 impl Default for PhysicalPlanner {
@@ -198,16 +201,24 @@
             session_ctx,
             partition,
             query_context_registry: datafusion_comet_spark_expr::create_query_context_map(),
+            task_context: None,
         }
     }
 
-    pub fn with_exec_id(self, exec_context_id: i64) -> Self {
-        Self {
-            exec_context_id,
-            partition: self.partition,
-            session_ctx: Arc::clone(&self.session_ctx),
-            query_context_registry: Arc::clone(&self.query_context_registry),
-        }
+    pub fn with_exec_id(mut self, exec_context_id: i64) -> Self {
+        self.exec_context_id = exec_context_id;
+        self
+    }
+
+    /// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned
+    /// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as
+    /// the thread-local on the Tokio worker driving the UDF.
+    pub fn with_task_context(
+        mut self,
+        task_context: Option<Arc<Global<JObject<'static>>>>,
+    ) -> Self {
+        self.task_context = task_context;
+        self
     }
 
     /// Return session context of this planner.
@@ -735,6 +746,7 @@
                     args,
                     return_type,
                     udf.return_nullable,
+                    self.task_context.clone(),
                 )))
             }
             expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
diff --git a/native/jni-bridge/src/comet_udf_bridge.rs b/native/jni-bridge/src/comet_udf_bridge.rs
index 89cd8ee514..e531d20cb1 100644
--- a/native/jni-bridge/src/comet_udf_bridge.rs
+++ b/native/jni-bridge/src/comet_udf_bridge.rs
@@ -41,7 +41,7 @@ impl<'a> CometUdfBridge<'a> {
             method_evaluate: env.get_static_method_id(
                 JNIString::new(Self::JVM_CLASS),
                 jni::jni_str!("evaluate"),
-                jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"),
+                jni::jni_sig!("(Ljava/lang/String;[J[JJJILorg/apache/spark/TaskContext;)V"),
             )?,
             method_evaluate_ret: ReturnType::Primitive(Primitive::Void),
             class,
diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs
index 668a2b6727..4ed25de6ee 100644
--- a/native/spark-expr/src/jvm_udf/mod.rs
+++ b/native/spark-expr/src/jvm_udf/mod.rs
@@ -31,7 +31,7 @@ use datafusion::physical_expr::PhysicalExpr;
 use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError};
 use datafusion_comet_jni_bridge::JVMClasses;
 
-use jni::objects::{JObject, JValue};
+use jni::objects::{Global, JObject, JValue};
 
 /// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI.
 /// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`.
@@ -41,6 +41,14 @@ pub struct JvmScalarUdfExpr {
     args: Vec<Arc<dyn PhysicalExpr>>,
     return_type: DataType,
     return_nullable: bool,
+    /// Captured at `createPlan` time and threaded here by the planner. Passed through the
+    /// JNI bridge so `CometUdfBridge.evaluate` can install it as the Tokio worker's
+    /// thread-local `TaskContext`. Without this, partition-sensitive built-ins inside a UDF
+    /// tree (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, user code reading
+    /// `TaskContext.get()`) see `null` and seed / branch incorrectly. `None` when no driving
+    /// Spark task is available; the bridge then leaves whatever `TaskContext.get()` already
+    /// returns in place.
+    task_context: Option<Arc<Global<JObject<'static>>>>,
 }
 
 impl JvmScalarUdfExpr {
@@ -49,12 +57,14 @@ impl JvmScalarUdfExpr {
     pub fn new(
         class_name: String,
         args: Vec<Arc<dyn PhysicalExpr>>,
         return_type: DataType,
         return_nullable: bool,
+        task_context: Option<Arc<Global<JObject<'static>>>>,
     ) -> Self {
         Self {
             class_name,
             args,
             return_type,
             return_nullable,
+            task_context,
         }
     }
 }
@@ -186,7 +196,14 @@
             .set_region(env, 0, &in_sch_ptrs)
             .map_err(|e| CometError::JNI { source: e })?;
 
-        // Call CometUdfBridge.evaluate(String, long[], long[], long, long)
+        // Pass a null jobject when no TaskContext was propagated so the bridge's null-guard
+        // leaves the worker thread's current TaskContext.get() in place. The borrow must
+        // outlive `call_static_method_unchecked`.
+        let null_task_context = JObject::null();
+        let task_context_ref: &JObject = match &self.task_context {
+            Some(gref) => gref.as_obj(),
+            None => &null_task_context,
+        };
         let ret = unsafe {
             env.call_static_method_unchecked(
                 &bridge.class,
@@ -198,6 +215,8 @@
                 JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(),
                 JValue::Long(out_arr_ptr).as_jni(),
                 JValue::Long(out_sch_ptr).as_jni(),
+                JValue::Int(batch.num_rows() as i32).as_jni(),
+                JValue::Object(task_context_ref).as_jni(),
             ],
         )
     };
@@ -234,6 +253,7 @@
             children,
             self.return_type.clone(),
             self.return_nullable,
+            self.task_context.clone(),
         )))
     }
 }
diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
index a93564811c..6140eca553 100644
--- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -127,7 +127,10 @@
       memoryConfig.memoryLimitPerTask,
       taskAttemptId,
       taskCPUs,
-      keyUnwrapper)
+      keyUnwrapper,
+      // Propagated to Tokio workers running JVM UDFs so they see this Spark task's
+      // TaskContext. See CometUdfBridge.evaluate.
+      TaskContext.get())
   }
 
   private var nextBatch: Option[ColumnarBatch] = None
diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala
index c003bcd138..3cfa51b6e1 100644
--- a/spark/src/main/scala/org/apache/comet/Native.scala
+++ b/spark/src/main/scala/org/apache/comet/Native.scala
@@ -21,7 +21,7 @@ package org.apache.comet
 
 import java.nio.ByteBuffer
 
-import org.apache.spark.CometTaskMemoryManager
+import org.apache.spark.{CometTaskMemoryManager, TaskContext}
 import org.apache.spark.sql.comet.CometMetricNode
 
 import org.apache.comet.parquet.CometFileKeyUnwrapper
@@ -69,7 +69,8 @@ class Native extends NativeBase {
       memoryLimitPerTask: Long,
       taskAttemptId: Long,
       taskCPUs: Long,
-      keyUnwrapper: CometFileKeyUnwrapper): Long
+      keyUnwrapper: CometFileKeyUnwrapper,
+      taskContext: TaskContext): Long
   // scalastyle:on
 
   /**
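A hypothetical end-to-end illustration of what this propagation enables (user-level code, not part of this change; assumes an active `SparkSession` in `spark`): a zero-arg, non-deterministic ScalaUDF whose body reads the thread-local `TaskContext`. Under Comet the body runs on a Tokio worker, where `TaskContext.get()` is only non-null because `CometExecIterator` captured the context at `createPlan` time and `CometUdfBridge` installs it around the call:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.functions.udf

// Zero-arg UDF: the bridge passes the batch's numRows so the JVM side knows
// how many rows to produce, and the propagated TaskContext makes
// partitionId() valid on the worker thread.
val partitionTag = udf(() => TaskContext.get().partitionId()).asNondeterministic()

spark
  .range(0, 1000, 1, numPartitions = 4)
  .withColumn("pid", partitionTag())
  .groupBy("pid")
  .count()
  .show()
```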