65 changes: 49 additions & 16 deletions datafusion/proto-common/src/from_proto/mod.rs
@@ -28,7 +28,12 @@ use arrow::datatypes::{
DataType, Field, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, Schema,
TimeUnit, UnionFields, UnionMode, i256,
};
use arrow::ipc::{reader::read_record_batch, root_as_message};
use arrow::ipc::{
convert::fb_to_schema,
reader::{read_dictionary, read_record_batch},
root_as_message,
writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions},
};

use datafusion_common::{
Column, ColumnStatistics, Constraint, Constraints, DFSchema, DFSchemaRef,
@@ -406,6 +411,35 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
));
};

// IPC dictionary batch IDs are assigned when encoding the schema, but our protobuf
// `Schema` doesn't preserve those IDs. Reconstruct them deterministically by
// round-tripping the schema through IPC.
let schema: Schema = {
let ipc_gen = IpcDataGenerator {};
let write_options = IpcWriteOptions::default();
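// `false` = `error_on_replacement`: do not error if a tracked dictionary is redefined.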
let mut dict_tracker = DictionaryTracker::new(false);
let encoded_schema = ipc_gen.schema_to_bytes_with_dictionary_tracker(
&schema,
&mut dict_tracker,
&write_options,
);
let message =
root_as_message(encoded_schema.ipc_message.as_slice()).map_err(
|e| {
Error::General(format!(
"Error IPC schema message while deserializing ScalarValue::List: {e}"
Member comment:
Why do the error messages say ScalarValue::List? Isn't this used for any nested type (List, Map, Struct, ...)?

))
},
)?;
let ipc_schema = message.header_as_schema().ok_or_else(|| {
Error::General(
"Unexpected message type deserializing ScalarValue::List schema"
.to_string(),
)
})?;
fb_to_schema(ipc_schema)
};
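// `schema` now carries the same deterministically assigned dictionary IDs as the
// schema encoded on the write side (see `encode_scalar_nested_value` in to_proto),
// so the dictionary batches decoded below can be matched to their fields by ID.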

let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
Error::General(format!(
"Error IPC message while deserializing ScalarValue::List: {e}"
@@ -420,7 +454,12 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
)
})?;

let dict_by_id: HashMap<i64,ArrayRef> = dictionaries.iter().map(|protobuf::scalar_nested_value::Dictionary { ipc_message, arrow_data }| {
let mut dict_by_id: HashMap<i64, ArrayRef> = HashMap::new();
for protobuf::scalar_nested_value::Dictionary {
ipc_message,
arrow_data,
} in dictionaries
{
let message = root_as_message(ipc_message.as_slice()).map_err(|e| {
Error::General(format!(
"Error IPC message while deserializing ScalarValue::List dictionary message: {e}"
@@ -434,22 +473,16 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue {
.to_string(),
)
})?;

let id = dict_batch.id();

let record_batch = read_record_batch(
read_dictionary(
&buffer,
dict_batch.data().unwrap(),
Arc::new(schema.clone()),
&Default::default(),
None,
dict_batch,
&schema,
&mut dict_by_id,
&message.version(),
)?;

let values: ArrayRef = Arc::clone(record_batch.column(0));

Ok((id, values))
}).collect::<datafusion_common::Result<HashMap<_, _>>>()?;
)
.map_err(|e| arrow_datafusion_err!(e))
.map_err(|e| e.context("Decoding ScalarValue::List dictionary"))?;
}
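// At this point `dict_by_id` maps each IPC dictionary ID to its decoded values
// array, for use when decoding the record batch that holds the nested scalar value.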

let record_batch = read_record_batch(
&buffer,
7 changes: 7 additions & 0 deletions datafusion/proto-common/src/to_proto/mod.rs
@@ -1025,6 +1025,13 @@ fn encode_scalar_nested_value(
let ipc_gen = IpcDataGenerator {};
let mut dict_tracker = DictionaryTracker::new(false);
let write_options = IpcWriteOptions::default();
// The IPC writer requires pre-allocated dictionary IDs (normally assigned when
// serializing the schema). Populate `dict_tracker` by encoding the schema first.
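// The encoded schema bytes returned by this call are discarded; it is invoked only
// for its side effect of assigning dictionary IDs in `dict_tracker`.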
ipc_gen.schema_to_bytes_with_dictionary_tracker(
batch.schema().as_ref(),
&mut dict_tracker,
&write_options,
);
let mut compression_context = CompressionContext::default();
let (encoded_dictionaries, encoded_message) = ipc_gen
.encode(
19 changes: 19 additions & 0 deletions datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -2564,3 +2564,22 @@ fn custom_proto_converter_intercepts() -> Result<()> {

Ok(())
}

#[test]
fn roundtrip_call_null_scalar_struct_dict() -> Result<()> {
let data_type = DataType::Struct(Fields::from(vec![Field::new(
"item",
DataType::Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
true,
)]));
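// A Struct wrapping a dictionary-encoded Utf8 field: exercises the nested-scalar
// dictionary handling added in this change.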

let schema = Arc::new(Schema::new(vec![Field::new("a", data_type.clone(), true)]));
let scan = Arc::new(EmptyExec::new(Arc::clone(&schema)));
let scalar = lit(ScalarValue::try_from(data_type)?);
let filter = Arc::new(FilterExec::try_new(
Arc::new(BinaryExpr::new(scalar, Operator::Eq, col("a", &schema)?)),
scan,
)?);

roundtrip_test(filter)
}