Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 240 additions & 2 deletions datafusion/functions/src/math/floor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ use arrow::datatypes::{
};
use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::preimage::PreimageResult;
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
TypeSignature, TypeSignatureClass, Volatility,
Coercion, ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl,
Signature, TypeSignature, TypeSignatureClass, Volatility,
};
use datafusion_macros::user_doc;
use num_traits::{CheckedAdd, Float, One};

use super::decimal::{apply_decimal_op, floor_decimal_value};

Expand Down Expand Up @@ -200,7 +203,242 @@ impl ScalarUDFImpl for FloorFunc {
Interval::make_unbounded(&data_type)
}

/// Compute the preimage for the floor function.
///
/// For `floor(x) = N`, the preimage is `x >= N AND x < N + 1`
/// because `floor(x) = N` for all `x` in `[N, N + 1)`.
///
/// This enables predicate pushdown optimizations, transforming
/// `floor(col) = 100` into `col >= 100 AND col < 101`.
///
/// Returns `PreimageResult::None` (no rewrite) when:
/// - `args` does not hold exactly one expression,
/// - `lit_expr` is not an `Expr::Literal`,
/// - the literal is not a non-null `Float32`/`Float64`/`Int8`/`Int16`/`Int32`/`Int64`,
/// - the bounds helper for that type declines to produce a range.
///
/// # Errors
///
/// Propagates any error from `Interval::try_new` when building the range.
fn preimage(
    &self,
    args: &[Expr],
    lit_expr: &Expr,
    _info: &SimplifyContext,
) -> Result<PreimageResult> {
    // floor takes exactly one argument.
    // NOTE(review): a reviewer suggested a `debug_assert!` on the argument count here.
    // It is deliberately not added: the unit tests in this file call `preimage` with
    // zero and with two arguments and expect `PreimageResult::None`, not a panic.
    let [arg] = args else {
        return Ok(PreimageResult::None);
    };

    // Only a literal on the other side of the comparison can be inverted.
    let Expr::Literal(lit_value, _) = lit_expr else {
        return Ok(PreimageResult::None);
    };

    // Compute lower bound (N) and upper bound (N + 1), re-wrapped in the same
    // `ScalarValue` variant as the literal. Null literals (`Some` does not match)
    // fall through to the catch-all arm.
    let bounds = match lit_value {
        // Floating-point types
        ScalarValue::Float64(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
            (
                ScalarValue::Float64(Some(lo)),
                ScalarValue::Float64(Some(hi)),
            )
        }),
        ScalarValue::Float32(Some(n)) => float_preimage_bounds(*n).map(|(lo, hi)| {
            (
                ScalarValue::Float32(Some(lo)),
                ScalarValue::Float32(Some(hi)),
            )
        }),

        // Integer types
        ScalarValue::Int8(Some(n)) => int_preimage_bounds(*n)
            .map(|(lo, hi)| (ScalarValue::Int8(Some(lo)), ScalarValue::Int8(Some(hi)))),
        ScalarValue::Int16(Some(n)) => int_preimage_bounds(*n)
            .map(|(lo, hi)| (ScalarValue::Int16(Some(lo)), ScalarValue::Int16(Some(hi)))),
        ScalarValue::Int32(Some(n)) => int_preimage_bounds(*n)
            .map(|(lo, hi)| (ScalarValue::Int32(Some(lo)), ScalarValue::Int32(Some(hi)))),
        ScalarValue::Int64(Some(n)) => int_preimage_bounds(*n)
            .map(|(lo, hi)| (ScalarValue::Int64(Some(lo)), ScalarValue::Int64(Some(hi)))),

        // Unsupported types
        // NOTE(review): floor also supports decimal types; a reviewer asked whether
        // they should be handled here. They currently fall through to `None`.
        _ => None,
    };

    let Some((lower, upper)) = bounds else {
        return Ok(PreimageResult::None);
    };

    // Clone the argument only once we know a range is actually produced.
    Ok(PreimageResult::Range {
        expr: arg.clone(),
        interval: Box::new(Interval::try_new(lower, upper)?),
    })
}

/// Returns this function's user-facing documentation by delegating to `self.doc()`.
// NOTE(review): `doc()` is presumably generated by the `#[user_doc]` attribute macro
// imported at the top of the file — confirm against the struct definition.
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

// ============ Helper functions for preimage bounds ============

/// Compute preimage bounds for the floor function on floating-point types.
///
/// For `floor(x) = n` with an integral `n`, the preimage is `[n, n + 1)`; the
/// bounds are returned as the pair `(n, n + 1)`.
///
/// Returns `None` when no usable range exists:
/// - `n` is NaN or infinite;
/// - `n` has a fractional part: `floor` only ever yields integral values, so
///   `floor(x) = 1.5` has no solution, and a range of `[1.5, 2.5)` would
///   wrongly match `x = 2.0`;
/// - `n + 1` is not strictly greater than `n`, i.e. adding one is lost to
///   rounding at large magnitudes.
fn float_preimage_bounds<F: Float>(n: F) -> Option<(F, F)> {
    // `n.floor() != n` is true exactly when a finite `n` is not an integer.
    if !n.is_finite() || n.floor() != n {
        return None;
    }

    let upper = n + F::one();
    // Precision loss: the upper bound collapsed onto (or below) the lower bound.
    if upper <= n {
        return None;
    }

    Some((n, upper))
}

/// Compute preimage bounds for the floor function on integer types.
///
/// For `floor(x) = n` the preimage is `[n, n + 1)`, returned as `(n, n + 1)`.
/// Yields `None` when `n + 1` is not representable in `I` (checked addition
/// overflows).
fn int_preimage_bounds<I: CheckedAdd + One + Copy>(n: I) -> Option<(I, I)> {
    n.checked_add(&I::one()).map(|successor| (n, successor))
}

#[cfg(test)]
mod tests {
    use super::*;
    use datafusion_expr::col;

    /// Evaluates `FloorFunc::preimage` for `floor(x) = <literal>` and returns
    /// the unwrapped result.
    fn preimage_for_literal(literal: &ScalarValue) -> PreimageResult {
        let comparison = Expr::Literal(literal.clone(), None);
        let context = SimplifyContext::default();
        FloorFunc::new()
            .preimage(&[col("x")], &comparison, &context)
            .unwrap()
    }

    /// Asserts that the preimage of `input` is a range over `col("x")` with the
    /// given lower and upper bounds.
    fn assert_preimage_range(
        input: ScalarValue,
        expected_lower: ScalarValue,
        expected_upper: ScalarValue,
    ) {
        let PreimageResult::Range { expr, interval } = preimage_for_literal(&input) else {
            panic!("Expected Range, got None for input {input:?}")
        };

        assert_eq!(expr, col("x"));
        assert_eq!(interval.lower().clone(), expected_lower);
        assert_eq!(interval.upper().clone(), expected_upper);
    }

    /// Asserts that `input` produces no preimage at all.
    fn assert_preimage_none(input: ScalarValue) {
        let outcome = preimage_for_literal(&input);
        assert!(
            matches!(outcome, PreimageResult::None),
            "Expected None for input {input:?}"
        );
    }

    #[test]
    fn test_floor_preimage_valid_cases() {
        // (literal, expected lower bound, expected upper bound)
        let cases = [
            // Float64
            (
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(100.0)),
                ScalarValue::Float64(Some(101.0)),
            ),
            // Float32
            (
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(50.0)),
                ScalarValue::Float32(Some(51.0)),
            ),
            // Int64
            (
                ScalarValue::Int64(Some(42)),
                ScalarValue::Int64(Some(42)),
                ScalarValue::Int64(Some(43)),
            ),
            // Int32
            (
                ScalarValue::Int32(Some(100)),
                ScalarValue::Int32(Some(100)),
                ScalarValue::Int32(Some(101)),
            ),
            // Negative values
            (
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-5.0)),
                ScalarValue::Float64(Some(-4.0)),
            ),
            // Zero
            (
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(0.0)),
                ScalarValue::Float64(Some(1.0)),
            ),
        ];

        for (literal, lower, upper) in cases {
            assert_preimage_range(literal, lower, upper);
        }
    }

    #[test]
    fn test_floor_preimage_integer_overflow() {
        // Every integer type at its MAX value has no representable `N + 1`.
        let at_max = [
            ScalarValue::Int64(Some(i64::MAX)),
            ScalarValue::Int32(Some(i32::MAX)),
            ScalarValue::Int16(Some(i16::MAX)),
            ScalarValue::Int8(Some(i8::MAX)),
        ];

        for literal in at_max {
            assert_preimage_none(literal);
        }
    }

    #[test]
    fn test_floor_preimage_float_edge_cases() {
        let edge_cases = [
            // Float64 edge cases
            ScalarValue::Float64(Some(f64::INFINITY)),
            ScalarValue::Float64(Some(f64::NEG_INFINITY)),
            ScalarValue::Float64(Some(f64::NAN)),
            ScalarValue::Float64(Some(f64::MAX)), // precision loss
            // Float32 edge cases
            ScalarValue::Float32(Some(f32::INFINITY)),
            ScalarValue::Float32(Some(f32::NEG_INFINITY)),
            ScalarValue::Float32(Some(f32::NAN)),
            ScalarValue::Float32(Some(f32::MAX)), // precision loss
        ];

        for literal in edge_cases {
            assert_preimage_none(literal);
        }
    }

    #[test]
    fn test_floor_preimage_null_values() {
        let nulls = [
            ScalarValue::Float64(None),
            ScalarValue::Float32(None),
            ScalarValue::Int64(None),
        ];

        for literal in nulls {
            assert_preimage_none(literal);
        }
    }

    #[test]
    fn test_floor_preimage_invalid_inputs() {
        let floor_func = FloorFunc::new();
        let info = SimplifyContext::default();
        let lit = Expr::Literal(ScalarValue::Float64(Some(100.0)), None);

        // Shared check: the call succeeded and declined to produce a preimage.
        let expect_none = |outcome: PreimageResult, message: &str| {
            assert!(matches!(outcome, PreimageResult::None), "{message}");
        };

        // Non-literal comparison value
        expect_none(
            floor_func.preimage(&[col("x")], &col("y"), &info).unwrap(),
            "Expected None for non-literal",
        );

        // Wrong argument count (too many)
        expect_none(
            floor_func
                .preimage(&[col("x"), col("y")], &lit, &info)
                .unwrap(),
            "Expected None for wrong arg count",
        );

        // Wrong argument count (zero)
        expect_none(
            floor_func.preimage(&[], &lit, &info).unwrap(),
            "Expected None for zero args",
        );
    }
}