From 8f2b6eb4be5d9945e5d1708b9b02e37f996504d2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 15:26:38 -0600 Subject: [PATCH 1/9] feat: add JVM UDF framework for native execution Add a framework that allows Comet to invoke JVM-side UDF implementations operating on Arrow data via JNI, avoiding expensive fallback to Spark while maintaining 100% Spark compatibility for expressions not yet implemented natively in Rust. Co-Authored-By: Claude Opus 4.6 --- .../org/apache/comet/udf/CometUdfBridge.java | 141 +++++++++++ .../comet/udf/CometLambdaRegistry.scala | 58 +++++ .../scala/org/apache/comet/udf/CometUDF.scala | 37 +++ native/Cargo.lock | 2 + native/core/src/execution/planner.rs | 25 +- native/jni-bridge/src/comet_udf_bridge.rs | 50 ++++ native/jni-bridge/src/lib.rs | 12 + native/proto/src/proto/expr.proto | 16 ++ native/spark-expr/Cargo.toml | 2 + native/spark-expr/src/jvm_udf/mod.rs | 239 ++++++++++++++++++ native/spark-expr/src/lib.rs | 2 + 11 files changed, 580 insertions(+), 4 deletions(-) create mode 100644 common/src/main/java/org/apache/comet/udf/CometUdfBridge.java create mode 100644 common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala create mode 100644 common/src/main/scala/org/apache/comet/udf/CometUDF.scala create mode 100644 native/jni-bridge/src/comet_udf_bridge.rs create mode 100644 native/spark-expr/src/jvm_udf/mod.rs diff --git a/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java new file mode 100644 index 0000000000..0e01c12d81 --- /dev/null +++ b/common/src/main/java/org/apache/comet/udf/CometUdfBridge.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf; + +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.arrow.c.ArrowArray; +import org.apache.arrow.c.ArrowSchema; +import org.apache.arrow.c.Data; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +/** + * JNI entry point for native execution to invoke a {@link CometUDF}. Matches the static-method + * pattern used by CometScalarSubquery so the native side can dispatch via + * call_static_method_unchecked. + */ +public class CometUdfBridge { + + // Per-thread, bounded LRU of UDF instances keyed by class name. Comet + // native execution threads (Tokio/DataFusion worker pool) are reused + // across tasks within an executor, so the effective lifetime of cached + // entries is the worker thread (i.e. the executor JVM). This is fine for + // stateless UDFs like ArrayExistsUDF; future stateful UDFs would need + // explicit per-task isolation. 
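+  // Eviction relies on the access-ordered LinkedHashMap below: once the map
+  // grows past CACHE_CAPACITY, removeEldestEntry drops the least-recently-used
+  // entry.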
+  private static final int CACHE_CAPACITY = 64;
+
+  private static final ThreadLocal<LinkedHashMap<String, CometUDF>> INSTANCES =
+      ThreadLocal.withInitial(
+          () ->
+              new LinkedHashMap<String, CometUDF>(CACHE_CAPACITY, 0.75f, true) {
+                @Override
+                protected boolean removeEldestEntry(Map.Entry<String, CometUDF> eldest) {
+                  return size() > CACHE_CAPACITY;
+                }
+              });
+
+  /**
+   * Called from native via JNI.
+   *
+   * @param udfClassName fully-qualified class name implementing CometUDF
+   * @param inputArrayPtrs addresses of pre-allocated FFI_ArrowArray structs (one per input)
+   * @param inputSchemaPtrs addresses of pre-allocated FFI_ArrowSchema structs (one per input)
+   * @param outArrayPtr address of pre-allocated FFI_ArrowArray for the result
+   * @param outSchemaPtr address of pre-allocated FFI_ArrowSchema for the result
+   */
+  public static void evaluate(
+      String udfClassName,
+      long[] inputArrayPtrs,
+      long[] inputSchemaPtrs,
+      long outArrayPtr,
+      long outSchemaPtr) {
+    LinkedHashMap<String, CometUDF> cache = INSTANCES.get();
+    CometUDF udf = cache.get(udfClassName);
+    if (udf == null) {
+      try {
+        // Resolve via the executor's context classloader so user-supplied UDF jars
+        // (added via spark.jars / --jars) are visible.
+        ClassLoader cl = Thread.currentThread().getContextClassLoader();
+        if (cl == null) {
+          cl = CometUdfBridge.class.getClassLoader();
+        }
+        udf =
+            (CometUDF) Class.forName(udfClassName, true, cl).getDeclaredConstructor().newInstance();
+      } catch (ReflectiveOperationException e) {
+        throw new RuntimeException("Failed to instantiate CometUDF: " + udfClassName, e);
+      }
+      cache.put(udfClassName, udf);
+    }
+
+    BufferAllocator allocator = org.apache.comet.package$.MODULE$.CometArrowAllocator();
+
+    ValueVector[] inputs = new ValueVector[inputArrayPtrs.length];
+    ValueVector result = null;
+    try {
+      for (int i = 0; i < inputArrayPtrs.length; i++) {
+        ArrowArray inArr = ArrowArray.wrap(inputArrayPtrs[i]);
+        ArrowSchema inSch = ArrowSchema.wrap(inputSchemaPtrs[i]);
+        inputs[i] = Data.importVector(allocator, inArr, inSch, null);
+      }
+
+      result = udf.evaluate(inputs);
+      if (!(result instanceof FieldVector)) {
+        throw new RuntimeException(
+            "CometUDF.evaluate() must return a FieldVector, got: " + result.getClass().getName());
+      }
+      // Result length must match the longest input. Scalar (length-1) inputs
+      // are allowed to be shorter, but a vector input bounds the output.
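+      // For example, a call over an 8192-row batch with one column argument and
+      // one literal argument arrives as inputs of length 8192 and 1, and the
+      // returned vector must have 8192 rows.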
+ int expectedLen = 0; + for (ValueVector v : inputs) { + expectedLen = Math.max(expectedLen, v.getValueCount()); + } + if (result.getValueCount() != expectedLen) { + throw new RuntimeException( + "CometUDF.evaluate() returned " + + result.getValueCount() + + " rows, expected " + + expectedLen); + } + ArrowArray outArr = ArrowArray.wrap(outArrayPtr); + ArrowSchema outSch = ArrowSchema.wrap(outSchemaPtr); + Data.exportVector(allocator, (FieldVector) result, null, outArr, outSch); + } finally { + for (ValueVector v : inputs) { + if (v != null) { + try { + v.close(); + } catch (RuntimeException ignored) { + // do not mask the original throwable + } + } + } + if (result != null) { + try { + result.close(); + } catch (RuntimeException ignored) { + // do not mask the original throwable + } + } + } + } +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala b/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala new file mode 100644 index 0000000000..5e020ae74a --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import java.util.UUID +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.sql.catalyst.expressions.Expression + +/** + * Thread-safe registry bridging plan-time Spark expressions to execution-time UDF lookup. At plan + * time the serde layer registers a lambda expression under a unique key; at execution time the + * UDF retrieves it by that key (passed as a scalar argument). + */ +object CometLambdaRegistry { + + private val registry = new ConcurrentHashMap[String, Expression]() + + def register(expression: Expression): String = { + val key = UUID.randomUUID().toString + registry.put(key, expression) + key + } + + def get(key: String): Expression = { + val expr = registry.get(key) + if (expr == null) { + throw new IllegalStateException( + s"Lambda expression not found in registry for key: $key. " + + "This indicates a lifecycle issue between plan creation and execution.") + } + expr + } + + def remove(key: String): Unit = { + registry.remove(key) + } + + // Visible for testing + def size(): Int = registry.size() +} diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala new file mode 100644 index 0000000000..ac7b72a883 --- /dev/null +++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.comet.udf

import org.apache.arrow.vector.ValueVector

/**
 * Scalar UDF invoked from native execution via JNI. Receives Arrow vectors as input and returns
 * an Arrow vector.
 *
 *   - Vector arguments arrive at the row count of the current batch.
 *   - Scalar (literal-folded) arguments arrive as length-1 vectors and must be read at index 0.
 *   - The returned vector's length must match the longest input.
 *
 * Implementations must have a public no-arg constructor and should be stateless: instances are
 * cached per executor thread for the lifetime of the JVM.
 */
trait CometUDF {
  def evaluate(inputs: Array[ValueVector]): ValueVector
}
diff --git a/native/Cargo.lock b/native/Cargo.lock
index ae2d6b074c..75e84d851d 100644
--- a/native/Cargo.lock
+++ b/native/Cargo.lock
@@ -2116,8 +2116,10 @@ dependencies = [
  "criterion",
  "datafusion",
  "datafusion-comet-common",
+ "datafusion-comet-jni-bridge",
  "futures",
  "hex",
+ "jni 0.22.4",
  "num",
  "rand 0.10.1",
  "regex",
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 844cc07c69..6019f168cc 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -122,10 +122,10 @@ use datafusion_comet_proto::{
     spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
 };
 use datafusion_comet_spark_expr::{
-    ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
-    DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
-    NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance,
-    WideDecimalBinaryExpr, WideDecimalOp,
+    jvm_udf::JvmScalarUdfExpr, ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation,
+    Covariance, CreateNamedStruct, DecimalRescaleCheckOverflow, GetArrayStructFields,
+    GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, SparkCastOptions, Stddev, SumDecimal,
+    ToJson, UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
 };
 use itertools::Itertools;
 use jni::objects::{Global, JObject};
@@ -701,6 +701,23 @@ impl PhysicalPlanner {
                     expr.names.clone(),
                 )))
             }
+            ExprStruct::JvmScalarUdf(udf) => {
+                let args = udf
+                    .args
+                    .iter()
+                    .map(|e| self.create_expr(e, Arc::clone(&input_schema)))
+                    .collect::<Result<Vec<_>, _>>()?;
+                let return_type =
+                    to_arrow_datatype(udf.return_type.as_ref().ok_or_else(|| {
+                        GeneralError("JvmScalarUdf missing return_type".to_string())
+                    })?);
+                Ok(Arc::new(JvmScalarUdfExpr::new(
+                    udf.class_name.clone(),
+                    args,
+                    return_type,
+                    udf.return_nullable,
+                )))
+            }
             expr => Err(GeneralError(format!("Not implemented: {expr:?}"))),
         }
     }
diff --git a/native/jni-bridge/src/comet_udf_bridge.rs b/native/jni-bridge/src/comet_udf_bridge.rs
new file mode 100644
index 0000000000..89cd8ee514
--- /dev/null
+++ b/native/jni-bridge/src/comet_udf_bridge.rs
@@ -0,0 +1,50 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use jni::{
+    errors::Result as JniResult,
+    objects::{JClass, JStaticMethodID},
+    signature::{Primitive, ReturnType},
+    strings::JNIString,
+    Env,
+};
+
+/// JNI handle for the JVM `org.apache.comet.udf.CometUdfBridge` class.
+/// Mirrors the static-method pattern in `comet_exec.rs` (`CometScalarSubquery`).
+#[allow(dead_code)] // class field is held to keep JStaticMethodID alive
+pub struct CometUdfBridge<'a> {
+    pub class: JClass<'a>,
+    pub method_evaluate: JStaticMethodID,
+    pub method_evaluate_ret: ReturnType,
+}
+
+impl<'a> CometUdfBridge<'a> {
+    pub const JVM_CLASS: &'static str = "org/apache/comet/udf/CometUdfBridge";
+
+    pub fn new(env: &mut Env<'a>) -> JniResult<CometUdfBridge<'a>> {
+        let class = env.find_class(JNIString::new(Self::JVM_CLASS))?;
+        Ok(CometUdfBridge {
+            method_evaluate: env.get_static_method_id(
+                JNIString::new(Self::JVM_CLASS),
+                jni::jni_str!("evaluate"),
+                jni::jni_sig!("(Ljava/lang/String;[J[JJJ)V"),
+            )?,
+            method_evaluate_ret: ReturnType::Primitive(Primitive::Void),
+            class,
+        })
+    }
+}
diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs
index 21c647135b..d72323c961 100644
--- a/native/jni-bridge/src/lib.rs
+++ b/native/jni-bridge/src/lib.rs
@@ -192,11 +192,13 @@ pub use comet_exec::*;
 mod batch_iterator;
 mod comet_metric_node;
 mod comet_task_memory_manager;
+mod comet_udf_bridge;
 mod shuffle_block_iterator;
 
 use batch_iterator::CometBatchIterator;
 pub use comet_metric_node::*;
 pub use comet_task_memory_manager::*;
+use comet_udf_bridge::CometUdfBridge;
 use shuffle_block_iterator::CometShuffleBlockIterator;
 
 /// The JVM classes that are used in the JNI calls.
@@ -228,6 +230,9 @@ pub struct JVMClasses<'a> {
     /// The CometTaskMemoryManager used for interacting with JVM side to
     /// acquire & release native memory.
     pub comet_task_memory_manager: CometTaskMemoryManager<'a>,
+    /// The CometUdfBridge class used to dispatch JVM scalar UDFs.
+    /// `None` if the class is not on the classpath.
+    pub comet_udf_bridge: Option<CometUdfBridge<'a>>,
 }
 
 unsafe impl Send for JVMClasses<'_> {}
 
@@ -298,6 +303,13 @@ impl JVMClasses<'_> {
             comet_batch_iterator: CometBatchIterator::new(env).unwrap(),
             comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(),
             comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(),
+            comet_udf_bridge: {
+                let bridge = CometUdfBridge::new(env).ok();
+                if env.exception_check() {
+                    env.exception_clear();
+                }
+                bridge
+            },
         }
     });
 }
diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto
index c7a305285d..90e3d87032 100644
--- a/native/proto/src/proto/expr.proto
+++ b/native/proto/src/proto/expr.proto
@@ -90,6 +90,7 @@ message Expr {
     ToCsv to_csv = 67;
     HoursTransform hours_transform = 68;
     ArraysZip arrays_zip = 69;
+    JvmScalarUdf jvm_scalar_udf = 70;
   }
 
   // Optional QueryContext for error reporting (contains SQL text and position)
@@ -514,3 +515,18 @@ message ArraysZip {
   repeated Expr values = 1;
   repeated string names = 2;
 }
+
+// Scalar UDF dispatched to the JVM via JNI. Native side exports input arrays
+// through Arrow C Data Interface, calls CometUdfBridge.evaluate, and imports
+// the result.
+message JvmScalarUdf {
+  // Fully-qualified Java/Scala class name implementing
+  // org.apache.comet.udf.CometUDF (must have a public no-arg constructor).
+  string class_name = 1;
+  // Argument expressions, evaluated by the native side before invocation.
+  repeated Expr args = 2;
+  // Expected return type. Used to import the result FFI_ArrowArray.
+  DataType return_type = 3;
+  // Whether the result column may contain nulls.
+  bool return_nullable = 4;
+}
diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml
index e9a4a546c1..33ffc1c886 100644
--- a/native/spark-expr/Cargo.toml
+++ b/native/spark-expr/Cargo.toml
@@ -36,6 +36,8 @@ regex = { workspace = true }
 # preserve_order: needed for get_json_object to match Spark's JSON key ordering
 serde_json = { version = "1.0", features = ["preserve_order"] }
 datafusion-comet-common = { workspace = true }
+datafusion-comet-jni-bridge = { workspace = true }
+jni = "0.22.4"
 futures = { workspace = true }
 twox-hash = "2.1.2"
 rand = { workspace = true }
diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs
new file mode 100644
index 0000000000..668a2b6727
--- /dev/null
+++ b/native/spark-expr/src/jvm_udf/mod.rs
@@ -0,0 +1,239 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
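+
+//! JVM-dispatched scalar UDF: argument arrays are exported through the Arrow
+//! C Data Interface, `CometUdfBridge.evaluate` is invoked over JNI, and the
+//! result array is imported back into DataFusion.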
+
+use std::any::Any;
+use std::fmt::{Display, Formatter};
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+use arrow::array::{make_array, ArrayRef};
+use arrow::datatypes::{DataType, Schema};
+use arrow::ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema};
+use arrow::record_batch::RecordBatch;
+
+use datafusion::common::Result as DFResult;
+use datafusion::logical_expr::ColumnarValue;
+use datafusion::physical_expr::PhysicalExpr;
+
+use datafusion_comet_jni_bridge::errors::{CometError, ExecutionError};
+use datafusion_comet_jni_bridge::JVMClasses;
+use jni::objects::{JObject, JValue};
+
+/// A scalar expression that delegates evaluation to a JVM-side `CometUDF` via JNI.
+/// The JVM class named by `class_name` must implement `org.apache.comet.udf.CometUDF`.
+#[derive(Debug)]
+pub struct JvmScalarUdfExpr {
+    class_name: String,
+    args: Vec<Arc<dyn PhysicalExpr>>,
+    return_type: DataType,
+    return_nullable: bool,
+}
+
+impl JvmScalarUdfExpr {
+    pub fn new(
+        class_name: String,
+        args: Vec<Arc<dyn PhysicalExpr>>,
+        return_type: DataType,
+        return_nullable: bool,
+    ) -> Self {
+        Self {
+            class_name,
+            args,
+            return_type,
+            return_nullable,
+        }
+    }
+}
+
+impl Display for JvmScalarUdfExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "JvmScalarUdf({}", self.class_name)?;
+        for a in &self.args {
+            write!(f, ", {a}")?;
+        }
+        write!(f, ")")
+    }
+}
+
+impl Hash for JvmScalarUdfExpr {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.class_name.hash(state);
+        for a in &self.args {
+            a.hash(state);
+        }
+        self.return_type.hash(state);
+        self.return_nullable.hash(state);
+    }
+}
+
+impl PartialEq for JvmScalarUdfExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.class_name == other.class_name
+            && self.return_type == other.return_type
+            && self.return_nullable == other.return_nullable
+            && self.args.len() == other.args.len()
+            && self.args.iter().zip(&other.args).all(|(a, b)| a.eq(b))
+    }
+}
+
+impl Eq for JvmScalarUdfExpr {}
+
+impl PhysicalExpr for JvmScalarUdfExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        Display::fmt(self, f)
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> DFResult<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> DFResult<bool> {
+        Ok(self.return_nullable)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> DFResult<ColumnarValue> {
+        // Step 1: evaluate child expressions to get Arrow arrays. Scalar children
+        // (e.g. literal patterns) are sent as length-1 vectors rather than expanded
+        // to batch-row count, so the JVM bridge does not pay an O(rows) copy for
+        // values that never vary across the batch.
+        let arrays: Vec<ArrayRef> = self
+            .args
+            .iter()
+            .map(|e| match e.evaluate(batch)? {
+                ColumnarValue::Array(a) => Ok(a),
+                ColumnarValue::Scalar(s) => s.to_array_of_size(1),
+            })
+            .collect::<DFResult<_>>()?;
+
+        // Step 2: allocate FFI structs on the Rust heap and collect their raw pointers.
+        // The JVM writes into the out_array/out_schema slots and reads from the in_ slots.
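+        // Boxing keeps each FFI struct at a stable heap address for the duration
+        // of the JNI call; the JVM side dereferences these raw addresses directly.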
+        let in_ffi_arrays: Vec<Box<FFI_ArrowArray>> = arrays
+            .iter()
+            .map(|arr| Box::new(FFI_ArrowArray::new(&arr.to_data())))
+            .collect();
+        let in_ffi_schemas: Vec<Box<FFI_ArrowSchema>> = arrays
+            .iter()
+            .map(|arr| {
+                FFI_ArrowSchema::try_from(arr.data_type())
+                    .map(Box::new)
+                    .map_err(|e| CometError::Arrow { source: e })
+            })
+            .collect::<Result<_, _>>()?;
+
+        let in_arr_ptrs: Vec<i64> = in_ffi_arrays
+            .iter()
+            .map(|b| b.as_ref() as *const FFI_ArrowArray as i64)
+            .collect();
+        let in_sch_ptrs: Vec<i64> = in_ffi_schemas
+            .iter()
+            .map(|b| b.as_ref() as *const FFI_ArrowSchema as i64)
+            .collect();
+
+        // Allocate output FFI slots.
+        let mut out_array = Box::new(FFI_ArrowArray::empty());
+        let mut out_schema = Box::new(FFI_ArrowSchema::empty());
+        let out_arr_ptr = out_array.as_mut() as *mut FFI_ArrowArray as i64;
+        let out_sch_ptr = out_schema.as_mut() as *mut FFI_ArrowSchema as i64;
+
+        let class_name = self.class_name.clone();
+        let n_args = arrays.len();
+
+        // Step 3: attach a JNI env for this thread and call the static bridge method.
+        JVMClasses::with_env(|env| {
+            let bridge = JVMClasses::get().comet_udf_bridge.as_ref().ok_or_else(|| {
+                CometError::from(ExecutionError::GeneralError(
+                    "JVM UDF bridge unavailable: org.apache.comet.udf.CometUdfBridge \
+                     class was not found on the JVM classpath."
+                        .to_string(),
+                ))
+            })?;
+
+            // Build the JVM String for the class name.
+            let jclass_name = env
+                .new_string(&class_name)
+                .map_err(|e| CometError::JNI { source: e })?;
+
+            // Build the long[] arrays for input pointers.
+            let in_arr_java = env
+                .new_long_array(n_args)
+                .map_err(|e| CometError::JNI { source: e })?;
+            in_arr_java
+                .set_region(env, 0, &in_arr_ptrs)
+                .map_err(|e| CometError::JNI { source: e })?;
+
+            let in_sch_java = env
+                .new_long_array(n_args)
+                .map_err(|e| CometError::JNI { source: e })?;
+            in_sch_java
+                .set_region(env, 0, &in_sch_ptrs)
+                .map_err(|e| CometError::JNI { source: e })?;
+
+            // Call CometUdfBridge.evaluate(String, long[], long[], long, long)
+            let ret = unsafe {
+                env.call_static_method_unchecked(
+                    &bridge.class,
+                    bridge.method_evaluate,
+                    bridge.method_evaluate_ret,
+                    &[
+                        JValue::from(&jclass_name).as_jni(),
+                        JValue::Object(JObject::from(in_arr_java).as_ref()).as_jni(),
+                        JValue::Object(JObject::from(in_sch_java).as_ref()).as_jni(),
+                        JValue::Long(out_arr_ptr).as_jni(),
+                        JValue::Long(out_sch_ptr).as_jni(),
+                    ],
+                )
+            };
+
+            if let Some(exception) = datafusion_comet_jni_bridge::check_exception(env)? {
+                return Err(exception);
+            }
+
+            ret.map_err(|e| CometError::JNI { source: e })?;
+            Ok(())
+        })?;
+
+        // Step 4: import the result from the FFI slots filled by the JVM.
+        // SAFETY: `*out_array` moves the FFI_ArrowArray out of the Box (the heap
+        // allocation is freed by the move), and `from_ffi` wraps it in an Arc that
+        // keeps the JVM-installed release callback alive until the resulting
+        // ArrayData drops. `out_schema` is borrowed; its release callback runs
+        // exactly once when the Box drops at end of scope.
+        let result_data = unsafe { from_ffi(*out_array, &out_schema) }
+            .map_err(|e| CometError::Arrow { source: e })?;
+        Ok(ColumnarValue::Array(make_array(result_data)))
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        self.args.iter().collect()
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> DFResult<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(JvmScalarUdfExpr::new(
+            self.class_name.clone(),
+            children,
+            self.return_type.clone(),
+            self.return_nullable,
+        )))
+    }
+}
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index eddf2ff460..d5297f27fd 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -55,6 +55,8 @@ pub use cast::{spark_cast, Cast, SparkCastOptions};
 mod bloom_filter;
 pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain, SparkBloomFilterVersion};
 
+pub mod jvm_udf;
+
 mod conditional_funcs;
 mod conversion_funcs;
 mod map_funcs;

From 8bf539c7ae360e60a77f7d47abf43d20975c099a Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Tue, 5 May 2026 15:43:21 -0600
Subject: [PATCH 2/9] feat: add user-facing CometUDF registration for custom
 JVM UDFs

Add CometUdfRegistry that allows end users to register their own CometUDF
implementations to be accelerated by Comet's native execution. When a
ScalaUDF is encountered during planning whose name matches a registry entry,
Comet emits a JvmScalarUdf proto instead of falling back to Spark's
row-at-a-time execution.

Also adds user guide documentation explaining how to write, register, and
deploy custom JVM UDFs.

Co-Authored-By: Claude Opus 4.6
---
 .../apache/comet/udf/CometUdfRegistry.scala   | 127 +++++++++++++++
 .../user-guide/latest/custom-jvm-udfs.md      | 148 ++++++++++++++++++
 docs/source/user-guide/latest/index.rst       |   1 +
 .../apache/comet/serde/CometScalaUdf.scala    |  73 +++++++++
 .../apache/comet/serde/QueryPlanSerde.scala   |   3 +
 5 files changed, 352 insertions(+)
 create mode 100644 common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala
 create mode 100644 docs/source/user-guide/latest/custom-jvm-udfs.md
 create mode 100644 spark/src/main/scala/org/apache/comet/serde/CometScalaUdf.scala

diff --git a/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala b/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala
new file mode 100644
index 0000000000..0b17ac66b2
--- /dev/null
+++ b/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.udf
+
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Registry for user-defined CometUDF implementations.
Users register their UDF class names here + * so that the Comet serde layer can intercept matching Spark UDFs and route them to native + * execution via the JVM UDF bridge. + * + * Usage: + * {{{ + * // Register a CometUDF implementation for a Spark UDF + * CometUdfRegistry.register( + * "my_func", // Spark UDF name (as used in spark.udf.register) + * "com.example.MyUdf", // CometUDF implementation class + * BooleanType, // return type + * nullable = true // whether the result may contain nulls + * ) + * + * // Or use the convenience method that also registers the Spark UDF: + * CometUdfRegistry.register( + * spark, + * "my_func", + * "com.example.MyUdf", + * sparkUdf, // the Spark UserDefinedFunction + * BooleanType, + * nullable = true + * ) + * }}} + */ +object CometUdfRegistry { + + case class UdfEntry(className: String, returnType: DataType, nullable: Boolean) + + private val registry = new ConcurrentHashMap[String, UdfEntry]() + + /** + * Register a CometUDF implementation for a named Spark UDF. + * + * @param name + * The UDF name as registered with Spark (via spark.udf.register) + * @param className + * Fully-qualified class name implementing CometUDF + * @param returnType + * The return DataType of the UDF + * @param nullable + * Whether the result column may contain nulls + */ + def register(name: String, className: String, returnType: DataType, nullable: Boolean): Unit = { + registry.put(name, UdfEntry(className, returnType, nullable)) + } + + /** + * Convenience method that registers both with Spark and with Comet in one call. + * + * @param spark + * The SparkSession + * @param name + * The UDF name + * @param className + * Fully-qualified CometUDF class name + * @param sparkUdf + * The Spark UserDefinedFunction (for row-at-a-time fallback) + * @param returnType + * The return DataType + * @param nullable + * Whether the result may contain nulls + */ + def register( + spark: SparkSession, + name: String, + className: String, + sparkUdf: org.apache.spark.sql.expressions.UserDefinedFunction, + returnType: DataType, + nullable: Boolean): Unit = { + spark.udf.register(name, sparkUdf) + registry.put(name, UdfEntry(className, returnType, nullable)) + } + + /** + * Look up a registered CometUDF by its Spark UDF name. + * + * @return + * Some(UdfEntry) if registered, None otherwise + */ + def get(name: String): Option[UdfEntry] = Option(registry.get(name)) + + /** + * Remove a previously registered UDF. + */ + def remove(name: String): Unit = { + registry.remove(name) + } + + /** + * Check whether a UDF name is registered. + */ + def isRegistered(name: String): Boolean = registry.containsKey(name) + + // Visible for testing + def size(): Int = registry.size() + + // Visible for testing + def clear(): Unit = registry.clear() +} diff --git a/docs/source/user-guide/latest/custom-jvm-udfs.md b/docs/source/user-guide/latest/custom-jvm-udfs.md new file mode 100644 index 0000000000..7ed151ab14 --- /dev/null +++ b/docs/source/user-guide/latest/custom-jvm-udfs.md @@ -0,0 +1,148 @@ + + +# Custom JVM UDFs + +Comet supports user-defined functions (UDFs) that operate on Arrow columnar data via the JVM UDF framework. This +allows UDFs to process entire batches of data at once rather than row-at-a-time, providing significant performance +improvements while maintaining full Spark compatibility. + +## Overview + +When Comet encounters a registered Spark UDF during query planning, it can route the UDF to a vectorized +JVM implementation that operates on Arrow vectors. 
This avoids the overhead of falling back to Spark's +row-at-a-time execution while keeping the implementation in Java/Scala. + +The framework consists of: + +- **`CometUDF`** — a trait your UDF class must implement, receiving and returning Arrow `ValueVector` instances +- **`CometUdfRegistry`** — a registry that maps Spark UDF names to CometUDF implementation classes +- **`CometUdfBridge`** — the JNI bridge that native execution uses to invoke your UDF (no user interaction needed) + +## Writing a CometUDF + +Implement the `org.apache.comet.udf.CometUDF` trait: + +```java +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.comet.udf.CometUDF; + +public class IsPositiveUdf implements CometUDF { + + @Override + public ValueVector evaluate(ValueVector[] inputs) { + IntVector input = (IntVector) inputs[0]; + int rowCount = input.getValueCount(); + + BitVector result = new BitVector("result", + org.apache.comet.package$.MODULE$.CometArrowAllocator()); + result.allocateNew(rowCount); + + for (int i = 0; i < rowCount; i++) { + if (input.isNull(i)) { + result.setNull(i); + } else { + result.set(i, input.get(i) > 0 ? 1 : 0); + } + } + result.setValueCount(rowCount); + return result; + } +} +``` + +Key requirements: + +- The class must have a **public no-arg constructor** +- Input vectors arrive at the row count of the current batch +- Scalar (literal) arguments arrive as length-1 vectors — read at index 0 +- The returned vector's length **must match** the longest input vector +- Instances are cached per executor thread, so implementations should be **stateless** + +## Registering a CometUDF + +### Option 1: Register Comet UDF separately from Spark UDF + +If you already have a Spark UDF registered, just tell Comet about the accelerated implementation: + +```scala +import org.apache.comet.udf.CometUdfRegistry +import org.apache.spark.sql.types.BooleanType + +// Register the Spark UDF (row-at-a-time fallback) +spark.udf.register("is_positive", (x: Int) => x > 0) + +// Register the CometUDF (vectorized Arrow implementation) +CometUdfRegistry.register( + "is_positive", // must match the Spark UDF name + "com.example.IsPositiveUdf", // CometUDF implementation class + BooleanType, // return type + nullable = true // whether results may contain nulls +) +``` + +### Option 2: Register both in one call + +```scala +import org.apache.comet.udf.CometUdfRegistry +import org.apache.spark.sql.functions.udf +import org.apache.spark.sql.types.BooleanType + +val sparkUdf = udf((x: Int) => x > 0) + +CometUdfRegistry.register( + spark, + "is_positive", + "com.example.IsPositiveUdf", + sparkUdf, + BooleanType, + nullable = true +) +``` + +## How It Works + +1. **Query planning**: When Comet's serde layer encounters a `ScalaUDF` expression with a name registered in + `CometUdfRegistry`, it emits a `JvmScalarUdf` protobuf message instead of falling back to Spark. + +2. **Native execution**: The Rust execution engine evaluates the UDF's input expressions to Arrow arrays, then + calls back into the JVM via the Arrow C Data Interface (zero-copy FFI). + +3. **JVM execution**: `CometUdfBridge` instantiates your `CometUDF` class (cached per thread), passes the input + Arrow vectors, and exports the result vector back to native execution. + +4. **Fallback**: If Comet is disabled or the UDF is not in the registry, Spark executes the UDF row-at-a-time + using the originally registered Scala/Java function. 
+ +## Packaging and Deployment + +1. Package your `CometUDF` implementation in a JAR +2. Include it on the Spark classpath via `--jars` or `spark.jars` +3. Register the UDF as shown above (in your application code or via a Spark session extension) + +The CometUDF class is resolved using the executor's context classloader, so user-supplied JARs added via +`spark.jars` or `--jars` are automatically visible. + +## Limitations + +- Only scalar UDFs are supported (not aggregate or table UDFs) +- The UDF must be registered by name — anonymous lambdas without a name cannot be intercepted +- All input and output types must be representable as Arrow vectors diff --git a/docs/source/user-guide/latest/index.rst b/docs/source/user-guide/latest/index.rst index 480ec4f702..db83ceb5ae 100644 --- a/docs/source/user-guide/latest/index.rst +++ b/docs/source/user-guide/latest/index.rst @@ -37,6 +37,7 @@ Comet $COMET_VERSION User Guide Compatibility Guide Understanding Comet Plans Tuning Guide + Custom JVM UDFs Metrics Guide Iceberg Guide Kubernetes Guide diff --git a/spark/src/main/scala/org/apache/comet/serde/CometScalaUdf.scala b/spark/src/main/scala/org/apache/comet/serde/CometScalaUdf.scala new file mode 100644 index 0000000000..1958ec3a76 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometScalaUdf.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{Attribute, ScalaUDF} + +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType} +import org.apache.comet.udf.CometUdfRegistry + +/** + * Handles serialization of Spark ScalaUDF expressions when a matching CometUDF implementation is + * registered in [[CometUdfRegistry]]. If the UDF is not registered, falls back to Spark. 
+ */ +object CometScalaUdf { + + def convert( + expr: ScalaUDF, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val name = expr.udfName.getOrElse { + withInfo(expr, "ScalaUDF has no name, cannot look up CometUDF registration") + return None + } + + val entry = CometUdfRegistry.get(name).getOrElse { + withInfo(expr, s"ScalaUDF '$name' is not registered in CometUdfRegistry") + return None + } + + val argProtos = expr.children.map(child => exprToProtoInternal(child, inputs, binding)) + if (argProtos.exists(_.isEmpty)) { + withInfo(expr, s"Failed to serialize one or more arguments for CometUDF '$name'") + return None + } + + val returnType = serializeDataType(entry.returnType).getOrElse { + withInfo(expr, s"Failed to serialize return type for CometUDF '$name'") + return None + } + + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName(entry.className) + .setReturnType(returnType) + .setReturnNullable(entry.nullable) + + argProtos.foreach(proto => udfBuilder.addArgs(proto.get)) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index d5dc5e983c..52069a9f6a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -662,6 +662,9 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { // `PromotePrecision` is just a wrapper, don't need to serialize it. exprToProtoInternal(child, inputs, binding) + case udf: ScalaUDF => + CometScalaUdf.convert(udf, inputs, binding) + case expr => QueryPlanSerde.exprSerdeMap.get(expr.getClass) match { case Some(handler) => From 813cfd7773b5a26cdba4f56c2e6dc2a554c83571 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 16:51:42 -0600 Subject: [PATCH 3/9] test: add CometUserUdfSuite for custom JVM UDF integration tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds end-to-end tests verifying: - Basic CometUDF execution (integer doubling via Arrow vectors) - Unregistered UDFs correctly fall back to Spark - Multiple UDF invocations in a single query - UDF combined with WHERE filter - CometUdfRegistry API (register, lookup, remove) Also fixes KnownNotNull unwrapping in CometScalaUdf — Spark wraps UDF arguments in KnownNotNull when the UDF is non-nullable, which needs to be stripped before serializing the underlying expression. 
Co-Authored-By: Claude Opus 4.6 --- .../apache/comet/serde/CometScalaUdf.scala | 13 ++- .../comet/CometArrayExpressionSuite.scala | 2 +- .../org/apache/comet/CometUserUdfSuite.scala | 103 ++++++++++++++++++ .../org/apache/comet/udf/DoubleIntUdf.scala | 50 +++++++++ 4 files changed, 165 insertions(+), 3 deletions(-) create mode 100644 spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala create mode 100644 spark/src/test/scala/org/apache/comet/udf/DoubleIntUdf.scala diff --git a/spark/src/main/scala/org/apache/comet/serde/CometScalaUdf.scala b/spark/src/main/scala/org/apache/comet/serde/CometScalaUdf.scala index 1958ec3a76..9026c84abe 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometScalaUdf.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometScalaUdf.scala @@ -19,7 +19,7 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Attribute, ScalaUDF} +import org.apache.spark.sql.catalyst.expressions.{Attribute, KnownNotNull, ScalaUDF} import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType} @@ -45,7 +45,16 @@ object CometScalaUdf { return None } - val argProtos = expr.children.map(child => exprToProtoInternal(child, inputs, binding)) + // Spark wraps UDF arguments in KnownNotNull when the UDF is declared non-nullable. + // Unwrap these since the CometUDF handles nullability itself. + val unwrappedChildren = expr.children.map { + case KnownNotNull(child) => child + case other => other + } + + val argProtos = unwrappedChildren.map { child => + exprToProtoInternal(child, inputs, binding) + } if (argProtos.exists(_.isEmpty)) { withInfo(expr, s"Failed to serialize one or more arguments for CometUDF '$name'") return None diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 48b8905035..63936a94b7 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -243,7 +243,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp val df = spark.read .parquet(path.toString) .withColumn("arr", array(col("_4"), lit(null), col("_4"))) - .withColumn("idx", udf((_: Int) => 1).apply(col("_4"))) + .withColumn("idx", org.apache.spark.sql.functions.udf((_: Int) => 1).apply(col("_4"))) .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)")) checkSparkAnswerAndFallbackReasons( df.select("arrUnsupportedArgs"), diff --git a/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala b/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala new file mode 100644 index 0000000000..31246efcd9 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.types.LongType + +import org.apache.comet.udf.CometUdfRegistry + +class CometUserUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override def afterEach(): Unit = { + CometUdfRegistry.clear() + super.afterEach() + } + + test("user CometUDF - basic integer doubling") { + CometUdfRegistry.register( + "double_int", + "org.apache.comet.udf.DoubleIntUdf", + LongType, + nullable = true) + spark.udf.register("double_int", (x: Int) => x.toLong * 2L) + + withTable("t") { + sql("CREATE TABLE t (x INT) USING parquet") + sql("INSERT INTO t VALUES (1), (2), (3), (NULL), (100)") + checkSparkAnswerAndOperator(sql("SELECT double_int(x) FROM t")) + } + } + + test("user CometUDF - unregistered UDF falls back to Spark") { + spark.udf.register("triple_int", (x: Int) => x * 3) + + withTable("t") { + sql("CREATE TABLE t (x INT) USING parquet") + sql("INSERT INTO t VALUES (1), (2), (3)") + // Should still produce correct results via Spark fallback + checkSparkAnswer(sql("SELECT triple_int(x) FROM t")) + } + } + + test("user CometUDF - multiple arguments") { + CometUdfRegistry.register( + "double_int", + "org.apache.comet.udf.DoubleIntUdf", + LongType, + nullable = true) + spark.udf.register("double_int", (x: Int) => x.toLong * 2L) + + withTable("t") { + sql("CREATE TABLE t (x INT, y INT) USING parquet") + sql("INSERT INTO t VALUES (10, 20), (NULL, 5), (3, NULL)") + checkSparkAnswerAndOperator(sql("SELECT double_int(x), double_int(y) FROM t")) + } + } + + test("user CometUDF - with filter") { + CometUdfRegistry.register( + "double_int", + "org.apache.comet.udf.DoubleIntUdf", + LongType, + nullable = true) + spark.udf.register("double_int", (x: Int) => x.toLong * 2L) + + withTable("t") { + sql("CREATE TABLE t (x INT) USING parquet") + sql("INSERT INTO t VALUES (1), (2), (3), (4), (5)") + checkSparkAnswerAndOperator(sql("SELECT double_int(x) FROM t WHERE x > 2")) + } + } + + test("CometUdfRegistry - register and lookup") { + assert(!CometUdfRegistry.isRegistered("test_func")) + CometUdfRegistry.register("test_func", "com.example.TestUdf", LongType, nullable = false) + assert(CometUdfRegistry.isRegistered("test_func")) + val entry = CometUdfRegistry.get("test_func") + assert(entry.isDefined) + assert(entry.get.className == "com.example.TestUdf") + assert(entry.get.returnType == LongType) + assert(!entry.get.nullable) + CometUdfRegistry.remove("test_func") + assert(!CometUdfRegistry.isRegistered("test_func")) + } +} diff --git a/spark/src/test/scala/org/apache/comet/udf/DoubleIntUdf.scala b/spark/src/test/scala/org/apache/comet/udf/DoubleIntUdf.scala new file mode 100644 index 0000000000..cf4a45e3b7 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/udf/DoubleIntUdf.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf + +import org.apache.arrow.vector.{BigIntVector, IntVector, ValueVector} + +import org.apache.comet.CometArrowAllocator + +/** + * Test CometUDF that doubles an integer input. Used by CometUserUdfSuite. + */ +class DoubleIntUdf extends CometUDF { + + override def evaluate(inputs: Array[ValueVector]): ValueVector = { + val input = inputs(0).asInstanceOf[IntVector] + val rowCount = input.getValueCount + + val result = new BigIntVector("result", CometArrowAllocator) + result.allocateNew(rowCount) + + var i = 0 + while (i < rowCount) { + if (input.isNull(i)) { + result.setNull(i) + } else { + result.set(i, input.get(i).toLong * 2L) + } + i += 1 + } + result.setValueCount(rowCount) + result + } +} From dc9ef3ce8f197c65e4feceffaa11c77753a54681 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 18:25:51 -0600 Subject: [PATCH 4/9] ci: add CometUserUdfSuite to PR build workflows --- .github/workflows/pr_build_linux.yml | 1 + .github/workflows/pr_build_macos.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index b0f09bc43b..efb3f941aa 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -356,6 +356,7 @@ jobs: org.apache.comet.exec.CometJoinSuite org.apache.comet.CometNativeSuite org.apache.comet.CometSparkSessionExtensionsSuite + org.apache.comet.CometUserUdfSuite org.apache.spark.CometPluginsSuite org.apache.spark.CometPluginsDefaultSuite org.apache.spark.CometPluginsNonOverrideSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index c743d1888a..3a12328d0f 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -195,6 +195,7 @@ jobs: org.apache.comet.exec.CometJoinSuite org.apache.comet.CometNativeSuite org.apache.comet.CometSparkSessionExtensionsSuite + org.apache.comet.CometUserUdfSuite org.apache.spark.CometPluginsSuite org.apache.spark.CometPluginsDefaultSuite org.apache.spark.CometPluginsNonOverrideSuite From 55d2240c685c14726ff5a66e2f5def4fff3d60be Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 20:35:15 -0600 Subject: [PATCH 5/9] fix: shade DoubleIntUdf test fixture along with CometUDF interface Move the DoubleIntUdf test fixture from spark/src/test/ to common/src/main/ so that its bytecode references to org.apache.arrow are relocated by common's shade plugin to org.apache.comet.shaded.arrow, matching the shaded CometUDF interface that user code sees at runtime. A test-scope class in spark/ was compiled against common/target/classes (unshaded) due to Maven workspace resolution and failed at runtime with AbstractMethodError when dispatched through the shaded interface. 
Update the user-guide page to import Arrow from org.apache.comet.shaded.arrow, which is the package real users compile against in the published comet-spark JAR. --- .../org/apache/comet/udf/testing}/DoubleIntUdf.scala | 11 +++++++---- docs/source/user-guide/latest/custom-jvm-udfs.md | 12 ++++++++---- .../scala/org/apache/comet/CometUserUdfSuite.scala | 6 +++--- 3 files changed, 18 insertions(+), 11 deletions(-) rename {spark/src/test/scala/org/apache/comet/udf => common/src/main/scala/org/apache/comet/udf/testing}/DoubleIntUdf.scala (75%) diff --git a/spark/src/test/scala/org/apache/comet/udf/DoubleIntUdf.scala b/common/src/main/scala/org/apache/comet/udf/testing/DoubleIntUdf.scala similarity index 75% rename from spark/src/test/scala/org/apache/comet/udf/DoubleIntUdf.scala rename to common/src/main/scala/org/apache/comet/udf/testing/DoubleIntUdf.scala index cf4a45e3b7..99716adac9 100644 --- a/spark/src/test/scala/org/apache/comet/udf/DoubleIntUdf.scala +++ b/common/src/main/scala/org/apache/comet/udf/testing/DoubleIntUdf.scala @@ -17,15 +17,18 @@ * under the License. */ -package org.apache.comet.udf +package org.apache.comet.udf.testing import org.apache.arrow.vector.{BigIntVector, IntVector, ValueVector} import org.apache.comet.CometArrowAllocator +import org.apache.comet.udf.CometUDF -/** - * Test CometUDF that doubles an integer input. Used by CometUserUdfSuite. - */ +// Test fixture for CometUserUdfSuite. Lives in common's main sources so that the Arrow +// references in its bytecode are relocated by common's shade plugin to match the shaded +// CometUDF interface that user code sees at runtime. A test-scope class in spark/ would +// compile against common/target/classes (unshaded) and fail at runtime with +// AbstractMethodError when dispatched through the shaded interface. class DoubleIntUdf extends CometUDF { override def evaluate(inputs: Array[ValueVector]): ValueVector = { diff --git a/docs/source/user-guide/latest/custom-jvm-udfs.md b/docs/source/user-guide/latest/custom-jvm-udfs.md index 7ed151ab14..6c99f5d2cd 100644 --- a/docs/source/user-guide/latest/custom-jvm-udfs.md +++ b/docs/source/user-guide/latest/custom-jvm-udfs.md @@ -37,12 +37,15 @@ The framework consists of: ## Writing a CometUDF -Implement the `org.apache.comet.udf.CometUDF` trait: +Implement the `org.apache.comet.udf.CometUDF` trait. Comet relocates Apache Arrow into +`org.apache.comet.shaded.arrow.*` to avoid version conflicts with Spark's bundled Arrow, +so your implementation must import Arrow types from the shaded package. This is the +same package that the published `comet-spark` JAR exposes on your classpath. 
```java -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.BitVector; -import org.apache.arrow.vector.ValueVector; +import org.apache.comet.shaded.arrow.vector.IntVector; +import org.apache.comet.shaded.arrow.vector.BitVector; +import org.apache.comet.shaded.arrow.vector.ValueVector; import org.apache.comet.udf.CometUDF; public class IsPositiveUdf implements CometUDF { @@ -72,6 +75,7 @@ public class IsPositiveUdf implements CometUDF { Key requirements: - The class must have a **public no-arg constructor** +- Arrow types must be imported from `org.apache.comet.shaded.arrow.*` (the relocated package) - Input vectors arrive at the row count of the current batch - Scalar (literal) arguments arrive as length-1 vectors — read at index 0 - The returned vector's length **must match** the longest input vector diff --git a/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala b/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala index 31246efcd9..8750171983 100644 --- a/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala @@ -35,7 +35,7 @@ class CometUserUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("user CometUDF - basic integer doubling") { CometUdfRegistry.register( "double_int", - "org.apache.comet.udf.DoubleIntUdf", + "org.apache.comet.udf.testing.DoubleIntUdf", LongType, nullable = true) spark.udf.register("double_int", (x: Int) => x.toLong * 2L) @@ -61,7 +61,7 @@ class CometUserUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("user CometUDF - multiple arguments") { CometUdfRegistry.register( "double_int", - "org.apache.comet.udf.DoubleIntUdf", + "org.apache.comet.udf.testing.DoubleIntUdf", LongType, nullable = true) spark.udf.register("double_int", (x: Int) => x.toLong * 2L) @@ -76,7 +76,7 @@ class CometUserUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("user CometUDF - with filter") { CometUdfRegistry.register( "double_int", - "org.apache.comet.udf.DoubleIntUdf", + "org.apache.comet.udf.testing.DoubleIntUdf", LongType, nullable = true) spark.udf.register("double_int", (x: Int) => x.toLong * 2L) From 0d44afcb998707768a34ecadbbb110e816f6b0fc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 20:35:20 -0600 Subject: [PATCH 6/9] test: update ArrayInsertUnsupportedArgs expected fallback reason The PR's new ScalaUDF dispatch in QueryPlanSerde changes the fallback message emitted for an anonymous (no-name) UDF from the generic "scalaudf is not supported" to "ScalaUDF has no name, cannot look up CometUDF registration". Update the test's expected fallback reasons accordingly. 
--- .../scala/org/apache/comet/CometArrayExpressionSuite.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 63936a94b7..37c4347022 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -247,7 +247,9 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)")) checkSparkAnswerAndFallbackReasons( df.select("arrUnsupportedArgs"), - Set("scalaudf is not supported", "unsupported arguments for ArrayInsert")) + Set( + "ScalaUDF has no name, cannot look up CometUDF registration", + "unsupported arguments for ArrayInsert")) } } } From b39dbb70a1e11fa17cd4644bdbad0deb1182cf00 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 5 May 2026 22:33:23 -0600 Subject: [PATCH 7/9] test: address PR feedback on CometUserUdfSuite Add Scala function overloads (arity 1-3) to CometUdfRegistry.register so callers can register a UDF with both Spark and Comet in a single call, without first wrapping the function in udf(). Update the test suite to use the facade and to assert the expected fallback reason in the unregistered-UDF case. --- .../apache/comet/udf/CometUdfRegistry.scala | 47 +++++++++++++++++++ .../org/apache/comet/CometUserUdfSuite.scala | 29 +++++------- 2 files changed, 58 insertions(+), 18 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala b/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala index 0b17ac66b2..0ae8bf5c0b 100644 --- a/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala +++ b/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala @@ -21,6 +21,8 @@ package org.apache.comet.udf import java.util.concurrent.ConcurrentHashMap +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.DataType @@ -99,6 +101,51 @@ object CometUdfRegistry { registry.put(name, UdfEntry(className, returnType, nullable)) } + /** + * Convenience method that registers an arity-1 Scala function with Spark and with Comet in one + * call. + */ + def register[A1: TypeTag, RT: TypeTag]( + spark: SparkSession, + name: String, + className: String, + func: A1 => RT, + returnType: DataType, + nullable: Boolean): Unit = { + spark.udf.register(name, func) + registry.put(name, UdfEntry(className, returnType, nullable)) + } + + /** + * Convenience method that registers an arity-2 Scala function with Spark and with Comet in one + * call. + */ + def register[A1: TypeTag, A2: TypeTag, RT: TypeTag]( + spark: SparkSession, + name: String, + className: String, + func: (A1, A2) => RT, + returnType: DataType, + nullable: Boolean): Unit = { + spark.udf.register(name, func) + registry.put(name, UdfEntry(className, returnType, nullable)) + } + + /** + * Convenience method that registers an arity-3 Scala function with Spark and with Comet in one + * call. + */ + def register[A1: TypeTag, A2: TypeTag, A3: TypeTag, RT: TypeTag]( + spark: SparkSession, + name: String, + className: String, + func: (A1, A2, A3) => RT, + returnType: DataType, + nullable: Boolean): Unit = { + spark.udf.register(name, func) + registry.put(name, UdfEntry(className, returnType, nullable)) + } + /** * Look up a registered CometUDF by its Spark UDF name. 
* diff --git a/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala b/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala index 8750171983..0074c5146b 100644 --- a/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala @@ -32,14 +32,18 @@ class CometUserUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper { super.afterEach() } - test("user CometUDF - basic integer doubling") { + private def registerDoubleInt(): Unit = { CometUdfRegistry.register( + spark, "double_int", "org.apache.comet.udf.testing.DoubleIntUdf", + (x: Int) => x.toLong * 2L, LongType, nullable = true) - spark.udf.register("double_int", (x: Int) => x.toLong * 2L) + } + test("user CometUDF - basic integer doubling") { + registerDoubleInt() withTable("t") { sql("CREATE TABLE t (x INT) USING parquet") sql("INSERT INTO t VALUES (1), (2), (3), (NULL), (100)") @@ -53,19 +57,14 @@ class CometUserUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper { withTable("t") { sql("CREATE TABLE t (x INT) USING parquet") sql("INSERT INTO t VALUES (1), (2), (3)") - // Should still produce correct results via Spark fallback - checkSparkAnswer(sql("SELECT triple_int(x) FROM t")) + checkSparkAnswerAndFallbackReason( + sql("SELECT triple_int(x) FROM t"), + "ScalaUDF 'triple_int' is not registered in CometUdfRegistry") } } test("user CometUDF - multiple arguments") { - CometUdfRegistry.register( - "double_int", - "org.apache.comet.udf.testing.DoubleIntUdf", - LongType, - nullable = true) - spark.udf.register("double_int", (x: Int) => x.toLong * 2L) - + registerDoubleInt() withTable("t") { sql("CREATE TABLE t (x INT, y INT) USING parquet") sql("INSERT INTO t VALUES (10, 20), (NULL, 5), (3, NULL)") @@ -74,13 +73,7 @@ class CometUserUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("user CometUDF - with filter") { - CometUdfRegistry.register( - "double_int", - "org.apache.comet.udf.testing.DoubleIntUdf", - LongType, - nullable = true) - spark.udf.register("double_int", (x: Int) => x.toLong * 2L) - + registerDoubleInt() withTable("t") { sql("CREATE TABLE t (x INT) USING parquet") sql("INSERT INTO t VALUES (1), (2), (3), (4), (5)") From 5b8b6ed5ebce2cce30dc21b1c54657ad4099e070 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 6 May 2026 06:43:46 -0600 Subject: [PATCH 8/9] refactor: move CometUDF metadata onto trait and add columnar-only registration Lift name, returnType, nullable, and inputTypes from CometUdfRegistry.register arguments onto the CometUDF trait itself. Registration now takes a class: CometUdfRegistry.register(classOf[MyUdf]) CometUdfRegistry.register(spark, classOf[MyUdf], rowFn) CometUdfRegistry.registerColumnarOnly(spark, classOf[MyUdf]) The third form synthesizes a stub Spark UDF (arities 1 to 5) that throws UnsupportedOperationException if invoked row-at-a-time, so users no longer have to write a row-based equivalent just to register a vectorized implementation. 
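Concretely, against the in-tree test UDF (assuming a SparkSession in scope as `spark`; the row function mirrors what CometUserUdfSuite registers, and in practice only one of the three forms is needed per UDF):

```scala
import org.apache.spark.sql.SparkSession

import org.apache.comet.udf.CometUdfRegistry
import org.apache.comet.udf.testing.DoubleIntUdf

val spark: SparkSession = SparkSession.builder().getOrCreate()

// Comet-only: assumes a Spark UDF named "double_int" was registered elsewhere.
CometUdfRegistry.register(classOf[DoubleIntUdf])

// Comet plus a row-based Spark fallback in one call (overloads for arities 1-3).
CometUdfRegistry.register(spark, classOf[DoubleIntUdf], (x: Int) => x.toLong * 2L)

// Columnar-only: the synthesized stub throws if ever invoked row-at-a-time.
CometUdfRegistry.registerColumnarOnly(spark, classOf[DoubleIntUdf])
```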
--- .../scala/org/apache/comet/udf/CometUDF.scala | 18 ++ .../apache/comet/udf/CometUdfRegistry.scala | 208 +++++++++--------- .../comet/udf/testing/DoubleIntUdf.scala | 7 + .../user-guide/latest/custom-jvm-udfs.md | 71 +++--- .../org/apache/comet/CometUserUdfSuite.scala | 55 +++-- 5 files changed, 214 insertions(+), 145 deletions(-) diff --git a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala index ac7b72a883..1087e65127 100644 --- a/common/src/main/scala/org/apache/comet/udf/CometUDF.scala +++ b/common/src/main/scala/org/apache/comet/udf/CometUDF.scala @@ -20,6 +20,7 @@ package org.apache.comet.udf import org.apache.arrow.vector.ValueVector +import org.apache.spark.sql.types.DataType /** * Scalar UDF invoked from native execution via JNI. Receives Arrow vectors as input and returns @@ -33,5 +34,22 @@ import org.apache.arrow.vector.ValueVector * cached per executor thread for the lifetime of the JVM. */ trait CometUDF { + + /** UDF name as invoked from SQL or DataFrame. Must match the name registered with Spark. */ + def name: String + + /** Output Arrow vector type. */ + def returnType: DataType + + /** Whether the result vector may contain nulls. */ + def nullable: Boolean = true + + /** + * Input data types. Required only for columnar-only registration via + * [[CometUdfRegistry.registerColumnarOnly]]; ignored when a row-based Spark UDF is also + * registered (Spark uses its own input schema in that case). + */ + def inputTypes: Seq[DataType] = Seq.empty + def evaluate(inputs: Array[ValueVector]): ValueVector } diff --git a/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala b/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala index 0ae8bf5c0b..e3571b2a57 100644 --- a/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala +++ b/common/src/main/scala/org/apache/comet/udf/CometUdfRegistry.scala @@ -24,32 +24,23 @@ import java.util.concurrent.ConcurrentHashMap import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.api.java.{UDF1, UDF2, UDF3, UDF4, UDF5} import org.apache.spark.sql.types.DataType /** - * Registry for user-defined CometUDF implementations. Users register their UDF class names here - * so that the Comet serde layer can intercept matching Spark UDFs and route them to native - * execution via the JVM UDF bridge. + * Registry for user-defined CometUDF implementations. Spark UDF metadata (name, return type, + * nullability) is read from the [[CometUDF]] class itself, so registration is a single call: * - * Usage: * {{{ - * // Register a CometUDF implementation for a Spark UDF - * CometUdfRegistry.register( - * "my_func", // Spark UDF name (as used in spark.udf.register) - * "com.example.MyUdf", // CometUDF implementation class - * BooleanType, // return type - * nullable = true // whether the result may contain nulls - * ) + * // Comet-only (user has already called spark.udf.register elsewhere): + * CometUdfRegistry.register(classOf[MyUdf]) * - * // Or use the convenience method that also registers the Spark UDF: - * CometUdfRegistry.register( - * spark, - * "my_func", - * "com.example.MyUdf", - * sparkUdf, // the Spark UserDefinedFunction - * BooleanType, - * nullable = true - * ) + * // Comet plus a row-based Spark fallback in one call: + * CometUdfRegistry.register(spark, classOf[MyUdf], (x: Int) => x > 0) + * + * // Columnar-only: no row-based equivalent. Calling the UDF row-at-a-time + * // (e.g. 
when Comet falls back) raises UnsupportedOperationException: + * CometUdfRegistry.registerColumnarOnly(spark, classOf[MyUdf]) * }}} */ object CometUdfRegistry { @@ -59,111 +50,70 @@ object CometUdfRegistry { private val registry = new ConcurrentHashMap[String, UdfEntry]() /** - * Register a CometUDF implementation for a named Spark UDF. - * - * @param name - * The UDF name as registered with Spark (via spark.udf.register) - * @param className - * Fully-qualified class name implementing CometUDF - * @param returnType - * The return DataType of the UDF - * @param nullable - * Whether the result column may contain nulls - */ - def register(name: String, className: String, returnType: DataType, nullable: Boolean): Unit = { - registry.put(name, UdfEntry(className, returnType, nullable)) - } - - /** - * Convenience method that registers both with Spark and with Comet in one call. - * - * @param spark - * The SparkSession - * @param name - * The UDF name - * @param className - * Fully-qualified CometUDF class name - * @param sparkUdf - * The Spark UserDefinedFunction (for row-at-a-time fallback) - * @param returnType - * The return DataType - * @param nullable - * Whether the result may contain nulls + * Register a CometUDF for use by Comet. The caller is responsible for separately registering a + * row-based Spark UDF under the same name (e.g. via `spark.udf.register(name, fn)`); without + * one, Spark will fail to bind the function unless it has a stub from [[registerColumnarOnly]]. */ - def register( - spark: SparkSession, - name: String, - className: String, - sparkUdf: org.apache.spark.sql.expressions.UserDefinedFunction, - returnType: DataType, - nullable: Boolean): Unit = { - spark.udf.register(name, sparkUdf) - registry.put(name, UdfEntry(className, returnType, nullable)) + def register(udfClass: Class[_ <: CometUDF]): Unit = { + val udf = newInstance(udfClass) + registry.put(udf.name, UdfEntry(udfClass.getName, udf.returnType, udf.nullable)) } - /** - * Convenience method that registers an arity-1 Scala function with Spark and with Comet in one - * call. - */ + /** Register an arity-1 Spark UDF and the matching CometUDF in one call. */ def register[A1: TypeTag, RT: TypeTag]( spark: SparkSession, - name: String, - className: String, - func: A1 => RT, - returnType: DataType, - nullable: Boolean): Unit = { - spark.udf.register(name, func) - registry.put(name, UdfEntry(className, returnType, nullable)) + udfClass: Class[_ <: CometUDF], + func: A1 => RT): Unit = { + val udf = newInstance(udfClass) + spark.udf.register(udf.name, func) + registry.put(udf.name, UdfEntry(udfClass.getName, udf.returnType, udf.nullable)) } - /** - * Convenience method that registers an arity-2 Scala function with Spark and with Comet in one - * call. - */ + /** Register an arity-2 Spark UDF and the matching CometUDF in one call. */ def register[A1: TypeTag, A2: TypeTag, RT: TypeTag]( spark: SparkSession, - name: String, - className: String, - func: (A1, A2) => RT, - returnType: DataType, - nullable: Boolean): Unit = { - spark.udf.register(name, func) - registry.put(name, UdfEntry(className, returnType, nullable)) + udfClass: Class[_ <: CometUDF], + func: (A1, A2) => RT): Unit = { + val udf = newInstance(udfClass) + spark.udf.register(udf.name, func) + registry.put(udf.name, UdfEntry(udfClass.getName, udf.returnType, udf.nullable)) } - /** - * Convenience method that registers an arity-3 Scala function with Spark and with Comet in one - * call. 
- */ + /** Register an arity-3 Spark UDF and the matching CometUDF in one call. */ def register[A1: TypeTag, A2: TypeTag, A3: TypeTag, RT: TypeTag]( spark: SparkSession, - name: String, - className: String, - func: (A1, A2, A3) => RT, - returnType: DataType, - nullable: Boolean): Unit = { - spark.udf.register(name, func) - registry.put(name, UdfEntry(className, returnType, nullable)) + udfClass: Class[_ <: CometUDF], + func: (A1, A2, A3) => RT): Unit = { + val udf = newInstance(udfClass) + spark.udf.register(udf.name, func) + registry.put(udf.name, UdfEntry(udfClass.getName, udf.returnType, udf.nullable)) } /** - * Look up a registered CometUDF by its Spark UDF name. + * Register a CometUDF without a row-based Spark equivalent. A stub Spark UDF is synthesized so + * Spark can bind the function name during analysis; calling the stub row-at-a-time (i.e. when + * Comet is disabled or falls back) raises [[UnsupportedOperationException]]. * - * @return - * Some(UdfEntry) if registered, None otherwise + * The CometUDF must declare [[CometUDF.inputTypes]] so that the synthesized stub has the + * correct arity. Arities 1 through 5 are supported; declare more inputs only if you have a + * concrete need (and extend the match below). */ + def registerColumnarOnly(spark: SparkSession, udfClass: Class[_ <: CometUDF]): Unit = { + val udf = newInstance(udfClass) + require( + udf.inputTypes.nonEmpty, + s"CometUDF '${udf.name}' must override inputTypes for columnar-only registration") + registerStub(spark, udf) + registry.put(udf.name, UdfEntry(udfClass.getName, udf.returnType, udf.nullable)) + } + + /** Look up a registered CometUDF by its Spark UDF name. */ def get(name: String): Option[UdfEntry] = Option(registry.get(name)) - /** - * Remove a previously registered UDF. - */ - def remove(name: String): Unit = { - registry.remove(name) - } + /** Remove a previously registered UDF. */ + def remove(name: String): Unit = registry.remove(name) - /** - * Check whether a UDF name is registered. - */ + /** Check whether a UDF name is registered. */ def isRegistered(name: String): Boolean = registry.containsKey(name) // Visible for testing @@ -171,4 +121,54 @@ object CometUdfRegistry { // Visible for testing def clear(): Unit = registry.clear() + + private def newInstance(cls: Class[_ <: CometUDF]): CometUDF = + cls.getDeclaredConstructor().newInstance() + + private def registerStub(spark: SparkSession, udf: CometUDF): Unit = { + val name = udf.name + val rt = udf.returnType + def fail(): Nothing = throw new UnsupportedOperationException( + s"CometUDF '$name' is columnar-only and cannot be evaluated row-at-a-time. 
" + + "Ensure Comet is enabled and supports this query.") + udf.inputTypes.length match { + case 1 => + spark.udf.register( + name, + new UDF1[AnyRef, AnyRef] { override def call(a: AnyRef): AnyRef = fail() }, + rt) + case 2 => + spark.udf.register( + name, + new UDF2[AnyRef, AnyRef, AnyRef] { + override def call(a: AnyRef, b: AnyRef): AnyRef = fail() + }, + rt) + case 3 => + spark.udf.register( + name, + new UDF3[AnyRef, AnyRef, AnyRef, AnyRef] { + override def call(a: AnyRef, b: AnyRef, c: AnyRef): AnyRef = fail() + }, + rt) + case 4 => + spark.udf.register( + name, + new UDF4[AnyRef, AnyRef, AnyRef, AnyRef, AnyRef] { + override def call(a: AnyRef, b: AnyRef, c: AnyRef, d: AnyRef): AnyRef = fail() + }, + rt) + case 5 => + spark.udf.register( + name, + new UDF5[AnyRef, AnyRef, AnyRef, AnyRef, AnyRef, AnyRef] { + override def call(a: AnyRef, b: AnyRef, c: AnyRef, d: AnyRef, e: AnyRef): AnyRef = + fail() + }, + rt) + case n => + throw new UnsupportedOperationException( + s"Columnar-only registration is not yet supported for arity $n") + } + } } diff --git a/common/src/main/scala/org/apache/comet/udf/testing/DoubleIntUdf.scala b/common/src/main/scala/org/apache/comet/udf/testing/DoubleIntUdf.scala index 99716adac9..7a1f80b40f 100644 --- a/common/src/main/scala/org/apache/comet/udf/testing/DoubleIntUdf.scala +++ b/common/src/main/scala/org/apache/comet/udf/testing/DoubleIntUdf.scala @@ -20,6 +20,7 @@ package org.apache.comet.udf.testing import org.apache.arrow.vector.{BigIntVector, IntVector, ValueVector} +import org.apache.spark.sql.types.{DataType, IntegerType, LongType} import org.apache.comet.CometArrowAllocator import org.apache.comet.udf.CometUDF @@ -31,6 +32,12 @@ import org.apache.comet.udf.CometUDF // AbstractMethodError when dispatched through the shaded interface. class DoubleIntUdf extends CometUDF { + override def name: String = "double_int" + + override def returnType: DataType = LongType + + override def inputTypes: Seq[DataType] = Seq(IntegerType) + override def evaluate(inputs: Array[ValueVector]): ValueVector = { val input = inputs(0).asInstanceOf[IntVector] val rowCount = input.getValueCount diff --git a/docs/source/user-guide/latest/custom-jvm-udfs.md b/docs/source/user-guide/latest/custom-jvm-udfs.md index 6c99f5d2cd..5903d80c7a 100644 --- a/docs/source/user-guide/latest/custom-jvm-udfs.md +++ b/docs/source/user-guide/latest/custom-jvm-udfs.md @@ -31,8 +31,9 @@ row-at-a-time execution while keeping the implementation in Java/Scala. 
The framework consists of: -- **`CometUDF`** — a trait your UDF class must implement, receiving and returning Arrow `ValueVector` instances -- **`CometUdfRegistry`** — a registry that maps Spark UDF names to CometUDF implementation classes +- **`CometUDF`** — a trait your UDF class must implement, declaring its name, return type, optional input types, + and the vectorized `evaluate` method +- **`CometUdfRegistry`** — a registry that introspects your `CometUDF` class to record metadata for the serde layer - **`CometUdfBridge`** — the JNI bridge that native execution uses to invoke your UDF (no user interaction needed) ## Writing a CometUDF @@ -47,9 +48,28 @@ import org.apache.comet.shaded.arrow.vector.IntVector; import org.apache.comet.shaded.arrow.vector.BitVector; import org.apache.comet.shaded.arrow.vector.ValueVector; import org.apache.comet.udf.CometUDF; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import scala.collection.JavaConverters; +import java.util.Arrays; public class IsPositiveUdf implements CometUDF { + @Override + public String name() { return "is_positive"; } + + @Override + public DataType returnType() { return DataTypes.BooleanType; } + + @Override + public boolean nullable() { return true; } + + @Override + public scala.collection.Seq inputTypes() { + return JavaConverters.asScalaBuffer( + Arrays.asList(DataTypes.IntegerType)).toSeq(); + } + @Override public ValueVector evaluate(ValueVector[] inputs) { IntVector input = (IntVector) inputs[0]; @@ -80,48 +100,48 @@ Key requirements: - Scalar (literal) arguments arrive as length-1 vectors — read at index 0 - The returned vector's length **must match** the longest input vector - Instances are cached per executor thread, so implementations should be **stateless** +- `inputTypes` is required only for columnar-only registration (see below) ## Registering a CometUDF -### Option 1: Register Comet UDF separately from Spark UDF +### Option 1: Comet UDF only (existing Spark UDF) If you already have a Spark UDF registered, just tell Comet about the accelerated implementation: ```scala import org.apache.comet.udf.CometUdfRegistry -import org.apache.spark.sql.types.BooleanType // Register the Spark UDF (row-at-a-time fallback) spark.udf.register("is_positive", (x: Int) => x > 0) -// Register the CometUDF (vectorized Arrow implementation) -CometUdfRegistry.register( - "is_positive", // must match the Spark UDF name - "com.example.IsPositiveUdf", // CometUDF implementation class - BooleanType, // return type - nullable = true // whether results may contain nulls -) +// Tell Comet about the vectorized implementation +CometUdfRegistry.register(classOf[IsPositiveUdf]) ``` ### Option 2: Register both in one call ```scala import org.apache.comet.udf.CometUdfRegistry -import org.apache.spark.sql.functions.udf -import org.apache.spark.sql.types.BooleanType - -val sparkUdf = udf((x: Int) => x > 0) - -CometUdfRegistry.register( - spark, - "is_positive", - "com.example.IsPositiveUdf", - sparkUdf, - BooleanType, - nullable = true -) + +CometUdfRegistry.register(spark, classOf[IsPositiveUdf], (x: Int) => x > 0) ``` +### Option 3: Columnar-only (no row-based equivalent) + +If you do not want to write a row-based fallback, Comet can synthesize a stub Spark UDF that +throws `UnsupportedOperationException` if invoked row-at-a-time. The CometUDF must declare +`inputTypes` so the stub has the correct arity. 
+ +```scala +import org.apache.comet.udf.CometUdfRegistry + +CometUdfRegistry.registerColumnarOnly(spark, classOf[IsPositiveUdf]) +``` + +When Comet is enabled and the query is supported, the vectorized implementation runs natively. +If Comet falls back (e.g. an unsupported expression elsewhere in the plan), the stub is invoked +and the query fails with a clear error rather than silently slow row-at-a-time execution. + ## How It Works 1. **Query planning**: When Comet's serde layer encounters a `ScalaUDF` expression with a name registered in @@ -134,7 +154,7 @@ CometUdfRegistry.register( Arrow vectors, and exports the result vector back to native execution. 4. **Fallback**: If Comet is disabled or the UDF is not in the registry, Spark executes the UDF row-at-a-time - using the originally registered Scala/Java function. + using the originally registered Scala/Java function. Columnar-only UDFs raise an exception in this case. ## Packaging and Deployment @@ -150,3 +170,4 @@ The CometUDF class is resolved using the executor's context classloader, so user - Only scalar UDFs are supported (not aggregate or table UDFs) - The UDF must be registered by name — anonymous lambdas without a name cannot be intercepted - All input and output types must be representable as Arrow vectors +- Columnar-only registration currently supports arities 1 through 5 diff --git a/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala b/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala index 0074c5146b..0543bbf231 100644 --- a/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometUserUdfSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.types.LongType import org.apache.comet.udf.CometUdfRegistry +import org.apache.comet.udf.testing.DoubleIntUdf class CometUserUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper { @@ -33,13 +34,7 @@ class CometUserUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper { } private def registerDoubleInt(): Unit = { - CometUdfRegistry.register( - spark, - "double_int", - "org.apache.comet.udf.testing.DoubleIntUdf", - (x: Int) => x.toLong * 2L, - LongType, - nullable = true) + CometUdfRegistry.register(spark, classOf[DoubleIntUdf], (x: Int) => x.toLong * 2L) } test("user CometUDF - basic integer doubling") { @@ -81,16 +76,44 @@ class CometUserUdfSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("CometUdfRegistry - register and lookup") { - assert(!CometUdfRegistry.isRegistered("test_func")) - CometUdfRegistry.register("test_func", "com.example.TestUdf", LongType, nullable = false) - assert(CometUdfRegistry.isRegistered("test_func")) - val entry = CometUdfRegistry.get("test_func") + test("user CometUDF - columnar-only registration runs natively") { + CometUdfRegistry.registerColumnarOnly(spark, classOf[DoubleIntUdf]) + withTable("t") { + sql("CREATE TABLE t (x INT) USING parquet") + sql("INSERT INTO t VALUES (1), (2), (3), (NULL), (100)") + // No Spark-side comparison: the synthesized stub intentionally throws when invoked + // row-at-a-time. With Comet enabled, the query routes to the vectorized implementation. 
+ val rows = sql("SELECT double_int(x) FROM t ORDER BY x").collect().toSeq.map(_.get(0)) + assert(rows == Seq(null, 2L, 4L, 6L, 200L)) + } + } + + test("user CometUDF - columnar-only stub raises when Comet is disabled") { + CometUdfRegistry.registerColumnarOnly(spark, classOf[DoubleIntUdf]) + withTable("t") { + sql("CREATE TABLE t (x INT) USING parquet") + sql("INSERT INTO t VALUES (1)") + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val ex = intercept[org.apache.spark.SparkException] { + sql("SELECT double_int(x) FROM t").collect() + } + assert( + ex.getMessage.contains("columnar-only") || + Option(ex.getCause).exists(_.getMessage.contains("columnar-only"))) + } + } + } + + test("CometUdfRegistry - register from class") { + assert(!CometUdfRegistry.isRegistered("double_int")) + CometUdfRegistry.register(classOf[DoubleIntUdf]) + assert(CometUdfRegistry.isRegistered("double_int")) + val entry = CometUdfRegistry.get("double_int") assert(entry.isDefined) - assert(entry.get.className == "com.example.TestUdf") + assert(entry.get.className == "org.apache.comet.udf.testing.DoubleIntUdf") assert(entry.get.returnType == LongType) - assert(!entry.get.nullable) - CometUdfRegistry.remove("test_func") - assert(!CometUdfRegistry.isRegistered("test_func")) + assert(entry.get.nullable) + CometUdfRegistry.remove("double_int") + assert(!CometUdfRegistry.isRegistered("double_int")) } } From 22439e38c0b40e7c57ad5e4c9560ef1a62ecde1e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 6 May 2026 06:45:56 -0600 Subject: [PATCH 9/9] docs: clean up custom JVM UDF guide for new registration API --- .../user-guide/latest/custom-jvm-udfs.md | 93 +++++++++++++------ 1 file changed, 64 insertions(+), 29 deletions(-) diff --git a/docs/source/user-guide/latest/custom-jvm-udfs.md b/docs/source/user-guide/latest/custom-jvm-udfs.md index 5903d80c7a..d9126ea8ff 100644 --- a/docs/source/user-guide/latest/custom-jvm-udfs.md +++ b/docs/source/user-guide/latest/custom-jvm-udfs.md @@ -31,10 +31,12 @@ row-at-a-time execution while keeping the implementation in Java/Scala. The framework consists of: -- **`CometUDF`** — a trait your UDF class must implement, declaring its name, return type, optional input types, - and the vectorized `evaluate` method -- **`CometUdfRegistry`** — a registry that introspects your `CometUDF` class to record metadata for the serde layer -- **`CometUdfBridge`** — the JNI bridge that native execution uses to invoke your UDF (no user interaction needed) +- **`CometUDF`**: a trait your UDF class must implement, declaring its name, return type, optional input + types, and the vectorized `evaluate` method. +- **`CometUdfRegistry`**: a registry that introspects your `CometUDF` class to record metadata for the serde + layer. +- **`CometUdfBridge`**: the JNI bridge that native execution uses to invoke your UDF (no user interaction + needed). ## Writing a CometUDF @@ -43,6 +45,8 @@ Implement the `org.apache.comet.udf.CometUDF` trait. Comet relocates Apache Arro so your implementation must import Arrow types from the shaded package. This is the same package that the published `comet-spark` JAR exposes on your classpath. 
+### Java + ```java import org.apache.comet.shaded.arrow.vector.IntVector; import org.apache.comet.shaded.arrow.vector.BitVector; @@ -50,8 +54,6 @@ import org.apache.comet.shaded.arrow.vector.ValueVector; import org.apache.comet.udf.CometUDF; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; -import scala.collection.JavaConverters; -import java.util.Arrays; public class IsPositiveUdf implements CometUDF { @@ -64,12 +66,6 @@ public class IsPositiveUdf implements CometUDF { @Override public boolean nullable() { return true; } - @Override - public scala.collection.Seq inputTypes() { - return JavaConverters.asScalaBuffer( - Arrays.asList(DataTypes.IntegerType)).toSeq(); - } - @Override public ValueVector evaluate(ValueVector[] inputs) { IntVector input = (IntVector) inputs[0]; @@ -92,18 +88,54 @@ public class IsPositiveUdf implements CometUDF { } ``` +### Scala + +```scala +import org.apache.comet.shaded.arrow.vector.{BitVector, IntVector, ValueVector} +import org.apache.comet.CometArrowAllocator +import org.apache.comet.udf.CometUDF +import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType} + +class IsPositiveUdf extends CometUDF { + override def name: String = "is_positive" + override def returnType: DataType = BooleanType + override def nullable: Boolean = true + + // Optional: declare only if you plan to use registerColumnarOnly. + override def inputTypes: Seq[DataType] = Seq(IntegerType) + + override def evaluate(inputs: Array[ValueVector]): ValueVector = { + val input = inputs(0).asInstanceOf[IntVector] + val rowCount = input.getValueCount + val result = new BitVector("result", CometArrowAllocator) + result.allocateNew(rowCount) + var i = 0 + while (i < rowCount) { + if (input.isNull(i)) result.setNull(i) + else result.set(i, if (input.get(i) > 0) 1 else 0) + i += 1 + } + result.setValueCount(rowCount) + result + } +} +``` + Key requirements: -- The class must have a **public no-arg constructor** -- Arrow types must be imported from `org.apache.comet.shaded.arrow.*` (the relocated package) -- Input vectors arrive at the row count of the current batch -- Scalar (literal) arguments arrive as length-1 vectors — read at index 0 -- The returned vector's length **must match** the longest input vector -- Instances are cached per executor thread, so implementations should be **stateless** -- `inputTypes` is required only for columnar-only registration (see below) +- The class must have a **public no-arg constructor**. +- Arrow types must be imported from `org.apache.comet.shaded.arrow.*` (the relocated package). +- Input vectors arrive at the row count of the current batch. +- Scalar (literal) arguments arrive as length-1 vectors: read at index 0. +- The returned vector's length **must match** the longest input vector. +- Instances are cached per executor thread, so implementations should be **stateless**. +- `inputTypes` is only required for columnar-only registration (see Option 3 below). ## Registering a CometUDF +There are three ways to register a `CometUDF` with Comet, depending on whether you also want a +row-based Spark fallback. + ### Option 1: Comet UDF only (existing Spark UDF) If you already have a Spark UDF registered, just tell Comet about the accelerated implementation: @@ -126,10 +158,13 @@ import org.apache.comet.udf.CometUdfRegistry CometUdfRegistry.register(spark, classOf[IsPositiveUdf], (x: Int) => x > 0) ``` +Convenience overloads exist for arities 1, 2, and 3. 
For higher arities, use Option 1 and call +`spark.udf.register` separately. + ### Option 3: Columnar-only (no row-based equivalent) If you do not want to write a row-based fallback, Comet can synthesize a stub Spark UDF that -throws `UnsupportedOperationException` if invoked row-at-a-time. The CometUDF must declare +throws `UnsupportedOperationException` if invoked row-at-a-time. The `CometUDF` must declare `inputTypes` so the stub has the correct arity. ```scala @@ -140,7 +175,7 @@ CometUdfRegistry.registerColumnarOnly(spark, classOf[IsPositiveUdf]) When Comet is enabled and the query is supported, the vectorized implementation runs natively. If Comet falls back (e.g. an unsupported expression elsewhere in the plan), the stub is invoked -and the query fails with a clear error rather than silently slow row-at-a-time execution. +and the query fails with a clear error rather than silently degrading to row-at-a-time execution. ## How It Works @@ -158,16 +193,16 @@ and the query fails with a clear error rather than silently slow row-at-a-time e ## Packaging and Deployment -1. Package your `CometUDF` implementation in a JAR -2. Include it on the Spark classpath via `--jars` or `spark.jars` -3. Register the UDF as shown above (in your application code or via a Spark session extension) +1. Package your `CometUDF` implementation in a JAR. +2. Include it on the Spark classpath via `--jars` or `spark.jars`. +3. Register the UDF as shown above (in your application code or via a Spark session extension). -The CometUDF class is resolved using the executor's context classloader, so user-supplied JARs added via +The `CometUDF` class is resolved using the executor's context classloader, so user-supplied JARs added via `spark.jars` or `--jars` are automatically visible. ## Limitations -- Only scalar UDFs are supported (not aggregate or table UDFs) -- The UDF must be registered by name — anonymous lambdas without a name cannot be intercepted -- All input and output types must be representable as Arrow vectors -- Columnar-only registration currently supports arities 1 through 5 +- Only scalar UDFs are supported (not aggregate or table UDFs). +- The UDF must be registered by name: anonymous lambdas without a name cannot be intercepted. +- All input and output types must be representable as Arrow vectors. +- `registerColumnarOnly` currently supports arities 1 through 5.
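As an appendix to the guide above, a sketch of the higher-arity path it recommends: `Sum4Udf` and the `sum4` name are hypothetical, and for brevity the sketch assumes all inputs arrive as full-length columns (no length-1 scalar literals).

```scala
import org.apache.comet.shaded.arrow.vector.{BigIntVector, IntVector, ValueVector}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DataType, LongType}

import org.apache.comet.CometArrowAllocator
import org.apache.comet.udf.{CometUDF, CometUdfRegistry}

// Hypothetical arity-4 UDF: sums four INT columns into a BIGINT.
class Sum4Udf extends CometUDF {
  override def name: String = "sum4"
  override def returnType: DataType = LongType

  override def evaluate(inputs: Array[ValueVector]): ValueVector = {
    val cols = inputs.map(_.asInstanceOf[IntVector])
    val rowCount = cols.head.getValueCount
    val result = new BigIntVector("result", CometArrowAllocator)
    result.allocateNew(rowCount)
    var i = 0
    while (i < rowCount) {
      // Null in any input yields a null result row.
      if (cols.exists(_.isNull(i))) result.setNull(i)
      else result.set(i, cols.map(_.get(i).toLong).sum)
      i += 1
    }
    result.setValueCount(rowCount)
    result
  }
}

val spark: SparkSession = SparkSession.builder().getOrCreate()

// Arity 4 has no one-call overload, so follow Option 1: register the
// row-based Spark UDF separately, then tell Comet about the vectorized class.
spark.udf.register("sum4", (a: Int, b: Int, c: Int, d: Int) => a.toLong + b + c + d)
CometUdfRegistry.register(classOf[Sum4Udf])
```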