diff --git a/datafusion/functions-aggregate-common/src/min_max.rs b/datafusion/functions-aggregate-common/src/min_max.rs index a9764d6ac6614..d4e0ce3864976 100644 --- a/datafusion/functions-aggregate-common/src/min_max.rs +++ b/datafusion/functions-aggregate-common/src/min_max.rs @@ -37,443 +37,353 @@ use datafusion_common::{ use datafusion_expr_common::accumulator::Accumulator; use std::{cmp::Ordering, mem::size_of_val}; -// min/max of two non-string scalar values. -macro_rules! typed_min_max { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident $(, $EXTRA_ARGS:ident)*) => {{ - ScalarValue::$SCALAR( - match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(*a), - (None, Some(b)) => Some(*b), - (Some(a), Some(b)) => Some((*a).$OP(*b)), - }, - $($EXTRA_ARGS.clone()),* - ) - }}; +macro_rules! choose_min_max { + (min) => { + std::cmp::Ordering::Greater + }; + (max) => { + std::cmp::Ordering::Less + }; } -macro_rules! typed_min_max_float { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ - ScalarValue::$SCALAR(match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(*a), - (None, Some(b)) => Some(*b), - (Some(a), Some(b)) => match a.total_cmp(b) { - choose_min_max!($OP) => Some(*b), - _ => Some(*a), - }, - }) - }}; +macro_rules! min_max { + ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ min_max_scalar($VALUE, $DELTA, choose_min_max!($OP)) }}; } -// min/max of two scalar string values. -macro_rules! typed_min_max_string { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident) => {{ - ScalarValue::$SCALAR(match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(a.clone()), - (None, Some(b)) => Some(b.clone()), - (Some(a), Some(b)) => Some((a).$OP(b).clone()), - }) - }}; +fn min_max_option( + lhs: &Option, + rhs: &Option, + ordering: Ordering, +) -> Option { + match (lhs, rhs) { + (None, None) => None, + (Some(a), None) => Some(a.clone()), + (None, Some(b)) => Some(b.clone()), + (Some(a), Some(b)) if a.cmp(b) == ordering => Some(b.clone()), + (Some(a), Some(_)) => Some(a.clone()), + } } -// min/max of two scalar string values with a prefix argument. -macro_rules! typed_min_max_string_arg { - ($VALUE:expr, $DELTA:expr, $SCALAR:ident, $OP:ident, $ARG:expr) => {{ - ScalarValue::$SCALAR( - $ARG, - match ($VALUE, $DELTA) { - (None, None) => None, - (Some(a), None) => Some(a.clone()), - (None, Some(b)) => Some(b.clone()), - (Some(a), Some(b)) => Some((a).$OP(b).clone()), - }, - ) - }}; +fn min_max_float_option( + lhs: &Option, + rhs: &Option, + ordering: Ordering, + cmp: impl Fn(&T, &T) -> Ordering, +) -> Option { + match (lhs, rhs) { + (None, None) => None, + (Some(a), None) => Some(*a), + (None, Some(b)) => Some(*b), + (Some(a), Some(b)) if cmp(a, b) == ordering => Some(*b), + (Some(a), Some(_)) => Some(*a), + } } -macro_rules! choose_min_max { - (min) => { - std::cmp::Ordering::Greater - }; - (max) => { - std::cmp::Ordering::Less - }; +fn ensure_decimal_compatibility( + lhs: &ScalarValue, + rhs: &ScalarValue, + lhs_type: (u8, i8), + rhs_type: (u8, i8), +) -> Result<()> { + if lhs_type == rhs_type { + Ok(()) + } else { + internal_err!( + "MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (lhs, rhs) + ) + } } -macro_rules! interval_min_max { - ($OP:tt, $LHS:expr, $RHS:expr) => {{ - match $LHS.partial_cmp(&$RHS) { - Some(choose_min_max!($OP)) => $RHS.clone(), - Some(_) => $LHS.clone(), - None => { - return internal_err!( - "Comparison error while computing interval min/max" - ); - } +fn min_max_generic_scalar( + lhs: &ScalarValue, + rhs: &ScalarValue, + ordering: Ordering, +) -> ScalarValue { + if lhs.is_null() { + let mut rhs_copy = rhs.clone(); + rhs_copy.compact(); + rhs_copy + } else if rhs.is_null() { + lhs.clone() + } else { + match lhs.partial_cmp(rhs) { + Some(order) if order == ordering => { + let mut rhs_copy = rhs.clone(); + rhs_copy.compact(); + rhs_copy + } + _ => lhs.clone(), } - }}; + } } -macro_rules! min_max_generic { - ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ - if $VALUE.is_null() { - let mut delta_copy = $DELTA.clone(); - // When the new value won we want to compact it to - // avoid storing the entire input - delta_copy.compact(); - delta_copy - } else if $DELTA.is_null() { - $VALUE.clone() - } else { - match $VALUE.partial_cmp(&$DELTA) { - Some(choose_min_max!($OP)) => { - // When the new value won we want to compact it to - // avoid storing the entire input - let mut delta_copy = $DELTA.clone(); - delta_copy.compact(); - delta_copy - } - _ => $VALUE.clone(), - } - } - }}; +fn min_max_interval_scalar( + lhs: &ScalarValue, + rhs: &ScalarValue, + ordering: Ordering, +) -> Result { + match lhs.partial_cmp(rhs) { + Some(order) if order == ordering => Ok(rhs.clone()), + Some(_) => Ok(lhs.clone()), + None => internal_err!("Comparison error while computing interval min/max"), + } } -macro_rules! min_max { - ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ - match choose_min_max!($OP) { - Ordering::Greater => Ok(min_max_scalar_impl!($VALUE, $DELTA, min)), - Ordering::Less => Ok(min_max_scalar_impl!($VALUE, $DELTA, max)), - Ordering::Equal => { - unreachable!("min/max comparisons do not use equal ordering") +fn min_max_dictionary_scalar( + lhs: &ScalarValue, + rhs: &ScalarValue, + ordering: Ordering, +) -> Result> { + match (lhs, rhs) { + ( + ScalarValue::Dictionary(lhs_dict_key_type, lhs_dict_value), + ScalarValue::Dictionary(rhs_dict_key_type, rhs_dict_value), + ) => { + if lhs_dict_key_type != rhs_dict_key_type { + return internal_err!( + "MIN/MAX is not expected to receive dictionary scalars with different key types ({:?} vs {:?})", + lhs_dict_key_type, + rhs_dict_key_type + ); } + + let result = min_max_scalar( + lhs_dict_value.as_ref(), + rhs_dict_value.as_ref(), + ordering, + )?; + Ok(Some(ScalarValue::Dictionary( + lhs_dict_key_type.clone(), + Box::new(result), + ))) } - }}; + (ScalarValue::Dictionary(_, lhs_dict_value), rhs_scalar) => { + min_max_scalar(lhs_dict_value.as_ref(), rhs_scalar, ordering).map(Some) + } + (lhs_scalar, ScalarValue::Dictionary(_, rhs_dict_value)) => { + min_max_scalar(lhs_scalar, rhs_dict_value.as_ref(), ordering).map(Some) + } + _ => Ok(None), + } } // min/max of two logically compatible scalar values. // Dictionary scalars participate by comparing their inner logical values. // When both inputs are dictionaries, matching key types are preserved in the // result; differing key types remain an unexpected invariant violation. -macro_rules! min_max_scalar_impl { - ($VALUE:expr, $DELTA:expr, $OP:ident) => {{ - match ($VALUE, $DELTA) { - (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, - ( - lhs @ ScalarValue::Decimal32(lhsv, lhsp, lhss), - rhs @ ScalarValue::Decimal32(rhsv, rhsp, rhss) - ) => { - if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max!(lhsv, rhsv, Decimal32, $OP, lhsp, lhss) - } else { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (lhs, rhs) - ); - } - } - ( - lhs @ ScalarValue::Decimal64(lhsv, lhsp, lhss), - rhs @ ScalarValue::Decimal64(rhsv, rhsp, rhss) - ) => { - if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max!(lhsv, rhsv, Decimal64, $OP, lhsp, lhss) - } else { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (lhs, rhs) - ); - } - } - ( - lhs @ ScalarValue::Decimal128(lhsv, lhsp, lhss), - rhs @ ScalarValue::Decimal128(rhsv, rhsp, rhss) - ) => { - if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max!(lhsv, rhsv, Decimal128, $OP, lhsp, lhss) - } else { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (lhs, rhs) - ); - } - } - ( - lhs @ ScalarValue::Decimal256(lhsv, lhsp, lhss), - rhs @ ScalarValue::Decimal256(rhsv, rhsp, rhss) - ) => { - if lhsp.eq(rhsp) && lhss.eq(rhss) { - typed_min_max!(lhsv, rhsv, Decimal256, $OP, lhsp, lhss) - } else { - return internal_err!( - "MIN/MAX is not expected to receive scalars of incompatible types {:?}", - (lhs, rhs) - ); - } - } - (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { - typed_min_max!(lhs, rhs, Boolean, $OP) - } - (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { - typed_min_max_float!(lhs, rhs, Float64, $OP) - } - (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { - typed_min_max_float!(lhs, rhs, Float32, $OP) - } - (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { - typed_min_max_float!(lhs, rhs, Float16, $OP) - } - (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { - typed_min_max!(lhs, rhs, UInt64, $OP) - } - (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { - typed_min_max!(lhs, rhs, UInt32, $OP) - } - (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { - typed_min_max!(lhs, rhs, UInt16, $OP) - } - (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { - typed_min_max!(lhs, rhs, UInt8, $OP) - } - (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { - typed_min_max!(lhs, rhs, Int64, $OP) - } - (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { - typed_min_max!(lhs, rhs, Int32, $OP) - } - (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { - typed_min_max!(lhs, rhs, Int16, $OP) - } - (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { - typed_min_max!(lhs, rhs, Int8, $OP) - } - (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { - typed_min_max_string!(lhs, rhs, Utf8, $OP) - } - (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeUtf8, $OP) - } - (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { - typed_min_max_string!(lhs, rhs, Utf8View, $OP) - } - (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { - typed_min_max_string!(lhs, rhs, Binary, $OP) - } - (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { - typed_min_max_string!(lhs, rhs, LargeBinary, $OP) - } - (ScalarValue::FixedSizeBinary(lsize, lhs), ScalarValue::FixedSizeBinary(rsize, rhs)) => { - if lsize == rsize { - typed_min_max_string_arg!(lhs, rhs, FixedSizeBinary, $OP, *lsize) - } - else { - return internal_err!( - "MIN/MAX is not expected to receive FixedSizeBinary of incompatible sizes {:?}", - (lsize, rsize)) - } - } - (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { - typed_min_max_string!(lhs, rhs, BinaryView, $OP) - } - (ScalarValue::TimestampSecond(lhs, l_tz), ScalarValue::TimestampSecond(rhs, _)) => { - typed_min_max!(lhs, rhs, TimestampSecond, $OP, l_tz) - } - ( - ScalarValue::TimestampMillisecond(lhs, l_tz), - ScalarValue::TimestampMillisecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampMillisecond, $OP, l_tz) - } - ( - ScalarValue::TimestampMicrosecond(lhs, l_tz), - ScalarValue::TimestampMicrosecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampMicrosecond, $OP, l_tz) - } - ( - ScalarValue::TimestampNanosecond(lhs, l_tz), - ScalarValue::TimestampNanosecond(rhs, _), - ) => { - typed_min_max!(lhs, rhs, TimestampNanosecond, $OP, l_tz) - } - ( - ScalarValue::Date32(lhs), - ScalarValue::Date32(rhs), - ) => { - typed_min_max!(lhs, rhs, Date32, $OP) - } - ( - ScalarValue::Date64(lhs), - ScalarValue::Date64(rhs), - ) => { - typed_min_max!(lhs, rhs, Date64, $OP) - } - ( - ScalarValue::Time32Second(lhs), - ScalarValue::Time32Second(rhs), - ) => { - typed_min_max!(lhs, rhs, Time32Second, $OP) - } - ( - ScalarValue::Time32Millisecond(lhs), - ScalarValue::Time32Millisecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time32Millisecond, $OP) - } - ( - ScalarValue::Time64Microsecond(lhs), - ScalarValue::Time64Microsecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time64Microsecond, $OP) - } - ( - ScalarValue::Time64Nanosecond(lhs), - ScalarValue::Time64Nanosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, Time64Nanosecond, $OP) - } - ( - ScalarValue::IntervalYearMonth(lhs), - ScalarValue::IntervalYearMonth(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalYearMonth, $OP) - } - ( - ScalarValue::IntervalMonthDayNano(lhs), - ScalarValue::IntervalMonthDayNano(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalMonthDayNano, $OP) - } - ( - ScalarValue::IntervalDayTime(lhs), - ScalarValue::IntervalDayTime(rhs), - ) => { - typed_min_max!(lhs, rhs, IntervalDayTime, $OP) - } - ( - ScalarValue::IntervalYearMonth(_), - ScalarValue::IntervalMonthDayNano(_), - ) | ( - ScalarValue::IntervalYearMonth(_), - ScalarValue::IntervalDayTime(_), - ) | ( - ScalarValue::IntervalMonthDayNano(_), - ScalarValue::IntervalDayTime(_), - ) | ( - ScalarValue::IntervalMonthDayNano(_), - ScalarValue::IntervalYearMonth(_), - ) | ( - ScalarValue::IntervalDayTime(_), - ScalarValue::IntervalYearMonth(_), - ) | ( - ScalarValue::IntervalDayTime(_), - ScalarValue::IntervalMonthDayNano(_), - ) => { - interval_min_max!($OP, $VALUE, $DELTA) - } - ( - ScalarValue::DurationSecond(lhs), - ScalarValue::DurationSecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationSecond, $OP) - } - ( - ScalarValue::DurationMillisecond(lhs), - ScalarValue::DurationMillisecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationMillisecond, $OP) - } - ( - ScalarValue::DurationMicrosecond(lhs), - ScalarValue::DurationMicrosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationMicrosecond, $OP) - } - ( - ScalarValue::DurationNanosecond(lhs), - ScalarValue::DurationNanosecond(rhs), - ) => { - typed_min_max!(lhs, rhs, DurationNanosecond, $OP) - } - - ( - lhs @ ScalarValue::Struct(_), - rhs @ ScalarValue::Struct(_), - ) => { - min_max_generic!(lhs, rhs, $OP) - } - - ( - lhs @ ScalarValue::List(_), - rhs @ ScalarValue::List(_), - ) => { - min_max_generic!(lhs, rhs, $OP) - } - - - ( - lhs @ ScalarValue::LargeList(_), - rhs @ ScalarValue::LargeList(_), - ) => { - min_max_generic!(lhs, rhs, $OP) - } - - - ( - lhs @ ScalarValue::FixedSizeList(_), - rhs @ ScalarValue::FixedSizeList(_), - ) => { - min_max_generic!(lhs, rhs, $OP) - } +fn min_max_scalar( + lhs: &ScalarValue, + rhs: &ScalarValue, + ordering: Ordering, +) -> Result { + if ordering == Ordering::Equal { + unreachable!("min/max comparisons do not use equal ordering"); + } - ( - ScalarValue::Dictionary(lhs_dict_key_type, lhs_dict_value), - ScalarValue::Dictionary(rhs_dict_key_type, rhs_dict_value), - ) => { - if lhs_dict_key_type != rhs_dict_key_type { - return internal_err!( - "MIN/MAX is not expected to receive dictionary scalars with different key types ({:?} vs {:?})", - lhs_dict_key_type, - rhs_dict_key_type - ); - } - - let result = min_max_scalar( - lhs_dict_value.as_ref(), - rhs_dict_value.as_ref(), - choose_min_max!($OP), - )?; - ScalarValue::Dictionary(lhs_dict_key_type.clone(), Box::new(result)) - } - (ScalarValue::Dictionary(_, lhs_dict_value), rhs_scalar) => { - min_max_scalar(lhs_dict_value.as_ref(), rhs_scalar, choose_min_max!($OP))? - } - (lhs_scalar, ScalarValue::Dictionary(_, rhs_dict_value)) => { - min_max_scalar(lhs_scalar, rhs_dict_value.as_ref(), choose_min_max!($OP))? - } + if let Some(result) = min_max_dictionary_scalar(lhs, rhs, ordering)? { + return Ok(result); + } - e => { - return internal_err!( - "MIN/MAX is not expected to receive logically incompatible scalar values {:?}", - e - ) - } - } - }}; + min_max_scalar_same_variant(lhs, rhs, ordering) } -fn min_max_scalar( +fn min_max_scalar_same_variant( lhs: &ScalarValue, rhs: &ScalarValue, ordering: Ordering, ) -> Result { - match ordering { - Ordering::Greater => Ok(min_max_scalar_impl!(lhs, rhs, min)), - Ordering::Less => Ok(min_max_scalar_impl!(lhs, rhs, max)), - Ordering::Equal => unreachable!("min/max comparisons do not use equal ordering"), - } + let result = match (lhs, rhs) { + (ScalarValue::Null, ScalarValue::Null) => ScalarValue::Null, + ( + ScalarValue::Decimal32(lhsv, lhsp, lhss), + ScalarValue::Decimal32(rhsv, rhsp, rhss), + ) => { + ensure_decimal_compatibility(lhs, rhs, (*lhsp, *lhss), (*rhsp, *rhss))?; + ScalarValue::Decimal32(min_max_option(lhsv, rhsv, ordering), *lhsp, *lhss) + } + ( + ScalarValue::Decimal64(lhsv, lhsp, lhss), + ScalarValue::Decimal64(rhsv, rhsp, rhss), + ) => { + ensure_decimal_compatibility(lhs, rhs, (*lhsp, *lhss), (*rhsp, *rhss))?; + ScalarValue::Decimal64(min_max_option(lhsv, rhsv, ordering), *lhsp, *lhss) + } + ( + ScalarValue::Decimal128(lhsv, lhsp, lhss), + ScalarValue::Decimal128(rhsv, rhsp, rhss), + ) => { + ensure_decimal_compatibility(lhs, rhs, (*lhsp, *lhss), (*rhsp, *rhss))?; + ScalarValue::Decimal128(min_max_option(lhsv, rhsv, ordering), *lhsp, *lhss) + } + ( + ScalarValue::Decimal256(lhsv, lhsp, lhss), + ScalarValue::Decimal256(rhsv, rhsp, rhss), + ) => { + ensure_decimal_compatibility(lhs, rhs, (*lhsp, *lhss), (*rhsp, *rhss))?; + ScalarValue::Decimal256(min_max_option(lhsv, rhsv, ordering), *lhsp, *lhss) + } + (ScalarValue::Boolean(lhs), ScalarValue::Boolean(rhs)) => { + ScalarValue::Boolean(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Float64(lhs), ScalarValue::Float64(rhs)) => { + ScalarValue::Float64(min_max_float_option(lhs, rhs, ordering, f64::total_cmp)) + } + (ScalarValue::Float32(lhs), ScalarValue::Float32(rhs)) => { + ScalarValue::Float32(min_max_float_option(lhs, rhs, ordering, f32::total_cmp)) + } + (ScalarValue::Float16(lhs), ScalarValue::Float16(rhs)) => { + ScalarValue::Float16(min_max_float_option(lhs, rhs, ordering, |a, b| { + a.total_cmp(b) + })) + } + (ScalarValue::UInt64(lhs), ScalarValue::UInt64(rhs)) => { + ScalarValue::UInt64(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::UInt32(lhs), ScalarValue::UInt32(rhs)) => { + ScalarValue::UInt32(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::UInt16(lhs), ScalarValue::UInt16(rhs)) => { + ScalarValue::UInt16(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::UInt8(lhs), ScalarValue::UInt8(rhs)) => { + ScalarValue::UInt8(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Int64(lhs), ScalarValue::Int64(rhs)) => { + ScalarValue::Int64(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Int32(lhs), ScalarValue::Int32(rhs)) => { + ScalarValue::Int32(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Int16(lhs), ScalarValue::Int16(rhs)) => { + ScalarValue::Int16(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Int8(lhs), ScalarValue::Int8(rhs)) => { + ScalarValue::Int8(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Utf8(lhs), ScalarValue::Utf8(rhs)) => { + ScalarValue::Utf8(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::LargeUtf8(lhs), ScalarValue::LargeUtf8(rhs)) => { + ScalarValue::LargeUtf8(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Utf8View(lhs), ScalarValue::Utf8View(rhs)) => { + ScalarValue::Utf8View(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Binary(lhs), ScalarValue::Binary(rhs)) => { + ScalarValue::Binary(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::LargeBinary(lhs), ScalarValue::LargeBinary(rhs)) => { + ScalarValue::LargeBinary(min_max_option(lhs, rhs, ordering)) + } + ( + ScalarValue::FixedSizeBinary(lsize, lhs), + ScalarValue::FixedSizeBinary(rsize, rhs), + ) => { + if lsize == rsize { + ScalarValue::FixedSizeBinary(*lsize, min_max_option(lhs, rhs, ordering)) + } else { + return internal_err!( + "MIN/MAX is not expected to receive FixedSizeBinary of incompatible sizes {:?}", + (lsize, rsize) + ); + } + } + (ScalarValue::BinaryView(lhs), ScalarValue::BinaryView(rhs)) => { + ScalarValue::BinaryView(min_max_option(lhs, rhs, ordering)) + } + ( + ScalarValue::TimestampSecond(lhs, l_tz), + ScalarValue::TimestampSecond(rhs, _), + ) => { + ScalarValue::TimestampSecond(min_max_option(lhs, rhs, ordering), l_tz.clone()) + } + ( + ScalarValue::TimestampMillisecond(lhs, l_tz), + ScalarValue::TimestampMillisecond(rhs, _), + ) => ScalarValue::TimestampMillisecond( + min_max_option(lhs, rhs, ordering), + l_tz.clone(), + ), + ( + ScalarValue::TimestampMicrosecond(lhs, l_tz), + ScalarValue::TimestampMicrosecond(rhs, _), + ) => ScalarValue::TimestampMicrosecond( + min_max_option(lhs, rhs, ordering), + l_tz.clone(), + ), + ( + ScalarValue::TimestampNanosecond(lhs, l_tz), + ScalarValue::TimestampNanosecond(rhs, _), + ) => ScalarValue::TimestampNanosecond( + min_max_option(lhs, rhs, ordering), + l_tz.clone(), + ), + (ScalarValue::Date32(lhs), ScalarValue::Date32(rhs)) => { + ScalarValue::Date32(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Date64(lhs), ScalarValue::Date64(rhs)) => { + ScalarValue::Date64(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Time32Second(lhs), ScalarValue::Time32Second(rhs)) => { + ScalarValue::Time32Second(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Time32Millisecond(lhs), ScalarValue::Time32Millisecond(rhs)) => { + ScalarValue::Time32Millisecond(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Time64Microsecond(lhs), ScalarValue::Time64Microsecond(rhs)) => { + ScalarValue::Time64Microsecond(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Time64Nanosecond(lhs), ScalarValue::Time64Nanosecond(rhs)) => { + ScalarValue::Time64Nanosecond(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::IntervalYearMonth(lhs), ScalarValue::IntervalYearMonth(rhs)) => { + ScalarValue::IntervalYearMonth(min_max_option(lhs, rhs, ordering)) + } + ( + ScalarValue::IntervalMonthDayNano(lhs), + ScalarValue::IntervalMonthDayNano(rhs), + ) => ScalarValue::IntervalMonthDayNano(min_max_option(lhs, rhs, ordering)), + (ScalarValue::IntervalDayTime(lhs), ScalarValue::IntervalDayTime(rhs)) => { + ScalarValue::IntervalDayTime(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::IntervalYearMonth(_), ScalarValue::IntervalMonthDayNano(_)) + | (ScalarValue::IntervalYearMonth(_), ScalarValue::IntervalDayTime(_)) + | (ScalarValue::IntervalMonthDayNano(_), ScalarValue::IntervalDayTime(_)) + | (ScalarValue::IntervalMonthDayNano(_), ScalarValue::IntervalYearMonth(_)) + | (ScalarValue::IntervalDayTime(_), ScalarValue::IntervalYearMonth(_)) + | (ScalarValue::IntervalDayTime(_), ScalarValue::IntervalMonthDayNano(_)) => { + return min_max_interval_scalar(lhs, rhs, ordering); + } + (ScalarValue::DurationSecond(lhs), ScalarValue::DurationSecond(rhs)) => { + ScalarValue::DurationSecond(min_max_option(lhs, rhs, ordering)) + } + ( + ScalarValue::DurationMillisecond(lhs), + ScalarValue::DurationMillisecond(rhs), + ) => ScalarValue::DurationMillisecond(min_max_option(lhs, rhs, ordering)), + ( + ScalarValue::DurationMicrosecond(lhs), + ScalarValue::DurationMicrosecond(rhs), + ) => ScalarValue::DurationMicrosecond(min_max_option(lhs, rhs, ordering)), + (ScalarValue::DurationNanosecond(lhs), ScalarValue::DurationNanosecond(rhs)) => { + ScalarValue::DurationNanosecond(min_max_option(lhs, rhs, ordering)) + } + (ScalarValue::Struct(_), ScalarValue::Struct(_)) + | (ScalarValue::List(_), ScalarValue::List(_)) + | (ScalarValue::LargeList(_), ScalarValue::LargeList(_)) + | (ScalarValue::FixedSizeList(_), ScalarValue::FixedSizeList(_)) => { + min_max_generic_scalar(lhs, rhs, ordering) + } + _ => { + return internal_err!( + "MIN/MAX is not expected to receive logically incompatible scalar values {:?}", + (lhs, rhs) + ); + } + }; + + Ok(result) } /// An accumulator to compute the maximum value @@ -904,6 +814,118 @@ pub fn max_batch(values: &ArrayRef) -> Result { mod tests { use super::*; + #[test] + fn min_max_scalar_preserves_core_behaviors() -> Result<()> { + let cases = [ + ( + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(2)), + Ordering::Less, + ScalarValue::Int32(Some(2)), + ), + ( + ScalarValue::Int32(Some(1)), + ScalarValue::Int32(Some(2)), + Ordering::Greater, + ScalarValue::Int32(Some(1)), + ), + ( + ScalarValue::Utf8(Some("a".to_string())), + ScalarValue::Utf8(Some("b".to_string())), + Ordering::Less, + ScalarValue::Utf8(Some("b".to_string())), + ), + ( + ScalarValue::Boolean(None), + ScalarValue::Boolean(Some(true)), + Ordering::Greater, + ScalarValue::Boolean(Some(true)), + ), + ]; + + for (lhs, rhs, ordering, expected) in cases { + assert_eq!(min_max_scalar(&lhs, &rhs, ordering)?, expected); + } + + Ok(()) + } + + #[test] + fn min_max_scalar_float_uses_total_cmp_for_nan() -> Result<()> { + type F16 = + ::Native; + + let lhs = ScalarValue::Float64(Some(f64::NAN)); + let rhs = ScalarValue::Float64(Some(1.0)); + assert_eq!(min_max_scalar(&lhs, &rhs, Ordering::Greater)?, rhs); + assert!(matches!( + min_max_scalar(&lhs, &rhs, Ordering::Less)?, + ScalarValue::Float64(Some(value)) if value.is_nan() + )); + + let lhs = ScalarValue::Float32(Some(f32::NAN)); + let rhs = ScalarValue::Float32(Some(1.0)); + assert_eq!(min_max_scalar(&lhs, &rhs, Ordering::Greater)?, rhs); + assert!(matches!( + min_max_scalar(&lhs, &rhs, Ordering::Less)?, + ScalarValue::Float32(Some(value)) if value.is_nan() + )); + + let lhs = ScalarValue::Float16(Some(F16::NAN)); + let rhs = ScalarValue::Float16(Some(F16::from_f32(1.0))); + assert_eq!(min_max_scalar(&lhs, &rhs, Ordering::Greater)?, rhs); + assert!(matches!( + min_max_scalar(&lhs, &rhs, Ordering::Less)?, + ScalarValue::Float16(Some(value)) if value.is_nan() + )); + Ok(()) + } + + #[test] + fn min_max_decimal_mismatch_error_is_preserved() -> Result<()> { + let lhs = ScalarValue::Decimal128(Some(1), 10, 2); + let rhs = ScalarValue::Decimal128(Some(2), 11, 2); + + let error = min_max_scalar(&lhs, &rhs, Ordering::Less).unwrap_err(); + let message = error.to_string(); + + assert!(message.starts_with(&format!( + "Internal error: MIN/MAX is not expected to receive scalars of incompatible types {:?}", + (&lhs, &rhs) + ))); + Ok(()) + } + + #[test] + fn min_max_fixed_size_binary_mismatch_error_is_preserved() -> Result<()> { + let lhs = ScalarValue::FixedSizeBinary(2, Some(vec![1, 2])); + let rhs = ScalarValue::FixedSizeBinary(3, Some(vec![1, 2, 3])); + + let error = min_max_scalar(&lhs, &rhs, Ordering::Less).unwrap_err(); + let message = error.to_string(); + + assert!(message.starts_with( + "Internal error: MIN/MAX is not expected to receive FixedSizeBinary of incompatible sizes (2, 3)" + )); + Ok(()) + } + + #[test] + fn min_max_mixed_interval_error_is_preserved() -> Result<()> { + let lhs = ScalarValue::IntervalYearMonth(Some(1)); + let rhs = ScalarValue::IntervalDayTime(Some( + arrow::datatypes::IntervalDayTime::new(1, 0), + )); + + let error = min_max_scalar(&lhs, &rhs, Ordering::Less).unwrap_err(); + let message = error.to_string(); + + assert!(message.starts_with( + "Internal error: Comparison error while computing interval min/max" + )); + Ok(()) + } + #[test] fn min_max_dictionary_and_scalar_compare_by_inner_value() -> Result<()> { let dictionary = ScalarValue::Dictionary(