From e0a0061176445a2372b729b276d03d73b8901896 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Wed, 28 Jan 2026 11:49:33 +0530 Subject: [PATCH 1/2] fix: regression of dict_id in physical plan proto --- datafusion/proto-common/src/from_proto/mod.rs | 65 ++++++++++++++----- datafusion/proto-common/src/to_proto/mod.rs | 7 ++ .../tests/cases/roundtrip_physical_plan.rs | 19 ++++++ 3 files changed, 75 insertions(+), 16 deletions(-) diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index 3c41b8cad9ed1..c9b129f37a772 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -28,7 +28,12 @@ use arrow::datatypes::{ DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode, i256, }; -use arrow::ipc::{reader::read_record_batch, root_as_message}; +use arrow::ipc::{ + convert::fb_to_schema, + reader::{read_dictionary, read_record_batch}, + root_as_message, + writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}, +}; use datafusion_common::{ Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef, @@ -406,6 +411,35 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { )); }; + // IPC dictionary batch IDs are assigned when encoding the schema, but our protobuf + // `Schema` doesn't preserve those IDs. Reconstruct them deterministically by + // round-tripping the schema through IPC. + let schema: Schema = { + let ipc_gen = IpcDataGenerator {}; + let write_options = IpcWriteOptions::default(); + let mut dict_tracker = DictionaryTracker::new(false); + let encoded_schema = ipc_gen.schema_to_bytes_with_dictionary_tracker( + &schema, + &mut dict_tracker, + &write_options, + ); + let message = + root_as_message(encoded_schema.ipc_message.as_slice()).map_err( + |e| { + Error::General(format!( + "Error IPC schema message while deserializing ScalarValue::List: {e}" + )) + }, + )?; + let ipc_schema = message.header_as_schema().ok_or_else(|| { + Error::General( + "Unexpected message type deserializing ScalarValue::List schema" + .to_string(), + ) + })?; + fb_to_schema(ipc_schema) + }; + let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( "Error IPC message while deserializing ScalarValue::List: {e}" @@ -420,7 +454,12 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { ) })?; - let dict_by_id: HashMap = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| { + let mut dict_by_id: HashMap = HashMap::new(); + for protobuf::scalar_nested_value::Dictionary { + ipc_message, + arrow_data, + } in dictionaries + { let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( "Error IPC message while deserializing ScalarValue::List dictionary message: {e}" @@ -434,22 +473,16 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { .to_string(), ) })?; - - let id = dict_batch.id(); - - let record_batch = read_record_batch( + read_dictionary( &buffer, - dict_batch.data().unwrap(), - Arc::new(schema.clone()), - &Default::default(), - None, + dict_batch, + &schema, + &mut dict_by_id, &message.version(), - )?; - - let values: ArrayRef = Arc::clone(record_batch.column(0)); - - Ok((id, values)) - }).collect::>>()?; + ) + .map_err(|e| arrow_datafusion_err!(e)) + .map_err(|e| e.context("Decoding ScalarValue::List dictionary"))?; + } let record_batch = read_record_batch( &buffer, diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index fee3656482005..cc3dde7d19cd7 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -1025,6 +1025,13 @@ fn encode_scalar_nested_value( let ipc_gen = IpcDataGenerator {}; let mut dict_tracker = DictionaryTracker::new(false); let write_options = IpcWriteOptions::default(); + // The IPC writer requires pre-allocated dictionary IDs (normally assigned when + // serializing the schema). Populate `dict_tracker` by encoding the schema first. + ipc_gen.schema_to_bytes_with_dictionary_tracker( + batch.schema().as_ref(), + &mut dict_tracker, + &write_options, + ); let mut compression_context = CompressionContext::default(); let (encoded_dictionaries, encoded_message) = ipc_gen .encode( diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 5bb771137fbb7..7ae0036105eae 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -2564,3 +2564,22 @@ fn custom_proto_converter_intercepts() -> Result<()> { Ok(()) } + +#[test] +fn roundtrip_call_null_scalar_struct_dict() -> Result<()> { + let data_type = DataType::Struct(Fields::from(vec![Field::new( + "item", + DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), + true, + )])); + + let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)])); + let scan = Arc::new(EmptyExec::new(Arc::clone(&schema))); + let scalar = lit(ScalarValue::try_from(data_type)?); + let filter = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)), + scan, + )?); + + roundtrip_test(filter) +} From 7248cb8d39c87e6c3d61954bca7a18c7e60470e0 Mon Sep 17 00:00:00 2001 From: Kumar Ujjawal Date: Sat, 31 Jan 2026 12:46:12 +0530 Subject: [PATCH 2/2] update the comments --- datafusion/proto-common/src/from_proto/mod.rs | 20 +++++++++---------- datafusion/proto-common/src/to_proto/mod.rs | 6 +++--- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/datafusion/proto-common/src/from_proto/mod.rs b/datafusion/proto-common/src/from_proto/mod.rs index c9b129f37a772..8fd691ea3a78d 100644 --- a/datafusion/proto-common/src/from_proto/mod.rs +++ b/datafusion/proto-common/src/from_proto/mod.rs @@ -389,7 +389,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Float32Value(v) => Self::Float32(Some(*v)), Value::Float64Value(v) => Self::Float64(Some(*v)), Value::Date32Value(v) => Self::Date32(Some(*v)), - // ScalarValue::List is serialized using arrow IPC format + // Nested ScalarValue types are serialized using arrow IPC format Value::ListValue(v) | Value::FixedSizeListValue(v) | Value::LargeListValue(v) @@ -406,7 +406,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { schema_ref.try_into()? } else { return Err(Error::General( - "Invalid schema while deserializing ScalarValue::List" + "Invalid schema while deserializing nested ScalarValue" .to_string(), )); }; @@ -427,13 +427,13 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { root_as_message(encoded_schema.ipc_message.as_slice()).map_err( |e| { Error::General(format!( - "Error IPC schema message while deserializing ScalarValue::List: {e}" + "Error IPC schema message while deserializing nested ScalarValue: {e}" )) }, )?; let ipc_schema = message.header_as_schema().ok_or_else(|| { Error::General( - "Unexpected message type deserializing ScalarValue::List schema" + "Unexpected message type deserializing nested ScalarValue schema" .to_string(), ) })?; @@ -442,14 +442,14 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( - "Error IPC message while deserializing ScalarValue::List: {e}" + "Error IPC message while deserializing nested ScalarValue: {e}" )) })?; let buffer = Buffer::from(arrow_data.as_slice()); let ipc_batch = message.header_as_record_batch().ok_or_else(|| { Error::General( - "Unexpected message type deserializing ScalarValue::List" + "Unexpected message type deserializing nested ScalarValue" .to_string(), ) })?; @@ -462,14 +462,14 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { { let message = root_as_message(ipc_message.as_slice()).map_err(|e| { Error::General(format!( - "Error IPC message while deserializing ScalarValue::List dictionary message: {e}" + "Error IPC message while deserializing nested ScalarValue dictionary message: {e}" )) })?; let buffer = Buffer::from(arrow_data.as_slice()); let dict_batch = message.header_as_dictionary_batch().ok_or_else(|| { Error::General( - "Unexpected message type deserializing ScalarValue::List dictionary message" + "Unexpected message type deserializing nested ScalarValue dictionary message" .to_string(), ) })?; @@ -481,7 +481,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { &message.version(), ) .map_err(|e| arrow_datafusion_err!(e)) - .map_err(|e| e.context("Decoding ScalarValue::List dictionary"))?; + .map_err(|e| e.context("Decoding nested ScalarValue dictionary"))?; } let record_batch = read_record_batch( @@ -493,7 +493,7 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { &message.version(), ) .map_err(|e| arrow_datafusion_err!(e)) - .map_err(|e| e.context("Decoding ScalarValue::List Value"))?; + .map_err(|e| e.context("Decoding nested ScalarValue value"))?; let arr = record_batch.column(0); match value { Value::ListValue(_) => { diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index cc3dde7d19cd7..06e3c6e6d4b37 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -1010,7 +1010,7 @@ fn create_proto_scalar protobuf::scalar_value::Value>( Ok(protobuf::ScalarValue { value: Some(value) }) } -// ScalarValue::List / FixedSizeList / LargeList / Struct / Map are serialized using +// Nested ScalarValue types (List / FixedSizeList / LargeList / Struct / Map) are serialized using // Arrow IPC messages as a single column RecordBatch fn encode_scalar_nested_value( arr: ArrayRef, @@ -1018,7 +1018,7 @@ fn encode_scalar_nested_value( ) -> Result { let batch = RecordBatch::try_from_iter(vec![("field_name", arr)]).map_err(|e| { Error::General(format!( - "Error creating temporary batch while encoding ScalarValue::List: {e}" + "Error creating temporary batch while encoding nested ScalarValue: {e}" )) })?; @@ -1041,7 +1041,7 @@ fn encode_scalar_nested_value( &mut compression_context, ) .map_err(|e| { - Error::General(format!("Error encoding ScalarValue::List as IPC: {e}")) + Error::General(format!("Error encoding nested ScalarValue as IPC: {e}")) })?; let schema: protobuf::Schema = batch.schema().try_into()?;