Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
271 changes: 269 additions & 2 deletions native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ use crate::{
SparkContains, SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate,
SparkSecondsToTimestamp, SparkSizeFunc,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
use arrow::array::{Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray};
use arrow::datatypes::{DataType, Field};
use datafusion::common::{DataFusionError, Result as DataFusionResult, ScalarValue};
use datafusion::execution::FunctionRegistry;
use datafusion::logical_expr::{
ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature,
Expand Down Expand Up @@ -196,6 +197,16 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(spark_map_sort);
make_comet_scalar_udf!("spark_map_sort", func, without data_type)
}
"array_except" => {
let df_udf = registry.udf("array_except")?;
let signature = df_udf.signature().clone();
let wrapper = NormalizingArrayExcept {
delegate: df_udf,
data_type,
signature,
};
Ok(Arc::new(ScalarUDF::new_from_impl(wrapper)))
}
_ => registry.udf(fun_name).map_err(|e| {
DataFusionError::Execution(format!(
"Function {fun_name} not found in the registry: {e}",
Expand All @@ -204,6 +215,215 @@ pub fn create_comet_physical_fun_with_eval_mode(
}
}

/// Normalize inner field nullability of list types to `true`.
///
/// Spark can produce arrays with different `containsNull` values for the same
/// element type (e.g., `array(not_null_col)` yields `containsNull=false` while
/// `array(nullable_col)` yields `containsNull=true`). DataFusion's array set
/// functions (e.g., `array_except`) use strict `equals_datatype()` which compares
/// inner field nullability, causing errors like:
/// "array_except received incompatible types: List(Int32), List(non-null Int32)"
///
/// This helper normalizes `false → true` so both inputs have compatible types.
fn normalize_list_data_type(data_type: &DataType) -> DataType {
match data_type {
DataType::List(field) => DataType::List(Arc::new(Field::new(
field.name(),
normalize_list_data_type(field.data_type()),
true,
))),
DataType::LargeList(field) => DataType::LargeList(Arc::new(Field::new(
field.name(),
normalize_list_data_type(field.data_type()),
true,
))),
DataType::FixedSizeList(field, size) => DataType::FixedSizeList(
Arc::new(Field::new(
field.name(),
normalize_list_data_type(field.data_type()),
true,
)),
*size,
),
_ => data_type.clone(),
}
}

/// Rebuild a list-typed array so that every list-level field is nullable
/// (`containsNull=true`), recursing into nested list values.
///
/// Arrays whose type is already fully normalized fall through the guards and
/// are returned unchanged. The match guards compare against
/// [`normalize_list_data_type`] by reference — no owned `DataType` clone is
/// needed for the comparison.
fn normalize_list_inner_nullability(arr: ArrayRef) -> ArrayRef {
    match arr.data_type() {
        DataType::List(field)
            if !field.is_nullable()
                || normalize_list_data_type(field.data_type()) != *field.data_type() =>
        {
            let list_arr = arr
                .as_any()
                .downcast_ref::<ListArray>()
                .expect("Expected ListArray");
            // Normalize nested values first so the new field's type matches them.
            let values = normalize_list_inner_nullability(list_arr.values().clone());
            let new_field = Arc::new(Field::new(field.name(), values.data_type().clone(), true));
            Arc::new(ListArray::new(
                new_field,
                list_arr.offsets().clone(),
                values,
                list_arr.nulls().cloned(),
            ))
        }
        DataType::LargeList(field)
            if !field.is_nullable()
                || normalize_list_data_type(field.data_type()) != *field.data_type() =>
        {
            let list_arr = arr
                .as_any()
                .downcast_ref::<LargeListArray>()
                .expect("Expected LargeListArray");
            let values = normalize_list_inner_nullability(list_arr.values().clone());
            let new_field = Arc::new(Field::new(field.name(), values.data_type().clone(), true));
            Arc::new(LargeListArray::new(
                new_field,
                list_arr.offsets().clone(),
                values,
                list_arr.nulls().cloned(),
            ))
        }
        DataType::FixedSizeList(field, size)
            if !field.is_nullable()
                || normalize_list_data_type(field.data_type()) != *field.data_type() =>
        {
            let list_arr = arr
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .expect("Expected FixedSizeListArray");
            let values = normalize_list_inner_nullability(list_arr.values().clone());
            let new_field = Arc::new(Field::new(field.name(), values.data_type().clone(), true));
            Arc::new(FixedSizeListArray::new(
                new_field,
                *size,
                values,
                list_arr.nulls().cloned(),
            ))
        }
        // Already normalized, or not a list type at all.
        _ => arr,
    }
}

/// Apply [`normalize_list_inner_nullability`] to the array payload of a
/// list-valued `ScalarValue`, re-wrapping the normalized array in the same
/// variant. Non-list scalars pass through untouched.
fn normalize_list_scalar(scalar: ScalarValue) -> ScalarValue {
    match scalar {
        ScalarValue::List(list) => {
            let fixed = normalize_list_inner_nullability(list);
            // Downcast back to the concrete array type the variant requires.
            let concrete = fixed
                .as_any()
                .downcast_ref::<ListArray>()
                .expect("Expected ListArray")
                .clone();
            ScalarValue::List(Arc::new(concrete))
        }
        ScalarValue::LargeList(list) => {
            let fixed = normalize_list_inner_nullability(list);
            let concrete = fixed
                .as_any()
                .downcast_ref::<LargeListArray>()
                .expect("Expected LargeListArray")
                .clone();
            ScalarValue::LargeList(Arc::new(concrete))
        }
        ScalarValue::FixedSizeList(list) => {
            let fixed = normalize_list_inner_nullability(list);
            let concrete = fixed
                .as_any()
                .downcast_ref::<FixedSizeListArray>()
                .expect("Expected FixedSizeListArray")
                .clone();
            ScalarValue::FixedSizeList(Arc::new(concrete))
        }
        other => other,
    }
}

/// Strip non-nullable markers from inner list fields in a `DataType` so the
/// return type matches what `normalize_list_inner_nullability` produces at runtime.
///
/// Thin by-value wrapper over [`normalize_list_data_type`]; exists so the
/// intent is named at the call site in `NormalizingArrayExcept::return_type`.
fn normalize_return_data_type(dt: DataType) -> DataType {
    normalize_list_data_type(&dt)
}

/// Wraps the DataFusion `array_except` UDF, normalizing inner list field nullability
/// on inputs to avoid `check_datatypes()` type compatibility errors.
struct NormalizingArrayExcept {
    // The registry-provided `array_except` UDF that performs the actual work.
    delegate: Arc<ScalarUDF>,
    // Data type supplied when the wrapper was created; used here only for
    // equality, hashing, and Debug output.
    data_type: DataType,
    // Signature cloned from the delegate UDF at creation time.
    signature: Signature,
}

// Equality intentionally compares only `data_type` and `signature`, not
// `delegate`. NOTE(review): this assumes the delegate is always the registry's
// `array_except` UDF for instances compared here — confirm that invariant holds
// wherever these wrappers are constructed.
impl PartialEq for NormalizingArrayExcept {
    fn eq(&self, other: &Self) -> bool {
        self.data_type == other.data_type && self.signature == other.signature
    }
}

impl Eq for NormalizingArrayExcept {}

// Hash must stay consistent with `eq` above: hash exactly the fields compared.
impl std::hash::Hash for NormalizingArrayExcept {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.data_type.hash(state);
        self.signature.hash(state);
    }
}

// Manual Debug: shows only the UDF name and the captured data type, omitting
// the delegate and signature fields.
impl Debug for NormalizingArrayExcept {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("NormalizingArrayExcept")
            .field("name", &"array_except")
            .field("data_type", &self.data_type)
            .finish()
    }
}

impl ScalarUDFImpl for NormalizingArrayExcept {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"array_except"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
let dt = self.delegate.inner().return_type(arg_types)?;
// Normalize the return type so it matches the runtime normalization
// applied to inputs. Without this, the planner might declare a return
// type like List(Int32, false) while the actual output is List(Int32, true).
Ok(normalize_return_data_type(dt))
}

fn invoke_with_args(&self, mut args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
for arg in args.args.iter_mut() {
match arg {
ColumnarValue::Array(arr) => {
*arg = ColumnarValue::Array(normalize_list_inner_nullability(Arc::clone(arr)));
}
ColumnarValue::Scalar(scalar) => {
*arg = ColumnarValue::Scalar(normalize_list_scalar(scalar.clone()));
}
}
}
self.delegate.invoke_with_args(args)
}

fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult<Vec<DataType>> {
self.delegate.inner().coerce_types(arg_types)
}

fn aliases(&self) -> &[String] {
&[]
}
}

fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
vec![
Arc::new(ScalarUDF::new_from_impl(SparkArrayCompact::default())),
Expand Down Expand Up @@ -303,3 +523,50 @@ impl ScalarUDFImpl for CometScalarFunction {
(self.func)(&args.args)
}
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;
    use arrow::buffer::OffsetBuffer;

    #[test]
    fn normalizes_scalar_list_nullability() {
        // One list row [1, 2] whose element field is declared non-nullable.
        let elements: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2)]));
        let non_null_field = Arc::new(Field::new("item", DataType::Int32, false));
        let offsets = OffsetBuffer::new(vec![0, 2].into());
        let scalar = ScalarValue::List(Arc::new(ListArray::new(
            non_null_field,
            offsets,
            elements,
            None,
        )));

        // After normalization the list's element field must be nullable.
        let ScalarValue::List(result) = normalize_list_scalar(scalar) else {
            panic!("Expected list scalar");
        };
        let DataType::List(field) = result.data_type() else {
            panic!("Expected list type");
        };
        assert!(field.is_nullable());
    }

    #[test]
    fn normalizes_nested_list_data_type() {
        // List<List<Int32>> with containsNull=false at both levels.
        let inner = Field::new("item", DataType::Int32, false);
        let middle = Field::new("item", DataType::List(Arc::new(inner)), false);
        let nested = DataType::List(Arc::new(middle));

        // Both nesting levels must come back nullable.
        let DataType::List(outer_field) = normalize_list_data_type(&nested) else {
            panic!("Expected outer list type");
        };
        assert!(outer_field.is_nullable());

        let DataType::List(inner_field) = outer_field.data_type() else {
            panic!("Expected inner list type");
        };
        assert!(inner_field.is_nullable());
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,9 @@ query
SELECT array_except(array(1, 2, 3), b) FROM test_array_except

-- literal + literal
query ignore(https://github.com/apache/datafusion-comet/issues/3646)
query
SELECT array_except(array(1, 2, 3), array(2, 3, 4)), array_except(array(1, 2), array()), array_except(array(), array(1)), array_except(cast(NULL as array<int>), array(1))

-- nested literal arrays with mixed element nullability
query
SELECT array_except(array(array(1, 2)), array(array(1, NULL)))
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,40 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
}
}

test("array_except with mixed nullable inputs (GH-3646)") {
  // Spark can produce arrays with different containsNull values depending on
  // whether the source columns are NOT NULL constrained. For example,
  // `array(1, 2)` produces ArrayType(Int, false) while `array(col)` from a
  // nullable column produces ArrayType(Int, true). DataFusion's check_datatypes
  // uses strict equals_datatype which compares inner field nullability, causing
  // "array_except received incompatible types". This test verifies the fix.
  // ArrayExcept is gated behind the allow-incompat config, so enable it here.
  withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[ArrayExcept]) -> "true") {
    Seq(true, false).foreach { dictionaryEnabled =>
      withTempDir { dir =>
        withTempView("t1") {
          val path = new Path(dir.toURI.toString, "test.parquet")
          makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
          spark.read.parquet(path.toString).createOrReplaceTempView("t1")

          // array(1, 2, 3) is built from non-null literals => containsNull=false
          // array(_4) is built from nullable column => containsNull=true
          checkSparkAnswerAndOperator(
            sql("SELECT array_except(array(1, 2, 3), array(_4)) from t1"))

          // same but with explicitly typed literal to match _3 type
          checkSparkAnswerAndOperator(
            sql("SELECT array_except(array(cast(1 as int), 2, 3), array(_3)) from t1"))

          // two column-sourced arrays whose element nullability can differ
          // after Spark analysis, exercising the column-vs-column path
          checkSparkAnswerAndOperator(
            sql("SELECT array_except(array(_2, _3), array(_4)) from t1"))
        }
      }
    }
  }
}

test("array_repeat") {
withSQLConf(
CometConf.getExprAllowIncompatConfigKey(classOf[ArrayRepeat]) -> "true",
Expand Down