diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index 3c41b8cad9ed1..c9b129f37a772 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -28,7 +28,12 @@ use arrow::datatypes::{ DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, i256, }; -use arrow::ipc::{reader::read_record_batch, root_as_message}; +use arrow::ipc::{ + convert::fb_to_schema, + reader::{read_dictionary, read_record_batch}, + root_as_message, + writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}, +}; use datafusion_common::{ Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, @@ -406,6 +411,35 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { )); }; + // IPC dictionary batch IDs are assigned when encoding the schema, but our protobuf + // `Schema` doesn't preserve those IDs. Reconstruct them deterministically by + // round-tripping the schema through IPC. + let schema: Schema = { + let ipc_gen = IpcDataGenerator {}; + let write_options = IpcWriteOptions::default(); + let mut dict_tracker = DictionaryTracker::new(false); + let encoded_schema = ipc_gen.schema_to_bytes_with_dictionary_tracker( + &schema, + &mut dict_tracker, + &write_options, + ); + let message = + root_as_message(encoded_schema.ipc_message.as_slice()).map_err( + |e| { + Error::General(format!( + "Error IPC schema message while deserializing ScalarValue::List: {e}" + )) + }, + )?; + let ipc_schema = message.header_as_schema().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing ScalarValue::List schema" + .to_string(), + ) + })?; + fb_to_schema(ipc_schema) + }; + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( "Error IPC message while deserializing ScalarValue::List: {e}" @@ -420,7 +454,12 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { ) })?; - let dict_by_id: HashMap = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| { + let mut dict_by_id: HashMap = HashMap::new(); + for protobuf::scalar_nested_value::Dictionary { + ipc_message, + arrow_data, + } in dictionaries + { let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( "Error IPC message while deserializing ScalarValue::List dictionary message: {e}" @@ -434,22 +473,16 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .to_string(), ) })?; - - let id = dict_batch.id(); - - let record_batch = read_record_batch( + read_dictionary( &buffer, - dict_batch.data().unwrap(), - Arc::new(schema.clone()), - &Default::default(), - None, + dict_batch, + &schema, + &mut dict_by_id, &message.version(), - )?; - - let values: ArrayRef = Arc::clone(record_batch.column(0)); - - Ok((id, values)) - }).collect::>>()?; + ) + .map_err(|e| arrow_datafusion_err!(e)) + .map_err(|e| e.context("Decoding ScalarValue::List dictionary"))?; + } let record_batch = read_record_batch( &buffer, diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index fee3656482005..cc3dde7d19cd7 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -1025,6 +1025,13 @@ fn encode_scalar_nested_value( let ipc_gen = IpcDataGenerator {}; let mut dict_tracker = DictionaryTracker::new(false); let write_options = IpcWriteOptions::default(); + // The IPC writer requires pre-allocated dictionary IDs (normally assigned when + // serializing the schema). Populate `dict_tracker` by encoding the schema first. + ipc_gen.schema_to_bytes_with_dictionary_tracker( + batch.schema().as_ref(), + &mut dict_tracker, + &write_options, + ); let mut compression_context = CompressionContext::default(); let (encoded_dictionaries, encoded_message) = ipc_gen .encode( diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 5bb771137fbb7..7ae0036105eae 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -2564,3 +2564,22 @@ fn custom_proto_converter_intercepts() -> Result<()> { Ok(()) } + +#[test] +fn roundtrip_call_null_scalar_struct_dict() -> Result<()> { + let data_type = DataType::Struct(Fields::from(vec![Field::new( + "item", + DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + true, + )])); + + let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)])); + let scan = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let scalar = lit(ScalarValue::try_from(data_type)?); + let filter = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)), + scan, + )?); + + roundtrip_test(filter) +}