From 68b34131ec5a11c14b4ba9fffe8945c9fd69f001 Mon Sep 17 00:00:00 2001
From: Yubo Xu <5395686+yuboxx@users.noreply.github.com>
Date: Tue, 5 May 2026 23:14:17 -0700
Subject: [PATCH] Fix array_except nullability mismatch

---
 native/spark-expr/src/comet_scalar_funcs.rs       | 271 +++++++++++++++++-
 .../expressions/array/array_except.sql            |   6 +-
 .../comet/CometArrayExpressionSuite.scala         |  34 +++
 3 files changed, 308 insertions(+), 3 deletions(-)

diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 9ecb11dc52..e97cb04855 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -28,8 +28,9 @@ use crate::{
     SparkContains, SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate,
     SparkSecondsToTimestamp, SparkSizeFunc,
 };
-use arrow::datatypes::DataType;
-use datafusion::common::{DataFusionError, Result as DataFusionResult};
+use arrow::array::{Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray};
+use arrow::datatypes::{DataType, Field};
+use datafusion::common::{DataFusionError, Result as DataFusionResult, ScalarValue};
 use datafusion::execution::FunctionRegistry;
 use datafusion::logical_expr::{
     ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature,
@@ -196,6 +197,16 @@ pub fn create_comet_physical_fun_with_eval_mode(
             let func = Arc::new(spark_map_sort);
             make_comet_scalar_udf!("spark_map_sort", func, without data_type)
         }
+        "array_except" => {
+            let df_udf = registry.udf("array_except")?;
+            let signature = df_udf.signature().clone();
+            let wrapper = NormalizingArrayExcept {
+                delegate: df_udf,
+                data_type,
+                signature,
+            };
+            Ok(Arc::new(ScalarUDF::new_from_impl(wrapper)))
+        }
         _ => registry.udf(fun_name).map_err(|e| {
             DataFusionError::Execution(format!(
                 "Function {fun_name} not found in the registry: {e}",
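For context, the mismatch this arm works around can be reproduced with arrow-rs alone. The following minimal standalone sketch (assuming only the `arrow` crate; it is not taken from Comet) builds two list arrays over the same element type that differ only in the inner field's nullability, which is exactly the difference strict `DataType` equality rejects:

    use std::sync::Arc;
    use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
    use arrow::buffer::OffsetBuffer;
    use arrow::datatypes::{DataType, Field};

    fn main() {
        let values = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
        // containsNull=false, like Spark's array(1, 2, 3) built from non-null literals.
        let strict = ListArray::new(
            Arc::new(Field::new("item", DataType::Int32, false)),
            OffsetBuffer::from_lengths([3]),
            values.clone(),
            None,
        );
        // containsNull=true, like Spark's array(nullable_col).
        let lenient = ListArray::new(
            Arc::new(Field::new("item", DataType::Int32, true)),
            OffsetBuffer::from_lengths([3]),
            values,
            None,
        );
        // Strict equality compares the inner field's nullability, so these two
        // "List(Int32)" types are not considered equal.
        assert_ne!(strict.data_type(), lenient.data_type());
    }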
@@ -204,6 +215,215 @@ pub fn create_comet_physical_fun_with_eval_mode(
         }),
     }
 }
+
+/// Normalize inner field nullability of list types to `true`.
+///
+/// Spark can produce arrays with different `containsNull` values for the same
+/// element type (e.g., `array(not_null_col)` yields `containsNull=false` while
+/// `array(nullable_col)` yields `containsNull=true`). DataFusion's array set
+/// functions (e.g., `array_except`) use strict `equals_datatype()`, which compares
+/// inner field nullability, causing errors like:
+/// "array_except received incompatible types: List(Int32), List(non-null Int32)"
+///
+/// This helper normalizes `false` to `true` so both inputs have compatible types.
+fn normalize_list_data_type(data_type: &DataType) -> DataType {
+    match data_type {
+        DataType::List(field) => DataType::List(Arc::new(Field::new(
+            field.name(),
+            normalize_list_data_type(field.data_type()),
+            true,
+        ))),
+        DataType::LargeList(field) => DataType::LargeList(Arc::new(Field::new(
+            field.name(),
+            normalize_list_data_type(field.data_type()),
+            true,
+        ))),
+        DataType::FixedSizeList(field, size) => DataType::FixedSizeList(
+            Arc::new(Field::new(
+                field.name(),
+                normalize_list_data_type(field.data_type()),
+                true,
+            )),
+            *size,
+        ),
+        _ => data_type.clone(),
+    }
+}
+
+fn normalize_list_inner_nullability(arr: ArrayRef) -> ArrayRef {
+    match arr.data_type() {
+        DataType::List(field)
+            if !field.is_nullable()
+                || normalize_list_data_type(field.data_type()) != *field.data_type() =>
+        {
+            let list_arr = arr
+                .as_any()
+                .downcast_ref::<ListArray>()
+                .expect("Expected ListArray");
+            let values = normalize_list_inner_nullability(list_arr.values().clone());
+            let new_field = Arc::new(Field::new(field.name(), values.data_type().clone(), true));
+            Arc::new(ListArray::new(
+                new_field,
+                list_arr.offsets().clone(),
+                values,
+                list_arr.nulls().cloned(),
+            ))
+        }
+        DataType::LargeList(field)
+            if !field.is_nullable()
+                || normalize_list_data_type(field.data_type()) != *field.data_type() =>
+        {
+            let list_arr = arr
+                .as_any()
+                .downcast_ref::<LargeListArray>()
+                .expect("Expected LargeListArray");
+            let values = normalize_list_inner_nullability(list_arr.values().clone());
+            let new_field = Arc::new(Field::new(field.name(), values.data_type().clone(), true));
+            Arc::new(LargeListArray::new(
+                new_field,
+                list_arr.offsets().clone(),
+                values,
+                list_arr.nulls().cloned(),
+            ))
+        }
+        DataType::FixedSizeList(field, size)
+            if !field.is_nullable()
+                || normalize_list_data_type(field.data_type()) != *field.data_type() =>
+        {
+            let list_arr = arr
+                .as_any()
+                .downcast_ref::<FixedSizeListArray>()
+                .expect("Expected FixedSizeListArray");
+            let values = normalize_list_inner_nullability(list_arr.values().clone());
+            let new_field = Arc::new(Field::new(field.name(), values.data_type().clone(), true));
+            Arc::new(FixedSizeListArray::new(
+                new_field,
+                *size,
+                values,
+                list_arr.nulls().cloned(),
+            ))
+        }
+        _ => arr,
+    }
+}
+
+fn normalize_list_scalar(scalar: ScalarValue) -> ScalarValue {
+    match scalar {
+        ScalarValue::List(arr) => {
+            let normalized = normalize_list_inner_nullability(arr);
+            ScalarValue::List(Arc::new(
+                normalized
+                    .as_any()
+                    .downcast_ref::<ListArray>()
+                    .expect("Expected ListArray")
+                    .clone(),
+            ))
+        }
+        ScalarValue::LargeList(arr) => {
+            let normalized = normalize_list_inner_nullability(arr);
+            ScalarValue::LargeList(Arc::new(
+                normalized
+                    .as_any()
+                    .downcast_ref::<LargeListArray>()
+                    .expect("Expected LargeListArray")
+                    .clone(),
+            ))
+        }
+        ScalarValue::FixedSizeList(arr) => {
+            let normalized = normalize_list_inner_nullability(arr);
+            ScalarValue::FixedSizeList(Arc::new(
+                normalized
+                    .as_any()
+                    .downcast_ref::<FixedSizeListArray>()
+                    .expect("Expected FixedSizeListArray")
+                    .clone(),
+            ))
+        }
+        _ => scalar,
+    }
+}
+
+/// Strip non-nullable markers from inner list fields in a `DataType` so the
+/// return type matches what `normalize_list_inner_nullability` produces at runtime.
+fn normalize_return_data_type(dt: DataType) -> DataType {
+    normalize_list_data_type(&dt)
+}
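As a quick illustration of what the runtime helper does (a sketch, assuming the helpers above are in scope along with `Int32Array` and `OffsetBuffer` from `arrow`): the offsets and values buffers are reused as-is, and only the inner field's `containsNull` flag is rewritten:

    let values = Arc::new(Int32Array::from(vec![1, 2, 3, 4])) as ArrayRef;
    let strict = Arc::new(ListArray::new(
        Arc::new(Field::new("item", DataType::Int32, false)),
        OffsetBuffer::from_lengths([2, 2]),
        values,
        None,
    )) as ArrayRef;

    let normalized = normalize_list_inner_nullability(Arc::clone(&strict));
    // Same rows and same values; only the type metadata changed.
    assert_eq!(normalized.len(), strict.len());
    assert_eq!(
        normalized.data_type(),
        &DataType::List(Arc::new(Field::new("item", DataType::Int32, true)))
    );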
+/// Wraps the DataFusion `array_except` UDF, normalizing inner list field nullability
+/// on inputs to avoid `check_datatypes()` type compatibility errors.
+struct NormalizingArrayExcept {
+    delegate: Arc<ScalarUDF>,
+    data_type: DataType,
+    signature: Signature,
+}
+
+impl PartialEq for NormalizingArrayExcept {
+    fn eq(&self, other: &Self) -> bool {
+        self.data_type == other.data_type && self.signature == other.signature
+    }
+}
+
+impl Eq for NormalizingArrayExcept {}
+
+impl std::hash::Hash for NormalizingArrayExcept {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.data_type.hash(state);
+        self.signature.hash(state);
+    }
+}
+
+impl Debug for NormalizingArrayExcept {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("NormalizingArrayExcept")
+            .field("name", &"array_except")
+            .field("data_type", &self.data_type)
+            .finish()
+    }
+}
+
+impl ScalarUDFImpl for NormalizingArrayExcept {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "array_except"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> DataFusionResult<DataType> {
+        let dt = self.delegate.inner().return_type(arg_types)?;
+        // Normalize the return type so it matches the runtime normalization
+        // applied to inputs. Without this, the planner might declare a return
+        // type like List(Int32, false) while the actual output is List(Int32, true).
+        Ok(normalize_return_data_type(dt))
+    }
+
+    fn invoke_with_args(&self, mut args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
+        for arg in args.args.iter_mut() {
+            match arg {
+                ColumnarValue::Array(arr) => {
+                    *arg = ColumnarValue::Array(normalize_list_inner_nullability(Arc::clone(arr)));
+                }
+                ColumnarValue::Scalar(scalar) => {
+                    *arg = ColumnarValue::Scalar(normalize_list_scalar(scalar.clone()));
+                }
+            }
+        }
+        self.delegate.invoke_with_args(args)
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> DataFusionResult<Vec<DataType>> {
+        self.delegate.inner().coerce_types(arg_types)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &[]
+    }
+}
+
 fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
     vec![
         Arc::new(ScalarUDF::new_from_impl(SparkArrayCompact::default())),
@@ -303,3 +523,50 @@ impl ScalarUDFImpl for CometScalarFunction {
         (self.func)(&args.args)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::Int32Array;
+    use arrow::buffer::OffsetBuffer;
+
+    #[test]
+    fn normalizes_scalar_list_nullability() {
+        let values = Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef;
+        let list = Arc::new(ListArray::new(
+            Arc::new(Field::new("item", DataType::Int32, false)),
+            OffsetBuffer::new(vec![0, 2].into()),
+            values,
+            None,
+        ));
+
+        let normalized = normalize_list_scalar(ScalarValue::List(list));
+        let ScalarValue::List(normalized) = normalized else {
+            panic!("Expected list scalar");
+        };
+        let DataType::List(field) = normalized.data_type() else {
+            panic!("Expected list type");
+        };
+        assert!(field.is_nullable());
+    }
+
+    #[test]
+    fn normalizes_nested_list_data_type() {
+        let data_type = DataType::List(Arc::new(Field::new(
+            "item",
+            DataType::List(Arc::new(Field::new("item", DataType::Int32, false))),
+            false,
+        )));
+
+        let normalized = normalize_list_data_type(&data_type);
+        let DataType::List(outer_field) = normalized else {
+            panic!("Expected outer list type");
+        };
+        assert!(outer_field.is_nullable());
+
+        let DataType::List(inner_field) = outer_field.data_type() else {
+            panic!("Expected inner list type");
+        };
+        assert!(inner_field.is_nullable());
+    }
+}
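The `return_type` override is the subtle half of the wrapper: normalizing the inputs changes what the delegate produces, so the declared return type must be normalized the same way. A minimal sketch of the agreement it maintains (assuming `normalize_return_data_type` from this patch is in scope):

    // Without normalization the plan would declare List(Field("item", Int32, false)),
    // while the delegate, fed normalized inputs, produces List(Field("item", Int32, true)).
    let declared = DataType::List(Arc::new(Field::new("item", DataType::Int32, false)));
    let produced = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
    assert_ne!(declared, produced);
    // Normalizing the declared return type restores plan/data agreement.
    assert_eq!(normalize_return_data_type(declared), produced);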
diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_except.sql b/spark/src/test/resources/sql-tests/expressions/array/array_except.sql
index 1b78896dbf..93499b24ef 100644
--- a/spark/src/test/resources/sql-tests/expressions/array/array_except.sql
+++ b/spark/src/test/resources/sql-tests/expressions/array/array_except.sql
@@ -35,5 +35,9 @@ query
 SELECT array_except(array(1, 2, 3), b) FROM test_array_except
 
 -- literal + literal
-query ignore(https://github.com/apache/datafusion-comet/issues/3646)
+query
 SELECT array_except(array(1, 2, 3), array(2, 3, 4)), array_except(array(1, 2), array()), array_except(array(), array(1)), array_except(cast(NULL as array<int>), array(1))
+
+-- nested literal arrays with mixed element nullability
+query
+SELECT array_except(array(array(1, 2)), array(array(1, NULL)))
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 0abf4e4e9e..18f2e84892 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -748,6 +748,40 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
     }
   }
 
+  test("array_except with mixed nullable inputs (GH-3646)") {
+    // Spark can produce arrays with different containsNull values depending on
+    // whether the source columns are NOT NULL constrained. For example,
+    // array(1, 2) produces ArrayType(Int, false) while array(col) from a
+    // nullable column produces ArrayType(Int, true). DataFusion's check_datatypes
+    // uses strict equals_datatype, which compares inner field nullability, causing
+    // "array_except received incompatible types". This test verifies the fix.
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[ArrayExcept]) -> "true") {
+      Seq(true, false).foreach { dictionaryEnabled =>
+        withTempDir { dir =>
+          withTempView("t1") {
+            val path = new Path(dir.toURI.toString, "test.parquet")
+            makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled, 10000)
+            spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+            // array(1, 2, 3) is built from non-null literals => containsNull=false;
+            // array(_4) is built from a nullable column => containsNull=true
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_except(array(1, 2, 3), array(_4)) from t1"))
+
+            // same, but with an explicitly typed literal to match the type of _3
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_except(array(cast(1 as int), 2, 3), array(_3)) from t1"))
+
+            // two column-sourced arrays whose element nullability can still
+            // differ after Spark analysis
+            checkSparkAnswerAndOperator(
+              sql("SELECT array_except(array(_2, _3), array(_4)) from t1"))
+          }
+        }
+      }
+    }
+  }
+
   test("array_repeat") {
     withSQLConf(
       CometConf.getExprAllowIncompatConfigKey(classOf[ArrayRepeat]) -> "true",