diff --git a/datafusion-examples/examples/custom_data_source/custom_file_casts.rs b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs
index 36cc936332065..aaac884fb5385 100644
--- a/datafusion-examples/examples/custom_data_source/custom_file_casts.rs
+++ b/datafusion-examples/examples/custom_data_source/custom_file_casts.rs
@@ -51,7 +51,7 @@ pub async fn custom_file_casts() -> Result<()> {
     // Create a logical / table schema with an Int32 column
     let logical_schema =
-        Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+        Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, true)]));

     // Create some data that can be cast (Int16 -> Int32 is widening) and some that cannot (Int64 -> Int32 is narrowing)
     let store = Arc::new(InMemory::new()) as Arc<dyn ObjectStore>;
diff --git a/datafusion/common/src/format.rs b/datafusion/common/src/format.rs
index a505bd0e1c74e..4543f2584a8bc 100644
--- a/datafusion/common/src/format.rs
+++ b/datafusion/common/src/format.rs
@@ -24,6 +24,213 @@ use arrow::util::display::{DurationFormat, FormatOptions};
 use crate::config::{ConfigField, Visit};
 use crate::error::{DataFusionError, Result};

+/// Owned version of Arrow's `FormatOptions` with `String` instead of `&'static str`.
+///
+/// Arrow's `FormatOptions<'a>` requires borrowed strings with lifetime bounds,
+/// and often requires `&'static str` for storage in long-lived types like `CastExpr`.
+/// This struct uses owned `String` values instead, allowing dynamic format options
+/// to be created from user queries, Protobuf deserialization, or IPC without
+/// memory leaks or string interning.
+///
+/// # Conversion to Arrow Types
+///
+/// Use the `as_arrow_options()` method to temporarily convert to `FormatOptions<'a>`
+/// with borrowed `&str` references for passing to Arrow compute kernels:
+///
+/// ```ignore
+/// let owned_options = OwnedFormatOptions { ... };
+/// let format_options = owned_options.as_arrow_options(); // borrows owned strings
+/// let cast_options = CastOptions { safe: false, format_options };
+/// arrow::compute::cast_with_options(&array, &data_type, &cast_options)?;
+/// ```
+#[derive(Debug, Clone, Eq, PartialEq, Hash)]
+pub struct OwnedFormatOptions {
+    /// String representation of null values
+    pub null: String,
+    /// Date format string
+    pub date_format: Option<String>,
+    /// Datetime format string
+    pub datetime_format: Option<String>,
+    /// Timestamp format string
+    pub timestamp_format: Option<String>,
+    /// Timestamp with timezone format string
+    pub timestamp_tz_format: Option<String>,
+    /// Time format string
+    pub time_format: Option<String>,
+    /// Duration format (owned, since DurationFormat is a simple enum)
+    pub duration_format: DurationFormat,
+    /// Include type information in formatted output
+    pub types_info: bool,
+}
+
+impl OwnedFormatOptions {
+    /// Create a new `OwnedFormatOptions` with default values.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the null string.
+    pub fn with_null(mut self, null: String) -> Self {
+        self.null = null;
+        self
+    }
+
+    /// Set the date format.
+    pub fn with_date_format(mut self, date_format: Option<String>) -> Self {
+        self.date_format = date_format;
+        self
+    }
+
+    /// Set the datetime format.
+    pub fn with_datetime_format(mut self, datetime_format: Option<String>) -> Self {
+        self.datetime_format = datetime_format;
+        self
+    }
+
+    /// Set the timestamp format.
+    pub fn with_timestamp_format(mut self, timestamp_format: Option<String>) -> Self {
+        self.timestamp_format = timestamp_format;
+        self
+    }
+
+    /// Set the timestamp with timezone format.
+    pub fn with_timestamp_tz_format(
+        mut self,
+        timestamp_tz_format: Option<String>,
+    ) -> Self {
+        self.timestamp_tz_format = timestamp_tz_format;
+        self
+    }
+
+    /// Set the time format.
+    pub fn with_time_format(mut self, time_format: Option<String>) -> Self {
+        self.time_format = time_format;
+        self
+    }
+
+    /// Set the duration format.
+    pub fn with_duration_format(mut self, duration_format: DurationFormat) -> Self {
+        self.duration_format = duration_format;
+        self
+    }
+
+    /// Set whether to include type information in formatted output.
+    pub fn with_types_info(mut self, types_info: bool) -> Self {
+        self.types_info = types_info;
+        self
+    }
+
+    /// Convert to Arrow's `FormatOptions<'a>` with borrowed references.
+    ///
+    /// This creates a temporary `FormatOptions` with borrowed `&str` references
+    /// to the owned strings. The returned options can be passed to Arrow compute
+    /// kernels. The borrowed references are valid only as long as `self` is alive.
+    pub fn as_arrow_options<'a>(&'a self) -> FormatOptions<'a> {
+        FormatOptions::new()
+            .with_null(self.null.as_str())
+            .with_date_format(self.date_format.as_deref())
+            .with_datetime_format(self.datetime_format.as_deref())
+            .with_timestamp_format(self.timestamp_format.as_deref())
+            .with_timestamp_tz_format(self.timestamp_tz_format.as_deref())
+            .with_time_format(self.time_format.as_deref())
+            .with_duration_format(self.duration_format)
+            .with_display_error(false) // the `safe` field is handled separately
+            .with_types_info(self.types_info)
+    }
+}
+
+impl Default for OwnedFormatOptions {
+    fn default() -> Self {
+        Self {
+            null: "NULL".to_string(),
+            date_format: None,
+            datetime_format: None,
+            timestamp_format: None,
+            timestamp_tz_format: None,
+            time_format: None,
+            duration_format: DurationFormat::Pretty,
+            types_info: false,
+        }
+    }
+}
+
+/// Owned version of Arrow's `CastOptions` with `OwnedFormatOptions` instead of `FormatOptions<'static>`.
+///
+/// Arrow's `CastOptions<'static>` requires `FormatOptions<'static>`, which mandates
+/// `&'static str` references. This struct uses `OwnedFormatOptions` with `String` values,
+/// allowing dynamic cast options to be created without memory leaks.
+///
+/// # Conversion to Arrow Types
+///
+/// Use the `as_arrow_options()` method to temporarily convert to `CastOptions<'a>`
+/// with borrowed references for passing to Arrow compute kernels:
+///
+/// ```ignore
+/// let owned_options = OwnedCastOptions { ... };
+/// let arrow_options = owned_options.as_arrow_options(); // borrows owned strings
+/// arrow::compute::cast_with_options(&array, &data_type, &arrow_options)?;
+/// ```
+#[derive(Debug, Clone, Eq, PartialEq, Hash, Default)]
+pub struct OwnedCastOptions {
+    /// Whether to use safe casting (cast failures yield NULL rather than an error)
+    pub safe: bool,
+    /// Format options for string output
+    pub format_options: OwnedFormatOptions,
+}
+
+impl OwnedCastOptions {
+    /// Create a new `OwnedCastOptions` with default values.
+    pub fn new(safe: bool) -> Self {
+        Self {
+            safe,
+            format_options: OwnedFormatOptions::default(),
+        }
+    }
+
+    /// Create a new `OwnedCastOptions` from an Arrow `CastOptions`.
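+    ///
+    /// Illustrative round trip (a sketch; assumes `DEFAULT_CAST_OPTIONS`, the
+    /// arrow-side constant defined elsewhere in this file):
+    ///
+    /// ```ignore
+    /// let owned = OwnedCastOptions::from_arrow_options(&DEFAULT_CAST_OPTIONS);
+    /// // Borrow back an arrow view whenever a kernel needs one:
+    /// let borrowed = owned.as_arrow_options();
+    /// assert_eq!(borrowed.safe, DEFAULT_CAST_OPTIONS.safe);
+    /// ```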
+    pub fn from_arrow_options(options: &CastOptions<'_>) -> Self {
+        Self {
+            safe: options.safe,
+            format_options: OwnedFormatOptions {
+                null: options.format_options.null().to_string(),
+                date_format: options
+                    .format_options
+                    .date_format()
+                    .map(ToString::to_string),
+                datetime_format: options
+                    .format_options
+                    .datetime_format()
+                    .map(ToString::to_string),
+                timestamp_format: options
+                    .format_options
+                    .timestamp_format()
+                    .map(ToString::to_string),
+                timestamp_tz_format: options
+                    .format_options
+                    .timestamp_tz_format()
+                    .map(ToString::to_string),
+                time_format: options
+                    .format_options
+                    .time_format()
+                    .map(ToString::to_string),
+                duration_format: options.format_options.duration_format(),
+                types_info: options.format_options.types_info(),
+            },
+        }
+    }
+
+    /// Convert to Arrow's `CastOptions<'a>` with borrowed references.
+    ///
+    /// This creates a temporary `CastOptions` with borrowed `&str` references
+    /// to the owned strings. The returned options can be passed to Arrow compute
+    /// kernels. The borrowed references are valid only as long as `self` is alive.
+    pub fn as_arrow_options<'a>(&'a self) -> CastOptions<'a> {
+        CastOptions {
+            safe: self.safe,
+            format_options: self.format_options.as_arrow_options(),
+        }
+    }
+}
+
 /// The default [`FormatOptions`] to use within DataFusion
 /// Also see [`crate::config::FormatOptions`]
 pub const DEFAULT_FORMAT_OPTIONS: FormatOptions<'static> =
diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs
index df6659c6f843c..65f0ed89a7620 100644
--- a/datafusion/common/src/lib.rs
+++ b/datafusion/common/src/lib.rs
@@ -80,6 +80,7 @@ pub use file_options::file_type::{
     DEFAULT_ARROW_EXTENSION, DEFAULT_AVRO_EXTENSION, DEFAULT_CSV_EXTENSION,
     DEFAULT_JSON_EXTENSION, DEFAULT_PARQUET_EXTENSION, GetExt,
 };
+pub use format::{OwnedCastOptions, OwnedFormatOptions};
 pub use functional_dependencies::{
     Constraint, Constraints, Dependency, FunctionalDependence, FunctionalDependencies,
     aggregate_functional_dependencies, get_required_group_by_exprs_indices,
diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs
index f3f45cfa44e9e..44081cdf76787 100644
--- a/datafusion/common/src/nested_struct.rs
+++ b/datafusion/common/src/nested_struct.rs
@@ -271,7 +271,15 @@
     Ok(())
 }

-fn validate_field_compatibility(
+/// Validate that a field can be cast from source to target type.
+///
+/// This function checks:
+/// - Nullability compatibility: cannot cast nullable → non-nullable
+/// - Data type castability using Arrow's can_cast_types
+/// - Recursive validation for nested struct types
+///
+/// This validation is used for both top-level fields and nested struct fields.
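+///
+/// Illustrative sketch (field names and types are examples):
+///
+/// ```ignore
+/// // Widening Int32 -> Int64 with matching nullability passes.
+/// let src = Field::new("id", DataType::Int32, true);
+/// let tgt = Field::new("id", DataType::Int64, true);
+/// assert!(validate_field_compatibility(&src, &tgt).is_ok());
+///
+/// // Nullable -> non-nullable is rejected so nulls are never silently dropped.
+/// let strict = Field::new("id", DataType::Int64, false);
+/// assert!(validate_field_compatibility(&src, &strict).is_err());
+/// ```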
+pub fn validate_field_compatibility(
     source_field: &Field,
     target_field: &Field,
 ) -> Result<()> {
diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs
index 064091971cf88..6d8e7cb4d781d 100644
--- a/datafusion/common/src/scalar/mod.rs
+++ b/datafusion/common/src/scalar/mod.rs
@@ -3678,7 +3678,7 @@ impl ScalarValue {
     pub fn cast_to_with_options(
         &self,
         target_type: &DataType,
-        cast_options: &CastOptions<'static>,
+        cast_options: &CastOptions<'_>,
     ) -> Result<Self> {
         let source_type = self.data_type();
         if let Some(multiplier) = date_to_timestamp_multiplier(&source_type, target_type)
diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs
index 9b4733dbcc178..89384f97b5f9a 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet.rs
@@ -277,7 +277,8 @@ mod tests {
         array: ArrayRef,
     ) -> RecordBatch {
         let mut fields = SchemaBuilder::from(batch.schema().fields());
-        fields.push(Field::new(field_name, array.data_type().clone(), true));
+        let nullable = array.null_count() > 0;
+        fields.push(Field::new(field_name, array.data_type().clone(), nullable));

         let schema = Arc::new(fields.finish());
         let mut columns = batch.columns().to_vec();
@@ -1135,12 +1136,24 @@ mod tests {
         let batch3 = create_batch(vec![("c1", c1.clone()), ("c2", c2.clone())]);

         // batch4 (has c2, c1) -- different column order, should still prune
-        let batch4 = create_batch(vec![("c2", c2), ("c1", c1)]);
+        let batch4 = create_batch(vec![
+            // Ensure c1 appears in this batch to avoid non-nullable missing column errors
+            ("c1", c1.clone()),
+            ("c2", c2),
+        ]);

         let filter = col("c2").eq(lit(1_i64));

+        // Provide a nullable logical schema so missing columns across batches
+        // are filled with nulls rather than treated as non-nullable.
+ let table_schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Utf8, true), + Field::new("c2", DataType::Int64, true), + ])); + // read/write them files: let rt = RoundTrip::new() + .with_table_schema(table_schema) .with_predicate(filter) .with_page_index_predicate() .round_trip(vec![batch1, batch2, batch3, batch4]) diff --git a/datafusion/core/tests/parquet/expr_adapter.rs b/datafusion/core/tests/parquet/expr_adapter.rs index aee37fda1670d..dbe389f243507 100644 --- a/datafusion/core/tests/parquet/expr_adapter.rs +++ b/datafusion/core/tests/parquet/expr_adapter.rs @@ -136,7 +136,7 @@ async fn test_custom_schema_adapter_and_custom_expression_adapter() { write_parquet(batch, store.clone(), path).await; let table_schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int64, false), + Field::new("c1", DataType::Int64, true), Field::new("c2", DataType::Utf8, true), ])); @@ -234,9 +234,9 @@ async fn test_physical_expr_adapter_with_non_null_defaults() { // Table schema has additional columns c2 (Utf8) and c3 (Int64) that don't exist in file let table_schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int64, false), // type differs from file (Int32 vs Int64) - Field::new("c2", DataType::Utf8, true), // missing from file - Field::new("c3", DataType::Int64, true), // missing from file + Field::new("c1", DataType::Int64, true), // type differs from file (Int32 vs Int64) + Field::new("c2", DataType::Utf8, true), // missing from file + Field::new("c3", DataType::Int64, true), // missing from file ])); let mut cfg = SessionConfig::new() @@ -343,7 +343,7 @@ async fn test_physical_expr_adapter_factory_reuse_across_tables() { // Table schema has additional columns that don't exist in files let table_schema = Arc::new(Schema::new(vec![ - Field::new("c1", DataType::Int64, false), + Field::new("c1", DataType::Int64, true), Field::new("c2", DataType::Utf8, true), // missing from files ])); diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 2924208c5bd99..0f461c2e2e2d8 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -728,7 +728,7 @@ mod test { let table_schema = Schema::new(vec![Field::new( "timestamp_col", DataType::Timestamp(Nanosecond, Some(Arc::from("UTC"))), - false, + true, )]); // Test all should fail diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 1aa42470a1481..1b13fa9173733 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -288,16 +288,17 @@ impl ColumnarValue { pub fn cast_to( &self, cast_type: &DataType, - cast_options: Option<&CastOptions<'static>>, + cast_options: Option<&CastOptions<'_>>, ) -> Result { - let cast_options = cast_options.cloned().unwrap_or(DEFAULT_CAST_OPTIONS); + // Use provided options when available; otherwise fallback to global default + let cast_options = cast_options.unwrap_or(&DEFAULT_CAST_OPTIONS); match self { ColumnarValue::Array(array) => { - let casted = cast_array_by_name(array, cast_type, &cast_options)?; + let casted = cast_array_by_name(array, cast_type, cast_options)?; Ok(ColumnarValue::Array(casted)) } ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar( - scalar.cast_to_with_options(cast_type, &cast_options)?, + scalar.cast_to_with_options(cast_type, cast_options)?, )), } } @@ -306,7 +307,7 @@ impl ColumnarValue { fn cast_array_by_name( array: &ArrayRef, 
cast_type: &DataType, - cast_options: &CastOptions<'static>, + cast_options: &CastOptions<'_>, ) -> Result { // If types are already equal, no cast needed if array.data_type() == cast_type { diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 7b94ed263b0e4..8a049fc8f6ddf 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -26,7 +26,7 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Field, SchemaRef}; use datafusion_common::{ Result, ScalarValue, exec_err, nested_struct::validate_struct_compatibility, @@ -260,20 +260,20 @@ impl DefaultPhysicalExprAdapter { impl PhysicalExprAdapter for DefaultPhysicalExprAdapter { fn rewrite(&self, expr: Arc) -> Result> { let rewriter = DefaultPhysicalExprAdapterRewriter { - logical_file_schema: &self.logical_file_schema, - physical_file_schema: &self.physical_file_schema, + logical_file_schema: Arc::clone(&self.logical_file_schema), + physical_file_schema: Arc::clone(&self.physical_file_schema), }; expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr))) .data() } } -struct DefaultPhysicalExprAdapterRewriter<'a> { - logical_file_schema: &'a Schema, - physical_file_schema: &'a Schema, +struct DefaultPhysicalExprAdapterRewriter { + logical_file_schema: SchemaRef, + physical_file_schema: SchemaRef, } -impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { +impl DefaultPhysicalExprAdapterRewriter { fn rewrite_expr( &self, expr: Arc, @@ -421,18 +421,19 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { }; let physical_field = self.physical_file_schema.field(physical_column_index); - let column = match ( - column.index() == physical_column_index, - logical_field.data_type() == physical_field.data_type(), - ) { - // If the column index matches and the data types match, we can use the column as is - (true, true) => return Ok(Transformed::no(expr)), - // If the indexes or data types do not match, we need to create a new column expression - (true, _) => column.clone(), - (false, _) => { - Column::new_with_schema(logical_field.name(), self.physical_file_schema)? - } - }; + // Check if index and types match - if so, we can return early + if column.index() == physical_column_index + && logical_field.data_type() == physical_field.data_type() + { + return Ok(Transformed::no(expr)); + } + + let column = self.resolve_column( + column, + physical_column_index, + logical_field.data_type(), + physical_field.data_type(), + )?; if logical_field.data_type() == physical_field.data_type() { // If the data types match, we can use the column as is @@ -443,36 +444,80 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { // TODO: add optimization to move the cast from the column to literal expressions in the case of `col = 123` // since that's much cheaper to evalaute. // See https://github.com/apache/datafusion/issues/15780#issuecomment-2824716928 - // + self.create_cast_column_expr(column, logical_field) + } + + /// Resolves a column expression, handling index and type mismatches. + /// + /// Returns the appropriate Column expression when the column's index or data type + /// don't match the physical schema. Assumes that the early-exit case (both index + /// and type match) has already been checked by the caller. 
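+    ///
+    /// Illustrative sketch (schema layout and indexes are examples, not taken
+    /// from this change):
+    ///
+    /// ```ignore
+    /// // physical file schema: [b: Utf8, a: Int32]; the plan refers to a@0 as Int64.
+    /// // Index 0 != 1, so the column is re-resolved by name:
+    /// let fixed = Column::new_with_schema("a", physical_file_schema.as_ref())?; // a@1
+    /// // Int32 != Int64, so the caller then wraps `fixed` in a CastColumnExpr.
+    /// ```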
+ fn resolve_column( + &self, + column: &Column, + physical_column_index: usize, + _logical_type: &DataType, + _physical_type: &DataType, + ) -> Result { + if column.index() == physical_column_index { + // Index matches but type differs - reuse the column as-is + Ok(column.clone()) + } else { + // Index doesn't match - create a new column with the correct index + Column::new_with_schema(column.name(), self.physical_file_schema.as_ref()) + } + } + + /// Validates type compatibility and creates a CastColumnExpr if needed. + /// + /// Checks whether the physical field can be cast to the logical field type, + /// handling both struct and scalar types. Returns a CastColumnExpr with the + /// appropriate configuration. + fn create_cast_column_expr( + &self, + column: Column, + logical_field: &Field, + ) -> Result>> { + // Get the actual field at the column's index (not pre-calculated) + // This is important when the column was recreated with a different index + let actual_physical_field = self.physical_file_schema.field(column.index()); + + // Validate type compatibility for struct and scalar types // For struct types, use validate_struct_compatibility which handles: // - Missing fields in source (filled with nulls) // - Extra fields in source (ignored) // - Recursive validation of nested structs // For non-struct types, use Arrow's can_cast_types - match (physical_field.data_type(), logical_field.data_type()) { + match (actual_physical_field.data_type(), logical_field.data_type()) { (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => { - validate_struct_compatibility(physical_fields, logical_fields)?; + validate_struct_compatibility( + physical_fields.as_ref(), + logical_fields.as_ref(), + )?; } _ => { - let is_compatible = - can_cast_types(physical_field.data_type(), logical_field.data_type()); + let is_compatible = can_cast_types( + actual_physical_field.data_type(), + logical_field.data_type(), + ); if !is_compatible { return exec_err!( "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)", column.name(), - physical_field.data_type(), + actual_physical_field.data_type(), logical_field.data_type() ); } } } - let cast_expr = Arc::new(CastColumnExpr::new( + let cast_expr = Arc::new(CastColumnExpr::new_with_schema( Arc::new(column), - Arc::new(physical_field.clone()), + Arc::new(actual_physical_field.as_ref().clone()), Arc::new(logical_field.clone()), None, - )); + Arc::clone(&self.physical_file_schema), + )?); Ok(Transformed::yes(cast_expr)) } @@ -662,9 +707,11 @@ mod tests { #[test] fn test_rewrite_multi_column_expr_with_type_cast() { let (physical_schema, logical_schema) = create_test_schema(); + let physical_schema = Arc::new(physical_schema); + let logical_schema = Arc::new(logical_schema); let factory = DefaultPhysicalExprAdapterFactory; let adapter = factory - .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)) .unwrap(); // Create a complex expression: (a + 5) OR (c > 0.0) that tests the recursive case of the rewriter @@ -689,12 +736,16 @@ mod tests { println!("Rewritten expression: {result}"); let expected = expressions::BinaryExpr::new( - Arc::new(CastColumnExpr::new( - Arc::new(Column::new("a", 0)), - Arc::new(Field::new("a", DataType::Int32, false)), - Arc::new(Field::new("a", DataType::Int64, false)), - None, - )), + Arc::new( + CastColumnExpr::new_with_schema( + Arc::new(Column::new("a", 0)), + Arc::new(Field::new("a", DataType::Int32, false)), + 
Arc::new(Field::new("a", DataType::Int64, false)), + None, + Arc::clone(&physical_schema), + ) + .expect("cast column expr"), + ), Operator::Plus, Arc::new(expressions::Literal::new(ScalarValue::Int64(Some(5)))), ); @@ -769,40 +820,51 @@ mod tests { false, )]); + let physical_schema = Arc::new(physical_schema); + let logical_schema = Arc::new(logical_schema); let factory = DefaultPhysicalExprAdapterFactory; let adapter = factory - .create(Arc::new(logical_schema), Arc::new(physical_schema)) + .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)) .unwrap(); let column_expr = Arc::new(Column::new("data", 0)); let result = adapter.rewrite(column_expr).unwrap(); - let expected = Arc::new(CastColumnExpr::new( - Arc::new(Column::new("data", 0)), - Arc::new(Field::new( - "data", - DataType::Struct( - vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, true), - ] - .into(), - ), - false, - )), - Arc::new(Field::new( - "data", - DataType::Struct( - vec![ - Field::new("id", DataType::Int64, false), - Field::new("name", DataType::Utf8View, true), - ] - .into(), - ), - false, - )), - None, - )) as Arc; + // Build expected physical (source) field: Struct(id: Int32, name: Utf8) + let physical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + let physical_field = Arc::new(Field::new( + "data", + DataType::Struct(physical_struct_fields), + false, + )); + + // Build expected logical (target) field: Struct(id: Int64, name: Utf8View) + let logical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8View, true), + ] + .into(); + let logical_field = Arc::new(Field::new( + "data", + DataType::Struct(logical_struct_fields), + false, + )); + + // Create the expected cast expression + let expected = Arc::new( + CastColumnExpr::new_with_schema( + Arc::new(Column::new("data", 0)), + physical_field, + logical_field, + None, + Arc::clone(&physical_schema), + ) + .expect("cast column expr"), + ) as Arc; assert_eq!(result.to_string(), expected.to_string()); } @@ -1193,8 +1255,8 @@ mod tests { )]); let rewriter = DefaultPhysicalExprAdapterRewriter { - logical_file_schema: &logical_schema, - physical_file_schema: &physical_schema, + logical_file_schema: Arc::new(logical_schema), + physical_file_schema: Arc::new(physical_schema), }; // Test that when a field exists in physical schema, it returns None diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index edbf7033f4e7a..3f2b03d9127f7 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -390,7 +390,7 @@ mod tests { convert_to_sort_reqs, create_test_params, create_test_schema, parse_sort_expr, }; use crate::equivalence::{ProjectionMapping, convert_to_sort_exprs}; - use crate::expressions::{BinaryExpr, CastExpr, Column, col}; + use crate::expressions::{BinaryExpr, CastColumnExpr, CastExpr, Column, col}; use crate::projection::tests::output_schema; use crate::{ConstExpr, EquivalenceProperties, ScalarFunctionExpr}; @@ -441,7 +441,8 @@ mod tests { let col_a2 = &col("a2", &out_schema)?; let col_a3 = &col("a3", &out_schema)?; let col_a4 = &col("a4", &out_schema)?; - let out_properties = input_properties.project(&projection_mapping, out_schema); + let out_properties = + 
input_properties.project(&projection_mapping, Arc::clone(&out_schema)); // At the output a1=a2=a3=a4 assert_eq!(out_properties.eq_group().len(), 1); @@ -500,6 +501,40 @@ mod tests { Ok(()) } + #[test] + fn project_ordering_with_cast_column_expr() -> Result<()> { + let input_schema = + Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let col_a = col("a", &input_schema)?; + let mut input_properties = EquivalenceProperties::new(Arc::clone(&input_schema)); + input_properties + .add_ordering([PhysicalSortExpr::new_default(Arc::clone(&col_a))]); + + let input_field = Arc::new(input_schema.field(0).clone()); + let target_field = Arc::new(Field::new("a_cast", DataType::Int64, true)); + let cast_col = Arc::new(CastColumnExpr::new_with_schema( + Arc::clone(&col_a), + input_field, + target_field, + None, + Arc::clone(&input_schema), + )?) as Arc; + + let proj_exprs = vec![ + (Arc::clone(&col_a), "a".to_string()), + (Arc::clone(&cast_col), "a_cast".to_string()), + ]; + let projection_mapping = ProjectionMapping::try_new(proj_exprs, &input_schema)?; + let out_schema = output_schema(&projection_mapping, &input_schema)?; + let out_properties = + input_properties.project(&projection_mapping, Arc::clone(&out_schema)); + + let cast_sort_expr = PhysicalSortExpr::new_default(col("a_cast", &out_schema)?); + assert!(out_properties.ordering_satisfy([cast_sort_expr])?); + + Ok(()) + } + #[test] fn test_normalize_ordering_equivalence_classes() -> Result<()> { let schema = Schema::new(vec![ diff --git a/datafusion/physical-expr/src/equivalence/properties/mod.rs b/datafusion/physical-expr/src/equivalence/properties/mod.rs index 996bc4b08fcd2..734abf85e849e 100644 --- a/datafusion/physical-expr/src/equivalence/properties/mod.rs +++ b/datafusion/physical-expr/src/equivalence/properties/mod.rs @@ -33,7 +33,7 @@ use self::dependency::{ use crate::equivalence::{ AcrossPartitions, EquivalenceGroup, OrderingEquivalenceClass, ProjectionMapping, }; -use crate::expressions::{CastExpr, Column, Literal, with_new_schema}; +use crate::expressions::{CastColumnExpr, CastExpr, Column, Literal, with_new_schema}; use crate::{ ConstExpr, LexOrdering, LexRequirement, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement, @@ -853,6 +853,15 @@ impl EquivalenceProperties { sort_expr.options, )); } + } else if let Some(cast_col) = + r_expr.as_any().downcast_ref::() + && cast_col.expr().eq(&sort_expr.expr) + && CastExpr::check_bigger_cast( + cast_col.target_field().data_type(), + &expr_type, + ) + { + result.push(PhysicalSortExpr::new(r_expr, sort_expr.options)); } } result.push(sort_expr); diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 6fced231f3e6f..cfc677ecb1397 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -22,25 +22,27 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; -use arrow::compute::{CastOptions, can_cast_types}; +use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion_common::format::OwnedCastOptions; use datafusion_common::nested_struct::validate_struct_compatibility; use datafusion_common::{Result, not_impl_err}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; use datafusion_expr_common::sort_properties::ExprProperties; -const 
DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { - safe: false, - format_options: DEFAULT_FORMAT_OPTIONS, -}; +// Default cast options using owned strings. +// These are created once and cloned as needed - the cloning is cheap +// since it only clones small String values (typically empty for null +// and None for optional format fields). +fn default_cast_options() -> OwnedCastOptions { + OwnedCastOptions::default() +} -const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { - safe: true, - format_options: DEFAULT_FORMAT_OPTIONS, -}; +fn default_safe_cast_options() -> OwnedCastOptions { + OwnedCastOptions::new(true) +} /// Check if struct-to-struct casting is allowed by validating field compatibility. /// @@ -65,8 +67,8 @@ pub struct CastExpr { pub expr: Arc, /// The data type to cast to cast_type: DataType, - /// Cast options - cast_options: CastOptions<'static>, + /// Cast options (owned, allowing dynamic format strings without leaks) + cast_options: OwnedCastOptions, } // Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 @@ -91,12 +93,12 @@ impl CastExpr { pub fn new( expr: Arc, cast_type: DataType, - cast_options: Option>, + cast_options: Option, ) -> Self { Self { expr, cast_type, - cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS), + cast_options: cast_options.unwrap_or_else(default_cast_options), } } @@ -110,8 +112,8 @@ impl CastExpr { &self.cast_type } - /// The cast options - pub fn cast_options(&self) -> &CastOptions<'static> { + /// The cast options (owned, with ephemeral borrowing for Arrow functions) + pub fn cast_options(&self) -> &OwnedCastOptions { &self.cast_options } @@ -166,7 +168,11 @@ impl PhysicalExpr for CastExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; - value.cast_to(&self.cast_type, Some(&self.cast_options)) + // Convert OwnedCastOptions to arrow's CastOptions for the computation + // This is ephemeral borrowing - the borrowed references are only + // valid during this call and are dropped immediately after + let arrow_options = self.cast_options.as_arrow_options(); + value.cast_to(&self.cast_type, Some(&arrow_options)) } fn return_field(&self, input_schema: &Schema) -> Result { @@ -195,8 +201,10 @@ impl PhysicalExpr for CastExpr { } fn evaluate_bounds(&self, children: &[&Interval]) -> Result { - // Cast current node's interval to the right type: - children[0].cast_to(&self.cast_type, &self.cast_options) + // Cast current node's interval to the right type. + // Convert OwnedCastOptions to arrow's CastOptions for the computation. 
+ let arrow_options = self.cast_options.as_arrow_options(); + children[0].cast_to(&self.cast_type, &arrow_options) } fn propagate_constraints( @@ -207,9 +215,9 @@ impl PhysicalExpr for CastExpr { let child_interval = children[0]; // Get child's datatype: let cast_type = child_interval.data_type(); - Ok(Some(vec![ - interval.cast_to(&cast_type, &DEFAULT_SAFE_CAST_OPTIONS)?, - ])) + let safe_options = default_safe_cast_options(); + let arrow_options = safe_options.as_arrow_options(); + Ok(Some(vec![interval.cast_to(&cast_type, &arrow_options)?])) } /// A [`CastExpr`] preserves the ordering of its child if the cast is done @@ -247,7 +255,7 @@ pub fn cast_with_options( expr: Arc, input_schema: &Schema, cast_type: DataType, - cast_options: Option>, + cast_options: Option, ) -> Result> { let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { @@ -468,7 +476,7 @@ mod tests { col("a", &schema)?, &schema, Decimal128(6, 2), - Some(DEFAULT_SAFE_CAST_OPTIONS), + Some(default_safe_cast_options()), )?; let result_safe = expression_safe .evaluate(&batch)? diff --git a/datafusion/physical-expr/src/expressions/cast_column.rs b/datafusion/physical-expr/src/expressions/cast_column.rs index d80b6f4a588a4..1487f3d3a9523 100644 --- a/datafusion/physical-expr/src/expressions/cast_column.rs +++ b/datafusion/physical-expr/src/expressions/cast_column.rs @@ -17,14 +17,19 @@ //! Physical expression for struct-aware casting of columns. -use crate::physical_expr::PhysicalExpr; +use crate::{expressions::Column, physical_expr::PhysicalExpr}; use arrow::{ - compute::CastOptions, + compute::can_cast_types, datatypes::{DataType, FieldRef, Schema}, record_batch::RecordBatch, }; use datafusion_common::{ - Result, ScalarValue, format::DEFAULT_CAST_OPTIONS, nested_struct::cast_column, + Result, ScalarValue, + format::OwnedCastOptions, + nested_struct::{ + cast_column, validate_field_compatibility, validate_struct_compatibility, + }, + plan_err, }; use datafusion_expr_common::columnar_value::ColumnarValue; use std::{ @@ -54,8 +59,10 @@ pub struct CastColumnExpr { input_field: FieldRef, /// The field metadata describing the desired output column. target_field: FieldRef, - /// Options forwarded to [`cast_column`]. - cast_options: CastOptions<'static>, + /// Options forwarded to [`cast_column`] (owned, allowing dynamic format strings). + cast_options: OwnedCastOptions, + /// Schema used to resolve expression data types during construction. + input_schema: Arc, } // Manually derive `PartialEq`/`Hash` as `Arc` does not @@ -78,20 +85,132 @@ impl Hash for CastColumnExpr { } } +fn normalize_cast_options(cast_options: Option) -> OwnedCastOptions { + cast_options.unwrap_or_default() +} + +/// Validates that a cast is compatible between input and target fields. 
+/// +/// This function checks: +/// - If the expression is a Column, its index is within the schema bounds +/// - If the expression is a Column, its data type is castable to the input field type +/// - The input field can be cast to the target field (using validate_field_compatibility) +/// - For struct types, field compatibility is validated recursively via validate_struct_compatibility +fn validate_cast_compatibility( + expr: &Arc, + input_field: &FieldRef, + target_field: &FieldRef, + input_schema: &Schema, +) -> Result<()> { + // Validate that if the expression is a Column, it's within the schema bounds + if let Some(column) = expr.as_any().downcast_ref::() { + let fields = input_schema.fields(); + if column.index() >= fields.len() { + return plan_err!( + "CastColumnExpr column index {} is out of bounds for input schema with {} fields", + column.index(), + fields.len() + ); + } + + // Validate that the column's field data type is compatible with the input_field for casting. + // We use can_cast_types for this check since schema fields may have different names/metadata. + let schema_field = &fields[column.index()]; + if schema_field.data_type() != input_field.data_type() { + let is_compatible = + can_cast_types(schema_field.data_type(), input_field.data_type()); + if !is_compatible { + return plan_err!( + "CastColumnExpr column '{}' at index {} has data type '{}' which is not compatible with input field data type '{}' - they cannot be cast", + column.name(), + column.index(), + schema_field.data_type(), + input_field.data_type() + ); + } + } + } + + // Validate the cast from input_field to target_field using the same logic as nested_struct. + // This ensures consistent nullability and data type checking across all field contexts. + match (input_field.data_type(), target_field.data_type()) { + (DataType::Struct(source_fields), DataType::Struct(target_fields)) => { + validate_struct_compatibility(source_fields, target_fields)?; + } + (_, DataType::Struct(_)) => { + return plan_err!( + "CastColumnExpr cannot cast non-struct input '{}' to struct target '{}'", + input_field.data_type(), + target_field.data_type() + ); + } + _ => { + // For non-struct types, use the same field validation as struct fields. + // This ensures consistent nullability checking across all contexts. + validate_field_compatibility(input_field, target_field)?; + } + } + + Ok(()) +} + impl CastColumnExpr { + fn build( + expr: Arc, + input_field: FieldRef, + target_field: FieldRef, + cast_options: Option, + input_schema: Arc, + ) -> Result { + let cast_options = normalize_cast_options(cast_options); + + // Validate cast compatibility before constructing the expression + validate_cast_compatibility(&expr, &input_field, &target_field, &input_schema)?; + + Ok(Self { + expr, + input_field, + target_field, + cast_options, + input_schema, + }) + } + /// Create a new [`CastColumnExpr`]. + /// + /// This constructor ensures that format options are populated with defaults, + /// normalizing the CastOptions for consistent behavior during serialization + /// and evaluation. It constructs a single-field schema from `input_field`, + /// so it should only be used for expressions that resolve their type from + /// that field alone. 
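+    ///
+    /// Illustrative sketch (field types are examples, not taken from this change):
+    ///
+    /// ```ignore
+    /// // The child column resolves against the implied one-field schema [a: Int32].
+    /// let expr = CastColumnExpr::new(
+    ///     Arc::new(Column::new("a", 0)),
+    ///     Arc::new(Field::new("a", DataType::Int32, true)),
+    ///     Arc::new(Field::new("a", DataType::Int64, true)),
+    ///     None,
+    /// )?; // fallible: compatibility is now validated at construction time
+    /// ```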
pub fn new( expr: Arc, input_field: FieldRef, target_field: FieldRef, - cast_options: Option>, - ) -> Self { - Self { + cast_options: Option, + ) -> Result { + let input_schema = Schema::new(vec![input_field.as_ref().clone()]); + Self::build( expr, input_field, target_field, - cast_options: cast_options.unwrap_or(DEFAULT_CAST_OPTIONS), - } + cast_options, + Arc::new(input_schema), + ) + } + + /// Create a new [`CastColumnExpr`] with a specific input schema. + /// + /// Use this constructor when the expression depends on a broader schema, + /// such as multi-column expressions or columns with non-zero indexes. + pub fn new_with_schema( + expr: Arc, + input_field: FieldRef, + target_field: FieldRef, + cast_options: Option, + input_schema: Arc, + ) -> Result { + Self::build(expr, input_field, target_field, cast_options, input_schema) } /// The expression that produces the value to be cast. @@ -108,6 +227,11 @@ impl CastColumnExpr { pub fn target_field(&self) -> &FieldRef { &self.target_field } + + /// Options forwarded to [`cast_column`]. + pub fn cast_options(&self) -> &OwnedCastOptions { + &self.cast_options + } } impl Display for CastColumnExpr { @@ -136,19 +260,18 @@ impl PhysicalExpr for CastColumnExpr { fn evaluate(&self, batch: &RecordBatch) -> Result { let value = self.expr.evaluate(batch)?; + // Convert OwnedCastOptions to arrow's CastOptions for the computation + let arrow_options = self.cast_options.as_arrow_options(); match value { ColumnarValue::Array(array) => { let casted = - cast_column(&array, self.target_field.as_ref(), &self.cast_options)?; + cast_column(&array, self.target_field.as_ref(), &arrow_options)?; Ok(ColumnarValue::Array(casted)) } ColumnarValue::Scalar(scalar) => { let as_array = scalar.to_array_of_size(1)?; - let casted = cast_column( - &as_array, - self.target_field.as_ref(), - &self.cast_options, - )?; + let casted = + cast_column(&as_array, self.target_field.as_ref(), &arrow_options)?; let result = ScalarValue::try_from_array(casted.as_ref(), 0)?; Ok(ColumnarValue::Scalar(result)) } @@ -169,12 +292,13 @@ impl PhysicalExpr for CastColumnExpr { ) -> Result> { assert_eq!(children.len(), 1); let child = children.pop().expect("CastColumnExpr child"); - Ok(Arc::new(Self::new( + Ok(Arc::new(Self::new_with_schema( child, Arc::clone(&self.input_field), Arc::clone(&self.target_field), Some(self.cast_options.clone()), - ))) + Arc::clone(&self.input_schema), + )?)) } fn fmt_sql(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -192,7 +316,7 @@ mod tests { datatypes::{DataType, Field, Fields, SchemaRef}, }; use datafusion_common::{ - Result as DFResult, ScalarValue, + Result as DFResult, ScalarValue, assert_contains, cast::{as_int64_array, as_string_array, as_struct_array, as_uint8_array}, }; @@ -214,12 +338,13 @@ mod tests { let batch = RecordBatch::try_new(Arc::clone(&schema), vec![values])?; let column = Arc::new(Column::new_with_schema("a", schema.as_ref())?); - let expr = CastColumnExpr::new( + let expr = CastColumnExpr::new_with_schema( column, Arc::new(input_field.clone()), Arc::new(target_field.clone()), None, - ); + Arc::clone(&schema), + )?; let result = expr.evaluate(&batch)?; let ColumnarValue::Array(array) = result else { @@ -268,12 +393,13 @@ mod tests { )?; let column = Arc::new(Column::new_with_schema("s", schema.as_ref())?); - let expr = CastColumnExpr::new( + let expr = CastColumnExpr::new_with_schema( column, Arc::new(input_field.clone()), Arc::new(target_field.clone()), None, - ); + Arc::clone(&schema), + )?; let result = expr.evaluate(&batch)?; 
let ColumnarValue::Array(array) = result else { @@ -338,12 +464,13 @@ mod tests { )?; let column = Arc::new(Column::new_with_schema("root", schema.as_ref())?); - let expr = CastColumnExpr::new( + let expr = CastColumnExpr::new_with_schema( column, Arc::new(outer_field.clone()), Arc::new(target_field.clone()), None, - ); + Arc::clone(&schema), + )?; let result = expr.evaluate(&batch)?; let ColumnarValue::Array(array) = result else { @@ -389,12 +516,13 @@ mod tests { ); let literal = Arc::new(Literal::new(ScalarValue::Struct(Arc::new(scalar_struct)))); - let expr = CastColumnExpr::new( + let expr = CastColumnExpr::new_with_schema( literal, Arc::new(input_field.clone()), Arc::new(target_field.clone()), None, - ); + Arc::clone(&schema), + )?; let batch = RecordBatch::new_empty(Arc::clone(&schema)); let result = expr.evaluate(&batch)?; @@ -406,4 +534,61 @@ mod tests { assert_eq!(casted.value(0), 9); Ok(()) } + + #[test] + fn cast_column_schema_mismatch() { + // Test that an error is raised when data types are not compatible for casting + let input_field = Field::new("a", DataType::Int32, true); + let target_field = Field::new("a", DataType::Int32, true); + let schema = Arc::new(Schema::new(vec![ + input_field.clone(), + Field::new( + "b", + DataType::Struct( + vec![Field::new("nested", DataType::Int32, true)].into(), + ), + true, + ), + ])); + + let column = Arc::new(Column::new("b", 1)); + let err = CastColumnExpr::new_with_schema( + column, + Arc::new(input_field), + Arc::new(target_field), + None, + schema, + ) + .expect_err("expected incompatible data type error"); + + assert_contains!( + err.to_string(), + r#"CastColumnExpr column 'b' at index 1 has data type 'Struct("nested": Int32)' which is not compatible with input field data type 'Int32' - they cannot be cast"# + ); + } + + #[test] + fn cast_column_schema_mismatch_nullability_metadata() { + // CastColumnExpr reuses validate_field_compatibility from nested_struct, + // it properly rejects nullable -> non-nullable casts to prevent data loss. + let input_field = Field::new("a", DataType::Int32, true); // nullable + let target_field = Field::new("a", DataType::Int32, false); // non-nullable + let schema = Arc::new(Schema::new(vec![input_field.clone()])); + + let column = Arc::new(Column::new("a", 0)); + + let err = CastColumnExpr::new_with_schema( + column, + Arc::new(input_field), + Arc::new(target_field), + None, + schema, + ) + .expect_err("should reject nullable -> non-nullable cast"); + + assert_contains!( + err.to_string(), + "Cannot cast nullable struct field 'a' to non-nullable field" + ); + } } diff --git a/datafusion/physical-expr/src/intervals/utils.rs b/datafusion/physical-expr/src/intervals/utils.rs index 3cada63a34ace..ab2f4be2ab96f 100644 --- a/datafusion/physical-expr/src/intervals/utils.rs +++ b/datafusion/physical-expr/src/intervals/utils.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use crate::{ PhysicalExpr, - expressions::{BinaryExpr, CastExpr, Column, Literal, NegativeExpr}, + expressions::{BinaryExpr, CastColumnExpr, CastExpr, Column, Literal, NegativeExpr}, }; use arrow::array::types::{IntervalDayTime, IntervalMonthDayNano}; @@ -34,7 +34,8 @@ use datafusion_expr::interval_arithmetic::Interval; /// Currently, we do not support all [`PhysicalExpr`]s for interval calculations. /// We do not support every type of [`Operator`]s either. Over time, this check /// will relax as more types of `PhysicalExpr`s and `Operator`s are supported. 
-/// Currently, [`CastExpr`], [`NegativeExpr`], [`BinaryExpr`], [`Column`] and [`Literal`] are supported.
+/// Currently, [`CastExpr`], [`CastColumnExpr`], [`NegativeExpr`], [`BinaryExpr`], [`Column`] and
+/// [`Literal`] are supported.
 pub fn check_support(expr: &Arc<dyn PhysicalExpr>, schema: &SchemaRef) -> bool {
     let expr_any = expr.as_any();
     if let Some(binary_expr) = expr_any.downcast_ref::<BinaryExpr>() {
@@ -55,6 +56,8 @@
         }
     } else if let Some(cast) = expr_any.downcast_ref::<CastExpr>() {
         check_support(cast.expr(), schema)
+    } else if let Some(cast_column) = expr_any.downcast_ref::<CastColumnExpr>() {
+        check_support(cast_column.expr(), schema)
     } else if let Some(negative) = expr_any.downcast_ref::<NegativeExpr>() {
         check_support(negative.arg(), schema)
     } else {
@@ -191,3 +194,32 @@ fn interval_dt_to_duration_ms(dt: &IntervalDayTime) -> Result<i64> {
         )
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::expressions::{CastColumnExpr, col};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use std::sync::Arc;
+
+    #[test]
+    fn test_check_support_with_cast_column_expr() {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+        let input_field = Arc::new(schema.field(0).clone());
+        let target_field = Arc::new(Field::new("a", DataType::Int64, true));
+
+        let column_expr = col("a", &schema).unwrap();
+        let cast_expr = Arc::new(
+            CastColumnExpr::new_with_schema(
+                column_expr,
+                input_field,
+                target_field,
+                None,
+                Arc::clone(&schema),
+            )
+            .expect("cast column expr"),
+        ) as Arc<dyn PhysicalExpr>;
+
+        assert!(check_support(&cast_expr, &schema));
+    }
+}
diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs
index ae6da9c5e0dc5..ac5d16110b0dc 100644
--- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs
+++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs
@@ -42,7 +42,9 @@ use datafusion_expr::Operator;
 use datafusion_expr_common::casts::try_cast_literal_to_type;

 use crate::PhysicalExpr;
-use crate::expressions::{BinaryExpr, CastExpr, Literal, TryCastExpr, lit};
+use crate::expressions::{
+    BinaryExpr, CastColumnExpr, CastExpr, Literal, TryCastExpr, lit,
+};

 /// Attempts to unwrap casts in comparison expressions.
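The unwrap decision is gated on whether the literal survives the move to the column's native type. A minimal sketch of that gate, assuming `try_cast_literal_to_type` keeps the `(&ScalarValue, &DataType) -> Option<ScalarValue>` shape suggested by the import above:

```rust
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
use datafusion_expr_common::casts::try_cast_literal_to_type;

fn main() {
    // Int64(10) round-trips into Int32, so `CAST(c1 AS Int64) > 10`
    // can be rewritten to `c1 > 10` with an Int32 literal.
    let ten = ScalarValue::Int64(Some(10));
    assert_eq!(
        try_cast_literal_to_type(&ten, &DataType::Int32),
        Some(ScalarValue::Int32(Some(10)))
    );

    // An out-of-range literal cannot be moved, so the cast stays in place.
    let big = ScalarValue::Int64(Some(i64::MAX));
    assert_eq!(try_cast_literal_to_type(&big, &DataType::Int32), None);
}
```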
pub(crate) fn unwrap_cast_in_comparison( @@ -112,6 +114,8 @@ fn extract_cast_info( ) -> Option<(&Arc, &DataType)> { if let Some(cast) = expr.as_any().downcast_ref::() { Some((cast.expr(), cast.cast_type())) + } else if let Some(cast_col) = expr.as_any().downcast_ref::() { + Some((cast_col.expr(), cast_col.target_field().data_type())) } else if let Some(try_cast) = expr.as_any().downcast_ref::() { Some((try_cast.expr(), try_cast.cast_type())) } else { @@ -142,7 +146,7 @@ fn try_unwrap_cast_comparison( #[cfg(test)] mod tests { use super::*; - use crate::expressions::{col, lit}; + use crate::expressions::{CastColumnExpr, col, lit}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::ScalarValue; use datafusion_expr::Operator; @@ -150,6 +154,7 @@ mod tests { /// Check if an expression is a cast expression fn is_cast_expr(expr: &Arc) -> bool { expr.as_any().downcast_ref::().is_some() + || expr.as_any().downcast_ref::().is_some() || expr.as_any().downcast_ref::().is_some() } @@ -208,6 +213,50 @@ mod tests { assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(10))); } + #[test] + fn test_unwrap_cast_with_cast_column_expr() { + let schema = test_schema(); + let input_field = Arc::new(schema.field(0).clone()); + let target_field = Arc::new(Field::new("c1", DataType::Int64, false)); + + // Create: cast_column(c1 as INT64) > INT64(10) + let column_expr = col("c1", &schema).unwrap(); + let cast_expr = Arc::new( + CastColumnExpr::new_with_schema( + column_expr, + input_field, + target_field, + None, + Arc::new(schema.clone()), + ) + .expect("cast column expr"), + ); + let literal_expr = lit(10i64); + let binary_expr = + Arc::new(BinaryExpr::new(cast_expr, Operator::Gt, literal_expr)); + + // Apply unwrap cast optimization + let result = unwrap_cast_in_comparison(binary_expr, &schema).unwrap(); + + // Should be transformed + assert!(result.transformed); + + // The result should be: c1 > INT32(10) + let optimized = result.data; + let optimized_binary = optimized.as_any().downcast_ref::().unwrap(); + + // Check that left side is no longer a cast + assert!(!is_cast_expr(optimized_binary.left())); + + // Check that right side is a literal with the correct type and value + let right_literal = optimized_binary + .right() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(10))); + } + #[test] fn test_unwrap_cast_with_literal_on_left() { let schema = test_schema(); diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 2efef9f12e701..5b9e2fd7b6510 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -875,6 +875,8 @@ message PhysicalExprNode { UnknownColumn unknown_column = 20; PhysicalHashExprNode hash_expr = 21; + + PhysicalCastColumnNode cast_column = 22; } } @@ -982,6 +984,42 @@ message PhysicalTryCastNode { message PhysicalCastNode { PhysicalExprNode expr = 1; datafusion_common.ArrowType arrow_type = 2; + PhysicalCastOptions cast_options = 3; +} + +message PhysicalCastColumnNode { + PhysicalExprNode expr = 1; + datafusion_common.Field input_field = 2; + datafusion_common.Field target_field = 3; + // DEPRECATED: Use cast_options instead of safe/format_options. + // These fields retained for backward compatibility with DataFusion < 43.0. + // When deserializing, safe and format_options are only used if cast_options is not set. 
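+  // Illustrative decode precedence: a message carrying only the deprecated
+  // pair { safe: true } is read as cast_options { safe: true }; when
+  // cast_options is also present, it wins and the deprecated pair is ignored.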
+ bool safe = 4; + FormatOptions format_options = 5; + PhysicalCastOptions cast_options = 6; +} + +message PhysicalCastOptions { + bool safe = 1; + FormatOptions format_options = 2; +} + +enum DurationFormat { + DURATION_FORMAT_UNSPECIFIED = 0; + DURATION_FORMAT_ISO8601 = 1; + DURATION_FORMAT_PRETTY = 2; +} + +message FormatOptions { + bool safe = 1; + string null = 2; + optional string date_format = 3; + optional string datetime_format = 4; + optional string timestamp_format = 5; + optional string timestamp_tz_format = 6; + optional string time_format = 7; + DurationFormat duration_format = 8; + bool types_info = 9; } message PhysicalNegativeNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 00870d5ce178c..41cf6ea326150 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5452,6 +5452,80 @@ impl<'de> serde::Deserialize<'de> for DropViewNode { deserializer.deserialize_struct("datafusion.DropViewNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for DurationFormat { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Unspecified => "DURATION_FORMAT_UNSPECIFIED", + Self::Iso8601 => "DURATION_FORMAT_ISO8601", + Self::Pretty => "DURATION_FORMAT_PRETTY", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for DurationFormat { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "DURATION_FORMAT_UNSPECIFIED", + "DURATION_FORMAT_ISO8601", + "DURATION_FORMAT_PRETTY", + ]; + + struct GeneratedVisitor; + + impl serde::de::Visitor<'_> for GeneratedVisitor { + type Value = DurationFormat; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "DURATION_FORMAT_UNSPECIFIED" => Ok(DurationFormat::Unspecified), + "DURATION_FORMAT_ISO8601" => Ok(DurationFormat::Iso8601), + "DURATION_FORMAT_PRETTY" => Ok(DurationFormat::Pretty), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for EmptyExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -6831,6 +6905,242 @@ impl<'de> serde::Deserialize<'de> for FixedSizeBinary { deserializer.deserialize_struct("datafusion.FixedSizeBinary", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for FormatOptions { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if self.safe { + len += 1; + } + if !self.null.is_empty() { + len += 1; + } + if 
self.date_format.is_some() {
+            len += 1;
+        }
+        if self.datetime_format.is_some() {
+            len += 1;
+        }
+        if self.timestamp_format.is_some() {
+            len += 1;
+        }
+        if self.timestamp_tz_format.is_some() {
+            len += 1;
+        }
+        if self.time_format.is_some() {
+            len += 1;
+        }
+        if self.duration_format != 0 {
+            len += 1;
+        }
+        if self.types_info {
+            len += 1;
+        }
+        let mut struct_ser = serializer.serialize_struct("datafusion.FormatOptions", len)?;
+        if self.safe {
+            struct_ser.serialize_field("safe", &self.safe)?;
+        }
+        if !self.null.is_empty() {
+            struct_ser.serialize_field("null", &self.null)?;
+        }
+        if let Some(v) = self.date_format.as_ref() {
+            struct_ser.serialize_field("dateFormat", v)?;
+        }
+        if let Some(v) = self.datetime_format.as_ref() {
+            struct_ser.serialize_field("datetimeFormat", v)?;
+        }
+        if let Some(v) = self.timestamp_format.as_ref() {
+            struct_ser.serialize_field("timestampFormat", v)?;
+        }
+        if let Some(v) = self.timestamp_tz_format.as_ref() {
+            struct_ser.serialize_field("timestampTzFormat", v)?;
+        }
+        if let Some(v) = self.time_format.as_ref() {
+            struct_ser.serialize_field("timeFormat", v)?;
+        }
+        if self.duration_format != 0 {
+            let v = DurationFormat::try_from(self.duration_format)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.duration_format)))?;
+            struct_ser.serialize_field("durationFormat", &v)?;
+        }
+        if self.types_info {
+            struct_ser.serialize_field("typesInfo", &self.types_info)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for FormatOptions {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "safe",
+            "null",
+            "date_format",
+            "dateFormat",
+            "datetime_format",
+            "datetimeFormat",
+            "timestamp_format",
+            "timestampFormat",
+            "timestamp_tz_format",
+            "timestampTzFormat",
+            "time_format",
+            "timeFormat",
+            "duration_format",
+            "durationFormat",
+            "types_info",
+            "typesInfo",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Safe,
+            Null,
+            DateFormat,
+            DatetimeFormat,
+            TimestampFormat,
+            TimestampTzFormat,
+            TimeFormat,
+            DurationFormat,
+            TypesInfo,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl serde::de::Visitor<'_> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "safe" => Ok(GeneratedField::Safe),
+                            "null" => Ok(GeneratedField::Null),
+                            "dateFormat" | "date_format" => Ok(GeneratedField::DateFormat),
+                            "datetimeFormat" | "datetime_format" => Ok(GeneratedField::DatetimeFormat),
+                            "timestampFormat" | "timestamp_format" => Ok(GeneratedField::TimestampFormat),
+                            "timestampTzFormat" | "timestamp_tz_format" => Ok(GeneratedField::TimestampTzFormat),
+                            "timeFormat" | "time_format" => Ok(GeneratedField::TimeFormat),
+                            "durationFormat" | "duration_format" => Ok(GeneratedField::DurationFormat),
+                            "typesInfo" | "types_info" => Ok(GeneratedField::TypesInfo),
+                            _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = FormatOptions;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                formatter.write_str("struct datafusion.FormatOptions")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> std::result::Result<FormatOptions, V::Error>
+            where
+                V: serde::de::MapAccess<'de>,
+            {
+                let mut safe__ = None;
+                let mut null__ = None;
+                let mut date_format__ = None;
+                let mut datetime_format__ = None;
+                let mut timestamp_format__ = None;
+                let mut timestamp_tz_format__ = None;
+                let mut time_format__ = None;
+                let mut duration_format__ = None;
+                let mut types_info__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::Safe => {
+                            if safe__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("safe"));
+                            }
+                            safe__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::Null => {
+                            if null__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("null"));
+                            }
+                            null__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::DateFormat => {
+                            if date_format__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("dateFormat"));
+                            }
+                            date_format__ = map_.next_value()?;
+                        }
+                        GeneratedField::DatetimeFormat => {
+                            if datetime_format__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("datetimeFormat"));
+                            }
+                            datetime_format__ = map_.next_value()?;
+                        }
+                        GeneratedField::TimestampFormat => {
+                            if timestamp_format__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("timestampFormat"));
+                            }
+                            timestamp_format__ = map_.next_value()?;
+                        }
+                        GeneratedField::TimestampTzFormat => {
+                            if timestamp_tz_format__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("timestampTzFormat"));
+                            }
+                            timestamp_tz_format__ = map_.next_value()?;
+                        }
+                        GeneratedField::TimeFormat => {
+                            if time_format__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("timeFormat"));
+                            }
+                            time_format__ = map_.next_value()?;
+                        }
+                        GeneratedField::DurationFormat => {
+                            if duration_format__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("durationFormat"));
+                            }
+                            duration_format__ = Some(map_.next_value::<DurationFormat>()? as i32);
+                        }
+                        GeneratedField::TypesInfo => {
+                            if types_info__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("typesInfo"));
+                            }
+                            types_info__ = Some(map_.next_value()?);
+                        }
+                    }
+                }
+                Ok(FormatOptions {
+                    safe: safe__.unwrap_or_default(),
+                    null: null__.unwrap_or_default(),
+                    date_format: date_format__,
+                    datetime_format: datetime_format__,
+                    timestamp_format: timestamp_format__,
+                    timestamp_tz_format: timestamp_tz_format__,
+                    time_format: time_format__,
+                    duration_format: duration_format__.unwrap_or_default(),
+                    types_info: types_info__.unwrap_or_default(),
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.FormatOptions", FIELDS, GeneratedVisitor)
+    }
+}
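Note: the pbjson-generated impls above follow a few conventions worth keeping in mind for the rest of this file: output always uses camelCase keys, input accepts both camelCase and snake_case, and any field still at its proto3 default is omitted entirely. A minimal sketch of the resulting JSON shape (illustrative only; assumes the generated types and serde_json are in scope, and is not part of the patch):

    let opts = FormatOptions {
        safe: true,
        null: "NULL".to_string(),
        date_format: Some("%Y-%m-%d".to_string()),
        ..Default::default()
    };
    // Emits roughly {"safe":true,"null":"NULL","dateFormat":"%Y-%m-%d"};
    // unset options and zero-valued enums are skipped on output.
    let json = serde_json::to_string(&opts).unwrap();
    let back: FormatOptions = serde_json::from_str(&json).unwrap();
    assert_eq!(opts, back);
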
"castOptions" | "cast_options" => Ok(GeneratedField::CastOptions), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = PhysicalCastColumnNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.PhysicalCastColumnNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut expr__ = None; + let mut input_field__ = None; + let mut target_field__ = None; + let mut safe__ = None; + let mut format_options__ = None; + let mut cast_options__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Expr => { + if expr__.is_some() { + return Err(serde::de::Error::duplicate_field("expr")); + } + expr__ = map_.next_value()?; + } + GeneratedField::InputField => { + if input_field__.is_some() { + return Err(serde::de::Error::duplicate_field("inputField")); + } + input_field__ = map_.next_value()?; + } + GeneratedField::TargetField => { + if target_field__.is_some() { + return Err(serde::de::Error::duplicate_field("targetField")); + } + target_field__ = map_.next_value()?; + } + GeneratedField::Safe => { + if safe__.is_some() { + return Err(serde::de::Error::duplicate_field("safe")); + } + safe__ = Some(map_.next_value()?); + } + GeneratedField::FormatOptions => { + if format_options__.is_some() { + return Err(serde::de::Error::duplicate_field("formatOptions")); + } + format_options__ = map_.next_value()?; + } + GeneratedField::CastOptions => { + if cast_options__.is_some() { + return Err(serde::de::Error::duplicate_field("castOptions")); + } + cast_options__ = map_.next_value()?; + } + } + } + Ok(PhysicalCastColumnNode { + expr: expr__, + input_field: input_field__, + target_field: target_field__, + safe: safe__.unwrap_or_default(), + format_options: format_options__, + cast_options: cast_options__, + }) + } + } + deserializer.deserialize_struct("datafusion.PhysicalCastColumnNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for PhysicalCastNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -15594,6 +16084,9 @@ impl serde::Serialize for PhysicalCastNode { if self.arrow_type.is_some() { len += 1; } + if self.cast_options.is_some() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalCastNode", len)?; if let Some(v) = self.expr.as_ref() { struct_ser.serialize_field("expr", v)?; @@ -15601,6 +16094,9 @@ impl serde::Serialize for PhysicalCastNode { if let Some(v) = self.arrow_type.as_ref() { struct_ser.serialize_field("arrowType", v)?; } + if let Some(v) = self.cast_options.as_ref() { + struct_ser.serialize_field("castOptions", v)?; + } struct_ser.end() } } @@ -15614,12 +16110,15 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode { "expr", "arrow_type", "arrowType", + "cast_options", + "castOptions", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { Expr, ArrowType, + CastOptions, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -15643,6 +16142,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode { match value { "expr" => Ok(GeneratedField::Expr), "arrowType" | "arrow_type" => Ok(GeneratedField::ArrowType), + "castOptions" | "cast_options" => Ok(GeneratedField::CastOptions), _ => 
                     _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                 }
             }
@@ -15664,6 +16164,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode {
         {
             let mut expr__ = None;
             let mut arrow_type__ = None;
+            let mut cast_options__ = None;
             while let Some(k) = map_.next_key()? {
                 match k {
                     GeneratedField::Expr => {
@@ -15678,17 +16179,133 @@ impl<'de> serde::Deserialize<'de> for PhysicalCastNode {
                         }
                         arrow_type__ = map_.next_value()?;
                     }
+                    GeneratedField::CastOptions => {
+                        if cast_options__.is_some() {
+                            return Err(serde::de::Error::duplicate_field("castOptions"));
+                        }
+                        cast_options__ = map_.next_value()?;
+                    }
                 }
             }
             Ok(PhysicalCastNode {
                 expr: expr__,
                 arrow_type: arrow_type__,
+                cast_options: cast_options__,
             })
         }
     }
     deserializer.deserialize_struct("datafusion.PhysicalCastNode", FIELDS, GeneratedVisitor)
 }
 }
+impl serde::Serialize for PhysicalCastOptions {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.safe {
+            len += 1;
+        }
+        if self.format_options.is_some() {
+            len += 1;
+        }
+        let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalCastOptions", len)?;
+        if self.safe {
+            struct_ser.serialize_field("safe", &self.safe)?;
+        }
+        if let Some(v) = self.format_options.as_ref() {
+            struct_ser.serialize_field("formatOptions", v)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for PhysicalCastOptions {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "safe",
+            "format_options",
+            "formatOptions",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Safe,
+            FormatOptions,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl serde::de::Visitor<'_> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "safe" => Ok(GeneratedField::Safe),
+                            "formatOptions" | "format_options" => Ok(GeneratedField::FormatOptions),
+                            _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = PhysicalCastOptions;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+                formatter.write_str("struct datafusion.PhysicalCastOptions")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> std::result::Result<PhysicalCastOptions, V::Error>
+            where
+                V: serde::de::MapAccess<'de>,
+            {
+                let mut safe__ = None;
+                let mut format_options__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::Safe => {
+                            if safe__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("safe"));
+                            }
+                            safe__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::FormatOptions => {
+                            if format_options__.is_some() {
+                                return Err(serde::de::Error::duplicate_field("formatOptions"));
+                            }
+                            format_options__ = map_.next_value()?;
+                        }
+                    }
+                }
+                Ok(PhysicalCastOptions {
+                    safe: safe__.unwrap_or_default(),
+                    format_options: format_options__,
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.PhysicalCastOptions", FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for PhysicalColumn {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
@@ -15995,6 +16612,9 @@ impl serde::Serialize for PhysicalExprNode {
             physical_expr_node::ExprType::HashExpr(v) => {
                 struct_ser.serialize_field("hashExpr", v)?;
             }
+            physical_expr_node::ExprType::CastColumn(v) => {
+                struct_ser.serialize_field("castColumn", v)?;
+            }
         }
     }
     struct_ser.end()
@@ -16039,6 +16659,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
             "unknownColumn",
             "hash_expr",
             "hashExpr",
+            "cast_column",
+            "castColumn",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -16062,6 +16684,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
             Extension,
             UnknownColumn,
             HashExpr,
+            CastColumn,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> std::result::Result<GeneratedField, D::Error>
@@ -16102,6 +16725,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
                     "extension" => Ok(GeneratedField::Extension),
                     "unknownColumn" | "unknown_column" => Ok(GeneratedField::UnknownColumn),
                     "hashExpr" | "hash_expr" => Ok(GeneratedField::HashExpr),
+                    "castColumn" | "cast_column" => Ok(GeneratedField::CastColumn),
                     _ => Err(serde::de::Error::unknown_field(value, FIELDS)),
                 }
             }
@@ -16255,6 +16879,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode {
                             return Err(serde::de::Error::duplicate_field("hashExpr"));
                         }
                         expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::HashExpr)
+;
+                    }
+                    GeneratedField::CastColumn => {
+                        if expr_type__.is_some() {
+                            return Err(serde::de::Error::duplicate_field("castColumn"));
+                        }
+                        expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::CastColumn)
;
                     }
                 }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 39d2604d45cd1..eaa6c3f2e216b 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1279,7 +1279,7 @@ pub struct PhysicalExtensionNode {
 pub struct PhysicalExprNode {
     #[prost(
         oneof = "physical_expr_node::ExprType",
-        tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21"
+        tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 18, 19, 20, 21, 22"
     )]
     pub expr_type: ::core::option::Option<physical_expr_node::ExprType>,
 }
@@ -1332,6 +1332,8 @@ pub mod physical_expr_node {
     UnknownColumn(super::UnknownColumn),
     #[prost(message, tag = "21")]
     HashExpr(super::PhysicalHashExprNode),
+    #[prost(message, tag = "22")]
+    CastColumn(::prost::alloc::boxed::Box<super::PhysicalCastColumnNode>),
 }
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -1508,6 +1510,54 @@ pub struct PhysicalCastNode {
     pub expr: ::core::option::Option<::prost::alloc::boxed::Box<PhysicalExprNode>>,
     #[prost(message, optional, tag = "2")]
     pub arrow_type: ::core::option::Option<ArrowType>,
+    #[prost(message, optional, tag = "3")]
+    pub cast_options: ::core::option::Option<PhysicalCastOptions>,
+}
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct PhysicalCastColumnNode {
boxed, tag = "1")] + pub expr: ::core::option::Option<::prost::alloc::boxed::Box>, + #[prost(message, optional, tag = "2")] + pub input_field: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub target_field: ::core::option::Option, + /// DEPRECATED: Use cast_options instead of safe/format_options. + /// These fields retained for backward compatibility with DataFusion < 43.0. + /// When deserializing, safe and format_options are only used if cast_options is not set. + #[prost(bool, tag = "4")] + pub safe: bool, + #[prost(message, optional, tag = "5")] + pub format_options: ::core::option::Option, + #[prost(message, optional, tag = "6")] + pub cast_options: ::core::option::Option, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct PhysicalCastOptions { + #[prost(bool, tag = "1")] + pub safe: bool, + #[prost(message, optional, tag = "2")] + pub format_options: ::core::option::Option, +} +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] +pub struct FormatOptions { + #[prost(bool, tag = "1")] + pub safe: bool, + #[prost(string, tag = "2")] + pub null: ::prost::alloc::string::String, + #[prost(string, optional, tag = "3")] + pub date_format: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "4")] + pub datetime_format: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "5")] + pub timestamp_format: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "6")] + pub timestamp_tz_format: ::core::option::Option<::prost::alloc::string::String>, + #[prost(string, optional, tag = "7")] + pub time_format: ::core::option::Option<::prost::alloc::string::String>, + #[prost(enumeration = "DurationFormat", tag = "8")] + pub duration_format: i32, + #[prost(bool, tag = "9")] + pub types_info: bool, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct PhysicalNegativeNode { @@ -2285,6 +2335,35 @@ impl InsertOp { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum DurationFormat { + Unspecified = 0, + Iso8601 = 1, + Pretty = 2, +} +impl DurationFormat { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Unspecified => "DURATION_FORMAT_UNSPECIFIED", + Self::Iso8601 => "DURATION_FORMAT_ISO8601", + Self::Pretty => "DURATION_FORMAT_PRETTY", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]
+#[repr(i32)]
 pub enum PartitionMode {
     CollectLeft = 0,
     Partitioned = 1,
diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs
index fc8eba12c5391..c433e862bca6f 100644
--- a/datafusion/proto/src/physical_plan/from_proto.rs
+++ b/datafusion/proto/src/physical_plan/from_proto.rs
@@ -24,7 +24,15 @@ use arrow::compute::SortOptions;
 use arrow::datatypes::{Field, Schema};
 use arrow::ipc::reader::StreamReader;
 use chrono::{TimeZone, Utc};
-use datafusion_common::{DataFusionError, Result, internal_datafusion_err, not_impl_err};
+use datafusion_expr::dml::InsertOp;
+use object_store::ObjectMeta;
+use object_store::path::Path;
+
+use arrow::util::display::DurationFormat;
+use datafusion_common::{
+    DataFusionError, OwnedCastOptions, OwnedFormatOptions, Result,
+    internal_datafusion_err, not_impl_err,
+};
 use datafusion_datasource::file::FileSource;
 use datafusion_datasource::file_groups::FileGroup;
 use datafusion_datasource::file_scan_config::{FileScanConfig, FileScanConfigBuilder};
@@ -37,19 +45,16 @@ use datafusion_datasource_parquet::file_format::ParquetSink;
 use datafusion_execution::object_store::ObjectStoreUrl;
 use datafusion_execution::{FunctionRegistry, TaskContext};
 use datafusion_expr::WindowFunctionDefinition;
-use datafusion_expr::dml::InsertOp;
 use datafusion_physical_expr::projection::{ProjectionExpr, ProjectionExprs};
 use datafusion_physical_expr::{LexOrdering, PhysicalSortExpr, ScalarFunctionExpr};
 use datafusion_physical_plan::expressions::{
-    BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, Literal,
-    NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, in_list,
+    BinaryExpr, CaseExpr, CastColumnExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr,
+    LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn, in_list,
 };
 use datafusion_physical_plan::joins::{HashExpr, SeededRandomState};
 use datafusion_physical_plan::windows::{create_window_expr, schema_add_window_field};
 use datafusion_physical_plan::{Partitioning, PhysicalExpr, WindowExpr};
 use datafusion_proto_common::common::proto_error;
-use object_store::ObjectMeta;
-use object_store::path::Path;
 
 use super::{
     DefaultPhysicalProtoConverter, PhysicalExtensionCodec,
@@ -413,8 +418,37 @@ pub fn parse_physical_expr_with_converter(
                 proto_converter,
             )?,
             convert_required!(e.arrow_type)?,
-            None,
+            cast_options_from_proto(e.cast_options.as_ref(), false, None)?,
         )),
+        ExprType::CastColumn(e) => {
+            let input_field = e
+                .input_field
+                .as_ref()
+                .ok_or_else(|| proto_error("Missing cast_column input_field"))?;
+            let target_field = e
+                .target_field
+                .as_ref()
+                .ok_or_else(|| proto_error("Missing cast_column target_field"))?;
+            let cast_options = cast_options_from_proto(
+                e.cast_options.as_ref(),
+                e.safe,
+                e.format_options.as_ref(),
+            )?;
+            Arc::new(CastColumnExpr::new_with_schema(
+                parse_required_physical_expr(
+                    e.expr.as_deref(),
+                    ctx,
+                    "expr",
+                    input_schema,
+                    codec,
+                    proto_converter,
+                )?,
+                Arc::new(Field::try_from(input_field)?),
+                Arc::new(Field::try_from(target_field)?),
+                cast_options,
+                Arc::new(input_schema.clone()),
+            )?)
+        }
         ExprType::TryCast(e) => Arc::new(TryCastExpr::new(
             parse_required_physical_expr(
                 e.expr.as_deref(),
@@ -835,6 +869,81 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig {
     }
 }
 
+// ============================================================================
+// Deserialization of Cast Options
+// ============================================================================
+//
+// OwnedCastOptions and OwnedFormatOptions resolve the lifetime mismatch
+// between Arrow's `CastOptions<'static>` and Protobuf's owned `String` values:
+//
+// 1. Protobuf deserialization yields owned `String` values with no lifetime
+//    constraints.
+// 2. OwnedCastOptions takes ownership of those strings, so no `&'static str`
+//    (and no string interning or leaking) is required.
+// 3. At execution time, CastExpr/CastColumnExpr borrow the owned strings
+//    ephemerally, converting to Arrow's `CastOptions<'a>` with `&str`
+//    references for the compute kernels.
+// 4. The strings are dropped together with the expression, so nothing leaks.
+
+/// Convert protobuf format options to owned format options.
+fn format_options_from_proto(
+    options: &protobuf::FormatOptions,
+) -> Result<OwnedFormatOptions> {
+    let duration_format = duration_format_from_proto(options.duration_format)?;
+    Ok(OwnedFormatOptions {
+        null: options.null.clone(),
+        date_format: options.date_format.clone(),
+        datetime_format: options.datetime_format.clone(),
+        timestamp_format: options.timestamp_format.clone(),
+        timestamp_tz_format: options.timestamp_tz_format.clone(),
+        time_format: options.time_format.clone(),
+        duration_format,
+        types_info: options.types_info,
+    })
+}
+
+/// Convert protobuf cast options to owned cast options.
+///
+/// Prefers the new `cast_options` field; the legacy `safe`/`format_options`
+/// pair is consulted only when `cast_options` is absent.
+fn cast_options_from_proto(
+    cast_options: Option<&protobuf::PhysicalCastOptions>,
+    legacy_safe: bool,
+    legacy_format_options: Option<&protobuf::FormatOptions>,
+) -> Result<Option<OwnedCastOptions>> {
+    if let Some(cast_options) = cast_options {
+        let format_options = cast_options
+            .format_options
+            .as_ref()
+            .map(format_options_from_proto)
+            .transpose()?
+            .unwrap_or_default();
+        return Ok(Some(OwnedCastOptions {
+            safe: cast_options.safe,
+            format_options,
+        }));
+    }
+
+    // Handle legacy fields for backward compatibility with DataFusion < 43.0
+    if !legacy_safe && legacy_format_options.is_none() {
+        return Ok(None);
+    }
+
+    let format_options = legacy_format_options
+        .map(format_options_from_proto)
+        .transpose()?
+        .unwrap_or_default();
+
+    Ok(Some(OwnedCastOptions {
+        safe: legacy_safe,
+        format_options,
+    }))
+}
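To make the precedence rules above concrete, here is roughly how the three cases behave (an illustrative sketch reusing the helpers in this file; the literal values are made up):

    // New-style payload: cast_options wins and the legacy `safe` flag is ignored.
    let proto_opts = protobuf::PhysicalCastOptions { safe: false, format_options: None };
    let opts = cast_options_from_proto(Some(&proto_opts), true, None)?;
    assert_eq!(opts.map(|o| o.safe), Some(false));

    // Legacy payload (DataFusion < 43.0): safe/format_options are honored.
    let opts = cast_options_from_proto(None, true, None)?;
    assert_eq!(opts.map(|o| o.safe), Some(true));

    // Nothing set at all: treated as "no explicit cast options".
    assert!(cast_options_from_proto(None, false, None)?.is_none());
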
+
+fn duration_format_from_proto(value: i32) -> Result<DurationFormat> {
+    Ok(match protobuf::DurationFormat::try_from(value) {
+        Ok(protobuf::DurationFormat::Pretty) => DurationFormat::Pretty,
+        Ok(protobuf::DurationFormat::Iso8601)
+        | Ok(protobuf::DurationFormat::Unspecified)
+        | Err(_) => DurationFormat::ISO8601,
+    })
+}
+
 #[cfg(test)]
 mod tests {
     use chrono::{TimeZone, Utc};
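One subtlety in `duration_format_from_proto`: the mapping is deliberately forgiving on the decode side. `DURATION_FORMAT_UNSPECIFIED` and any out-of-range value fall back to Arrow's `DurationFormat::ISO8601` rather than erroring, which keeps older or partially understood payloads decodable. A hedged sketch of the observable behavior (assumes Arrow's `DurationFormat` implements `PartialEq`/`Debug`, as it does in current arrow-rs):

    assert_eq!(duration_format_from_proto(0)?, DurationFormat::ISO8601); // UNSPECIFIED
    assert_eq!(duration_format_from_proto(1)?, DurationFormat::ISO8601);
    assert_eq!(duration_format_from_proto(2)?, DurationFormat::Pretty);
    assert_eq!(duration_format_from_proto(99)?, DurationFormat::ISO8601); // unknown value
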
diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs
index f85b1d1e12b9d..39e6ff8b71b27 100644
--- a/datafusion/proto/src/physical_plan/to_proto.rs
+++ b/datafusion/proto/src/physical_plan/to_proto.rs
@@ -20,8 +20,10 @@ use std::sync::Arc;
 use arrow::array::RecordBatch;
 use arrow::datatypes::Schema;
 use arrow::ipc::writer::StreamWriter;
+use arrow::util::display::{DurationFormat, FormatOptions as ArrowFormatOptions};
 use datafusion_common::{
-    DataFusionError, Result, internal_datafusion_err, internal_err, not_impl_err,
+    DataFusionError, OwnedCastOptions, OwnedFormatOptions, Result,
+    format::DEFAULT_CAST_OPTIONS, internal_datafusion_err, internal_err, not_impl_err,
 };
 use datafusion_datasource::file_scan_config::FileScanConfig;
 use datafusion_datasource::file_sink_config::{FileSink, FileSinkConfig};
@@ -36,8 +38,8 @@ use datafusion_physical_expr::window::{SlidingAggregateWindowExpr, StandardWindowExpr};
 use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr;
 use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr;
 use datafusion_physical_plan::expressions::{
-    BinaryExpr, CaseExpr, CastExpr, Column, InListExpr, IsNotNullExpr, IsNullExpr,
-    LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn,
+    BinaryExpr, CaseExpr, CastColumnExpr, CastExpr, Column, InListExpr, IsNotNullExpr,
+    IsNullExpr, LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, UnKnownColumn,
 };
 use datafusion_physical_plan::joins::{HashExpr, HashTableLookupExpr};
 use datafusion_physical_plan::udaf::AggregateFunctionExpr;
@@ -418,6 +420,7 @@ pub fn serialize_physical_expr_with_converter(
             )),
         })
     } else if let Some(cast) = expr.downcast_ref::<CastExpr>() {
+        let cast_options = serialize_cast_options(cast.cast_options())?;
         Ok(protobuf::PhysicalExprNode {
             expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new(
                 protobuf::PhysicalCastNode {
@@ -425,9 +428,31 @@ pub fn serialize_physical_expr_with_converter(
                         proto_converter.physical_expr_to_proto(cast.expr(), codec)?,
                     )),
                     arrow_type: Some(cast.cast_type().try_into()?),
+                    cast_options: Some(cast_options),
                 },
             ))),
         })
+    } else if let Some(cast_column) = expr.downcast_ref::<CastColumnExpr>() {
+        let cast_options = serialize_cast_options(cast_column.cast_options())?;
+        let format_options = match cast_options.format_options.clone() {
+            Some(format_options) => format_options,
+            None => serialize_format_options(&DEFAULT_CAST_OPTIONS.format_options)?,
+        };
+        Ok(protobuf::PhysicalExprNode {
+            expr_type: Some(protobuf::physical_expr_node::ExprType::CastColumn(
+                Box::new(protobuf::PhysicalCastColumnNode {
+                    expr: Some(Box::new(serialize_physical_expr(
+                        cast_column.expr(),
+                        codec,
+                    )?)),
+                    input_field: Some(cast_column.input_field().as_ref().try_into()?),
+                    target_field: Some(cast_column.target_field().as_ref().try_into()?),
+                    safe: cast_column.cast_options().safe,
+                    format_options: Some(format_options),
+                    cast_options: Some(cast_options),
+                }),
+            )),
+        })
     } else if let Some(cast) = expr.downcast_ref::<TryCastExpr>() {
         Ok(protobuf::PhysicalExprNode {
             expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new(
@@ -581,6 +606,55 @@ impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile {
     }
 }
 
+fn serialize_format_options(
+    options: &ArrowFormatOptions<'_>,
+) -> Result<protobuf::FormatOptions> {
+    Ok(protobuf::FormatOptions {
+        safe: options.safe(),
+        null: options.null().to_string(),
+        date_format: options.date_format().map(ToString::to_string),
+        datetime_format: options.datetime_format().map(ToString::to_string),
+        timestamp_format: options.timestamp_format().map(ToString::to_string),
+        timestamp_tz_format: options.timestamp_tz_format().map(ToString::to_string),
+        time_format: options.time_format().map(ToString::to_string),
+        duration_format: duration_format_to_proto(options.duration_format()) as i32,
+        types_info: options.types_info(),
+    })
+}
+
+fn serialize_cast_options(
+    options: &OwnedCastOptions,
+) -> Result<protobuf::PhysicalCastOptions> {
+    Ok(protobuf::PhysicalCastOptions {
+        safe: options.safe,
+        format_options: Some(serialize_owned_format_options(&options.format_options)?),
+    })
+}
+
+fn serialize_owned_format_options(
+    options: &OwnedFormatOptions,
+) -> Result<protobuf::FormatOptions> {
+    Ok(protobuf::FormatOptions {
+        safe: false, // safe is stored in CastOptions, not FormatOptions
+        null: options.null.clone(),
+        date_format: options.date_format.clone(),
+        datetime_format: options.datetime_format.clone(),
+        timestamp_format: options.timestamp_format.clone(),
+        timestamp_tz_format: options.timestamp_tz_format.clone(),
+        time_format: options.time_format.clone(),
+        duration_format: duration_format_to_proto(options.duration_format) as i32,
+        types_info: options.types_info,
+    })
+}
+
+fn duration_format_to_proto(format: DurationFormat) -> protobuf::DurationFormat {
+    match format {
+        DurationFormat::ISO8601 => protobuf::DurationFormat::Iso8601,
+        DurationFormat::Pretty => protobuf::DurationFormat::Pretty,
+        _ => protobuf::DurationFormat::Unspecified,
+    }
+}
+
 impl TryFrom<&FileRange> for protobuf::FileRange {
     type Error = DataFusionError;
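Together with `cast_options_from_proto` on the read side, the serializers above give a simple round-trip property: `CastColumnExpr` nodes dual-write the new `cast_options` message plus the legacy `safe`/`format_options` fields, and readers prefer the former. A hedged sketch of the invariant (these helpers are module-private; shown only for illustration):

    let owned = OwnedCastOptions {
        safe: true,
        format_options: OwnedFormatOptions::default().with_null("N/A".to_string()),
    };
    let proto = serialize_cast_options(&owned)?;
    // Readers prefer cast_options, so legacy arguments passed alongside
    // are never consulted when it is present.
    let back = cast_options_from_proto(Some(&proto), false, None)?;
    assert_eq!(back, Some(owned));
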
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index f262020ab843c..902aab7ffa598 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -24,6 +24,7 @@ use std::vec;
 use arrow::array::RecordBatch;
 use arrow::csv::WriterBuilder;
 use arrow::datatypes::{Fields, TimeUnit};
+use arrow::util::display::DurationFormat;
 use datafusion::arrow::array::ArrayRef;
 use datafusion::arrow::compute::kernels::sort::SortOptions;
 use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
@@ -62,7 +63,8 @@ use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
 use datafusion::physical_plan::empty::EmptyExec;
 use datafusion::physical_plan::expressions::{
-    BinaryExpr, Column, NotExpr, PhysicalSortExpr, binary, cast, col, in_list, like, lit,
+    BinaryExpr, CastColumnExpr, Column, NotExpr, PhysicalSortExpr, binary, cast, col,
+    in_list, like, lit,
 };
 use datafusion::physical_plan::filter::{FilterExec, FilterExecBuilder};
 use datafusion::physical_plan::joins::{
@@ -92,7 +94,8 @@ use datafusion_common::file_options::json_writer::JsonWriterOptions;
 use datafusion_common::parsers::CompressionTypeVariant;
 use datafusion_common::stats::Precision;
 use datafusion_common::{
-    DataFusionError, NullEquality, Result, UnnestOptions, exec_datafusion_err,
+    DataFusionError, NullEquality, OwnedCastOptions, OwnedFormatOptions, Result,
+    UnnestOptions, exec_datafusion_err, format::DEFAULT_CAST_OPTIONS,
     internal_datafusion_err, internal_err, not_impl_err,
 };
 use datafusion_datasource::TableSchema;
@@ -212,6 +215,204 @@ async fn all_types_context() -> Result<SessionContext> {
     Ok(ctx)
 }
 
+fn cast_fields(
+    name: &str,
+    input_type: DataType,
+    target_type: DataType,
+) -> (Field, Field) {
+    let input_field = field_with_origin(name, input_type, false, "input");
+    let target_field = field_with_origin(name, target_type, false, "target");
+    (input_field, target_field)
+}
+
+fn field_with_origin(
+    name: &str,
+    data_type: DataType,
+    nullable: bool,
+    origin: &str,
+) -> Field {
+    let mut metadata = HashMap::new();
+    metadata.insert("origin".to_string(), origin.to_string());
+    Field::new(name, data_type, nullable).with_metadata(metadata)
+}
+
+fn round_trip_cast_expr(
+    expr: Arc<dyn PhysicalExpr>,
+    input_schema: &Schema,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    let proto = serialize_cast_expr(&expr)?;
+    parse_cast_expr(&proto, input_schema)
+}
+
+fn serialize_cast_expr(
+    expr: &Arc<dyn PhysicalExpr>,
+) -> Result<protobuf::PhysicalExprNode> {
+    let codec = DefaultPhysicalExtensionCodec {};
+    datafusion_proto::physical_plan::to_proto::serialize_physical_expr(expr, &codec)
+}
+
+fn parse_cast_expr(
+    proto: &protobuf::PhysicalExprNode,
+    input_schema: &Schema,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    let ctx = SessionContext::new();
+    let codec = DefaultPhysicalExtensionCodec {};
+    datafusion_proto::physical_plan::from_proto::parse_physical_expr(
+        proto,
+        &ctx.task_ctx(),
+        input_schema,
+        &codec,
+    )
+}
+
+fn downcast_cast_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<&CastColumnExpr> {
+    expr.as_any()
+        .downcast_ref::<CastColumnExpr>()
+        .ok_or_else(|| internal_datafusion_err!("Expected CastColumnExpr"))
+}
+
+#[test]
+fn roundtrip_cast_column_expr() -> Result<()> {
+    let (input_field, target_field) = cast_fields("a", DataType::Int32, DataType::Int64);
+    let cast_options = OwnedCastOptions {
+        safe: true,
+        format_options: OwnedFormatOptions {
+            null: "NULL".to_string(),
+            date_format: Some("%Y/%m/%d".to_string()),
+            datetime_format: None,
+            timestamp_format: None,
+            timestamp_tz_format: None,
+            time_format: None,
+            duration_format: DurationFormat::ISO8601,
+            types_info: false,
+        },
+    };
+    let input_schema = Schema::new(vec![input_field.clone()]);
+
+    let expr = Arc::new(CastColumnExpr::new_with_schema(
+        Arc::new(Column::new("a", 0)),
+        Arc::new(input_field.clone()),
+        Arc::new(target_field.clone()),
+        Some(cast_options.clone()),
+        Arc::new(input_schema.clone()),
+    )?);
+
+    let round_trip = round_trip_cast_expr(expr.clone(), &input_schema)?;
+    let cast_expr = downcast_cast_expr(&round_trip)?;
+
+    let expected = CastColumnExpr::new_with_schema(
+        Arc::new(Column::new("a", 0)),
+        Arc::new(input_field.clone()),
+        Arc::new(target_field.clone()),
+        Some(cast_options),
+        Arc::new(input_schema.clone()),
+    )?;
+
+    assert_eq!(cast_expr, &expected);
+    assert_eq!(cast_expr.input_field().as_ref(), &input_field);
+    assert_eq!(cast_expr.target_field().as_ref(), &target_field);
+    assert_eq!(
+        cast_expr.data_type(&input_schema)?,
+        target_field.data_type().clone()
+    );
+
+    Ok(())
+}
+
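The next test simulates a payload produced by an older writer: it serializes a `CastColumnExpr`, strips the format options back out of the proto by hand, and then checks that deserialization falls back to `OwnedFormatOptions::default()` instead of failing.
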
+#[test]
+fn roundtrip_cast_column_expr_with_missing_format_options() -> Result<()> {
+    let input_field = Field::new("a", DataType::Int32, true);
+    let target_field = Field::new("a", DataType::Int64, true);
+
+    let cast_options = OwnedCastOptions {
+        safe: true,
+        format_options: OwnedFormatOptions::default(),
+    };
+    let expr: Arc<dyn PhysicalExpr> = Arc::new(CastColumnExpr::new(
+        Arc::new(Column::new("a", 0)),
+        Arc::new(input_field.clone()),
+        Arc::new(target_field.clone()),
+        Some(cast_options),
+    )?);
+
+    let mut proto = serialize_cast_expr(&expr)?;
+    let cast_column = match proto.expr_type.as_mut() {
+        Some(protobuf::physical_expr_node::ExprType::CastColumn(cast_column)) => {
+            cast_column.as_mut()
+        }
+        _ => {
+            return Err(internal_datafusion_err!(
+                "Expected PhysicalCastColumnNode in proto"
+            ));
+        }
+    };
+    cast_column.format_options = None;
+    match cast_column.cast_options.as_mut() {
+        Some(cast_options) => {
+            cast_options.format_options = None;
+        }
+        None => {
+            cast_column.cast_options = Some(protobuf::PhysicalCastOptions {
+                safe: DEFAULT_CAST_OPTIONS.safe,
+                format_options: None,
+            });
+        }
+    }
+    let input_schema = Schema::new(vec![input_field.clone()]);
+    let round_trip = parse_cast_expr(&proto, &input_schema)?;
+
+    let cast_expr = round_trip
+        .as_any()
+        .downcast_ref::<CastColumnExpr>()
+        .ok_or_else(|| internal_datafusion_err!("Expected CastColumnExpr"))?;
+
+    assert_eq!(
+        cast_expr.cast_options().format_options,
+        OwnedFormatOptions::default()
+    );
+
+    Ok(())
+}
+
+#[test]
+fn roundtrip_cast_column_expr_with_target_field_change() -> Result<()> {
+    let input_field = field_with_origin("payload", DataType::Int32, false, "input");
+    let target_field = field_with_origin("payload_cast", DataType::Utf8, false, "target");
+
+    let input_schema = Schema::new(vec![input_field.clone()]);
+    let expr: Arc<dyn PhysicalExpr> = Arc::new(CastColumnExpr::new_with_schema(
+        Arc::new(Column::new("payload", 0)),
+        Arc::new(input_field.clone()),
+        Arc::new(target_field.clone()),
+        None,
+        Arc::new(input_schema.clone()),
+    )?);
+
+    let proto = serialize_cast_expr(&expr)?;
+    let round_trip = parse_cast_expr(&proto, &input_schema)?;
+
+    let cast_expr = round_trip
+        .as_any()
+        .downcast_ref::<CastColumnExpr>()
+        .ok_or_else(|| internal_datafusion_err!("Expected CastColumnExpr"))?;
+
+    assert_eq!(cast_expr.target_field().name(), "payload_cast");
+    assert_eq!(
+        cast_expr.target_field().data_type(),
+        target_field.data_type()
+    );
+
+    let column_expr = cast_expr
+        .expr()
+        .as_any()
+        .downcast_ref::<Column>()
+        .ok_or_else(|| internal_datafusion_err!("Expected Column"))?;
+    assert_eq!(column_expr.name(), "payload");
+    assert_eq!(column_expr.index(), 0);
+
+    Ok(())
+}
+
 #[test]
 fn roundtrip_empty() -> Result<()> {
     roundtrip_test(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))))