From 944c108c34b3dbbfea0c73147322a7676f9f7909 Mon Sep 17 00:00:00 2001
From: B Vadlamani
Date: Wed, 25 Feb 2026 13:22:34 -0800
Subject: [PATCH 1/4] refactor_numeric

---
 .../spark-expr/src/conversion_funcs/cast.rs   | 780 +--------------
 native/spark-expr/src/conversion_funcs/mod.rs |   1 +
 .../src/conversion_funcs/numeric.rs           | 898 ++++++++++++++++++
 3 files changed, 904 insertions(+), 775 deletions(-)
 create mode 100644 native/spark-expr/src/conversion_funcs/numeric.rs

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 37604ab4aa..32a7abf650 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -18,11 +18,15 @@
 use crate::conversion_funcs::boolean::{
     cast_boolean_to_decimal, is_df_cast_from_bool_spark_compatible,
 };
+use crate::conversion_funcs::numeric::{
+    cast_float32_to_decimal128, cast_float64_to_decimal128, cast_int_to_decimal128,
+    spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8, spark_cast_float64_to_utf8,
+    spark_cast_int_to_int, spark_cast_nonintegral_numeric_to_integral,
+};
 use crate::conversion_funcs::string::{
     cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int,
     cast_string_to_timestamp, is_df_cast_from_string_spark_compatible, spark_cast_utf8_to_boolean,
 };
-use crate::conversion_funcs::utils::cast_overflow;
 use crate::conversion_funcs::utils::spark_cast_postprocess;
 use crate::utils::array_with_timezone;
 use crate::EvalMode::Legacy;
@@ -101,74 +105,6 @@ impl Hash for Cast {
     }
 }
 
-macro_rules! cast_float_to_string {
-    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
-
-        fn cast<OffsetSize>(
-            from: &dyn Array,
-            _eval_mode: EvalMode,
-        ) -> SparkResult<ArrayRef>
-        where
-            OffsetSize: OffsetSizeTrait, {
-            let array = from.as_any().downcast_ref::<$output_type>().unwrap();
-
-            // If the absolute number is less than 10,000,000 and greater than or equal to 0.001,
-            // the result is expressed without scientific notation, with at least one digit on
-            // either side of the decimal point. Otherwise, Spark uses a mantissa followed by E and
-            // an exponent. The mantissa has an optional leading minus sign followed by one digit
-            // to the left of the decimal point, and the minimal number of digits greater than
-            // zero to the right. The exponent has an optional leading minus sign.
- // source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html - - const LOWER_SCIENTIFIC_BOUND: $type = 0.001; - const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0; - - let output_array = array - .iter() - .map(|value| match value { - Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())), - Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())), - Some(value) - if (value.abs() < UPPER_SCIENTIFIC_BOUND - && value.abs() >= LOWER_SCIENTIFIC_BOUND) - || value.abs() == 0.0 => - { - let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" }; - - Ok(Some(format!("{value}{trailing_zero}"))) - } - Some(value) - if value.abs() >= UPPER_SCIENTIFIC_BOUND - || value.abs() < LOWER_SCIENTIFIC_BOUND => - { - let formatted = format!("{value:E}"); - - if formatted.contains(".") { - Ok(Some(formatted)) - } else { - // `formatted` is already in scientific notation and can be split up by E - // in order to add the missing trailing 0 which gets removed for numbers with a fraction of 0.0 - let prepare_number: Vec<&str> = formatted.split("E").collect(); - - let coefficient = prepare_number[0]; - - let exponent = prepare_number[1]; - - Ok(Some(format!("{coefficient}.0E{exponent}"))) - } - } - Some(value) => Ok(Some(value.to_string())), - _ => Ok(None), - }) - .collect::, SparkError>>()?; - - Ok(Arc::new(output_array)) - } - - cast::<$offset_type>($from, $eval_mode) - }}; -} - // eval mode is not needed since all ints can be implemented in binary format macro_rules! cast_whole_num_to_binary { ($array:expr, $primitive_type:ty, $byte_size:expr) => {{ @@ -192,317 +128,6 @@ macro_rules! cast_whole_num_to_binary { }}; } -macro_rules! cast_int_to_int_macro { - ( - $array: expr, - $eval_mode:expr, - $from_arrow_primitive_type: ty, - $to_arrow_primitive_type: ty, - $from_data_type: expr, - $to_native_type: ty, - $spark_from_data_type_name: expr, - $spark_to_data_type_name: expr - ) => {{ - let cast_array = $array - .as_any() - .downcast_ref::>() - .unwrap(); - let spark_int_literal_suffix = match $from_data_type { - &DataType::Int64 => "L", - &DataType::Int16 => "S", - &DataType::Int8 => "T", - _ => "", - }; - - let output_array = match $eval_mode { - EvalMode::Legacy => cast_array - .iter() - .map(|value| match value { - Some(value) => { - Ok::, SparkError>(Some(value as $to_native_type)) - } - _ => Ok(None), - }) - .collect::, _>>(), - _ => cast_array - .iter() - .map(|value| match value { - Some(value) => { - let res = <$to_native_type>::try_from(value); - if res.is_err() { - Err(cast_overflow( - &(value.to_string() + spark_int_literal_suffix), - $spark_from_data_type_name, - $spark_to_data_type_name, - )) - } else { - Ok::, SparkError>(Some(res.unwrap())) - } - } - _ => Ok(None), - }) - .collect::, _>>(), - }?; - let result: SparkResult = Ok(Arc::new(output_array) as ArrayRef); - result - }}; -} - -// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short. -// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast, -// this can cause unexpected Short/Byte cast results. Replicate this behavior. -macro_rules! 
cast_float_to_int16_down { - ( - $array:expr, - $eval_mode:expr, - $src_array_type:ty, - $dest_array_type:ty, - $rust_src_type:ty, - $rust_dest_type:ty, - $src_type_str:expr, - $dest_type_str:expr, - $format_str:expr - ) => {{ - let cast_array = $array - .as_any() - .downcast_ref::<$src_array_type>() - .expect(concat!("Expected a ", stringify!($src_array_type))); - - let output_array = match $eval_mode { - EvalMode::Ansi => cast_array - .iter() - .map(|value| match value { - Some(value) => { - let is_overflow = value.is_nan() || value.abs() as i32 == i32::MAX; - if is_overflow { - return Err(cast_overflow( - &format!($format_str, value).replace("e", "E"), - $src_type_str, - $dest_type_str, - )); - } - let i32_value = value as i32; - <$rust_dest_type>::try_from(i32_value) - .map_err(|_| { - cast_overflow( - &format!($format_str, value).replace("e", "E"), - $src_type_str, - $dest_type_str, - ) - }) - .map(Some) - } - None => Ok(None), - }) - .collect::>()?, - _ => cast_array - .iter() - .map(|value| match value { - Some(value) => { - let i32_value = value as i32; - Ok::, SparkError>(Some( - i32_value as $rust_dest_type, - )) - } - None => Ok(None), - }) - .collect::>()?, - }; - Ok(Arc::new(output_array) as ArrayRef) - }}; -} - -macro_rules! cast_float_to_int32_up { - ( - $array:expr, - $eval_mode:expr, - $src_array_type:ty, - $dest_array_type:ty, - $rust_src_type:ty, - $rust_dest_type:ty, - $src_type_str:expr, - $dest_type_str:expr, - $max_dest_val:expr, - $format_str:expr - ) => {{ - let cast_array = $array - .as_any() - .downcast_ref::<$src_array_type>() - .expect(concat!("Expected a ", stringify!($src_array_type))); - - let output_array = match $eval_mode { - EvalMode::Ansi => cast_array - .iter() - .map(|value| match value { - Some(value) => { - let is_overflow = - value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val; - if is_overflow { - return Err(cast_overflow( - &format!($format_str, value).replace("e", "E"), - $src_type_str, - $dest_type_str, - )); - } - Ok(Some(value as $rust_dest_type)) - } - None => Ok(None), - }) - .collect::>()?, - _ => cast_array - .iter() - .map(|value| match value { - Some(value) => { - Ok::, SparkError>(Some(value as $rust_dest_type)) - } - None => Ok(None), - }) - .collect::>()?, - }; - Ok(Arc::new(output_array) as ArrayRef) - }}; -} - -// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short. -// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast, -// this can cause unexpected Short/Byte cast results. Replicate this behavior. -macro_rules! 
cast_decimal_to_int16_down { - ( - $array:expr, - $eval_mode:expr, - $dest_array_type:ty, - $rust_dest_type:ty, - $dest_type_str:expr, - $precision:expr, - $scale:expr - ) => {{ - let cast_array = $array - .as_any() - .downcast_ref::() - .expect("Expected a Decimal128ArrayType"); - - let output_array = match $eval_mode { - EvalMode::Ansi => cast_array - .iter() - .map(|value| match value { - Some(value) => { - let divisor = 10_i128.pow($scale as u32); - let truncated = value / divisor; - let is_overflow = truncated.abs() > i32::MAX.into(); - if is_overflow { - return Err(cast_overflow( - &format!( - "{}BD", - format_decimal_str( - &value.to_string(), - $precision as usize, - $scale - ) - ), - &format!("DECIMAL({},{})", $precision, $scale), - $dest_type_str, - )); - } - let i32_value = truncated as i32; - <$rust_dest_type>::try_from(i32_value) - .map_err(|_| { - cast_overflow( - &format!( - "{}BD", - format_decimal_str( - &value.to_string(), - $precision as usize, - $scale - ) - ), - &format!("DECIMAL({},{})", $precision, $scale), - $dest_type_str, - ) - }) - .map(Some) - } - None => Ok(None), - }) - .collect::>()?, - _ => cast_array - .iter() - .map(|value| match value { - Some(value) => { - let divisor = 10_i128.pow($scale as u32); - let i32_value = (value / divisor) as i32; - Ok::, SparkError>(Some( - i32_value as $rust_dest_type, - )) - } - None => Ok(None), - }) - .collect::>()?, - }; - Ok(Arc::new(output_array) as ArrayRef) - }}; -} - -macro_rules! cast_decimal_to_int32_up { - ( - $array:expr, - $eval_mode:expr, - $dest_array_type:ty, - $rust_dest_type:ty, - $dest_type_str:expr, - $max_dest_val:expr, - $precision:expr, - $scale:expr - ) => {{ - let cast_array = $array - .as_any() - .downcast_ref::() - .expect("Expected a Decimal128ArrayType"); - - let output_array = match $eval_mode { - EvalMode::Ansi => cast_array - .iter() - .map(|value| match value { - Some(value) => { - let divisor = 10_i128.pow($scale as u32); - let truncated = value / divisor; - let is_overflow = truncated.abs() > $max_dest_val.into(); - if is_overflow { - return Err(cast_overflow( - &format!( - "{}BD", - format_decimal_str( - &value.to_string(), - $precision as usize, - $scale - ) - ), - &format!("DECIMAL({},{})", $precision, $scale), - $dest_type_str, - )); - } - Ok(Some(truncated as $rust_dest_type)) - } - None => Ok(None), - }) - .collect::>()?, - _ => cast_array - .iter() - .map(|value| match value { - Some(value) => { - let divisor = 10_i128.pow($scale as u32); - let truncated = value / divisor; - Ok::, SparkError>(Some( - truncated as $rust_dest_type, - )) - } - None => Ok(None), - }) - .collect::>()?, - }; - Ok(Arc::new(output_array) as ArrayRef) - }}; -} - macro_rules! cast_int_to_timestamp_impl { ($array:expr, $builder:expr, $primitive_type:ty) => {{ let arr = $array.as_primitive::<$primitive_type>(); @@ -520,30 +145,6 @@ macro_rules! 
cast_int_to_timestamp_impl { }}; } -// copied from arrow::dataTypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly -fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String { - let (sign, rest) = match value_str.strip_prefix('-') { - Some(stripped) => ("-", stripped), - None => ("", value_str), - }; - let bound = precision.min(rest.len()) + sign.len(); - let value_str = &value_str[0..bound]; - - if scale == 0 { - value_str.to_string() - } else if scale < 0 { - let padding = value_str.len() + scale.unsigned_abs() as usize; - format!("{value_str:0 scale as usize { - // Decimal separator is in the middle of the string - let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize); - format!("{whole}.{decimal}") - } else { - // String has to be padded - format!("{}0.{:0>width$}", sign, rest, width = scale as usize) - } -} - impl Cast { pub fn new( child: Arc, @@ -1101,377 +702,6 @@ fn casts_struct_to_string( Ok(Arc::new(builder.finish())) } -fn cast_float64_to_decimal128( - array: &dyn Array, - precision: u8, - scale: i8, - eval_mode: EvalMode, -) -> SparkResult { - cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) -} - -fn cast_float32_to_decimal128( - array: &dyn Array, - precision: u8, - scale: i8, - eval_mode: EvalMode, -) -> SparkResult { - cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) -} - -fn cast_floating_point_to_decimal128( - array: &dyn Array, - precision: u8, - scale: i8, - eval_mode: EvalMode, -) -> SparkResult -where - ::Native: AsPrimitive, -{ - let input = array.as_any().downcast_ref::>().unwrap(); - let mut cast_array = PrimitiveArray::::builder(input.len()); - - let mul = 10_f64.powi(scale as i32); - - for i in 0..input.len() { - if input.is_null(i) { - cast_array.append_null(); - continue; - } - - let input_value = input.value(i).as_(); - if let Some(v) = (input_value * mul).round().to_i128() { - if is_validate_decimal_precision(v, precision) { - cast_array.append_value(v); - continue; - } - }; - - if eval_mode == EvalMode::Ansi { - return Err(SparkError::NumericValueOutOfRange { - value: input_value.to_string(), - precision, - scale, - }); - } - cast_array.append_null(); - } - - let res = Arc::new( - cast_array - .with_precision_and_scale(precision, scale)? 
- .finish(), - ) as ArrayRef; - Ok(res) -} - -fn spark_cast_float64_to_utf8( - from: &dyn Array, - _eval_mode: EvalMode, -) -> SparkResult -where - OffsetSize: OffsetSizeTrait, -{ - cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize) -} - -fn spark_cast_float32_to_utf8( - from: &dyn Array, - _eval_mode: EvalMode, -) -> SparkResult -where - OffsetSize: OffsetSizeTrait, -{ - cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize) -} - -fn cast_int_to_decimal128_internal( - array: &PrimitiveArray, - precision: u8, - scale: i8, - eval_mode: EvalMode, -) -> SparkResult -where - T: ArrowPrimitiveType, - T::Native: Into, -{ - let mut builder = Decimal128Builder::with_capacity(array.len()); - let multiplier = 10_i128.pow(scale as u32); - - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null(); - } else { - let v = array.value(i).into(); - let scaled = v.checked_mul(multiplier); - match scaled { - Some(scaled) => { - if !is_validate_decimal_precision(scaled, precision) { - match eval_mode { - EvalMode::Ansi => { - return Err(SparkError::NumericValueOutOfRange { - value: v.to_string(), - precision, - scale, - }); - } - EvalMode::Try | EvalMode::Legacy => builder.append_null(), - } - } else { - builder.append_value(scaled); - } - } - _ => match eval_mode { - EvalMode::Ansi => { - return Err(SparkError::NumericValueOutOfRange { - value: v.to_string(), - precision, - scale, - }) - } - EvalMode::Legacy | EvalMode::Try => builder.append_null(), - }, - } - } - } - Ok(Arc::new( - builder.with_precision_and_scale(precision, scale)?.finish(), - )) -} - -fn cast_int_to_decimal128( - array: &dyn Array, - eval_mode: EvalMode, - from_type: &DataType, - to_type: &DataType, - precision: u8, - scale: i8, -) -> SparkResult { - match (from_type, to_type) { - (DataType::Int8, DataType::Decimal128(_p, _s)) => { - cast_int_to_decimal128_internal::( - array.as_primitive::(), - precision, - scale, - eval_mode, - ) - } - (DataType::Int16, DataType::Decimal128(_p, _s)) => { - cast_int_to_decimal128_internal::( - array.as_primitive::(), - precision, - scale, - eval_mode, - ) - } - (DataType::Int32, DataType::Decimal128(_p, _s)) => { - cast_int_to_decimal128_internal::( - array.as_primitive::(), - precision, - scale, - eval_mode, - ) - } - (DataType::Int64, DataType::Decimal128(_p, _s)) => { - cast_int_to_decimal128_internal::( - array.as_primitive::(), - precision, - scale, - eval_mode, - ) - } - _ => Err(SparkError::Internal(format!( - "Unsupported cast from datatype : {}", - from_type - ))), - } -} - -fn spark_cast_int_to_int( - array: &dyn Array, - eval_mode: EvalMode, - from_type: &DataType, - to_type: &DataType, -) -> SparkResult { - match (from_type, to_type) { - (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!( - array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT" - ), - (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!( - array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT" - ), - (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!( - array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT" - ), - (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!( - array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT" - ), - (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!( - array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT" - ), - (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!( - array, eval_mode, Int16Type, 
Int8Type, from_type, i8, "SMALLINT", "TINYINT" - ), - _ => unreachable!( - "{}", - format!("invalid integer type {to_type} in cast from {from_type}") - ), - } -} - -fn spark_cast_decimal_to_boolean(array: &dyn Array) -> SparkResult { - let decimal_array = array.as_primitive::(); - let mut result = BooleanBuilder::with_capacity(decimal_array.len()); - for i in 0..decimal_array.len() { - if decimal_array.is_null(i) { - result.append_null() - } else { - result.append_value(!decimal_array.value(i).is_zero()); - } - } - Ok(Arc::new(result.finish())) -} - -fn spark_cast_nonintegral_numeric_to_integral( - array: &dyn Array, - eval_mode: EvalMode, - from_type: &DataType, - to_type: &DataType, -) -> SparkResult { - match (from_type, to_type) { - (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( - array, - eval_mode, - Float32Array, - Int8Array, - f32, - i8, - "FLOAT", - "TINYINT", - "{:e}" - ), - (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!( - array, - eval_mode, - Float32Array, - Int16Array, - f32, - i16, - "FLOAT", - "SMALLINT", - "{:e}" - ), - (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!( - array, - eval_mode, - Float32Array, - Int32Array, - f32, - i32, - "FLOAT", - "INT", - i32::MAX, - "{:e}" - ), - (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( - array, - eval_mode, - Float32Array, - Int64Array, - f32, - i64, - "FLOAT", - "BIGINT", - i64::MAX, - "{:e}" - ), - (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( - array, - eval_mode, - Float64Array, - Int8Array, - f64, - i8, - "DOUBLE", - "TINYINT", - "{:e}D" - ), - (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!( - array, - eval_mode, - Float64Array, - Int16Array, - f64, - i16, - "DOUBLE", - "SMALLINT", - "{:e}D" - ), - (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!( - array, - eval_mode, - Float64Array, - Int32Array, - f64, - i32, - "DOUBLE", - "INT", - i32::MAX, - "{:e}D" - ), - (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( - array, - eval_mode, - Float64Array, - Int64Array, - f64, - i64, - "DOUBLE", - "BIGINT", - i64::MAX, - "{:e}D" - ), - (DataType::Decimal128(precision, scale), DataType::Int8) => { - cast_decimal_to_int16_down!( - array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale - ) - } - (DataType::Decimal128(precision, scale), DataType::Int16) => { - cast_decimal_to_int16_down!( - array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale - ) - } - (DataType::Decimal128(precision, scale), DataType::Int32) => { - cast_decimal_to_int32_up!( - array, - eval_mode, - Int32Array, - i32, - "INT", - i32::MAX, - *precision, - *scale - ) - } - (DataType::Decimal128(precision, scale), DataType::Int64) => { - cast_decimal_to_int32_up!( - array, - eval_mode, - Int64Array, - i64, - "BIGINT", - i64::MAX, - *precision, - *scale - ) - } - _ => unreachable!( - "{}", - format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}") - ), - } -} - impl Display for Cast { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( diff --git a/native/spark-expr/src/conversion_funcs/mod.rs b/native/spark-expr/src/conversion_funcs/mod.rs index 33d7a8e211..8e3bbe1c6e 100644 --- a/native/spark-expr/src/conversion_funcs/mod.rs +++ b/native/spark-expr/src/conversion_funcs/mod.rs @@ -17,5 +17,6 @@ mod boolean; pub mod cast; +mod numeric; mod string; mod utils; diff --git a/native/spark-expr/src/conversion_funcs/numeric.rs 
b/native/spark-expr/src/conversion_funcs/numeric.rs
new file mode 100644
index 0000000000..620bf17e76
--- /dev/null
+++ b/native/spark-expr/src/conversion_funcs/numeric.rs
@@ -0,0 +1,898 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::conversion_funcs::utils::cast_overflow;
+use crate::{EvalMode, SparkError, SparkResult};
+use arrow::array::{
+    Array, ArrayRef, BooleanBuilder, Decimal128Array, Decimal128Builder, Float32Array,
+    Float64Array, GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array,
+    OffsetSizeTrait, PrimitiveArray,
+};
+use arrow::datatypes::{
+    is_validate_decimal_precision, ArrowPrimitiveType, DataType, Decimal128Type, Float32Type,
+    Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
+};
+use num::{cast::AsPrimitive, ToPrimitive, Zero};
+use std::sync::Arc;
+
+macro_rules! cast_float_to_string {
+    ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{
+
+        fn cast<OffsetSize>(
+            from: &dyn Array,
+            _eval_mode: EvalMode,
+        ) -> SparkResult<ArrayRef>
+        where
+            OffsetSize: OffsetSizeTrait, {
+            let array = from.as_any().downcast_ref::<$output_type>().unwrap();
+
+            // If the absolute number is less than 10,000,000 and greater than or equal to 0.001,
+            // the result is expressed without scientific notation, with at least one digit on
+            // either side of the decimal point. Otherwise, Spark uses a mantissa followed by E and
+            // an exponent. The mantissa has an optional leading minus sign followed by one digit
+            // to the left of the decimal point, and the minimal number of digits greater than
+            // zero to the right. The exponent has an optional leading minus sign.
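+            // For orientation, a few illustrative input/output pairs implied by the
+            // rule above (hypothetical values, not taken from the docs cited below):
+            //   42.0       -> "42.0"         (whole numbers keep a trailing ".0")
+            //   1234.5     -> "1234.5"
+            //   0.0001     -> "1.0E-4"       (below the lower bound)
+            //   12345678.0 -> "1.2345678E7"  (at or above the upper bound)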
+            // source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html
+
+            const LOWER_SCIENTIFIC_BOUND: $type = 0.001;
+            const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0;
+
+            let output_array = array
+                .iter()
+                .map(|value| match value {
+                    Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())),
+                    Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())),
+                    Some(value)
+                        if (value.abs() < UPPER_SCIENTIFIC_BOUND
+                            && value.abs() >= LOWER_SCIENTIFIC_BOUND)
+                            || value.abs() == 0.0 =>
+                    {
+                        let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" };
+
+                        Ok(Some(format!("{value}{trailing_zero}")))
+                    }
+                    Some(value)
+                        if value.abs() >= UPPER_SCIENTIFIC_BOUND
+                            || value.abs() < LOWER_SCIENTIFIC_BOUND =>
+                    {
+                        let formatted = format!("{value:E}");
+
+                        if formatted.contains(".") {
+                            Ok(Some(formatted))
+                        } else {
+                            // `formatted` is already in scientific notation and can be split up by E
+                            // in order to add the missing trailing 0 which gets removed for numbers with a fraction of 0.0
+                            let prepare_number: Vec<&str> = formatted.split("E").collect();
+
+                            let coefficient = prepare_number[0];
+
+                            let exponent = prepare_number[1];
+
+                            Ok(Some(format!("{coefficient}.0E{exponent}")))
+                        }
+                    }
+                    Some(value) => Ok(Some(value.to_string())),
+                    _ => Ok(None),
+                })
+                .collect::<Result<GenericStringArray<OffsetSize>, SparkError>>()?;
+
+            Ok(Arc::new(output_array))
+        }
+
+        cast::<$offset_type>($from, $eval_mode)
+    }};
+}
+
+macro_rules! cast_int_to_int_macro {
+    (
+        $array: expr,
+        $eval_mode:expr,
+        $from_arrow_primitive_type: ty,
+        $to_arrow_primitive_type: ty,
+        $from_data_type: expr,
+        $to_native_type: ty,
+        $spark_from_data_type_name: expr,
+        $spark_to_data_type_name: expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<$from_arrow_primitive_type>>()
+            .unwrap();
+        let spark_int_literal_suffix = match $from_data_type {
+            &DataType::Int64 => "L",
+            &DataType::Int16 => "S",
+            &DataType::Int8 => "T",
+            _ => "",
+        };
+
+        let output_array = match $eval_mode {
+            EvalMode::Legacy => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        Ok::<Option<$to_native_type>, SparkError>(Some(value as $to_native_type))
+                    }
+                    _ => Ok(None),
+                })
+                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
+            _ => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let res = <$to_native_type>::try_from(value);
+                        if res.is_err() {
+                            Err(cast_overflow(
+                                &(value.to_string() + spark_int_literal_suffix),
+                                $spark_from_data_type_name,
+                                $spark_to_data_type_name,
+                            ))
+                        } else {
+                            Ok::<Option<$to_native_type>, SparkError>(Some(res.unwrap()))
+                        }
+                    }
+                    _ => Ok(None),
+                })
+                .collect::<Result<PrimitiveArray<$to_arrow_primitive_type>, _>>(),
+        }?;
+        let result: SparkResult<ArrayRef> = Ok(Arc::new(output_array) as ArrayRef);
+        result
+    }};
+}
+
+// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short.
+// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast,
+// this can cause unexpected Short/Byte cast results. Replicate this behavior.
+macro_rules!
cast_float_to_int16_down { + ( + $array:expr, + $eval_mode:expr, + $src_array_type:ty, + $dest_array_type:ty, + $rust_src_type:ty, + $rust_dest_type:ty, + $src_type_str:expr, + $dest_type_str:expr, + $format_str:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::<$src_array_type>() + .expect(concat!("Expected a ", stringify!($src_array_type))); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let is_overflow = value.is_nan() || value.abs() as i32 == i32::MAX; + if is_overflow { + return Err(cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + )); + } + let i32_value = value as i32; + <$rust_dest_type>::try_from(i32_value) + .map_err(|_| { + cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + ) + }) + .map(Some) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let i32_value = value as i32; + Ok::, SparkError>(Some( + i32_value as $rust_dest_type, + )) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +macro_rules! cast_float_to_int32_up { + ( + $array:expr, + $eval_mode:expr, + $src_array_type:ty, + $dest_array_type:ty, + $rust_src_type:ty, + $rust_dest_type:ty, + $src_type_str:expr, + $dest_type_str:expr, + $max_dest_val:expr, + $format_str:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::<$src_array_type>() + .expect(concat!("Expected a ", stringify!($src_array_type))); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let is_overflow = + value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val; + if is_overflow { + return Err(cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + )); + } + Ok(Some(value as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + Ok::, SparkError>(Some(value as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short. +// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast, +// this can cause unexpected Short/Byte cast results. Replicate this behavior. +macro_rules! 
cast_decimal_to_int16_down {
+    (
+        $array:expr,
+        $eval_mode:expr,
+        $dest_array_type:ty,
+        $rust_dest_type:ty,
+        $dest_type_str:expr,
+        $precision:expr,
+        $scale:expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<Decimal128Array>()
+            .expect("Expected a Decimal128ArrayType");
+
+        let output_array = match $eval_mode {
+            EvalMode::Ansi => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let divisor = 10_i128.pow($scale as u32);
+                        let truncated = value / divisor;
+                        let is_overflow = truncated.abs() > i32::MAX.into();
+                        if is_overflow {
+                            return Err(cast_overflow(
+                                &format!(
+                                    "{}BD",
+                                    format_decimal_str(
+                                        &value.to_string(),
+                                        $precision as usize,
+                                        $scale
+                                    )
+                                ),
+                                &format!("DECIMAL({},{})", $precision, $scale),
+                                $dest_type_str,
+                            ));
+                        }
+                        let i32_value = truncated as i32;
+                        <$rust_dest_type>::try_from(i32_value)
+                            .map_err(|_| {
+                                cast_overflow(
+                                    &format!(
+                                        "{}BD",
+                                        format_decimal_str(
+                                            &value.to_string(),
+                                            $precision as usize,
+                                            $scale
+                                        )
+                                    ),
+                                    &format!("DECIMAL({},{})", $precision, $scale),
+                                    $dest_type_str,
+                                )
+                            })
+                            .map(Some)
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+            _ => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let divisor = 10_i128.pow($scale as u32);
+                        let i32_value = (value / divisor) as i32;
+                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
+                            i32_value as $rust_dest_type,
+                        ))
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+        };
+        Ok(Arc::new(output_array) as ArrayRef)
+    }};
+}
+
+macro_rules! cast_decimal_to_int32_up {
+    (
+        $array:expr,
+        $eval_mode:expr,
+        $dest_array_type:ty,
+        $rust_dest_type:ty,
+        $dest_type_str:expr,
+        $max_dest_val:expr,
+        $precision:expr,
+        $scale:expr
+    ) => {{
+        let cast_array = $array
+            .as_any()
+            .downcast_ref::<Decimal128Array>()
+            .expect("Expected a Decimal128ArrayType");
+
+        let output_array = match $eval_mode {
+            EvalMode::Ansi => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let divisor = 10_i128.pow($scale as u32);
+                        let truncated = value / divisor;
+                        let is_overflow = truncated.abs() > $max_dest_val.into();
+                        if is_overflow {
+                            return Err(cast_overflow(
+                                &format!(
+                                    "{}BD",
+                                    format_decimal_str(
+                                        &value.to_string(),
+                                        $precision as usize,
+                                        $scale
+                                    )
+                                ),
+                                &format!("DECIMAL({},{})", $precision, $scale),
+                                $dest_type_str,
+                            ));
+                        }
+                        Ok(Some(truncated as $rust_dest_type))
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+            _ => cast_array
+                .iter()
+                .map(|value| match value {
+                    Some(value) => {
+                        let divisor = 10_i128.pow($scale as u32);
+                        let truncated = value / divisor;
+                        Ok::<Option<$rust_dest_type>, SparkError>(Some(
+                            truncated as $rust_dest_type,
+                        ))
+                    }
+                    None => Ok(None),
+                })
+                .collect::<Result<$dest_array_type, _>>()?,
+        };
+        Ok(Arc::new(output_array) as ArrayRef)
+    }};
+}
+
+// copied from arrow::datatypes::Decimal128Type since Decimal128Type::format_decimal can't be called directly
+pub(crate) fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String {
+    let (sign, rest) = match value_str.strip_prefix('-') {
+        Some(stripped) => ("-", stripped),
+        None => ("", value_str),
+    };
+    let bound = precision.min(rest.len()) + sign.len();
+    let value_str = &value_str[0..bound];
+
+    if scale == 0 {
+        value_str.to_string()
+    } else if scale < 0 {
+        // Negative scale: pad the digits with trailing zeros up to the implied magnitude
+        let padding = value_str.len() + scale.unsigned_abs() as usize;
+        format!("{value_str:0<padding$}")
+    } else if rest.len() > scale as usize {
+        // Decimal separator is in the middle of the string
+        let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize);
+        format!("{whole}.{decimal}")
+    } else {
+        // String has to be padded
+        format!("{}0.{:0>width$}", sign, rest, width = scale as
usize) + } +} + +pub(crate) fn spark_cast_float64_to_utf8( + from: &dyn Array, + _eval_mode: EvalMode, +) -> SparkResult +where + OffsetSize: OffsetSizeTrait, +{ + cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize) +} + +pub(crate) fn spark_cast_float32_to_utf8( + from: &dyn Array, + _eval_mode: EvalMode, +) -> SparkResult +where + OffsetSize: OffsetSizeTrait, +{ + cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize) +} + +fn cast_int_to_decimal128_internal( + array: &PrimitiveArray, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> SparkResult +where + T: ArrowPrimitiveType, + T::Native: Into, +{ + let mut builder = Decimal128Builder::with_capacity(array.len()); + let multiplier = 10_i128.pow(scale as u32); + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let v = array.value(i).into(); + let scaled = v.checked_mul(multiplier); + match scaled { + Some(scaled) => { + if !is_validate_decimal_precision(scaled, precision) { + match eval_mode { + EvalMode::Ansi => { + return Err(SparkError::NumericValueOutOfRange { + value: v.to_string(), + precision, + scale, + }); + } + EvalMode::Try | EvalMode::Legacy => builder.append_null(), + } + } else { + builder.append_value(scaled); + } + } + _ => match eval_mode { + EvalMode::Ansi => { + return Err(SparkError::NumericValueOutOfRange { + value: v.to_string(), + precision, + scale, + }) + } + EvalMode::Legacy | EvalMode::Try => builder.append_null(), + }, + } + } + } + Ok(Arc::new( + builder.with_precision_and_scale(precision, scale)?.finish(), + )) +} + +pub(crate) fn cast_int_to_decimal128( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, + precision: u8, + scale: i8, +) -> SparkResult { + match (from_type, to_type) { + (DataType::Int8, DataType::Decimal128(_p, _s)) => { + cast_int_to_decimal128_internal::( + array.as_primitive::(), + precision, + scale, + eval_mode, + ) + } + (DataType::Int16, DataType::Decimal128(_p, _s)) => { + cast_int_to_decimal128_internal::( + array.as_primitive::(), + precision, + scale, + eval_mode, + ) + } + (DataType::Int32, DataType::Decimal128(_p, _s)) => { + cast_int_to_decimal128_internal::( + array.as_primitive::(), + precision, + scale, + eval_mode, + ) + } + (DataType::Int64, DataType::Decimal128(_p, _s)) => { + cast_int_to_decimal128_internal::( + array.as_primitive::(), + precision, + scale, + eval_mode, + ) + } + _ => Err(SparkError::Internal(format!( + "Unsupported cast from datatype : {}", + from_type + ))), + } +} + +pub(crate) fn spark_cast_int_to_int( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, +) -> SparkResult { + match (from_type, to_type) { + (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT" + ), + (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT" + ), + (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT" + ), + (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!( + array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT" + ), + (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT" + ), + (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, 
Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT" + ), + _ => unreachable!( + "{}", + format!("invalid integer type {to_type} in cast from {from_type}") + ), + } +} + +pub(crate) fn spark_cast_decimal_to_boolean(array: &dyn Array) -> SparkResult { + let decimal_array = array.as_primitive::(); + let mut result = BooleanBuilder::with_capacity(decimal_array.len()); + for i in 0..decimal_array.len() { + if decimal_array.is_null(i) { + result.append_null() + } else { + result.append_value(!decimal_array.value(i).is_zero()); + } + } + Ok(Arc::new(result.finish())) +} + +pub(crate) fn cast_float64_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> SparkResult { + cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) +} + +pub(crate) fn cast_float32_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> SparkResult { + cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) +} + +fn cast_floating_point_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> SparkResult +where + ::Native: AsPrimitive, +{ + let input = array.as_any().downcast_ref::>().unwrap(); + let mut cast_array = PrimitiveArray::::builder(input.len()); + + let mul = 10_f64.powi(scale as i32); + + for i in 0..input.len() { + if input.is_null(i) { + cast_array.append_null(); + continue; + } + + let input_value = input.value(i).as_(); + if let Some(v) = (input_value * mul).round().to_i128() { + if is_validate_decimal_precision(v, precision) { + cast_array.append_value(v); + continue; + } + }; + + if eval_mode == EvalMode::Ansi { + return Err(SparkError::NumericValueOutOfRange { + value: input_value.to_string(), + precision, + scale, + }); + } + cast_array.append_null(); + } + + let res = Arc::new( + cast_array + .with_precision_and_scale(precision, scale)? 
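+            // `with_precision_and_scale` re-types the builder's output as
+            // Decimal128(precision, scale); it errors if the pair is not a valid
+            // Decimal128 combination (e.g. precision greater than 38).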
+ .finish(), + ) as ArrayRef; + Ok(res) +} + +pub(crate) fn spark_cast_nonintegral_numeric_to_integral( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, +) -> SparkResult { + match (from_type, to_type) { + (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int8Array, + f32, + i8, + "FLOAT", + "TINYINT", + "{:e}" + ), + (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int16Array, + f32, + i16, + "FLOAT", + "SMALLINT", + "{:e}" + ), + (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int32Array, + f32, + i32, + "FLOAT", + "INT", + i32::MAX, + "{:e}" + ), + (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int64Array, + f32, + i64, + "FLOAT", + "BIGINT", + i64::MAX, + "{:e}" + ), + (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int8Array, + f64, + i8, + "DOUBLE", + "TINYINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int16Array, + f64, + i16, + "DOUBLE", + "SMALLINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int32Array, + f64, + i32, + "DOUBLE", + "INT", + i32::MAX, + "{:e}D" + ), + (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int64Array, + f64, + i64, + "DOUBLE", + "BIGINT", + i64::MAX, + "{:e}D" + ), + (DataType::Decimal128(precision, scale), DataType::Int8) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int8Array, i8, "TINYINT", *precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int16) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int16Array, i16, "SMALLINT", *precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int32) => { + cast_decimal_to_int32_up!( + array, + eval_mode, + Int32Array, + i32, + "INT", + i32::MAX, + *precision, + *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int64) => { + cast_decimal_to_int32_up!( + array, + eval_mode, + Int64Array, + i64, + "BIGINT", + i64::MAX, + *precision, + *scale + ) + } + _ => unreachable!( + "{}", + format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}") + ), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::AsArray; + use core::f64; + + #[test] + #[cfg_attr(miri, ignore)] + fn test_cast_float_to_decimal() { + let a: ArrayRef = Arc::new(Float64Array::from(vec![ + Some(42.), + Some(0.5153125), + Some(-42.4242415), + Some(42e-314), + Some(0.), + Some(-4242.424242), + Some(f64::INFINITY), + Some(f64::NEG_INFINITY), + Some(f64::NAN), + None, + ])); + let b = + cast_floating_point_to_decimal128::(&a, 8, 6, EvalMode::Legacy).unwrap(); + assert_eq!(b.len(), a.len()); + let casted = b.as_primitive::(); + assert_eq!(casted.value(0), 42000000); + // https://github.com/apache/datafusion-comet/issues/1371 + // assert_eq!(casted.value(1), 515313); + assert_eq!(casted.value(2), -42424242); + assert_eq!(casted.value(3), 0); + assert_eq!(casted.value(4), 0); + assert!(casted.is_null(5)); + assert!(casted.is_null(6)); + assert!(casted.is_null(7)); + assert!(casted.is_null(8)); + assert!(casted.is_null(9)); + } + + #[test] + fn test_spark_cast_int_to_int_overflow() 
{ + // Test Int64 -> Int32 overflow + let array: ArrayRef = Arc::new(Int64Array::from(vec![ + Some(i64::MAX), + Some(i64::MIN), + Some(100), + ])); + + // Legacy mode should wrap around + let result = + spark_cast_int_to_int(&array, EvalMode::Legacy, &DataType::Int64, &DataType::Int32) + .unwrap(); + let int32_array = result.as_primitive::(); + assert_eq!(int32_array.value(2), 100); + + // Ansi mode should error on overflow + let result = + spark_cast_int_to_int(&array, EvalMode::Ansi, &DataType::Int64, &DataType::Int32); + assert!(result.is_err()); + } + + #[test] + fn test_spark_cast_decimal_to_boolean() { + let array: ArrayRef = Arc::new( + Decimal128Array::from(vec![Some(0), Some(100), Some(-100), None]) + .with_precision_and_scale(10, 2) + .unwrap(), + ); + let result = spark_cast_decimal_to_boolean(&array).unwrap(); + let bool_array = result.as_boolean(); + assert!(!bool_array.value(0)); // 0 -> false + assert!(bool_array.value(1)); // 100 -> true + assert!(bool_array.value(2)); // -100 -> true + assert!(bool_array.is_null(3)); // null -> null + } + + #[test] + fn test_cast_int_to_decimal128() { + let array: ArrayRef = Arc::new(Int32Array::from(vec![Some(100), Some(-100), None])); + let result = cast_int_to_decimal128( + &array, + EvalMode::Legacy, + &DataType::Int32, + &DataType::Decimal128(10, 2), + 10, + 2, + ) + .unwrap(); + let decimal_array = result.as_primitive::(); + assert_eq!(decimal_array.value(0), 10000); // 100 * 10^2 + assert_eq!(decimal_array.value(1), -10000); // -100 * 10^2 + assert!(decimal_array.is_null(2)); + } +} From aee32993ecdc5a4e9bbe17f498d321827c3dc585 Mon Sep 17 00:00:00 2001 From: B Vadlamani Date: Mon, 2 Mar 2026 14:57:21 -0800 Subject: [PATCH 2/4] refactor_cast_module_numeric --- .../spark-expr/src/conversion_funcs/cast.rs | 214 +----------------- .../src/conversion_funcs/numeric.rs | 196 +++++++++++++++- .../spark-expr/src/conversion_funcs/utils.rs | 3 +- 3 files changed, 204 insertions(+), 209 deletions(-) diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 32a7abf650..2777ed0a7c 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -20,8 +20,8 @@ use crate::conversion_funcs::boolean::{ }; use crate::conversion_funcs::numeric::{ cast_float32_to_decimal128, cast_float64_to_decimal128, cast_int_to_decimal128, - spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8, spark_cast_float64_to_utf8, - spark_cast_int_to_int, spark_cast_nonintegral_numeric_to_integral, + cast_int_to_timestamp, spark_cast_decimal_to_boolean, spark_cast_float32_to_utf8, + spark_cast_float64_to_utf8, spark_cast_int_to_int, spark_cast_nonintegral_numeric_to_integral, }; use crate::conversion_funcs::string::{ cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int, @@ -30,12 +30,12 @@ use crate::conversion_funcs::string::{ use crate::conversion_funcs::utils::spark_cast_postprocess; use crate::utils::array_with_timezone; use crate::EvalMode::Legacy; -use crate::{timezone, BinaryOutputStyle}; +use crate::{cast_whole_num_to_binary, timezone, BinaryOutputStyle}; use crate::{EvalMode, SparkError, SparkResult}; use arrow::array::builder::StringBuilder; use arrow::array::{ - BinaryBuilder, BooleanBuilder, Decimal128Builder, DictionaryArray, GenericByteArray, ListArray, - StringArray, StructArray, TimestampMicrosecondBuilder, + BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, StringArray, StructArray, + 
TimestampMicrosecondBuilder, }; use arrow::compute::can_cast_types; use arrow::datatypes::GenericBinaryType; @@ -44,15 +44,11 @@ use arrow::error::ArrowError; use arrow::{ array::{ cast::AsArray, - types::{Date32Type, Int16Type, Int32Type, Int8Type}, - Array, ArrayRef, Decimal128Array, Float32Array, Float64Array, GenericStringArray, - Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray, + types::{Date32Type, Int32Type}, + Array, ArrayRef, GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, + OffsetSizeTrait, PrimitiveArray, }, compute::{cast_with_options, take, CastOptions}, - datatypes::{ - is_validate_decimal_precision, ArrowPrimitiveType, Decimal128Type, Float32Type, - Float64Type, Int64Type, - }, record_batch::RecordBatch, util::display::FormatOptions, }; @@ -62,7 +58,6 @@ use chrono::{NaiveDate, TimeZone}; use datafusion::common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ColumnarValue; -use num::{cast::AsPrimitive, ToPrimitive, Zero}; use std::str::FromStr; use std::{ any::Any, @@ -73,8 +68,6 @@ use std::{ static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); -pub(crate) const MICROS_PER_SECOND: i64 = 1000000; - static CAST_OPTIONS: CastOptions = CastOptions { safe: true, format_options: FormatOptions::new() @@ -105,46 +98,6 @@ impl Hash for Cast { } } -// eval mode is not needed since all ints can be implemented in binary format -macro_rules! cast_whole_num_to_binary { - ($array:expr, $primitive_type:ty, $byte_size:expr) => {{ - let input_arr = $array - .as_any() - .downcast_ref::<$primitive_type>() - .ok_or_else(|| SparkError::Internal("Expected numeric array".to_string()))?; - - let len = input_arr.len(); - let mut builder = BinaryBuilder::with_capacity(len, len * $byte_size); - - for i in 0..input_arr.len() { - if input_arr.is_null(i) { - builder.append_null(); - } else { - builder.append_value(input_arr.value(i).to_be_bytes()); - } - } - - Ok(Arc::new(builder.finish()) as ArrayRef) - }}; -} - -macro_rules! cast_int_to_timestamp_impl { - ($array:expr, $builder:expr, $primitive_type:ty) => {{ - let arr = $array.as_primitive::<$primitive_type>(); - for i in 0..arr.len() { - if arr.is_null(i) { - $builder.append_null(); - } else { - // saturating_mul limits to i64::MIN/MAX on overflow instead of panicking, - // which could occur when converting extreme values (e.g., Long.MIN_VALUE) - // matching spark behavior (irrespective of EvalMode) - let micros = (arr.value(i) as i64).saturating_mul(MICROS_PER_SECOND); - $builder.append_value(micros); - } - } - }}; -} - impl Cast { pub fn new( child: Arc, @@ -442,29 +395,6 @@ pub(crate) fn cast_array( Ok(spark_cast_postprocess(cast_result?, from_type, to_type)) } -fn cast_int_to_timestamp( - array_ref: &ArrayRef, - target_tz: &Option>, -) -> SparkResult { - // Input is seconds since epoch, multiply by MICROS_PER_SECOND to get microseconds. 
- let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len()); - - match array_ref.data_type() { - DataType::Int8 => cast_int_to_timestamp_impl!(array_ref, builder, Int8Type), - DataType::Int16 => cast_int_to_timestamp_impl!(array_ref, builder, Int16Type), - DataType::Int32 => cast_int_to_timestamp_impl!(array_ref, builder, Int32Type), - DataType::Int64 => cast_int_to_timestamp_impl!(array_ref, builder, Int64Type), - dt => { - return Err(SparkError::Internal(format!( - "Unsupported type for cast_int_to_timestamp: {:?}", - dt - ))) - } - } - - Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef) -} - fn cast_date_to_timestamp( array_ref: &ArrayRef, cast_options: &SparkCastOptions, @@ -1008,44 +938,6 @@ mod tests { } } - #[test] - // Currently the cast function depending on `f64::powi`, which has unspecified precision according to the doc - // https://doc.rust-lang.org/std/primitive.f64.html#unspecified-precision. - // Miri deliberately apply random floating-point errors to these operations to expose bugs - // https://github.com/rust-lang/miri/issues/4395. - // The random errors may interfere with test cases at rounding edge, so we ignore it on miri for now. - // Once https://github.com/apache/datafusion-comet/issues/1371 is fixed, this should no longer be an issue. - #[cfg_attr(miri, ignore)] - fn test_cast_float_to_decimal() { - let a: ArrayRef = Arc::new(Float64Array::from(vec![ - Some(42.), - Some(0.5153125), - Some(-42.4242415), - Some(42e-314), - Some(0.), - Some(-4242.424242), - Some(f64::INFINITY), - Some(f64::NEG_INFINITY), - Some(f64::NAN), - None, - ])); - let b = - cast_floating_point_to_decimal128::(&a, 8, 6, EvalMode::Legacy).unwrap(); - assert_eq!(b.len(), a.len()); - let casted = b.as_primitive::(); - assert_eq!(casted.value(0), 42000000); - // https://github.com/apache/datafusion-comet/issues/1371 - // assert_eq!(casted.value(1), 515313); - assert_eq!(casted.value(2), -42424242); - assert_eq!(casted.value(3), 0); - assert_eq!(casted.value(4), 0); - assert!(casted.is_null(5)); - assert!(casted.is_null(6)); - assert!(casted.is_null(7)); - assert!(casted.is_null(8)); - assert!(casted.is_null(9)); - } - #[test] fn test_cast_string_array_to_string() { use arrow::array::ListArray; @@ -1096,94 +988,4 @@ mod tests { assert_eq!(r#"[null]"#, string_array.value(2)); assert_eq!(r#"[]"#, string_array.value(3)); } - - #[test] - fn test_cast_int_to_timestamp() { - let timezones: [Option>; 6] = [ - Some(Arc::from("UTC")), - Some(Arc::from("America/New_York")), - Some(Arc::from("America/Los_Angeles")), - Some(Arc::from("Europe/London")), - Some(Arc::from("Asia/Tokyo")), - Some(Arc::from("Australia/Sydney")), - ]; - - for tz in &timezones { - let int8_array: ArrayRef = Arc::new(Int8Array::from(vec![ - Some(0), - Some(1), - Some(-1), - Some(127), - Some(-128), - None, - ])); - - let result = cast_int_to_timestamp(&int8_array, tz).unwrap(); - let ts_array = result.as_primitive::(); - - assert_eq!(ts_array.value(0), 0); - assert_eq!(ts_array.value(1), 1_000_000); - assert_eq!(ts_array.value(2), -1_000_000); - assert_eq!(ts_array.value(3), 127_000_000); - assert_eq!(ts_array.value(4), -128_000_000); - assert!(ts_array.is_null(5)); - assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); - - let int16_array: ArrayRef = Arc::new(Int16Array::from(vec![ - Some(0), - Some(1), - Some(-1), - Some(32767), - Some(-32768), - None, - ])); - - let result = cast_int_to_timestamp(&int16_array, tz).unwrap(); - let ts_array = result.as_primitive::(); - - 
assert_eq!(ts_array.value(0), 0); - assert_eq!(ts_array.value(1), 1_000_000); - assert_eq!(ts_array.value(2), -1_000_000); - assert_eq!(ts_array.value(3), 32_767_000_000_i64); - assert_eq!(ts_array.value(4), -32_768_000_000_i64); - assert!(ts_array.is_null(5)); - assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); - - let int32_array: ArrayRef = Arc::new(Int32Array::from(vec![ - Some(0), - Some(1), - Some(-1), - Some(1704067200), - None, - ])); - - let result = cast_int_to_timestamp(&int32_array, tz).unwrap(); - let ts_array = result.as_primitive::(); - - assert_eq!(ts_array.value(0), 0); - assert_eq!(ts_array.value(1), 1_000_000); - assert_eq!(ts_array.value(2), -1_000_000); - assert_eq!(ts_array.value(3), 1_704_067_200_000_000_i64); - assert!(ts_array.is_null(4)); - assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); - - let int64_array: ArrayRef = Arc::new(Int64Array::from(vec![ - Some(0), - Some(1), - Some(-1), - Some(i64::MAX), - Some(i64::MIN), - ])); - - let result = cast_int_to_timestamp(&int64_array, tz).unwrap(); - let ts_array = result.as_primitive::(); - - assert_eq!(ts_array.value(0), 0); - assert_eq!(ts_array.value(1), 1_000_000_i64); - assert_eq!(ts_array.value(2), -1_000_000_i64); - assert_eq!(ts_array.value(3), i64::MAX); - assert_eq!(ts_array.value(4), i64::MIN); - assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); - } - } } diff --git a/native/spark-expr/src/conversion_funcs/numeric.rs b/native/spark-expr/src/conversion_funcs/numeric.rs index 620bf17e76..ceb440228e 100644 --- a/native/spark-expr/src/conversion_funcs/numeric.rs +++ b/native/spark-expr/src/conversion_funcs/numeric.rs @@ -16,11 +16,12 @@ // under the License. use crate::conversion_funcs::utils::cast_overflow; +use crate::conversion_funcs::utils::MICROS_PER_SECOND; use crate::{EvalMode, SparkError, SparkResult}; use arrow::array::{ - Array, ArrayRef, BooleanBuilder, Decimal128Array, Decimal128Builder, Float32Array, + Array, ArrayRef, AsArray, BooleanBuilder, Decimal128Array, Decimal128Builder, Float32Array, Float64Array, GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, - OffsetSizeTrait, PrimitiveArray, + OffsetSizeTrait, PrimitiveArray, TimestampMicrosecondBuilder, }; use arrow::datatypes::{ is_validate_decimal_precision, ArrowPrimitiveType, DataType, Decimal128Type, Float32Type, @@ -97,6 +98,47 @@ macro_rules! cast_float_to_string { }}; } +// eval mode is not needed since all ints can be implemented in binary format +#[macro_export] +macro_rules! cast_whole_num_to_binary { + ($array:expr, $primitive_type:ty, $byte_size:expr) => {{ + let input_arr = $array + .as_any() + .downcast_ref::<$primitive_type>() + .ok_or_else(|| SparkError::Internal("Expected numeric array".to_string()))?; + + let len = input_arr.len(); + let mut builder = BinaryBuilder::with_capacity(len, len * $byte_size); + + for i in 0..input_arr.len() { + if input_arr.is_null(i) { + builder.append_null(); + } else { + builder.append_value(input_arr.value(i).to_be_bytes()); + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) + }}; +} + +macro_rules! 
cast_int_to_timestamp_impl { + ($array:expr, $builder:expr, $primitive_type:ty) => {{ + let arr = $array.as_primitive::<$primitive_type>(); + for i in 0..arr.len() { + if arr.is_null(i) { + $builder.append_null(); + } else { + // saturating_mul limits to i64::MIN/MAX on overflow instead of panicking, + // which could occur when converting extreme values (e.g., Long.MIN_VALUE) + // matching spark behavior (irrespective of EvalMode) + let micros = (arr.value(i) as i64).saturating_mul(MICROS_PER_SECOND); + $builder.append_value(micros); + } + } + }}; +} + macro_rules! cast_int_to_int_macro { ( $array: expr, @@ -803,10 +845,34 @@ pub(crate) fn spark_cast_nonintegral_numeric_to_integral( } } +pub(crate) fn cast_int_to_timestamp( + array_ref: &ArrayRef, + target_tz: &Option>, +) -> SparkResult { + // Input is seconds since epoch, multiply by MICROS_PER_SECOND to get microseconds. + let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len()); + + match array_ref.data_type() { + DataType::Int8 => cast_int_to_timestamp_impl!(array_ref, builder, Int8Type), + DataType::Int16 => cast_int_to_timestamp_impl!(array_ref, builder, Int16Type), + DataType::Int32 => cast_int_to_timestamp_impl!(array_ref, builder, Int32Type), + DataType::Int64 => cast_int_to_timestamp_impl!(array_ref, builder, Int64Type), + dt => { + return Err(SparkError::Internal(format!( + "Unsupported type for cast_int_to_timestamp: {:?}", + dt + ))) + } + } + + Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef) +} + #[cfg(test)] mod tests { use super::*; use arrow::array::AsArray; + use arrow::datatypes::TimestampMicrosecondType; use core::f64; #[test] @@ -895,4 +961,130 @@ mod tests { assert_eq!(decimal_array.value(1), -10000); // -100 * 10^2 assert!(decimal_array.is_null(2)); } + #[test] + fn test_cast_int_to_timestamp() { + let timezones: [Option>; 6] = [ + Some(Arc::from("UTC")), + Some(Arc::from("America/New_York")), + Some(Arc::from("America/Los_Angeles")), + Some(Arc::from("Europe/London")), + Some(Arc::from("Asia/Tokyo")), + Some(Arc::from("Australia/Sydney")), + ]; + + for tz in &timezones { + let int8_array: ArrayRef = Arc::new(Int8Array::from(vec![ + Some(0), + Some(1), + Some(-1), + Some(127), + Some(-128), + None, + ])); + + let result = cast_int_to_timestamp(&int8_array, tz).unwrap(); + let ts_array = result.as_primitive::(); + + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000); + assert_eq!(ts_array.value(2), -1_000_000); + assert_eq!(ts_array.value(3), 127_000_000); + assert_eq!(ts_array.value(4), -128_000_000); + assert!(ts_array.is_null(5)); + assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); + + let int16_array: ArrayRef = Arc::new(Int16Array::from(vec![ + Some(0), + Some(1), + Some(-1), + Some(32767), + Some(-32768), + None, + ])); + + let result = cast_int_to_timestamp(&int16_array, tz).unwrap(); + let ts_array = result.as_primitive::(); + + assert_eq!(ts_array.value(0), 0); + assert_eq!(ts_array.value(1), 1_000_000); + assert_eq!(ts_array.value(2), -1_000_000); + assert_eq!(ts_array.value(3), 32_767_000_000_i64); + assert_eq!(ts_array.value(4), -32_768_000_000_i64); + assert!(ts_array.is_null(5)); + assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref())); + + let int32_array: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(0), + Some(1), + Some(-1), + Some(1704067200), + None, + ])); + + let result = cast_int_to_timestamp(&int32_array, tz).unwrap(); + let ts_array = result.as_primitive::(); + + 
+pub(crate) fn cast_int_to_timestamp(
+    array_ref: &ArrayRef,
+    target_tz: &Option<Arc<str>>,
+) -> SparkResult<ArrayRef> {
+    // Input is seconds since epoch, multiply by MICROS_PER_SECOND to get microseconds.
+    let mut builder = TimestampMicrosecondBuilder::with_capacity(array_ref.len());
+
+    match array_ref.data_type() {
+        DataType::Int8 => cast_int_to_timestamp_impl!(array_ref, builder, Int8Type),
+        DataType::Int16 => cast_int_to_timestamp_impl!(array_ref, builder, Int16Type),
+        DataType::Int32 => cast_int_to_timestamp_impl!(array_ref, builder, Int32Type),
+        DataType::Int64 => cast_int_to_timestamp_impl!(array_ref, builder, Int64Type),
+        dt => {
+            return Err(SparkError::Internal(format!(
+                "Unsupported type for cast_int_to_timestamp: {:?}",
+                dt
+            )))
+        }
+    }
+
+    Ok(Arc::new(builder.finish().with_timezone_opt(target_tz.clone())) as ArrayRef)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use arrow::array::AsArray;
+    use arrow::datatypes::TimestampMicrosecondType;
     use core::f64;
 
     #[test]
@@ -895,4 +961,130 @@ mod tests {
         assert_eq!(decimal_array.value(1), -10000); // -100 * 10^2
         assert!(decimal_array.is_null(2));
     }
+    #[test]
+    fn test_cast_int_to_timestamp() {
+        let timezones: [Option<Arc<str>>; 6] = [
+            Some(Arc::from("UTC")),
+            Some(Arc::from("America/New_York")),
+            Some(Arc::from("America/Los_Angeles")),
+            Some(Arc::from("Europe/London")),
+            Some(Arc::from("Asia/Tokyo")),
+            Some(Arc::from("Australia/Sydney")),
+        ];
+
+        for tz in &timezones {
+            let int8_array: ArrayRef = Arc::new(Int8Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(127),
+                Some(-128),
+                None,
+            ]));
+
+            let result = cast_int_to_timestamp(&int8_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000);
+            assert_eq!(ts_array.value(2), -1_000_000);
+            assert_eq!(ts_array.value(3), 127_000_000);
+            assert_eq!(ts_array.value(4), -128_000_000);
+            assert!(ts_array.is_null(5));
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+            let int16_array: ArrayRef = Arc::new(Int16Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(32767),
+                Some(-32768),
+                None,
+            ]));
+
+            let result = cast_int_to_timestamp(&int16_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000);
+            assert_eq!(ts_array.value(2), -1_000_000);
+            assert_eq!(ts_array.value(3), 32_767_000_000_i64);
+            assert_eq!(ts_array.value(4), -32_768_000_000_i64);
+            assert!(ts_array.is_null(5));
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+            let int32_array: ArrayRef = Arc::new(Int32Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(1704067200),
+                None,
+            ]));
+
+            let result = cast_int_to_timestamp(&int32_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000);
+            assert_eq!(ts_array.value(2), -1_000_000);
+            assert_eq!(ts_array.value(3), 1_704_067_200_000_000_i64);
+            assert!(ts_array.is_null(4));
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+
+            let int64_array: ArrayRef = Arc::new(Int64Array::from(vec![
+                Some(0),
+                Some(1),
+                Some(-1),
+                Some(i64::MAX),
+                Some(i64::MIN),
+            ]));
+
+            let result = cast_int_to_timestamp(&int64_array, tz).unwrap();
+            let ts_array = result.as_primitive::<TimestampMicrosecondType>();
+
+            assert_eq!(ts_array.value(0), 0);
+            assert_eq!(ts_array.value(1), 1_000_000_i64);
+            assert_eq!(ts_array.value(2), -1_000_000_i64);
+            assert_eq!(ts_array.value(3), i64::MAX);
+            assert_eq!(ts_array.value(4), i64::MIN);
+            assert_eq!(ts_array.timezone(), tz.as_ref().map(|s| s.as_ref()));
+        }
+    }
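+    // For intuition (editor's sketch, not from this change): the expected values
+    // in the next test come from scaling each finite input by 10^scale (10^6
+    // here) and rounding to the nearest integer, e.g. 42.0 -> 42_000_000 and
+    // -42.4242415 -> -42_424_242. Non-finite inputs and values exceeding
+    // precision 8 become NULL under EvalMode::Legacy. The 0.5153125 case sits
+    // on a rounding edge and is disabled pending issue #1371.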
+    #[test]
+    // The cast currently depends on `f64::powi`, which has unspecified precision:
+    // https://doc.rust-lang.org/std/primitive.f64.html#unspecified-precision.
+    // Miri deliberately applies random floating-point errors to such operations to
+    // expose bugs (https://github.com/rust-lang/miri/issues/4395), and those errors
+    // can flip results that sit on a rounding edge, so this test is ignored under
+    // Miri for now. Once https://github.com/apache/datafusion-comet/issues/1371 is
+    // fixed, this should no longer be an issue.
+    #[cfg_attr(miri, ignore)]
+    fn test_cast_float_to_decimal() {
+        let a: ArrayRef = Arc::new(Float64Array::from(vec![
+            Some(42.),
+            Some(0.5153125),
+            Some(-42.4242415),
+            Some(42e-314),
+            Some(0.),
+            Some(-4242.424242),
+            Some(f64::INFINITY),
+            Some(f64::NEG_INFINITY),
+            Some(f64::NAN),
+            None,
+        ]));
+        let b =
+            cast_floating_point_to_decimal128::<Float64Type>(&a, 8, 6, EvalMode::Legacy).unwrap();
+        assert_eq!(b.len(), a.len());
+        let casted = b.as_primitive::<Decimal128Type>();
+        assert_eq!(casted.value(0), 42000000);
+        // https://github.com/apache/datafusion-comet/issues/1371
+        // assert_eq!(casted.value(1), 515313);
+        assert_eq!(casted.value(2), -42424242);
+        assert_eq!(casted.value(3), 0);
+        assert_eq!(casted.value(4), 0);
+        assert!(casted.is_null(5));
+        assert!(casted.is_null(6));
+        assert!(casted.is_null(7));
+        assert!(casted.is_null(8));
+        assert!(casted.is_null(9));
+    }
 }
diff --git a/native/spark-expr/src/conversion_funcs/utils.rs b/native/spark-expr/src/conversion_funcs/utils.rs
index bac080a968..174efb0b87 100644
--- a/native/spark-expr/src/conversion_funcs/utils.rs
+++ b/native/spark-expr/src/conversion_funcs/utils.rs
@@ -15,7 +15,6 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::cast::MICROS_PER_SECOND;
 use crate::SparkError;
 use arrow::array::{
     Array, ArrayRef, ArrowPrimitiveType, AsArray, GenericStringArray, PrimitiveArray,
@@ -27,6 +26,8 @@ use datafusion::common::cast::as_generic_string_array;
 use num::integer::div_floor;
 use std::sync::Arc;
 
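+/// Microseconds in one second. Defined here in utils so sibling conversion
+/// modules can share it (numeric.rs imports it for the int -> timestamp cast).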
+pub(crate) const MICROS_PER_SECOND: i64 = 1_000_000;
+
 /// A fork & modified version of Arrow's `unary_dyn` which is being deprecated
 pub fn unary_dyn<F, T>(array: &ArrayRef, op: F) -> Result<ArrayRef, ArrowError>
 where

From f48b314609533fb06e55e4920a87be9cf10e3a95 Mon Sep 17 00:00:00 2001
From: B Vadlamani
Date: Mon, 2 Mar 2026 22:39:48 -0800
Subject: [PATCH 3/4] refactor_cast_module_temporal

---
 .../spark-expr/src/conversion_funcs/cast.rs   | 132 ++--------------
 native/spark-expr/src/conversion_funcs/mod.rs |   1 +
 .../src/conversion_funcs/temporal.rs          | 145 ++++++++++++++++++
 3 files changed, 158 insertions(+), 120 deletions(-)
 create mode 100644 native/spark-expr/src/conversion_funcs/temporal.rs

diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 4b93c79418..ff09dbe06e 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -29,15 +29,18 @@ use crate::conversion_funcs::string::{
     cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int,
     cast_string_to_timestamp, is_df_cast_from_string_spark_compatible, spark_cast_utf8_to_boolean,
 };
+use crate::conversion_funcs::temporal::{
+    cast_date_to_timestamp, is_df_cast_from_date_spark_compatible,
+    is_df_cast_from_timestamp_spark_compatible,
+};
 use crate::conversion_funcs::utils::spark_cast_postprocess;
 use crate::utils::array_with_timezone;
 use crate::EvalMode::Legacy;
-use crate::{cast_whole_num_to_binary, timezone, BinaryOutputStyle};
-use crate::{EvalMode, SparkError, SparkResult};
+use crate::{cast_whole_num_to_binary, BinaryOutputStyle};
+use crate::{EvalMode, SparkError};
 use arrow::array::builder::StringBuilder;
 use arrow::array::{
-    BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray,
-    StructArray, TimestampMicrosecondBuilder,
+    BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray,
 };
 use arrow::compute::can_cast_types;
 use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
@@ -45,10 +48,8 @@ use arrow::datatypes::{Field, Fields, GenericBinaryType};
 use arrow::error::ArrowError;
 use arrow::{
     array::{
-        cast::AsArray,
-        types::{Date32Type, Int32Type},
-        Array, ArrayRef, GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array,
-        OffsetSizeTrait, PrimitiveArray,
+        cast::AsArray, types::Int32Type, Array, ArrayRef, GenericStringArray, Int16Array,
+        Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray,
     },
     compute::{cast_with_options, take, CastOptions},
     record_batch::RecordBatch,
@@ -56,11 +57,9 @@ use arrow::{
 };
 use base64::prelude::BASE64_STANDARD_NO_PAD;
 use base64::Engine;
-use chrono::{NaiveDate, TimeZone};
 use datafusion::common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue};
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::ColumnarValue;
-use std::str::FromStr;
 use std::{
     any::Any,
     fmt::{Debug, Display, Formatter},
@@ -404,50 +403,6 @@ pub(crate) fn cast_array(
     Ok(spark_cast_postprocess(cast_result?, &from_type, to_type))
 }
 
-fn cast_date_to_timestamp(
-    array_ref: &ArrayRef,
-    cast_options: &SparkCastOptions,
-    target_tz: &Option<Arc<str>>,
-) -> SparkResult<ArrayRef> {
-    let tz_str = if cast_options.timezone.is_empty() {
-        "UTC"
-    } else {
-        cast_options.timezone.as_str()
-    };
-    // safe to unwrap since we are falling back to UTC above
-    let tz = timezone::Tz::from_str(tz_str)?;
-    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
-    let date_array = array_ref.as_primitive::<Date32Type>();
-
-    let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len());
-
-    for date in date_array.iter() {
-        match date {
-            Some(date) => {
-                // safe to unwrap since chrono's range ( 262,143 yrs) is higher than
-                // number of years possible with days as i32 (~ 6 mil yrs)
-                // convert date in session timezone to timestamp in UTC
-                let naive_date = epoch + chrono::Duration::days(date as i64);
-                let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap();
-                let local_midnight_in_microsec = tz
-                    .from_local_datetime(&local_midnight)
-                    // return earliest possible time (edge case with spring / fall DST changes)
-                    .earliest()
-                    .map(|dt| dt.timestamp_micros())
-                    // in case there is an issue with DST and returns None , we fall back to UTC
-                    .unwrap_or((date as i64) * 86_400 * 1_000_000);
-                builder.append_value(local_midnight_in_microsec);
-            }
-            None => {
-                builder.append_null();
-            }
-        }
-    }
-    Ok(Arc::new(
-        builder.finish().with_timezone_opt(target_tz.clone()),
-    ))
-}
-
 /// Determines if DataFusion supports the given cast in a way that is
 /// compatible with Spark
 fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
@@ -467,13 +422,8 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
             is_df_cast_from_decimal_spark_compatible(to_type)
         }
         DataType::Utf8 => is_df_cast_from_string_spark_compatible(to_type),
-        DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8),
-        DataType::Timestamp(_, _) => {
-            matches!(
-                to_type,
-                DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
-            )
-        }
+        DataType::Date32 => is_df_cast_from_date_spark_compatible(to_type),
+        DataType::Timestamp(_, _) => is_df_cast_from_timestamp_spark_compatible(to_type),
         DataType::Binary => {
             // note that this is not completely Spark compatible because
             // DataFusion only supports binary data containing valid UTF-8 strings
@@ -827,7 +777,7 @@ mod tests {
     use super::*;
     use arrow::array::StringArray;
     use arrow::datatypes::TimestampMicrosecondType;
-    use arrow::datatypes::{Field, Fields, TimeUnit};
+    use arrow::datatypes::{Field, Fields};
     #[test]
     fn test_cast_unsupported_timestamp_to_date() {
         // Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported
@@ -853,64 +803,6 @@ mod tests {
         assert!(result.is_err())
     }
 
-    #[test]
-    fn test_cast_date_to_timestamp() {
-        use arrow::array::Date32Array;
-
-        // verifying epoch , DST change dates (US) and a null value (comprehensive tests on spark side)
-        let dates: ArrayRef = Arc::new(Date32Array::from(vec![
-            Some(0),
-            Some(19723),
-            Some(19793),
-            None,
-        ]));
-
-        let non_dst_date = 1704067200000000i64;
-        let dst_date = 1710115200000000i64;
-        let seven_hours_ts = 25200000000i64;
-        let eight_hours_ts = 28800000000i64;
-
-        // validate UTC
-        let result = cast_array(
-            Arc::clone(&dates),
-            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
-            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
-        )
-        .unwrap();
-        let ts = result.as_primitive::<TimestampMicrosecondType>();
-        assert_eq!(ts.value(0), 0);
-        assert_eq!(ts.value(1), non_dst_date);
-        assert_eq!(ts.value(2), dst_date);
-        assert!(ts.is_null(3));
-
-        // validate LA timezone (follows Daylight savings)
-        let result = cast_array(
-            Arc::clone(&dates),
-            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
-            &SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles", false),
-        )
-        .unwrap();
-        let ts = result.as_primitive::<TimestampMicrosecondType>();
-        assert_eq!(ts.value(0), eight_hours_ts);
-        assert_eq!(ts.value(1), non_dst_date + eight_hours_ts);
-        // should adjust for DST
-        assert_eq!(ts.value(2), dst_date + seven_hours_ts);
-        assert!(ts.is_null(3));
-
-        // Phoenix timezone (does not follow Daylight savings)
-        let result = cast_array(
-            Arc::clone(&dates),
-            &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
-            &SparkCastOptions::new(EvalMode::Legacy, "America/Phoenix", false),
-        )
-        .unwrap();
-        let ts = result.as_primitive::<TimestampMicrosecondType>();
-        assert_eq!(ts.value(0), seven_hours_ts);
-        assert_eq!(ts.value(1), non_dst_date + seven_hours_ts);
-        assert_eq!(ts.value(2), dst_date + seven_hours_ts);
-        assert!(ts.is_null(3));
-    }
-
     #[test]
     fn test_cast_struct_to_utf8() {
         let a: ArrayRef = Arc::new(Int32Array::from(vec![
diff --git a/native/spark-expr/src/conversion_funcs/mod.rs b/native/spark-expr/src/conversion_funcs/mod.rs
index 8e3bbe1c6e..94ab6ac169 100644
--- a/native/spark-expr/src/conversion_funcs/mod.rs
+++ b/native/spark-expr/src/conversion_funcs/mod.rs
@@ -19,4 +19,5 @@ mod boolean;
 pub mod cast;
 mod numeric;
 mod string;
+mod temporal;
 mod utils;
diff --git a/native/spark-expr/src/conversion_funcs/temporal.rs b/native/spark-expr/src/conversion_funcs/temporal.rs
new file mode 100644
index 0000000000..a95e408584
--- /dev/null
+++ b/native/spark-expr/src/conversion_funcs/temporal.rs
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
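+
+//! Spark-compatible casts between temporal types. This module currently hosts
+//! the Date32 -> Timestamp conversion moved here from cast.rs: each date is
+//! interpreted as local midnight in the session timezone and converted to
+//! microseconds since the UTC epoch, so 2024-01-01 in America/Los_Angeles
+//! resolves to 08:00 UTC (see the tests below).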
+
+use crate::{timezone, SparkCastOptions, SparkResult};
+use arrow::array::{ArrayRef, AsArray, TimestampMicrosecondBuilder};
+use arrow::datatypes::{DataType, Date32Type};
+use chrono::{NaiveDate, TimeZone};
+use std::str::FromStr;
+use std::sync::Arc;
+
+pub(crate) fn is_df_cast_from_date_spark_compatible(to_type: &DataType) -> bool {
+    matches!(to_type, DataType::Int32 | DataType::Utf8)
+}
+
+pub(crate) fn is_df_cast_from_timestamp_spark_compatible(to_type: &DataType) -> bool {
+    matches!(
+        to_type,
+        DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
+    )
+}
+
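+// Minimal usage sketch for the function below (hypothetical values; Date32
+// 19723 is 2024-01-01):
+//
+//     let dates: ArrayRef = Arc::new(Date32Array::from(vec![Some(19723)]));
+//     let opts = SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles", false);
+//     let ts = cast_date_to_timestamp(&dates, &opts, &Some(Arc::from("UTC")))?;
+//     // local midnight in LA -> 2024-01-01T08:00:00Z, i.e. 1_704_096_000_000_000 micros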
+pub(crate) fn cast_date_to_timestamp(
+    array_ref: &ArrayRef,
+    cast_options: &SparkCastOptions,
+    target_tz: &Option<Arc<str>>,
+) -> SparkResult<ArrayRef> {
+    let tz_str = if cast_options.timezone.is_empty() {
+        "UTC"
+    } else {
+        cast_options.timezone.as_str()
+    };
+    // safe to unwrap since we fall back to UTC above
+    let tz = timezone::Tz::from_str(tz_str)?;
+    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
+    let date_array = array_ref.as_primitive::<Date32Type>();
+
+    let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len());
+
+    for date in date_array.iter() {
+        match date {
+            Some(date) => {
+                // convert the date in the session timezone to a timestamp in UTC;
+                // `and_hms_opt(0, 0, 0)` is always a valid time, so the unwrap below
+                // cannot fail (note that chrono's supported range of about +/- 262,143
+                // years is narrower than the ~5.8 million years representable by i32 days)
+                let naive_date = epoch + chrono::Duration::days(date as i64);
+                let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap();
+                let local_midnight_in_microsec = tz
+                    .from_local_datetime(&local_midnight)
+                    // take the earliest possible time (edge case with spring / fall DST changes)
+                    .earliest()
+                    .map(|dt| dt.timestamp_micros())
+                    // if the local datetime is ambiguous around a DST change and
+                    // `earliest()` returns None, fall back to UTC
+                    .unwrap_or((date as i64) * 86_400 * 1_000_000);
+                builder.append_value(local_midnight_in_microsec);
+            }
+            None => {
+                builder.append_null();
+            }
+        }
+    }
+    Ok(Arc::new(
+        builder.finish().with_timezone_opt(target_tz.clone()),
+    ))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::Arc;
+    #[test]
+    fn test_cast_date_to_timestamp() {
+        use crate::EvalMode;
+        use arrow::array::{Array, ArrayRef};
+        use arrow::array::Date32Array;
+        use arrow::datatypes::TimestampMicrosecondType;
+
+        // verifying epoch, DST change dates (US) and a null value (comprehensive tests on spark side)
+        let dates: ArrayRef = Arc::new(Date32Array::from(vec![
+            Some(0),
+            Some(19723),
+            Some(19793),
+            None,
+        ]));
+
+        let non_dst_date = 1704067200000000i64;
+        let dst_date = 1710115200000000i64;
+        let seven_hours_ts = 25200000000i64;
+        let eight_hours_ts = 28800000000i64;
+
+        // validate UTC
+        let target_tz: Option<Arc<str>> = Some("UTC".into());
+        let result = cast_date_to_timestamp(
+            &dates,
+            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
+            &target_tz,
+        )
+        .unwrap();
+        let ts = result.as_primitive::<TimestampMicrosecondType>();
+        assert_eq!(ts.value(0), 0);
+        assert_eq!(ts.value(1), non_dst_date);
+        assert_eq!(ts.value(2), dst_date);
+        assert!(ts.is_null(3));
+
+        // validate LA timezone (follows Daylight savings)
+        let result = cast_date_to_timestamp(
+            &dates,
+            &SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles", false),
+            &target_tz,
+        )
+        .unwrap();
+        let ts = result.as_primitive::<TimestampMicrosecondType>();
+        assert_eq!(ts.value(0), eight_hours_ts);
+        assert_eq!(ts.value(1), non_dst_date + eight_hours_ts);
+        // should adjust for DST
+        assert_eq!(ts.value(2), dst_date + seven_hours_ts);
+        assert!(ts.is_null(3));
+
+        // Phoenix timezone (does not follow Daylight savings)
+        let result = cast_date_to_timestamp(
+            &dates,
+            &SparkCastOptions::new(EvalMode::Legacy, "America/Phoenix", false),
+            &target_tz,
+        )
+        .unwrap();
+        let ts = result.as_primitive::<TimestampMicrosecondType>();
+        assert_eq!(ts.value(0), seven_hours_ts);
+        assert_eq!(ts.value(1), non_dst_date + seven_hours_ts);
+        assert_eq!(ts.value(2), dst_date + seven_hours_ts);
+        assert!(ts.is_null(3));
+    }
+}

From 540581181d40d3f6294b04e0cd41d64136e61f6f Mon Sep 17 00:00:00 2001
From: B Vadlamani
Date: Mon, 2 Mar 2026 23:31:43 -0800
Subject: [PATCH 4/4] refactor_cast_module_temporal

---
 native/spark-expr/src/conversion_funcs/temporal.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/native/spark-expr/src/conversion_funcs/temporal.rs b/native/spark-expr/src/conversion_funcs/temporal.rs
index a95e408584..f49c39ae50 100644
--- a/native/spark-expr/src/conversion_funcs/temporal.rs
+++ b/native/spark-expr/src/conversion_funcs/temporal.rs
@@ -84,8 +84,8 @@ mod tests {
     #[test]
     fn test_cast_date_to_timestamp() {
         use crate::EvalMode;
-        use arrow::array::{Array, ArrayRef};
         use arrow::array::Date32Array;
+        use arrow::array::{Array, ArrayRef};
         use arrow::datatypes::TimestampMicrosecondType;
 
         // verifying epoch, DST change dates (US) and a null value (comprehensive tests on spark side)