diff --git a/datafusion/functions/src/math/floor.rs b/datafusion/functions/src/math/floor.rs
index d69f9b9d86fe0..db07fa3e5f787 100644
--- a/datafusion/functions/src/math/floor.rs
+++ b/datafusion/functions/src/math/floor.rs
@@ -25,12 +25,15 @@ use arrow::datatypes::{
 };
 use datafusion_common::{Result, ScalarValue, exec_err};
 use datafusion_expr::interval_arithmetic::Interval;
+use datafusion_expr::preimage::PreimageResult;
+use datafusion_expr::simplify::SimplifyContext;
 use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
 use datafusion_expr::{
-    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
-    TypeSignature, TypeSignatureClass, Volatility,
+    Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl,
+    Signature, TypeSignature, TypeSignatureClass, Volatility,
 };
 use datafusion_macros::user_doc;
+use num_traits::{CheckedAdd, Float, One};
 
 use super::decimal::{apply_decimal_op, floor_decimal_value};
 
@@ -200,7 +203,242 @@ impl ScalarUDFImpl for FloorFunc {
         Interval::make_unbounded(&data_type)
     }
 
+    /// Compute the preimage for floor function.
+    ///
+    /// For `floor(x) = N`, the preimage is `x >= N AND x < N + 1`
+    /// because floor(x) = N for all x in [N, N+1).
+    ///
+    /// This enables predicate pushdown optimizations, transforming:
+    /// `floor(col) = 100` into `col >= 100 AND col < 101`
+    fn preimage(
+        &self,
+        args: &[Expr],
+        lit_expr: &Expr,
+        _info: &SimplifyContext,
+    ) -> Result<PreimageResult> {
+        // floor takes exactly one argument
+        if args.len() != 1 {
+            return Ok(PreimageResult::None);
+        }
+
+        let arg = args[0].clone();
+
+        // Extract the literal value being compared to
+        let Expr::Literal(lit_value, _) = lit_expr else {
+            return Ok(PreimageResult::None);
+        };
+
+        // Compute lower bound (N) and upper bound (N + 1) using helper functions
+        let Some((lower, upper)) = (match lit_value {
+            // Floating-point types
+            ScalarValue::Float64(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
+                (
+                    ScalarValue::Float64(Some(lo)),
+                    ScalarValue::Float64(Some(hi)),
+                )
+            }),
+            ScalarValue::Float32(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
+                (
+                    ScalarValue::Float32(Some(lo)),
+                    ScalarValue::Float32(Some(hi)),
+                )
+            }),
+
+            // Integer types
+            ScalarValue::Int8(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
+                (ScalarValue::Int8(Some(lo)), ScalarValue::Int8(Some(hi)))
+            }),
+            ScalarValue::Int16(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
+                (ScalarValue::Int16(Some(lo)), ScalarValue::Int16(Some(hi)))
+            }),
+            ScalarValue::Int32(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
+                (ScalarValue::Int32(Some(lo)), ScalarValue::Int32(Some(hi)))
+            }),
+            ScalarValue::Int64(Some(n)) => int_preimage_bounds(*n).map(|(lo, hi)| {
+                (ScalarValue::Int64(Some(lo)), ScalarValue::Int64(Some(hi)))
+            }),
+
+            // Unsupported types
+            _ => None,
+        }) else {
+            return Ok(PreimageResult::None);
+        };
+
+        Ok(PreimageResult::Range {
+            expr: arg,
+            interval: Box::new(Interval::try_new(lower, upper)?),
+        })
+    }
+
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
 }
+
+// ============ Helper functions for preimage bounds ============
+
+/// Compute preimage bounds for floor function on floating-point types.
+/// For floor(x) = n, the preimage is [n, n+1).
+/// Returns None if n is non-finite, not a whole number, or n + 1 would lose precision.
+fn float_preimage_bounds<F: Float>(n: F) -> Option<(F, F)> {
+    let one = F::one();
+    // Reject NaN/infinity, fractional targets (never a floor result), and precision loss
+    if !n.is_finite() || n.floor() != n || n + one <= n {
+        return None;
+    }
+    Some((n, n + one))
+}
+
+/// Compute preimage bounds for floor function on integer types.
+/// For floor(x) = n, the preimage is [n, n+1).
+/// Returns None if adding 1 would overflow.
+fn int_preimage_bounds<I: CheckedAdd + One>(n: I) -> Option<(I, I)> {
+    let upper = n.checked_add(&I::one())?;
+    Some((n, upper))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion_expr::col;
+
+    /// Helper to test valid preimage cases that should return a Range
+    fn assert_preimage_range(
+        input: ScalarValue,
+        expected_lower: ScalarValue,
+        expected_upper: ScalarValue,
+    ) {
+        let floor_func = FloorFunc::new();
+        let args = vec![col("x")];
+        let lit_expr = Expr::Literal(input.clone(), None);
+        let info = SimplifyContext::default();
+
+        let result = floor_func.preimage(&args, &lit_expr, &info).unwrap();
+
+        match result {
+            PreimageResult::Range { expr, interval } => {
+                assert_eq!(expr, col("x"));
+                assert_eq!(interval.lower().clone(), expected_lower);
+                assert_eq!(interval.upper().clone(), expected_upper);
+            }
+            PreimageResult::None => {
+                panic!("Expected Range, got None for input {input:?}")
+            }
+        }
+    }
+
+    /// Helper to test cases that should return None
+    fn assert_preimage_none(input: ScalarValue) {
+        let floor_func = FloorFunc::new();
+        let args = vec![col("x")];
+        let lit_expr = Expr::Literal(input.clone(), None);
+        let info = SimplifyContext::default();
+
+        let result = floor_func.preimage(&args, &lit_expr, &info).unwrap();
+        assert!(
+            matches!(result, PreimageResult::None),
+            "Expected None for input {input:?}"
+        );
+    }
+
+    #[test]
+    fn test_floor_preimage_valid_cases() {
+        // Float64
+        assert_preimage_range(
+            ScalarValue::Float64(Some(100.0)),
+            ScalarValue::Float64(Some(100.0)),
+            ScalarValue::Float64(Some(101.0)),
+        );
+        // Float32
+        assert_preimage_range(
+            ScalarValue::Float32(Some(50.0)),
+            ScalarValue::Float32(Some(50.0)),
+            ScalarValue::Float32(Some(51.0)),
+        );
+        // Int64
+        assert_preimage_range(
+            ScalarValue::Int64(Some(42)),
+            ScalarValue::Int64(Some(42)),
+            ScalarValue::Int64(Some(43)),
+        );
+        // Int32
+        assert_preimage_range(
+            ScalarValue::Int32(Some(100)),
+            ScalarValue::Int32(Some(100)),
+            ScalarValue::Int32(Some(101)),
+        );
+        // Negative values
+        assert_preimage_range(
+            ScalarValue::Float64(Some(-5.0)),
+            ScalarValue::Float64(Some(-5.0)),
+            ScalarValue::Float64(Some(-4.0)),
+        );
+        // Zero
+        assert_preimage_range(
+            ScalarValue::Float64(Some(0.0)),
+            ScalarValue::Float64(Some(0.0)),
+            ScalarValue::Float64(Some(1.0)),
+        );
+    }
+
+    #[test]
+    fn test_floor_preimage_integer_overflow() {
+        // All integer types at MAX value should return None
+        assert_preimage_none(ScalarValue::Int64(Some(i64::MAX)));
+        assert_preimage_none(ScalarValue::Int32(Some(i32::MAX)));
+        assert_preimage_none(ScalarValue::Int16(Some(i16::MAX)));
+        assert_preimage_none(ScalarValue::Int8(Some(i8::MAX)));
+    }
+
+    #[test]
+    fn test_floor_preimage_float_edge_cases() {
+        // Float64 edge cases
+        assert_preimage_none(ScalarValue::Float64(Some(f64::INFINITY)));
+        assert_preimage_none(ScalarValue::Float64(Some(f64::NEG_INFINITY)));
+        assert_preimage_none(ScalarValue::Float64(Some(f64::NAN)));
+        assert_preimage_none(ScalarValue::Float64(Some(f64::MAX))); // precision loss
+
+        // Float32 edge cases
+        assert_preimage_none(ScalarValue::Float32(Some(f32::INFINITY)));
+        assert_preimage_none(ScalarValue::Float32(Some(f32::NEG_INFINITY)));
+        assert_preimage_none(ScalarValue::Float32(Some(f32::NAN)));
+        assert_preimage_none(ScalarValue::Float32(Some(f32::MAX))); // precision loss
+    }
+
+    #[test]
+    fn test_floor_preimage_null_values() {
+        assert_preimage_none(ScalarValue::Float64(None));
+        assert_preimage_none(ScalarValue::Float32(None));
+        assert_preimage_none(ScalarValue::Int64(None));
+    }
+
+    #[test]
+    fn test_floor_preimage_invalid_inputs() {
+        let floor_func = FloorFunc::new();
+        let info = SimplifyContext::default();
+
+        // Non-literal comparison value
+        let result = floor_func.preimage(&[col("x")], &col("y"), &info).unwrap();
+        assert!(
+            matches!(result, PreimageResult::None),
+            "Expected None for non-literal"
+        );
+
+        // Wrong argument count (too many)
+        let lit = Expr::Literal(ScalarValue::Float64(Some(100.0)), None);
+        let result = floor_func
+            .preimage(&[col("x"), col("y")], &lit, &info)
+            .unwrap();
+        assert!(
+            matches!(result, PreimageResult::None),
+            "Expected None for wrong arg count"
+        );
+
+        // Wrong argument count (zero)
+        let result = floor_func.preimage(&[], &lit, &info).unwrap();
+        assert!(
+            matches!(result, PreimageResult::None),
+            "Expected None for zero args"
+        );
+    }
+}