Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ object Utils extends CometTypeShim with Logging {
case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH =>
YearMonthIntervalType()
case di: ArrowType.Interval if di.getUnit == IntervalUnit.DAY_TIME => DayTimeIntervalType()
case t: ArrowType.Time if t.getUnit == TimeUnit.NANOSECOND && t.getBitWidth == 64 =>
// scalastyle:off classforname
val clazz = Class.forName("org.apache.spark.sql.types.TimeType$")
// scalastyle:on classforname
val module = clazz.getField("MODULE$").get(null)
clazz.getMethod("apply").invoke(module).asInstanceOf[DataType]
case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.toString}")
}

Expand Down Expand Up @@ -142,6 +148,8 @@ object Utils extends CometTypeShim with Logging {
}
case TimestampNTZType =>
new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
case dt if dt.getClass.getSimpleName.startsWith("TimeType") =>
new ArrowType.Time(TimeUnit.NANOSECOND, 64)
case _ =>
throw new UnsupportedOperationException(
s"Unsupported data type: [${dt.getClass.getName}] ${dt.catalogString}")
Expand Down
4 changes: 4 additions & 0 deletions docs/source/user-guide/latest/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ of expressions that can be disabled.
| DayOfYear | `dayofyear` |
| WeekOfYear | `weekofyear` |
| Quarter | `quarter` |
| MakeDate | `make_date` |
| MakeTime | `make_time` |
| ToTime | `to_time` |
| TryToTime | `try_to_time` |

## Math Expressions

Expand Down
39 changes: 36 additions & 3 deletions native/core/src/execution/columnar_to_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ enum TypedArray<'a> {
Float64(&'a Float64Array),
Date32(&'a Date32Array),
TimestampMicro(&'a TimestampMicrosecondArray),
Time64Nano(&'a Time64NanosecondArray),
Decimal128(&'a Decimal128Array, u8), // array + precision
String(&'a StringArray),
LargeString(&'a LargeStringArray),
Expand Down Expand Up @@ -200,6 +201,10 @@ impl<'a> TypedArray<'a> {
DataType::Timestamp(TimeUnit::Microsecond, _) => Ok(TypedArray::TimestampMicro(
downcast_array!(array, TimestampMicrosecondArray)?,
)),
DataType::Time64(TimeUnit::Nanosecond) => Ok(TypedArray::Time64Nano(downcast_array!(
array,
Time64NanosecondArray
)?)),
DataType::Decimal128(p, _) => Ok(TypedArray::Decimal128(
downcast_array!(array, Decimal128Array)?,
*p,
Expand Down Expand Up @@ -267,6 +272,7 @@ impl<'a> TypedArray<'a> {
Float64,
Date32,
TimestampMicro,
Time64Nano,
Decimal128,
String,
LargeString,
Expand Down Expand Up @@ -295,6 +301,7 @@ impl<'a> TypedArray<'a> {
TypedArray::Float64(arr) => arr.value(row_idx).to_bits() as i64,
TypedArray::Date32(arr) => arr.value(row_idx) as i64,
TypedArray::TimestampMicro(arr) => arr.value(row_idx),
TypedArray::Time64Nano(arr) => arr.value(row_idx),
TypedArray::Decimal128(arr, precision) if *precision <= MAX_LONG_DIGITS => {
arr.value(row_idx) as i64
}
Expand All @@ -317,7 +324,8 @@ impl<'a> TypedArray<'a> {
| TypedArray::Float32(_)
| TypedArray::Float64(_)
| TypedArray::Date32(_)
| TypedArray::TimestampMicro(_) => false,
| TypedArray::TimestampMicro(_)
| TypedArray::Time64Nano(_) => false,
TypedArray::Decimal128(_, precision) => *precision > MAX_LONG_DIGITS,
_ => true,
}
Expand Down Expand Up @@ -380,6 +388,7 @@ enum TypedElements<'a> {
Float64(&'a Float64Array),
Date32(&'a Date32Array),
TimestampMicro(&'a TimestampMicrosecondArray),
Time64Nano(&'a Time64NanosecondArray),
Decimal128(&'a Decimal128Array, u8),
String(&'a StringArray),
LargeString(&'a LargeStringArray),
Expand Down Expand Up @@ -418,6 +427,11 @@ impl<'a> TypedElements<'a> {
return TypedElements::TimestampMicro(arr);
}
}
DataType::Time64(TimeUnit::Nanosecond) => {
if let Some(arr) = array.as_any().downcast_ref::<Time64NanosecondArray>() {
return TypedElements::Time64Nano(arr);
}
}
DataType::Decimal128(p, _) => {
if let Some(arr) = array.as_any().downcast_ref::<Decimal128Array>() {
return TypedElements::Decimal128(arr, *p);
Expand All @@ -442,6 +456,7 @@ impl<'a> TypedElements<'a> {
TypedElements::Int32(_) | TypedElements::Date32(_) | TypedElements::Float32(_) => 4,
TypedElements::Int64(_)
| TypedElements::TimestampMicro(_)
| TypedElements::Time64Nano(_)
| TypedElements::Float64(_) => 8,
TypedElements::Decimal128(_, p) if *p <= MAX_LONG_DIGITS => 8,
_ => 8, // Variable-length uses 8 bytes for offset+length
Expand All @@ -460,6 +475,7 @@ impl<'a> TypedElements<'a> {
| TypedElements::Float64(_)
| TypedElements::Date32(_)
| TypedElements::TimestampMicro(_)
| TypedElements::Time64Nano(_)
)
}

Expand All @@ -479,6 +495,7 @@ impl<'a> TypedElements<'a> {
Float64,
Date32,
TimestampMicro,
Time64Nano,
Decimal128,
String,
LargeString,
Expand All @@ -502,7 +519,8 @@ impl<'a> TypedElements<'a> {
| TypedElements::Float32(_)
| TypedElements::Float64(_)
| TypedElements::Date32(_)
| TypedElements::TimestampMicro(_) => true,
| TypedElements::TimestampMicro(_)
| TypedElements::Time64Nano(_) => true,
TypedElements::Decimal128(_, p) => *p <= MAX_LONG_DIGITS,
_ => false,
}
Expand All @@ -521,6 +539,7 @@ impl<'a> TypedElements<'a> {
TypedElements::Float64(arr) => arr.value(idx).to_bits() as i64,
TypedElements::Date32(arr) => arr.value(idx) as i64,
TypedElements::TimestampMicro(arr) => arr.value(idx),
TypedElements::Time64Nano(arr) => arr.value(idx),
TypedElements::Decimal128(arr, _) => arr.value(idx) as i64,
_ => 0, // Should not be called for variable-length types
}
Expand Down Expand Up @@ -655,6 +674,7 @@ impl<'a> TypedElements<'a> {
TypedElements::Float64(arr) => bulk_copy_range!(arr, 8),
TypedElements::Date32(arr) => bulk_copy_range!(arr, 4),
TypedElements::TimestampMicro(arr) => bulk_copy_range!(arr, 8),
TypedElements::Time64Nano(arr) => bulk_copy_range!(arr, 8),
_ => {} // Should not reach here due to supports_bulk_copy check
}
}
Expand Down Expand Up @@ -827,7 +847,8 @@ fn is_fixed_width(data_type: &DataType) -> bool {
| DataType::Float32
| DataType::Float64
| DataType::Date32
| DataType::Timestamp(TimeUnit::Microsecond, _) => true,
| DataType::Timestamp(TimeUnit::Microsecond, _)
| DataType::Time64(TimeUnit::Nanosecond) => true,
DataType::Decimal128(p, _) => *p <= MAX_LONG_DIGITS,
_ => false,
}
Expand Down Expand Up @@ -1235,6 +1256,15 @@ impl ColumnarToRowContext {
TimestampMicrosecondArray,
|v: i64| v
),
DataType::Time64(TimeUnit::Nanosecond) => write_fixed_column_primitive!(
self,
array,
row_size,
field_offset_in_row,
num_rows,
Time64NanosecondArray,
|v: i64| v
),
DataType::Decimal128(precision, _) if *precision <= MAX_LONG_DIGITS => {
write_fixed_column_primitive!(
self,
Expand Down Expand Up @@ -1360,6 +1390,9 @@ fn get_field_value(data_type: &DataType, array: &ArrayRef, row_idx: usize) -> Co
DataType::Timestamp(TimeUnit::Microsecond, _) => {
get_field_value_primitive!(array, row_idx, TimestampMicrosecondArray, |v: i64| v)
}
DataType::Time64(TimeUnit::Nanosecond) => {
get_field_value_primitive!(array, row_idx, Time64NanosecondArray, |v: i64| v)
}
DataType::Decimal128(precision, _) if *precision <= MAX_LONG_DIGITS => {
get_field_value_primitive!(array, row_idx, Decimal128Array, |v: i128| v as i64)
}
Expand Down
3 changes: 3 additions & 0 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,9 @@ impl PhysicalPlanner {
DataType::Map(f, s) => DataType::Map(f, s).try_into()?,
DataType::List(f) => DataType::List(f).try_into()?,
DataType::Null => ScalarValue::Null,
DataType::Time64(TimeUnit::Nanosecond) => {
ScalarValue::Time64Nanosecond(None)
}
dt => {
return Err(GeneralError(format!("{dt:?} is not supported in Comet")))
}
Expand Down
1 change: 1 addition & 0 deletions native/core/src/execution/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ pub fn to_arrow_datatype(dt_value: &DataType) -> ArrowDataType {
}
DataTypeId::TimestampNtz => ArrowDataType::Timestamp(TimeUnit::Microsecond, None),
DataTypeId::Date => ArrowDataType::Date32,
DataTypeId::Time => ArrowDataType::Time64(TimeUnit::Nanosecond),
DataTypeId::Null => ArrowDataType::Null,
DataTypeId::List => match dt_value
.type_info
Expand Down
1 change: 1 addition & 0 deletions native/proto/src/proto/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ message DataType {
LIST = 14;
MAP = 15;
STRUCT = 16;
TIME = 17;
}
DataTypeId type_id = 1;

Expand Down
13 changes: 9 additions & 4 deletions native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ use crate::math_funcs::log::spark_log;
use crate::math_funcs::modulo_expr::spark_modulo;
use crate::{
spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArrayPositionFunc, SparkArraysOverlap,
SparkContains, SparkDateDiff, SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate,
SparkSecondsToTimestamp, SparkSizeFunc,
spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
spark_to_time, spark_unhex, spark_unscaled_value, EvalMode, SparkArrayCompact,
SparkArrayPositionFunc, SparkArraysOverlap, SparkContains, SparkDateDiff,
SparkDateFromUnixDate, SparkDateTrunc, SparkMakeDate, SparkMakeTime, SparkSecondsToTimestamp,
SparkSizeFunc,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
Expand Down Expand Up @@ -196,6 +197,9 @@ pub fn create_comet_physical_fun_with_eval_mode(
let func = Arc::new(spark_map_sort);
make_comet_scalar_udf!("spark_map_sort", func, without data_type)
}
"to_time" => {
make_comet_scalar_udf!("to_time", spark_to_time, without data_type, fail_on_error)
}
_ => registry.udf(fun_name).map_err(|e| {
DataFusionError::Execution(format!(
"Function {fun_name} not found in the registry: {e}",
Expand All @@ -214,6 +218,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
Arc::new(ScalarUDF::new_from_impl(SparkDateFromUnixDate::default())),
Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
Arc::new(ScalarUDF::new_from_impl(SparkMakeDate::default())),
Arc::new(ScalarUDF::new_from_impl(SparkMakeTime::default())),
Arc::new(ScalarUDF::new_from_impl(SparkSecondsToTimestamp::default())),
Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
]
Expand Down
Loading
Loading