Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 11 additions & 28 deletions datafusion/expr/src/higher_order_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ pub struct HigherOrderSignature {
pub type_signature: HigherOrderTypeSignature,
/// The volatility of the function. See [Volatility] for more information.
pub volatility: Volatility,
/// Whether [HigherOrderUDF::coerce_values_for_lambdas] should be called
pub coerce_values_for_lambdas: bool,
/// The max number of times to call [HigherOrderUDF::lambda_parameters] before raising an error.
/// Used to guard against implementations that causes an infinite loop by endlessly returning
/// [LambdaParametersProgress::Partial]. Defaults to 256
Expand All @@ -90,7 +88,6 @@ impl HigherOrderSignature {
HigherOrderSignature {
type_signature,
volatility,
coerce_values_for_lambdas: false,
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
}
}
Expand All @@ -100,7 +97,6 @@ impl HigherOrderSignature {
Self {
type_signature: HigherOrderTypeSignature::UserDefined,
volatility,
coerce_values_for_lambdas: false,
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
}
}
Expand All @@ -110,7 +106,6 @@ impl HigherOrderSignature {
Self {
type_signature: HigherOrderTypeSignature::VariadicAny,
volatility,
coerce_values_for_lambdas: false,
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
}
}
Expand All @@ -120,18 +115,9 @@ impl HigherOrderSignature {
Self {
type_signature: HigherOrderTypeSignature::Any(arg_count),
volatility,
coerce_values_for_lambdas: false,
lambda_parameters_max_iterations: LAMBDA_PARAMETERS_MAX_ITERATIONS,
}
}

/// Set [Self::coerce_values_for_lambdas] to true to indicate that [HigherOrderUDF::coerce_values_for_lambdas]
/// should be called
pub fn with_coerce_values_for_lambdas(mut self) -> Self {
self.coerce_values_for_lambdas = true;

self
}
}

impl PartialEq for dyn HigherOrderUDF {
Expand Down Expand Up @@ -608,12 +594,12 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any {
///
/// assert_eq!(
/// coerce_to,
/// vec![
/// Some(vec![
/// // return the same type for the array being reduced
/// DataType::new_list(DataType::Float32, true),
/// // coerce the initial value to the output of the merge lambda
/// DataType::Float32,
/// ]
/// ])
/// );
///
/// ```
Expand All @@ -623,7 +609,7 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any {
///
/// The implementation can assume that some other part of the code has coerced
/// the actual argument types to match [`Self::signature`], except the coercion defined by
/// [Self::coerce_values_for_lambdas], if applicable.
/// [Self::coerce_values_for_lambdas].
///
/// [`HigherOrderFunction`]: crate::expr::HigherOrderFunction
/// [`HigherOrderFunction::lambda_parameters`]: crate::expr::HigherOrderFunction::lambda_parameters
Expand All @@ -636,8 +622,7 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any {
/// Coerce value arguments of a function call to types that the function can evaluate also taking into
/// account the *output type of it's lambdas*. This differs from [HigherOrderUDF::coerce_value_types]
/// that only has access to the type of it's value arguments because it's called before the output type
/// of lambdas are known. So that this method is called, the function must have it's
/// [HigherOrderSignature::coerce_values_for_lambdas] set to true
/// of lambdas are known.
///
/// See the [type coercion module](crate::type_coercion)
/// documentation for more details on type coercion
Expand All @@ -646,29 +631,27 @@ pub trait HigherOrderUDF: Debug + DynEq + DynHash + Send + Sync + Any {
/// * `fields`: The argument types of the value arguments of this function, or the output type of lambdas
///
/// # Return value
/// A Vec with the same number of [ValueOrLambda::Value] in `fields`. DataFusion will `CAST` the
/// function call arguments to these specific types.
/// If `Some`, contains a Vec with the same number of [ValueOrLambda::Value] in `fields`.
/// DataFusion will `CAST` the function call arguments to these specific types. If `None`, no
/// coercion will be applied beyond the one defined by the function signature.
///
/// For example, a flexible array_reduce implementation (see [Self::lambda_parameters] docs), when working
/// with the expression below, may want to coerce it's initial value argument, the *integer* `0`,
/// to match the output it's merge function, which is a *float*:
/// to match the output of it's merge function, which is a *float*:
///
/// `array_reduce([1.2, 2.1], 0, (acc, v) -> acc + v + 1.5, v -> v > 2.0)`
fn coerce_values_for_lambdas(
&self,
_fields: &[ValueOrLambda<DataType, DataType>],
) -> Result<Vec<DataType>> {
not_impl_err!(
"{} coerce_values_for_lambdas is not implemented",
self.name()
)
) -> Result<Option<Vec<DataType>>> {
Ok(None)
}

/// What type will be returned by this function, given the arguments?
///
/// The implementation can assume that some other part of the code has coerced
/// the actual argument types to match [`Self::signature`], including the coercion
/// defined by [Self::coerce_values_for_lambdas], if applicable.
/// defined by [Self::coerce_values_for_lambdas].
///
/// # Example creating `Field`
///
Expand Down
35 changes: 17 additions & 18 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ pub fn fields_with_udf<F: UDFCoercionExt>(
/// argument must be coerced to match `signature`.
/// For lambda arguments, returns a clone of the associated data
///
/// Note this does not invokes [HigherOrderUDF::coerce_values_for_lambdas]
/// if requested by the function signature. If that's required, use
/// [value_fields_with_higher_order_udf_and_lambdas] instead
/// Note this does not invokes [HigherOrderUDF::coerce_values_for_lambdas].
/// If that's required, use [value_fields_with_higher_order_udf_and_lambdas]
/// instead
///
/// For more details on coercion in general, please see the
/// [`type_coercion`](crate::type_coercion) module.
Expand Down Expand Up @@ -235,8 +235,8 @@ pub fn value_fields_with_higher_order_udf<L: Clone>(

/// Performs type coercion for higher order function arguments,
/// including those defined by [HigherOrderUDF::coerce_values_for_lambdas],
/// if defined by the signature. Note that compared to
/// [value_fields_with_higher_order_udf], this function requires
/// if it returns `Some(...)` instead of the default `None`. Note that
/// compared to [value_fields_with_higher_order_udf], this function requires
/// the [ValueOrLambda::Lambda] variant to contain the output field of the lambda.
///
/// For value arguments, returns the field to which each
Expand All @@ -251,16 +251,16 @@ pub fn value_fields_with_higher_order_udf_and_lambdas(
) -> Result<Vec<ValueOrLambda<FieldRef, FieldRef>>> {
let mut new_fields = value_fields_with_higher_order_udf(current_fields, func)?;

if func.signature().coerce_values_for_lambdas {
let new_types = new_fields
.iter()
.map(|f| match f {
ValueOrLambda::Value(f) => ValueOrLambda::Value(f.data_type().clone()),
ValueOrLambda::Lambda(f) => ValueOrLambda::Lambda(f.data_type().clone()),
})
.collect::<Vec<_>>();
let new_types = new_fields
.iter()
.map(|f| match f {
ValueOrLambda::Value(f) => ValueOrLambda::Value(f.data_type().clone()),
ValueOrLambda::Lambda(f) => ValueOrLambda::Lambda(f.data_type().clone()),
})
.collect::<Vec<_>>();

let mut new_value_types = func.coerce_values_for_lambdas(&new_types)?.into_iter();
if let Some(new_value_types) = func.coerce_values_for_lambdas(&new_types)? {
let mut new_value_types = new_value_types.into_iter();

let value_types_count = new_types
.iter()
Expand Down Expand Up @@ -1851,7 +1851,7 @@ mod tests {
fn coerce_values_for_lambdas(
&self,
fields: &[ValueOrLambda<DataType, DataType>],
) -> Result<Vec<DataType>> {
) -> Result<Option<Vec<DataType>>> {
// thoerical impl of array_reduce without finish
let [
ValueOrLambda::Value(list),
Expand All @@ -1862,7 +1862,7 @@ mod tests {
unreachable!()
};

Ok(vec![list.clone(), merge.clone()])
Ok(Some(vec![list.clone(), merge.clone()]))
}

fn lambda_parameters(
Expand Down Expand Up @@ -1925,8 +1925,7 @@ mod tests {
#[test]
fn test_higher_order_function_coerce_values_for_lambdas() {
let fun = MockHigherOrderUDF {
signature: HigherOrderSignature::variadic_any(Volatility::Immutable)
.with_coerce_values_for_lambdas(),
signature: HigherOrderSignature::variadic_any(Volatility::Immutable),
coerced_value_types: vec![],
};

Expand Down
Loading