Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/substrait/src/extensions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use substrait::proto::extensions::simple_extension_declaration::{
/// types. This struct facilitates the use of these extensions in DataFusion.
/// TODO: DF doesn't yet use extensions for type variations <https://github.com/apache/datafusion/issues/11544>
/// TODO: DF doesn't yet provide valid extensionUris <https://github.com/apache/datafusion/issues/11545>
#[derive(Default, Debug, PartialEq)]
#[derive(Clone, Default, Debug, PartialEq)]
pub struct Extensions {
pub functions: HashMap<u32, String>, // anchor -> function name
pub types: HashMap<u32, String>, // anchor -> type name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use datafusion::logical_expr::Expr;
use std::sync::Arc;
use substrait::proto::expression::FieldReference;
use substrait::proto::expression::field_reference::ReferenceType::DirectReference;
use substrait::proto::expression::field_reference::RootType;
use substrait::proto::expression::field_reference::{LambdaParameterReference, RootType};
use substrait::proto::expression::reference_segment::ReferenceType::StructField;

pub async fn from_field_reference(
Expand Down Expand Up @@ -56,9 +56,9 @@ pub(crate) fn from_substrait_field_reference(
Some(RootType::Expression(_)) => not_impl_err!(
"Expression root type in field reference is not supported"
),
Some(RootType::LambdaParameterReference(_)) => not_impl_err!(
"Lambda parameter reference in field reference is not yet supported"
),
Some(RootType::LambdaParameterReference(
LambdaParameterReference { steps_out },
)) => consumer.lambda_variable(*steps_out as usize, field_idx),
}
}
_ => not_impl_err!(
Expand All @@ -85,3 +85,85 @@ fn resolve_outer_reference(
let col = Column::from((qualifier, field));
Ok(Expr::OuterReferenceColumn(Arc::clone(field), col))
}

#[cfg(test)]
mod tests {
    use datafusion::{
        common::{DFSchema, assert_contains},
        prelude::SessionContext,
    };
    use substrait::proto::{
        Type,
        expression::{
            FieldReference, ReferenceSegment,
            field_reference::{self, LambdaParameterReference, RootType},
            reference_segment::{ReferenceType, StructField},
        },
        r#type::{I64, Kind},
    };

    use crate::{
        extensions::Extensions,
        logical_plan::consumer::{
            DefaultSubstraitConsumer, SubstraitConsumer, from_field_reference,
        },
    };

    /// Builds a direct struct-field reference to position `field` whose root is a
    /// lambda parameter reference carrying the given `steps_out`.
    fn lambda_field_ref(field: i32, steps_out: u32) -> FieldReference {
        let segment = ReferenceSegment {
            reference_type: Some(ReferenceType::StructField(Box::new(StructField {
                field,
                child: None,
            }))),
        };
        let root = RootType::LambdaParameterReference(LambdaParameterReference { steps_out });

        FieldReference {
            reference_type: Some(field_reference::ReferenceType::DirectReference(segment)),
            root_type: Some(root),
        }
    }

    /// A consumer with no lambda parameters registered must reject a reference
    /// that asks for a lambda 99 steps out.
    #[tokio::test]
    async fn test_lambda_variable_invalid_steps_out() {
        let extensions = Extensions::default();
        let session_state = SessionContext::new().state();
        let consumer = DefaultSubstraitConsumer::new(&extensions, &session_state);

        let reference = lambda_field_ref(0, 99);
        let outcome =
            from_field_reference(&consumer, &reference, DFSchema::empty_ref()).await;

        let message = outcome.unwrap_err().to_string();
        assert_contains!(message, "No lambda at 99 steps out, got only 0");
    }

    /// With a single I64 lambda parameter registered, referencing field index 1
    /// must be rejected.
    #[tokio::test]
    async fn test_lambda_variable_invalid_field_idx() {
        let extensions = Extensions::default();
        let session_state = SessionContext::new().state();

        // Register exactly one lambda parameter before resolving the reference.
        let parameter_types = [Type {
            kind: Some(Kind::I64(I64::default())),
        }];
        let (_names, consumer) = DefaultSubstraitConsumer::new(&extensions, &session_state)
            .with_lambda_parameters(&parameter_types, DFSchema::empty_ref())
            .unwrap();

        let reference = lambda_field_ref(1, 0);
        let outcome =
            from_field_reference(&consumer, &reference, DFSchema::empty_ref()).await;

        let message = outcome.unwrap_err().to_string();
        assert_contains!(
            message,
            "At lambda 0 steps out, no field at index 1, got only 1"
        );
    }
}
103 changes: 103 additions & 0 deletions datafusion/substrait/src/logical_plan/consumer/expr/lambda.rs
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we include tests here to check that the following two things result in an error?

  • missing parameters or body
  • invalid steps_out or field index

Or somewhere else if you think it would be more appropriate.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, added tests for all 4 errors at 93f17e4 and c019769, thanks

Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use datafusion::{
common::{DFSchema, substrait_err},
prelude::{Expr, lambda},
};
use substrait::proto;

use crate::logical_plan::consumer::SubstraitConsumer;

/// Converts a Substrait lambda expression into a DataFusion [`Expr`].
///
/// The lambda's parameter types are handed to `consumer.with_lambda_parameters`,
/// which returns a `(names, consumer)` pair. The body is then consumed through
/// that returned consumer against `input_schema`, and the result is wrapped
/// together with `names` by `lambda`.
///
/// # Errors
///
/// Returns a Substrait error when `expr.parameters` or `expr.body` is unset, and
/// propagates any error from `with_lambda_parameters` or from consuming the body.
pub async fn from_lambda(
    consumer: &impl SubstraitConsumer,
    expr: &proto::expression::Lambda,
    input_schema: &DFSchema,
) -> datafusion::common::Result<Expr> {
    let param_types = match expr.parameters.as_ref() {
        Some(parameters) => &parameters.types,
        None => {
            return substrait_err!("Lambda expression without parameters is not allowed");
        }
    };

    // Parameters are bound before the body is inspected, so a failure to bind
    // them is reported ahead of a missing body.
    let (names, scoped_consumer) =
        consumer.with_lambda_parameters(param_types, input_schema)?;

    let body_expr = match expr.body.as_ref() {
        Some(body) => {
            scoped_consumer
                .consume_expression(body, input_schema)
                .await?
        }
        None => return substrait_err!("Lambda expression without body is not allowed"),
    };

    Ok(lambda(names, body_expr))
}

#[cfg(test)]
mod tests {
    use datafusion::{
        common::{DFSchema, assert_contains},
        prelude::SessionContext,
    };
    use substrait::proto::{self, Expression, r#type::Struct};

    use crate::{
        extensions::Extensions,
        logical_plan::consumer::{DefaultSubstraitConsumer, from_lambda},
    };

    /// Runs `from_lambda` on `lambda` with a default consumer and an empty schema
    /// and returns the rendered error. Panics if the conversion succeeds.
    async fn conversion_error(lambda: &proto::expression::Lambda) -> String {
        let extensions = Extensions::default();
        let session_state = SessionContext::new().state();
        let consumer = DefaultSubstraitConsumer::new(&extensions, &session_state);

        let outcome = from_lambda(&consumer, lambda, DFSchema::empty_ref()).await;
        let err = outcome.unwrap_err();
        err.to_string()
    }

    /// A lambda with parameters but no body must be rejected.
    #[tokio::test]
    async fn test_lambda_without_body() {
        let bodiless = proto::expression::Lambda {
            parameters: Some(Struct::default()),
            body: None,
        };

        let message = conversion_error(&bodiless).await;

        assert_contains!(message, "Lambda expression without body is not allowed");
    }

    /// A lambda with a body but no parameters must be rejected.
    #[tokio::test]
    async fn test_lambda_without_parameters() {
        let parameterless = proto::expression::Lambda {
            parameters: None,
            body: Some(Box::new(Expression::default())),
        };

        let message = conversion_error(&parameterless).await;

        assert_contains!(
            message,
            "Lambda expression without parameters is not allowed"
        );
    }
}
9 changes: 7 additions & 2 deletions datafusion/substrait/src/logical_plan/consumer/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod cast;
mod field_reference;
mod function_arguments;
mod if_then;
mod lambda;
mod literal;
mod nested;
mod scalar_function;
Expand All @@ -32,6 +33,7 @@ pub use cast::*;
pub use field_reference::*;
pub use function_arguments::*;
pub use if_then::*;
pub use lambda::*;
pub use literal::*;
pub use nested::*;
pub use scalar_function::*;
Expand Down Expand Up @@ -95,8 +97,11 @@ pub async fn from_substrait_rex(
RexType::DynamicParameter(expr) => {
consumer.consume_dynamic_parameter(expr, input_schema).await
}
RexType::Lambda(_) | RexType::LambdaInvocation(_) => {
not_impl_err!("Lambda expressions are not yet supported")
RexType::Lambda(lambda) => {
consumer.consume_lambda(lambda.as_ref(), input_schema).await
}
RexType::LambdaInvocation(_) => {
not_impl_err!("Lambda invocations are not supported")
}
},
None => substrait_err!("Expression must set rex_type: {expression:?}"),
Expand Down
Comment thread
gstvg marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ pub async fn from_scalar_function(
f: &ScalarFunction,
input_schema: &DFSchema,
) -> Result<Expr> {
//TODO: handle higher order functions, as they are also encoded as scalar functions
let Some(fn_signature) = consumer
.get_extensions()
.functions
Expand All @@ -45,6 +44,20 @@ pub async fn from_scalar_function(
let fn_name = substrait_fun_name(fn_signature);
let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?;

let higher_order_func = consumer
.get_function_registry()
.higher_order_function(fn_name)
.or_else(|e| {
if let Some(alt_name) = substrait_to_df_name(fn_name) {
consumer
.get_function_registry()
.higher_order_function(alt_name)
.or(Err(e))
} else {
Err(e)
}
});

let udf_func = consumer.get_function_registry().udf(fn_name).or_else(|e| {
if let Some(alt_name) = substrait_to_df_name(fn_name) {
consumer.get_function_registry().udf(alt_name).or(Err(e))
Expand All @@ -53,9 +66,14 @@ pub async fn from_scalar_function(
}
});

// try to first match the requested function into registered udfs, then built-in ops
// try to first match the requested function into registered higher-order functions, then udfs, built-in ops
// and finally built-in expressions
if let Ok(func) = udf_func {
if let Ok(func) = higher_order_func {
Ok(Expr::HigherOrderFunction(expr::HigherOrderFunction::new(
func.to_owned(),
args,
)))
} else if let Ok(func) = udf_func {
Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
func.to_owned(),
args,
Expand Down
Loading
Loading