55 changes: 45 additions & 10 deletions common/src/main/java/org/apache/comet/udf/CometUdfBridge.java
@@ -27,6 +27,8 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.spark.TaskContext;
import org.apache.spark.comet.CometTaskContextShim;

/**
* JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method
@@ -48,13 +50,52 @@ public class CometUdfBridge {
* @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
* @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
* @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
* @param numRows row count of the current batch. Mirrors DataFusion's {@code
* ScalarFunctionArgs.number_rows}; the only batch-size signal a zero-input UDF (e.g. a
* zero-arg non-deterministic ScalaUDF) ever sees.
* @param taskContext propagated Spark {@link TaskContext} from the driving Spark task thread, or
* {@code null} outside a Spark task. Treated as ground truth for the call: installed as the
* thread-local on entry, with the prior value (if any) saved and restored in {@code finally}.
* Lets partition-sensitive built-ins ({@code Rand}, {@code Uuid}, {@code
* MonotonicallyIncreasingID}) work from Tokio workers and avoids reusing a stale TaskContext
* left on a worker by a previous task.
*/
public static void evaluate(
String udfClassName,
long[] inputArrayPtrs,
long[] inputSchemaPtrs,
long outArrayPtr,
long outSchemaPtr) {
long outSchemaPtr,
int numRows,
TaskContext taskContext) {
// Save-and-restore rather than only-install-if-null: the propagated context is the ground
// truth for this call. Any value already on the thread is either (a) the same object on a
// Spark task thread, or (b) stale from a prior task on a reused Tokio worker.
TaskContext prior = TaskContext.get();
if (taskContext != null) {
CometTaskContextShim.set(taskContext);
}
try {
evaluateInternal(
udfClassName, inputArrayPtrs, inputSchemaPtrs, outArrayPtr, outSchemaPtr, numRows);
} finally {
if (taskContext != null) {
if (prior != null) {
CometTaskContextShim.set(prior);
} else {
CometTaskContextShim.unset();
}
}
}
}

private static void evaluateInternal(
String udfClassName,
long[] inputArrayPtrs,
long[] inputSchemaPtrs,
long outArrayPtr,
long outSchemaPtr,
int numRows) {
CometUDF udf =
INSTANCES.computeIfAbsent(
udfClassName,
@@ -84,23 +125,17 @@ public static void evaluate(
inputs[i] = Data.importVector(allocator, inArr, inSch, null);
}

result = udf.evaluate(inputs);
result = udf.evaluate(inputs, numRows);
if (!(result instanceof FieldVector)) {
throw new RuntimeException(
"CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
}
// Result length must match the longest input. Scalar (length-1) inputs
// are allowed to be shorter, but a vector input bounds the output.
int expectedLen = 0;
for (ValueVector v : inputs) {
expectedLen = Math.max(expectedLen, v.getValueCount());
}
if (result.getValueCount() != expectedLen) {
if (result.getValueCount() != numRows) {
throw new RuntimeException(
"CometUDF.evaluate() returned "
+ result.getValueCount()
+ " rows, expected "
+ expectedLen);
+ numRows);
}
ArrowArray outArr = ArrowArray.wrap(outArrayPtr);
ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr);
9 changes: 7 additions & 2 deletions common/src/main/scala/org/apache/comet/udf/CometUDF.scala
@@ -27,11 +27,16 @@ import org.apache.arrow.vector.ValueVector
*
* - Vector arguments arrive at the row count of the current batch.
* - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
* - The returned vector's length must match the longest input.
* - The returned vector's length must match `numRows`.
*
* `numRows` mirrors DataFusion's `ScalarFunctionArgs.number_rows` and is the batch row count.
* UDFs that always have at least one batch-length input can derive length from the inputs and
* ignore `numRows`; UDFs that may be called with zero data columns (e.g. a zero-arg ScalaUDF)
* need `numRows` to know how many rows to produce.
*
* Implementations must have a public no-arg constructor and must be stateless: a single instance
* per class is cached and shared across native worker threads for the lifetime of the JVM.
*/
trait CometUDF {
def evaluate(inputs: Array[ValueVector]): ValueVector
def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector
}
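For illustration, a minimal zero-input implementation against the new signature might look like the sketch below. This is hypothetical code, not part of this PR; the class name `FortyTwoUdf` and the private `RootAllocator` are assumptions.

```scala
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{IntVector, ValueVector}

import org.apache.comet.udf.CometUDF

// Hypothetical zero-input UDF: numRows is its only batch-size signal.
class FortyTwoUdf extends CometUDF {
  // Shared across native worker threads for the lifetime of the JVM;
  // RootAllocator is thread-safe, so the statelessness contract holds.
  private val allocator = new RootAllocator()

  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    val out = new IntVector("out", allocator)
    out.allocateNew(numRows)
    var i = 0
    while (i < numRows) {
      out.setSafe(i, 42)
      i += 1
    }
    out.setValueCount(numRows) // must equal numRows, or the bridge throws
    out
  }
}
```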
org/apache/spark/comet/CometTaskContextShim.scala (new file)
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.comet

import org.apache.spark.TaskContext

/**
 * Access shim for Spark's `TaskContext.setTaskContext` / `TaskContext.unset`.
*
* Both methods are declared `protected[spark]` on Spark's `TaskContext` companion, so they are
* reachable from code inside the `org.apache.spark` package tree but not from `org.apache.comet`.
* The Comet JVM UDF bridge needs to set the thread-local `TaskContext` on its caller thread (a
* Tokio worker thread with no `TaskContext`) so the user's UDF body and any partition-sensitive
* built-ins (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, etc.) see the driving Spark task's
* `TaskContext`. This shim lives in `org.apache.spark.comet` so it can call through to the
* protected methods, and exposes plain public forwarders the bridge (which lives in
* `org.apache.comet.udf`) can use.
*/
object CometTaskContextShim {

def set(taskContext: TaskContext): Unit = TaskContext.setTaskContext(taskContext)

def unset(): Unit = TaskContext.unset()
}
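As a sketch of what this unlocks (hypothetical code, not from this PR): once `CometUdfBridge.evaluate` has installed the propagated context via this shim, a UDF body running on a Tokio worker can read `TaskContext.get()` directly. The class name `PartitionIdUdf` is an assumption.

```scala
import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.{BigIntVector, ValueVector}
import org.apache.spark.TaskContext

import org.apache.comet.udf.CometUDF

// Hypothetical partition-sensitive UDF.
class PartitionIdUdf extends CometUDF {
  private val allocator = new RootAllocator()

  override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = {
    // Non-null on a Tokio worker only because CometUdfBridge installed the
    // propagated TaskContext as this thread's thread-local before the call.
    val partitionId = TaskContext.get().partitionId()
    val out = new BigIntVector("partition_id", allocator)
    out.allocateNew(numRows)
    var i = 0
    while (i < numRows) {
      out.setSafe(i, partitionId.toLong)
      i += 1
    }
    out.setValueCount(numRows)
    out
  }
}
```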
21 changes: 20 additions & 1 deletion native/core/src/execution/jni_api.rs
@@ -306,6 +306,13 @@ struct ExecutionContext {
pub tracing_memory_metric_name: String,
/// Pre-computed tracing event name for executePlan calls
pub tracing_event_name: String,
/// Spark `TaskContext` captured on the driving Spark task thread at `createPlan` time.
/// Threaded into every JVM scalar UDF the planner builds so the JNI bridge can install it
/// as the thread-local `TaskContext` for the Tokio worker running the UDF. `None` when no
/// driving Spark task is present (unit tests, direct native driver runs). The `Arc` is
/// cheap to clone; the underlying `Global<JObject>` releases its JNI global ref on drop
/// via `jni`'s `Drop` impl.
pub task_context: Option<Arc<Global<JObject<'static>>>>,
}

/// Accept serialized query plan and return the address of the native query plan.
@@ -332,6 +339,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
task_attempt_id: jlong,
task_cpus: jlong,
key_unwrapper_obj: JObject,
task_context_obj: JObject,
) -> jlong {
try_unwrap_or_throw(&e, |env| {
// Deserialize Spark configs
@@ -453,6 +461,15 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
String::new()
};

// Capture the driving Spark task's TaskContext as a JNI global reference when
// non-null. The `Arc<Global<JObject>>` releases its global ref on drop, so cleanup
// is automatic when the ExecutionContext drops.
let task_context = if !task_context_obj.is_null() {
Some(Arc::new(jni_new_global_ref!(env, task_context_obj)?))
} else {
None
};

let exec_context = Box::new(ExecutionContext {
id,
task_attempt_id,
@@ -479,6 +496,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
"thread_{rust_thread_id}_comet_memory_reserved"
),
tracing_event_name,
task_context,
});

Ok(Box::into_raw(exec_context) as i64)
@@ -703,7 +721,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
let start = Instant::now();
let planner =
PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
.with_exec_id(exec_context_id);
.with_exec_id(exec_context_id)
.with_task_context(exec_context.task_context.clone());
let (scans, shuffle_scans, root_op) = planner.create_plan(
&exec_context.spark_plan,
&mut exec_context.input_sources.clone(),
26 changes: 19 additions & 7 deletions native/core/src/execution/planner.rs
@@ -183,6 +183,9 @@ pub struct PhysicalPlanner {
partition: i32,
session_ctx: Arc<SessionContext>,
query_context_registry: Arc<datafusion_comet_spark_expr::QueryContextMap>,
/// Captured at `createPlan` time on `ExecutionContext`; see that struct for the
/// propagation rationale. `None` when no driving Spark task is available.
task_context: Option<Arc<Global<JObject<'static>>>>,
}

impl Default for PhysicalPlanner {
@@ -198,16 +201,24 @@
session_ctx,
partition,
query_context_registry: datafusion_comet_spark_expr::create_query_context_map(),
task_context: None,
}
}

pub fn with_exec_id(self, exec_context_id: i64) -> Self {
Self {
exec_context_id,
partition: self.partition,
session_ctx: Arc::clone(&self.session_ctx),
query_context_registry: Arc::clone(&self.query_context_registry),
}
pub fn with_exec_id(mut self, exec_context_id: i64) -> Self {
self.exec_context_id = exec_context_id;
self
}

/// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned
/// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as
/// the thread-local on the Tokio worker driving the UDF.
pub fn with_task_context(
mut self,
task_context: Option<Arc<Global<JObject<'static>>>>,
) -> Self {
self.task_context = task_context;
self
}

/// Return session context of this planner.
@@ -735,6 +746,7 @@ impl PhysicalPlanner {
args,
return_type,
udf.return_nullable,
self.task_context.clone(),
)))
}
expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
2 changes: 1 addition & 1 deletion native/jni-bridge/src/comet_udf_bridge.rs
@@ -41,7 +41,7 @@ impl<'a> CometUdfBridge<'a> {
method_evaluate: env.get_static_method_id(
JNIString::new(Self::JVM_CLASS),
jni::jni_str!("evaluate"),
jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"),
jni::jni_sig!("(Ljava/lang/String;[J[JJJILorg/apache/spark/TaskContext;)V"),
)?,
method_evaluate_ret: ReturnType::Primitive(Primitive::Void),
class,
24 changes: 22 additions & 2 deletions native/spark-expr/src/jvm_udf/mod.rs
@@ -31,7 +31,7 @@ use datafusion::physical_expr::PhysicalExpr;

use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError};
use datafusion_comet_jni_bridge::JVMClasses;
use jni::objects::{JObject, JValue};
use jni::objects::{Global, JObject, JValue};

/// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI.
/// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`.
@@ -41,6 +41,14 @@ pub struct JvmScalarUdfExpr {
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
return_nullable: bool,
/// Captured at `createPlan` time and threaded here by the planner. Passed through the
/// JNI bridge so `CometUdfBridge.evaluate` can install it as the Tokio worker's
/// thread-local `TaskContext`. Without this, partition-sensitive built-ins inside a UDF
/// tree (`Rand`, `Uuid`, `MonotonicallyIncreasingID`, user code reading
/// `TaskContext.get()`) see `null` and seed / branch incorrectly. `None` when no driving
/// Spark task is available; the bridge then leaves whatever `TaskContext.get()` already
/// returns in place.
task_context: Option<Arc<Global<JObject<'static>>>>,
}

impl JvmScalarUdfExpr {
@@ -49,12 +57,14 @@
args: Vec<Arc<dyn PhysicalExpr>>,
return_type: DataType,
return_nullable: bool,
task_context: Option<Arc<Global<JObject<'static>>>>,
) -> Self {
Self {
class_name,
args,
return_type,
return_nullable,
task_context,
}
}
}
@@ -186,7 +196,14 @@ impl PhysicalExpr for JvmScalarUdfExpr {
.set_region(env, 0, &in_sch_ptrs)
.map_err(|e| CometError::JNI { source: e })?;

// Call CometUdfBridge.evaluate(String, long[], long[], long, long)
// Pass a null jobject when no TaskContext was propagated so the bridge's null-guard
// leaves the worker thread's current TaskContext.get() in place. The borrow must
// outlive `call_static_method_unchecked`.
let null_task_context = JObject::null();
let task_context_ref: &JObject = match &self.task_context {
Some(gref) => gref.as_obj(),
None => &null_task_context,
};
let ret = unsafe {
env.call_static_method_unchecked(
&bridge.class,
@@ -198,6 +215,8 @@
JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(),
JValue::Long(out_arr_ptr).as_jni(),
JValue::Long(out_sch_ptr).as_jni(),
JValue::Int(batch.num_rows() as i32).as_jni(),
JValue::Object(task_context_ref).as_jni(),
],
)
};
@@ -234,6 +253,7 @@ impl PhysicalExpr for JvmScalarUdfExpr {
children,
self.return_type.clone(),
self.return_nullable,
self.task_context.clone(),
)))
}
}
CometExecIterator.scala
@@ -127,7 +127,10 @@ class CometExecIterator(
memoryConfig.memoryLimitPerTask,
taskAttemptId,
taskCPUs,
keyUnwrapper)
keyUnwrapper,
// Propagated to Tokio workers running JVM UDFs so they see this Spark task's
// TaskContext. See CometUdfBridge.evaluate.
TaskContext.get())
}

private var nextBatch: Option[ColumnarBatch] = None
5 changes: 3 additions & 2 deletions spark/src/main/scala/org/apache/comet/Native.scala
@@ -21,7 +21,7 @@ package org.apache.comet

import java.nio.ByteBuffer

import org.apache.spark.CometTaskMemoryManager
import org.apache.spark.{CometTaskMemoryManager, TaskContext}
import org.apache.spark.sql.comet.CometMetricNode

import org.apache.comet.parquet.CometFileKeyUnwrapper
@@ -69,7 +69,8 @@ class Native extends NativeBase {
memoryLimitPerTask: Long,
taskAttemptId: Long,
taskCPUs: Long,
keyUnwrapper: CometFileKeyUnwrapper): Long
keyUnwrapper: CometFileKeyUnwrapper,
taskContext: TaskContext): Long
// scalastyle:on

/**