diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 4b93c79418..ff09dbe06e 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -29,15 +29,18 @@ use crate::conversion_funcs::string::{ cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int, cast_string_to_timestamp, is_df_cast_from_string_spark_compatible, spark_cast_utf8_to_boolean, }; +use crate::conversion_funcs::temporal::{ + cast_date_to_timestamp, is_df_cast_from_date_spark_compatible, + is_df_cast_from_timestamp_spark_compatible, +}; use crate::conversion_funcs::utils::spark_cast_postprocess; use crate::utils::array_with_timezone; use crate::EvalMode::Legacy; -use crate::{cast_whole_num_to_binary, timezone, BinaryOutputStyle}; -use crate::{EvalMode, SparkError, SparkResult}; +use crate::{cast_whole_num_to_binary, BinaryOutputStyle}; +use crate::{EvalMode, SparkError}; use arrow::array::builder::StringBuilder; use arrow::array::{ - BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, - StructArray, TimestampMicrosecondBuilder, + BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray, }; use arrow::compute::can_cast_types; use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema}; @@ -45,10 +48,8 @@ use arrow::datatypes::{Field, Fields, GenericBinaryType}; use arrow::error::ArrowError; use arrow::{ array::{ - cast::AsArray, - types::{Date32Type, Int32Type}, - Array, ArrayRef, GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, - OffsetSizeTrait, PrimitiveArray, + cast::AsArray, types::Int32Type, Array, ArrayRef, GenericStringArray, Int16Array, + Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray, }, compute::{cast_with_options, take, CastOptions}, record_batch::RecordBatch, @@ -56,11 +57,9 @@ use arrow::{ }; use base64::prelude::BASE64_STANDARD_NO_PAD; use base64::Engine; -use chrono::{NaiveDate, TimeZone}; use datafusion::common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue}; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ColumnarValue; -use std::str::FromStr; use std::{ any::Any, fmt::{Debug, Display, Formatter}, @@ -404,50 +403,6 @@ pub(crate) fn cast_array( Ok(spark_cast_postprocess(cast_result?, &from_type, to_type)) } -fn cast_date_to_timestamp( - array_ref: &ArrayRef, - cast_options: &SparkCastOptions, - target_tz: &Option>, -) -> SparkResult { - let tz_str = if cast_options.timezone.is_empty() { - "UTC" - } else { - cast_options.timezone.as_str() - }; - // safe to unwrap since we are falling back to UTC above - let tz = timezone::Tz::from_str(tz_str)?; - let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - let date_array = array_ref.as_primitive::(); - - let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len()); - - for date in date_array.iter() { - match date { - Some(date) => { - // safe to unwrap since chrono's range ( 262,143 yrs) is higher than - // number of years possible with days as i32 (~ 6 mil yrs) - // convert date in session timezone to timestamp in UTC - let naive_date = epoch + chrono::Duration::days(date as i64); - let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap(); - let local_midnight_in_microsec = tz - .from_local_datetime(&local_midnight) - // return earliest possible time (edge case with spring / fall DST changes) - .earliest() - .map(|dt| dt.timestamp_micros()) - // in case there is an issue with DST and returns None , we fall back to UTC - .unwrap_or((date as i64) * 86_400 * 1_000_000); - builder.append_value(local_midnight_in_microsec); - } - None => { - builder.append_null(); - } - } - } - Ok(Arc::new( - builder.finish().with_timezone_opt(target_tz.clone()), - )) -} - /// Determines if DataFusion supports the given cast in a way that is /// compatible with Spark fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { @@ -467,13 +422,8 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b is_df_cast_from_decimal_spark_compatible(to_type) } DataType::Utf8 => is_df_cast_from_string_spark_compatible(to_type), - DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8), - DataType::Timestamp(_, _) => { - matches!( - to_type, - DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) - ) - } + DataType::Date32 => is_df_cast_from_date_spark_compatible(to_type), + DataType::Timestamp(_, _) => is_df_cast_from_timestamp_spark_compatible(to_type), DataType::Binary => { // note that this is not completely Spark compatible because // DataFusion only supports binary data containing valid UTF-8 strings @@ -827,7 +777,7 @@ mod tests { use super::*; use arrow::array::StringArray; use arrow::datatypes::TimestampMicrosecondType; - use arrow::datatypes::{Field, Fields, TimeUnit}; + use arrow::datatypes::{Field, Fields}; #[test] fn test_cast_unsupported_timestamp_to_date() { // Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported @@ -853,64 +803,6 @@ mod tests { assert!(result.is_err()) } - #[test] - fn test_cast_date_to_timestamp() { - use arrow::array::Date32Array; - - // verifying epoch , DST change dates (US) and a null value (comprehensive tests on spark side) - let dates: ArrayRef = Arc::new(Date32Array::from(vec![ - Some(0), - Some(19723), - Some(19793), - None, - ])); - - let non_dst_date = 1704067200000000i64; - let dst_date = 1710115200000000i64; - let seven_hours_ts = 25200000000i64; - let eight_hours_ts = 28800000000i64; - - // validate UTC - let result = cast_array( - Arc::clone(&dates), - &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), - &SparkCastOptions::new(EvalMode::Legacy, "UTC", false), - ) - .unwrap(); - let ts = result.as_primitive::(); - assert_eq!(ts.value(0), 0); - assert_eq!(ts.value(1), non_dst_date); - assert_eq!(ts.value(2), dst_date); - assert!(ts.is_null(3)); - - // validate LA timezone (follows Daylight savings) - let result = cast_array( - Arc::clone(&dates), - &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), - &SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles", false), - ) - .unwrap(); - let ts = result.as_primitive::(); - assert_eq!(ts.value(0), eight_hours_ts); - assert_eq!(ts.value(1), non_dst_date + eight_hours_ts); - // should adjust for DST - assert_eq!(ts.value(2), dst_date + seven_hours_ts); - assert!(ts.is_null(3)); - - // Phoenix timezone (does not follow Daylight savings) - let result = cast_array( - Arc::clone(&dates), - &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())), - &SparkCastOptions::new(EvalMode::Legacy, "America/Phoenix", false), - ) - .unwrap(); - let ts = result.as_primitive::(); - assert_eq!(ts.value(0), seven_hours_ts); - assert_eq!(ts.value(1), non_dst_date + seven_hours_ts); - assert_eq!(ts.value(2), dst_date + seven_hours_ts); - assert!(ts.is_null(3)); - } - #[test] fn test_cast_struct_to_utf8() { let a: ArrayRef = Arc::new(Int32Array::from(vec![ diff --git a/native/spark-expr/src/conversion_funcs/mod.rs b/native/spark-expr/src/conversion_funcs/mod.rs index 8e3bbe1c6e..94ab6ac169 100644 --- a/native/spark-expr/src/conversion_funcs/mod.rs +++ b/native/spark-expr/src/conversion_funcs/mod.rs @@ -19,4 +19,5 @@ mod boolean; pub mod cast; mod numeric; mod string; +mod temporal; mod utils; diff --git a/native/spark-expr/src/conversion_funcs/temporal.rs b/native/spark-expr/src/conversion_funcs/temporal.rs new file mode 100644 index 0000000000..f49c39ae50 --- /dev/null +++ b/native/spark-expr/src/conversion_funcs/temporal.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{timezone, SparkCastOptions, SparkResult}; +use arrow::array::{ArrayRef, AsArray, TimestampMicrosecondBuilder}; +use arrow::datatypes::{DataType, Date32Type}; +use chrono::{NaiveDate, TimeZone}; +use std::str::FromStr; +use std::sync::Arc; + +pub(crate) fn is_df_cast_from_date_spark_compatible(to_type: &DataType) -> bool { + matches!(to_type, DataType::Int32 | DataType::Utf8) +} + +pub(crate) fn is_df_cast_from_timestamp_spark_compatible(to_type: &DataType) -> bool { + matches!( + to_type, + DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) + ) +} + +pub(crate) fn cast_date_to_timestamp( + array_ref: &ArrayRef, + cast_options: &SparkCastOptions, + target_tz: &Option>, +) -> SparkResult { + let tz_str = if cast_options.timezone.is_empty() { + "UTC" + } else { + cast_options.timezone.as_str() + }; + // safe to unwrap since we are falling back to UTC above + let tz = timezone::Tz::from_str(tz_str)?; + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + let date_array = array_ref.as_primitive::(); + + let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len()); + + for date in date_array.iter() { + match date { + Some(date) => { + // safe to unwrap since chrono's range ( 262,143 yrs) is higher than + // number of years possible with days as i32 (~ 6 mil yrs) + // convert date in session timezone to timestamp in UTC + let naive_date = epoch + chrono::Duration::days(date as i64); + let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap(); + let local_midnight_in_microsec = tz + .from_local_datetime(&local_midnight) + // return earliest possible time (edge case with spring / fall DST changes) + .earliest() + .map(|dt| dt.timestamp_micros()) + // in case there is an issue with DST and returns None , we fall back to UTC + .unwrap_or((date as i64) * 86_400 * 1_000_000); + builder.append_value(local_midnight_in_microsec); + } + None => { + builder.append_null(); + } + } + } + Ok(Arc::new( + builder.finish().with_timezone_opt(target_tz.clone()), + )) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + #[test] + fn test_cast_date_to_timestamp() { + use crate::EvalMode; + use arrow::array::Date32Array; + use arrow::array::{Array, ArrayRef}; + use arrow::datatypes::TimestampMicrosecondType; + + // verifying epoch , DST change dates (US) and a null value (comprehensive tests on spark side) + let dates: ArrayRef = Arc::new(Date32Array::from(vec![ + Some(0), + Some(19723), + Some(19793), + None, + ])); + + let non_dst_date = 1704067200000000i64; + let dst_date = 1710115200000000i64; + let seven_hours_ts = 25200000000i64; + let eight_hours_ts = 28800000000i64; + + // validate UTC + let target_tz: Option> = Some("UTC".into()); + let result = cast_date_to_timestamp( + &dates, + &SparkCastOptions::new(EvalMode::Legacy, "UTC", false), + &target_tz, + ) + .unwrap(); + let ts = result.as_primitive::(); + assert_eq!(ts.value(0), 0); + assert_eq!(ts.value(1), non_dst_date); + assert_eq!(ts.value(2), dst_date); + assert!(ts.is_null(3)); + + // validate LA timezone (follows Daylight savings) + let result = cast_date_to_timestamp( + &dates, + &SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles", false), + &target_tz, + ) + .unwrap(); + let ts = result.as_primitive::(); + assert_eq!(ts.value(0), eight_hours_ts); + assert_eq!(ts.value(1), non_dst_date + eight_hours_ts); + // should adjust for DST + assert_eq!(ts.value(2), dst_date + seven_hours_ts); + assert!(ts.is_null(3)); + + // Phoenix timezone (does not follow Daylight savings) + let result = cast_date_to_timestamp( + &dates, + &SparkCastOptions::new(EvalMode::Legacy, "America/Phoenix", false), + &target_tz, + ) + .unwrap(); + let ts = result.as_primitive::(); + assert_eq!(ts.value(0), seven_hours_ts); + assert_eq!(ts.value(1), non_dst_date + seven_hours_ts); + assert_eq!(ts.value(2), dst_date + seven_hours_ts); + assert!(ts.is_null(3)); + } +}