Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 12 additions & 120 deletions native/spark-expr/src/conversion_funcs/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,38 +29,37 @@ use crate::conversion_funcs::string::{
cast_string_to_date, cast_string_to_decimal, cast_string_to_float, cast_string_to_int,
cast_string_to_timestamp, is_df_cast_from_string_spark_compatible, spark_cast_utf8_to_boolean,
};
use crate::conversion_funcs::temporal::{
cast_date_to_timestamp, is_df_cast_from_date_spark_compatible,
is_df_cast_from_timestamp_spark_compatible,
};
use crate::conversion_funcs::utils::spark_cast_postprocess;
use crate::utils::array_with_timezone;
use crate::EvalMode::Legacy;
use crate::{cast_whole_num_to_binary, timezone, BinaryOutputStyle};
use crate::{EvalMode, SparkError, SparkResult};
use crate::{cast_whole_num_to_binary, BinaryOutputStyle};
use crate::{EvalMode, SparkError};
use arrow::array::builder::StringBuilder;
use arrow::array::{
BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray,
StructArray, TimestampMicrosecondBuilder,
BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray,
};
use arrow::compute::can_cast_types;
use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema};
use arrow::datatypes::{Field, Fields, GenericBinaryType};
use arrow::error::ArrowError;
use arrow::{
array::{
cast::AsArray,
types::{Date32Type, Int32Type},
Array, ArrayRef, GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array,
OffsetSizeTrait, PrimitiveArray,
cast::AsArray, types::Int32Type, Array, ArrayRef, GenericStringArray, Int16Array,
Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray,
},
compute::{cast_with_options, take, CastOptions},
record_batch::RecordBatch,
util::display::FormatOptions,
};
use base64::prelude::BASE64_STANDARD_NO_PAD;
use base64::Engine;
use chrono::{NaiveDate, TimeZone};
use datafusion::common::{internal_err, DataFusionError, Result as DataFusionResult, ScalarValue};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::ColumnarValue;
use std::str::FromStr;
use std::{
any::Any,
fmt::{Debug, Display, Formatter},
Expand Down Expand Up @@ -404,50 +403,6 @@ pub(crate) fn cast_array(
Ok(spark_cast_postprocess(cast_result?, &from_type, to_type))
}

/// Spark-compatible cast from `Date32` to `Timestamp(Microsecond)`.
///
/// Each date is interpreted as local midnight in the session timezone
/// (`cast_options.timezone`, defaulting to UTC when empty) and converted to
/// microseconds since the UTC epoch. The returned array carries `target_tz`
/// as its timezone metadata; null inputs stay null.
fn cast_date_to_timestamp(
    array_ref: &ArrayRef,
    cast_options: &SparkCastOptions,
    target_tz: &Option<Arc<str>>,
) -> SparkResult<ArrayRef> {
    // Fall back to UTC when no session timezone was configured.
    let session_tz = match cast_options.timezone.as_str() {
        "" => "UTC",
        configured => configured,
    };
    let tz = timezone::Tz::from_str(session_tz)?;
    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
    let dates = array_ref.as_primitive::<Date32Type>();

    let mut out = TimestampMicrosecondBuilder::with_capacity(dates.len());
    for entry in dates.iter() {
        let Some(days) = entry else {
            out.append_null();
            continue;
        };
        // Local midnight of this date. NOTE(review): the date addition panics
        // outside chrono's ~+/-262,143-year range, which is narrower than what
        // i32 days can encode — assumes Spark keeps dates well inside it.
        let local_midnight = (epoch + chrono::Duration::days(days as i64))
            .and_hms_opt(0, 0, 0)
            .unwrap();
        let micros = tz
            .from_local_datetime(&local_midnight)
            // spring/fall DST transitions can make local midnight ambiguous or
            // nonexistent; take the earliest candidate when one exists
            .earliest()
            .map(|dt| dt.timestamp_micros())
            // no valid local time at all -> fall back to plain UTC arithmetic
            .unwrap_or((days as i64) * 86_400 * 1_000_000);
        out.append_value(micros);
    }
    Ok(Arc::new(out.finish().with_timezone_opt(target_tz.clone())))
}

/// Determines if DataFusion supports the given cast in a way that is
/// compatible with Spark
fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool {
Expand All @@ -467,13 +422,8 @@ fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> b
is_df_cast_from_decimal_spark_compatible(to_type)
}
DataType::Utf8 => is_df_cast_from_string_spark_compatible(to_type),
DataType::Date32 => matches!(to_type, DataType::Int32 | DataType::Utf8),
DataType::Timestamp(_, _) => {
matches!(
to_type,
DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _)
)
}
DataType::Date32 => is_df_cast_from_date_spark_compatible(to_type),
DataType::Timestamp(_, _) => is_df_cast_from_timestamp_spark_compatible(to_type),
DataType::Binary => {
// note that this is not completely Spark compatible because
// DataFusion only supports binary data containing valid UTF-8 strings
Expand Down Expand Up @@ -827,7 +777,7 @@ mod tests {
use super::*;
use arrow::array::StringArray;
use arrow::datatypes::TimestampMicrosecondType;
use arrow::datatypes::{Field, Fields, TimeUnit};
use arrow::datatypes::{Field, Fields};
#[test]
fn test_cast_unsupported_timestamp_to_date() {
// Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported
Expand All @@ -853,64 +803,6 @@ mod tests {
assert!(result.is_err())
}

#[test]
fn test_cast_date_to_timestamp() {
    use arrow::array::Date32Array;

    // verifying epoch , DST change dates (US) and a null value (comprehensive tests on spark side)
    let dates: ArrayRef = Arc::new(Date32Array::from(vec![
        Some(0),     // 1970-01-01, the epoch
        Some(19723), // 2024-01-01, outside US daylight saving time
        Some(19793), // 2024-03-11, inside US daylight saving time
        None,        // null must be preserved
    ]));

    // expected UTC midnights in microseconds, plus the UTC offsets of the
    // US Pacific/Mountain zones (-08:00 standard, -07:00 daylight)
    let non_dst_date = 1704067200000000i64; // 2024-01-01T00:00:00Z
    let dst_date = 1710115200000000i64; // 2024-03-11T00:00:00Z
    let seven_hours_ts = 25200000000i64;
    let eight_hours_ts = 28800000000i64;

    // validate UTC
    let result = cast_array(
        Arc::clone(&dates),
        &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
        &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
    )
    .unwrap();
    let ts = result.as_primitive::<TimestampMicrosecondType>();
    assert_eq!(ts.value(0), 0);
    assert_eq!(ts.value(1), non_dst_date);
    assert_eq!(ts.value(2), dst_date);
    assert!(ts.is_null(3));

    // validate LA timezone (follows Daylight savings)
    let result = cast_array(
        Arc::clone(&dates),
        &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
        &SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles", false),
    )
    .unwrap();
    let ts = result.as_primitive::<TimestampMicrosecondType>();
    assert_eq!(ts.value(0), eight_hours_ts);
    assert_eq!(ts.value(1), non_dst_date + eight_hours_ts);
    // should adjust for DST
    assert_eq!(ts.value(2), dst_date + seven_hours_ts);
    assert!(ts.is_null(3));

    // Phoenix timezone (does not follow Daylight savings)
    let result = cast_array(
        Arc::clone(&dates),
        &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
        &SparkCastOptions::new(EvalMode::Legacy, "America/Phoenix", false),
    )
    .unwrap();
    let ts = result.as_primitive::<TimestampMicrosecondType>();
    assert_eq!(ts.value(0), seven_hours_ts);
    assert_eq!(ts.value(1), non_dst_date + seven_hours_ts);
    // Phoenix stays at -07:00 year-round, so no DST shift here
    assert_eq!(ts.value(2), dst_date + seven_hours_ts);
    assert!(ts.is_null(3));
}

#[test]
fn test_cast_struct_to_utf8() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![
Expand Down
1 change: 1 addition & 0 deletions native/spark-expr/src/conversion_funcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,5 @@ mod boolean;
pub mod cast;
mod numeric;
mod string;
mod temporal;
mod utils;
145 changes: 145 additions & 0 deletions native/spark-expr/src/conversion_funcs/temporal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::{timezone, SparkCastOptions, SparkResult};
use arrow::array::{ArrayRef, AsArray, TimestampMicrosecondBuilder};
use arrow::datatypes::{DataType, Date32Type};
use chrono::{NaiveDate, TimeZone};
use std::str::FromStr;
use std::sync::Arc;

/// Reports whether DataFusion's built-in cast from `Date32` to `to_type`
/// already matches Spark's semantics (and can therefore be delegated to it).
pub(crate) fn is_df_cast_from_date_spark_compatible(to_type: &DataType) -> bool {
    *to_type == DataType::Int32 || *to_type == DataType::Utf8
}

/// Reports whether DataFusion's built-in cast from `Timestamp(_, _)` to
/// `to_type` already matches Spark's semantics.
pub(crate) fn is_df_cast_from_timestamp_spark_compatible(to_type: &DataType) -> bool {
    match to_type {
        DataType::Int64 | DataType::Date32 | DataType::Utf8 => true,
        DataType::Timestamp(_, _) => true,
        _ => false,
    }
}

/// Spark-compatible cast from `Date32` to `Timestamp(Microsecond)`.
///
/// Each date value is interpreted as midnight in the session timezone
/// (`cast_options.timezone`, defaulting to UTC when empty) and converted to
/// microseconds since the UTC epoch. The resulting array carries `target_tz`
/// as its timezone metadata; null inputs remain null.
pub(crate) fn cast_date_to_timestamp(
    array_ref: &ArrayRef,
    cast_options: &SparkCastOptions,
    target_tz: &Option<Arc<str>>,
) -> SparkResult<ArrayRef> {
    // Empty session timezone means "not configured" -> default to UTC.
    let tz_str = if cast_options.timezone.is_empty() {
        "UTC"
    } else {
        cast_options.timezone.as_str()
    };
    // safe to unwrap since we are falling back to UTC above
    let tz = timezone::Tz::from_str(tz_str)?;
    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();
    let date_array = array_ref.as_primitive::<Date32Type>();

    let mut builder = TimestampMicrosecondBuilder::with_capacity(date_array.len());

    for date in date_array.iter() {
        match date {
            Some(date) => {
                // NOTE(review): chrono's NaiveDate range is roughly +/-262,143
                // years, which is *narrower* than the ~+/-5.8M years that i32
                // days can encode, so this addition can panic for extreme
                // Date32 values — assumes Spark's date range stays well inside
                // chrono's bounds; TODO confirm.
                // convert date in session timezone to timestamp in UTC
                let naive_date = epoch + chrono::Duration::days(date as i64);
                // and_hms_opt(0, 0, 0) is always Some, so unwrap cannot fail
                let local_midnight = naive_date.and_hms_opt(0, 0, 0).unwrap();
                let local_midnight_in_microsec = tz
                    .from_local_datetime(&local_midnight)
                    // return earliest possible time (edge case with spring / fall DST changes)
                    .earliest()
                    // in case there is an issue with DST and returns None , we fall back to UTC
                    .map(|dt| dt.timestamp_micros())
                    .unwrap_or((date as i64) * 86_400 * 1_000_000);
                builder.append_value(local_midnight_in_microsec);
            }
            None => {
                // propagate nulls unchanged
                builder.append_null();
            }
        }
    }
    Ok(Arc::new(
        builder.finish().with_timezone_opt(target_tz.clone()),
    ))
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_cast_date_to_timestamp() {
        use crate::EvalMode;
        use arrow::array::Date32Array;
        use arrow::array::{Array, ArrayRef};
        use arrow::datatypes::TimestampMicrosecondType;

        // verifying epoch , DST change dates (US) and a null value (comprehensive tests on spark side)
        let dates: ArrayRef = Arc::new(Date32Array::from(vec![
            Some(0),     // 1970-01-01, the epoch
            Some(19723), // 2024-01-01, outside US daylight saving time
            Some(19793), // 2024-03-11, inside US daylight saving time
            None,        // null must be preserved
        ]));

        // expected UTC midnights in microseconds, plus the UTC offsets of the
        // US Pacific/Mountain zones (-08:00 standard, -07:00 daylight)
        let non_dst_date = 1704067200000000i64; // 2024-01-01T00:00:00Z
        let dst_date = 1710115200000000i64; // 2024-03-11T00:00:00Z
        let seven_hours_ts = 25200000000i64;
        let eight_hours_ts = 28800000000i64;

        // validate UTC
        let target_tz: Option<Arc<str>> = Some("UTC".into());
        let result = cast_date_to_timestamp(
            &dates,
            &SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
            &target_tz,
        )
        .unwrap();
        let ts = result.as_primitive::<TimestampMicrosecondType>();
        assert_eq!(ts.value(0), 0);
        assert_eq!(ts.value(1), non_dst_date);
        assert_eq!(ts.value(2), dst_date);
        assert!(ts.is_null(3));

        // validate LA timezone (follows Daylight savings)
        let result = cast_date_to_timestamp(
            &dates,
            &SparkCastOptions::new(EvalMode::Legacy, "America/Los_Angeles", false),
            &target_tz,
        )
        .unwrap();
        let ts = result.as_primitive::<TimestampMicrosecondType>();
        assert_eq!(ts.value(0), eight_hours_ts);
        assert_eq!(ts.value(1), non_dst_date + eight_hours_ts);
        // should adjust for DST
        assert_eq!(ts.value(2), dst_date + seven_hours_ts);
        assert!(ts.is_null(3));

        // Phoenix timezone (does not follow Daylight savings)
        let result = cast_date_to_timestamp(
            &dates,
            &SparkCastOptions::new(EvalMode::Legacy, "America/Phoenix", false),
            &target_tz,
        )
        .unwrap();
        let ts = result.as_primitive::<TimestampMicrosecondType>();
        assert_eq!(ts.value(0), seven_hours_ts);
        assert_eq!(ts.value(1), non_dst_date + seven_hours_ts);
        // Phoenix stays at -07:00 year-round, so no DST shift here
        assert_eq!(ts.value(2), dst_date + seven_hours_ts);
        assert!(ts.is_null(3));
    }
}
Loading