Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions datafusion/datasource-parquet/src/row_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ impl TreeNodeVisitor<'_> for PushdownChecker<'_> {
&& (!DataType::is_nested(return_type)
|| self.is_nested_type_supported(return_type))
{
// try to resolve all field name arguments to strinrg literals
// try to resolve all field name arguments to string literals
// if any argument is not a string literal, we can not determine the exact
// leaf path so we fall back to reading the entire struct root column
let field_path = args[1..]
Expand Down Expand Up @@ -766,11 +766,7 @@ fn resolve_struct_field_leaves(

// A leaf matches if its path starts with our prefix.
// e.g., prefix=["s", "value"] matches leaf path ["s", "value"]
// prefix=["s", "outer"] matches ["s", "outer", "inner"]

// a leaf matches if its path starts with our prefix
// for example: prefix=["s", "value"] matches leaf path ["s", "value"]
// prefix=["s", "outer"] matches ["s", "outer", "inner"]
// prefix=["s", "outer"] matches ["s", "outer", "inner"]
let leaf_matches_path = col_path.len() >= prefix.len()
&& col_path.iter().zip(prefix.iter()).all(|(a, b)| a == b);

Expand Down Expand Up @@ -1523,9 +1519,8 @@ mod test {
}

/// Regression test: when a schema has Struct columns, Arrow field indices diverge
/// from Parquet leaf indices (Struct children become separate leaves). The
/// `PrimitiveOnly` fast-path in `leaf_indices_for_roots` assumes they are equal,
/// so a filter on a primitive column *after* a Struct gets the wrong leaf index.
/// from Parquet leaf indices (Struct children become separate leaves).
/// A filter on a primitive column *after* a Struct must use the correct leaf index.
///
/// Schema:
/// Arrow indices: col_a=0 struct_col=1 col_b=2
Expand Down Expand Up @@ -2045,7 +2040,7 @@ mod test {
),
);

// all3 Parquet leaves should be in the projection mask
// all 3 Parquet leaves should be in the projection mask
let expected_mask = ProjectionMask::leaves(schema_descr, [0, 1, 2]);
assert_eq!(read_plan.projection_mask, expected_mask,);
}
Expand Down
183 changes: 174 additions & 9 deletions datafusion/spark/src/function/datetime/make_interval.rs
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't belong in this PR

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will update by removing those changes.

Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,25 @@ use std::sync::Arc;
use arrow::array::{Array, ArrayRef, IntervalMonthDayNanoBuilder, PrimitiveArray};
use arrow::datatypes::DataType::Interval;
use arrow::datatypes::IntervalUnit::MonthDayNano;
use arrow::datatypes::{DataType, IntervalMonthDayNano};
use arrow::datatypes::{DataType, Field, FieldRef, IntervalMonthDayNano};
use datafusion_common::config::ConfigOptions;
use datafusion_common::types::{NativeType, logical_float64, logical_int32};
use datafusion_common::{DataFusionError, Result, ScalarValue, plan_datafusion_err};
use datafusion_common::{
DataFusionError, Result, ScalarValue, exec_err, plan_datafusion_err,
};
use datafusion_expr::{
Coercion, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature,
TypeSignatureClass, Volatility,
Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
ScalarUDFImpl, Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_functions::utils::make_scalar_function;

/// Spark-compatible `make_interval(years, months, weeks, days, hours, mins, secs)`
/// scalar UDF producing an `Interval(MonthDayNano)` value.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkMakeInterval {
// accepted argument-type combinations (built in `new_with_config`)
signature: Signature,
/// Mirrors `spark.sql.ansi.enabled` / `enable_ansi_mode`.
/// When true (failOnError=true in Spark) arithmetic overflow returns an error;
/// when false (default) it returns NULL instead.
ansi_mode: bool,
}

impl Default for SparkMakeInterval {
Expand All @@ -42,6 +49,10 @@ impl Default for SparkMakeInterval {

impl SparkMakeInterval {
/// Creates the function with the default session configuration
/// (ANSI mode off — arithmetic overflow yields NULL).
pub fn new() -> Self {
Self::new_with_config(&ConfigOptions::default())
}

pub fn new_with_config(config: &ConfigOptions) -> Self {
let int32 = Coercion::new_implicit(
TypeSignatureClass::Native(logical_int32()),
vec![TypeSignatureClass::Integer],
Expand Down Expand Up @@ -100,6 +111,7 @@ impl SparkMakeInterval {

Self {
signature: Signature::one_of(variants, Volatility::Immutable),
ansi_mode: config.execution.enable_ansi_mode,
}
}
}
Expand All @@ -114,20 +126,50 @@ impl ScalarUDFImpl for SparkMakeInterval {
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
// return_field_from_args is the authoritative implementation
// (it also computes nullability); this only reports the data type.
Ok(Interval(MonthDayNano))
}

/// Rebuilds the UDF when the session config changes so `ansi_mode`
/// tracks the current `enable_ansi_mode` setting.
fn with_updated_config(&self, config: &ConfigOptions) -> Option<ScalarUDF> {
Some(ScalarUDF::from(Self::new_with_config(config)))
}

/// Spark nullability rule (mirrors `failOnError` in Spark source):
/// - nullary call → never null (always returns zero interval)
/// - ANSI mode on → nullable only when any input field is nullable
/// - ANSI mode off → always nullable (overflow silently produces NULL)
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
    let some_input_nullable = args.arg_fields.iter().any(|f| f.is_nullable());
    let nullable = match (args.arg_fields.is_empty(), self.ansi_mode) {
        // make_interval() with no arguments is the constant zero interval
        (true, _) => false,
        // ANSI: NULL can only come from a NULL input
        (false, true) => some_input_nullable,
        // non-ANSI: overflow silently yields NULL, so always nullable
        (false, false) => true,
    };
    Ok(Arc::new(Field::new(
        self.name(),
        Interval(MonthDayNano),
        nullable,
    )))
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
// Nullary form: make_interval() is the constant zero interval.
if args.args.is_empty() {
return Ok(ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(
Some(IntervalMonthDayNano::new(0, 0, 0)),
)));
}
// Copy the flag out of self so the closure captures it by value and
// does not borrow `self`.
let ansi_mode = self.ansi_mode;
make_scalar_function(move |cols| make_interval_kernel(cols, ansi_mode), vec![])(
&args.args,
)
}
}

fn make_interval_kernel(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
fn make_interval_kernel(
args: &[ArrayRef],
ansi_mode: bool,
) -> Result<ArrayRef, DataFusionError> {
use arrow::array::AsArray;
use arrow::datatypes::{Float64Type, Int32Type};

Expand Down Expand Up @@ -216,6 +258,11 @@ fn make_interval_kernel(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError>
match make_interval_month_day_nano(y, mo, w, d, h, mi, s) {
Some(v) => builder.append_value(v),
None => {
if ansi_mode {
return exec_err!(
"Arithmetic overflow in make_interval: result does not fit in IntervalMonthDayNano"
);
}
builder.append_null();
continue;
}
Expand Down Expand Up @@ -274,7 +321,7 @@ mod tests {

use super::*;
// Test helper: runs the kernel in non-ANSI mode (overflow -> NULL).
fn run_make_interval_month_day_nano(arrs: Vec<ArrayRef>) -> Result<ArrayRef> {
make_interval_kernel(&arrs, false)
}

#[test]
Expand Down Expand Up @@ -537,6 +584,14 @@ mod tests {
// Test helper: invokes the UDF with the default session configuration.
fn invoke_make_interval_with_args(
args: Vec<ColumnarValue>,
number_rows: usize,
) -> Result<ColumnarValue, DataFusionError> {
invoke_make_interval_with_config(args, number_rows, &ConfigOptions::default())
}

fn invoke_make_interval_with_config(
args: Vec<ColumnarValue>,
number_rows: usize,
config: &ConfigOptions,
) -> Result<ColumnarValue, DataFusionError> {
let arg_fields = args
.iter()
Expand All @@ -547,9 +602,9 @@ mod tests {
arg_fields,
number_rows,
return_field: Field::new("f", Interval(MonthDayNano), true).into(),
config_options: Arc::new(ConfigOptions::default()),
config_options: Arc::new(config.clone()),
};
SparkMakeInterval::new().invoke_with_args(args)
SparkMakeInterval::new_with_config(config).invoke_with_args(args)
}

#[test]
Expand Down Expand Up @@ -601,4 +656,114 @@ mod tests {

Ok(())
}

// --- nullability / return_field_from_args tests ---

/// Builds a session config with ANSI semantics switched on
/// (`execution.enable_ansi_mode = true`).
fn make_ansi_config() -> ConfigOptions {
    let mut config = ConfigOptions::default();
    config.execution.enable_ansi_mode = true;
    config
}

#[test]
fn return_field_nullary_is_not_nullable() {
    // make_interval() with zero arguments always produces the zero
    // interval, so its output field must be declared non-nullable.
    let no_args = ReturnFieldArgs {
        arg_fields: &[],
        scalar_arguments: &[],
    };
    let field = SparkMakeInterval::new()
        .return_field_from_args(no_args)
        .unwrap();
    assert!(!field.is_nullable(), "nullary call must not be nullable");
}

#[test]
fn return_field_non_ansi_always_nullable() {
    // Even with all non-null inputs, non-ANSI mode is always nullable
    // because overflow silently returns NULL.
    let udf = SparkMakeInterval::new(); // ansi_mode = false
    let input: FieldRef = Arc::new(Field::new("x", DataType::Int32, false));
    let args = ReturnFieldArgs {
        arg_fields: &[input],
        scalar_arguments: &[None],
    };
    let field = udf.return_field_from_args(args).unwrap();
    assert!(field.is_nullable(), "non-ANSI must always be nullable");
}

#[test]
fn return_field_ansi_mode_not_nullable_when_inputs_not_null() {
// ANSI mode: no overflow → null; nullable only if inputs are nullable.
// A single non-nullable Int32 input therefore yields a non-nullable
// output field (overflow raises an error instead of producing NULL).
let udf = SparkMakeInterval::new_with_config(&make_ansi_config());
let non_null_field: FieldRef = Arc::new(Field::new("x", DataType::Int32, false));
let field = udf
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[non_null_field],
scalar_arguments: &[None],
})
.unwrap();
assert!(
!field.is_nullable(),
"ANSI mode with non-null inputs must not be nullable"
);
}

#[test]
fn return_field_ansi_mode_nullable_when_any_input_nullable() {
// ANSI mode: a nullable input can propagate NULL through the function,
// so the output field must remain nullable.
let udf = SparkMakeInterval::new_with_config(&make_ansi_config());
let nullable_field: FieldRef = Arc::new(Field::new("x", DataType::Int32, true));
let field = udf
.return_field_from_args(ReturnFieldArgs {
arg_fields: &[nullable_field],
scalar_arguments: &[None],
})
.unwrap();
assert!(
field.is_nullable(),
"ANSI mode with nullable inputs must be nullable"
);
}

// --- ANSI mode overflow error tests ---

#[test]
fn ansi_mode_overflow_returns_error() {
    // i32::MAX years plus one extra month cannot be represented, so under
    // ANSI mode the call must fail instead of producing a NULL row.
    let int_col =
        |v: i32| ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(v)])));
    let sec_col = ColumnarValue::Array(Arc::new(Float64Array::from(vec![Some(0.0)])));
    let args = vec![
        int_col(i32::MAX), // year
        int_col(1),        // month
        int_col(0),        // week
        int_col(0),        // day
        int_col(0),        // hour
        int_col(0),        // minute
        sec_col,           // second
    ];
    let result = invoke_make_interval_with_config(args, 1, &make_ansi_config());
    assert!(
        result.is_err(),
        "ANSI mode overflow must return an error, not NULL"
    );
}

#[test]
fn non_ansi_overflow_returns_null() {
    // Existing behavior must be preserved: overflow → NULL in non-ANSI mode.
    let i32_col = |v: i32| Arc::new(Int32Array::from(vec![Some(v)])) as ArrayRef;
    let cols: Vec<ArrayRef> = vec![
        i32_col(i32::MAX), // year — overflows combined with the extra month
        i32_col(1),        // month
        i32_col(0),        // week
        i32_col(0),        // day
        i32_col(0),        // hour
        i32_col(0),        // minute
        Arc::new(Float64Array::from(vec![Some(0.0)])), // second
    ];
    let out = run_make_interval_month_day_nano(cols).unwrap();
    assert_eq!(out.null_count(), 1, "non-ANSI overflow must produce NULL");
}
}