diff --git a/core/src/main/java/io/substrait/expression/MaskExpression.java b/core/src/main/java/io/substrait/expression/MaskExpression.java
new file mode 100644
index 000000000..730e3d222
--- /dev/null
+++ b/core/src/main/java/io/substrait/expression/MaskExpression.java
@@ -0,0 +1,422 @@
+package io.substrait.expression;
+
+import io.substrait.util.VisitationContext;
+import java.util.List;
+import java.util.Optional;
+import org.immutables.value.Value;
+
+/**
+ * A mask expression that selectively removes fields from complex types (struct, list, map).
+ *
+ * <p>This corresponds to the {@code Expression.MaskExpression} message in the Substrait protobuf
+ * specification. It is used in {@code ReadRel} to describe column projection — the subset of a
+ * relation's schema that should actually be read.
+ *
+ * @see <a href="https://substrait.io/expressions/field_references/">Substrait Field References</a>
+ */
+@Value.Enclosing
+@Value.Immutable
+public interface MaskExpression {
+
+ /**
+ * The top-level struct selection describing which fields to include.
+ *
+ * @return the top-level struct selection
+ */
+ StructSelect getSelect();
+
+ /**
+ * When {@code true}, a struct that has only a single selected field will not be
+ * unwrapped into its child type.
+ *
+ * @return {@code true} if singular structs should be maintained, {@code false} otherwise
+ */
+ @Value.Default
+ default boolean getMaintainSingularStruct() {
+ return false;
+ }
+
+ /**
+ * Creates a new builder for constructing a MaskExpression.
+ *
+ * @return a new builder instance
+ */
+ static ImmutableMaskExpression.Builder builder() {
+ return ImmutableMaskExpression.builder();
+ }
+
+ // ---------------------------------------------------------------------------
+ // Select – a union of StructSelect | ListSelect | MapSelect
+ // ---------------------------------------------------------------------------
+
+ /** A selection on a complex type – one of StructSelect, ListSelect, or MapSelect. */
+ interface Select {
+ /**
+ * Accepts a visitor to process this select node.
+ *
+ * @param the return type of the visitor
+ * @param the context type
+ * @param the exception type that may be thrown
+ * @param visitor the visitor to accept
+ * @param context the visitation context
+ * @return the result of the visitation
+ * @throws E if an error occurs during visitation
+ */
+ R accept(
+ MaskExpressionVisitor visitor, C context) throws E;
+ }
+
+ // ---------------------------------------------------------------------------
+ // Struct selection
+ // ---------------------------------------------------------------------------
+
+ /** Selects a subset of fields from a struct type. */
+ @Value.Immutable
+ interface StructSelect extends Select {
+ /**
+ * Returns the list of struct items being selected.
+ *
+ * @return the list of struct items
+ */
+ List getStructItems();
+
+ /**
+ * Creates a new builder for constructing a StructSelect.
+ *
+ * @return a new builder instance
+ */
+ static ImmutableMaskExpression.StructSelect.Builder builder() {
+ return ImmutableMaskExpression.StructSelect.builder();
+ }
+
+ @Override
+ default R accept(
+ MaskExpressionVisitor visitor, C context) throws E {
+ return visitor.visit(this, context);
+ }
+ }
+
+  /** Selects a single field from a struct, with an optional nested child selection. */
+  @Value.Immutable
+  interface StructItem {
+    /**
+     * Returns the zero-based field index within the struct.
+     *
+     * @return the field index
+     */
+    int getField();
+
+    /**
+     * Returns the optional child selection for nested complex types.
+     *
+     * @return the optional child selection
+     */
+    Optional<Select> getChild();
+
+    /**
+     * Creates a new builder for constructing a StructItem.
+     *
+     * @return a new builder instance
+     */
+    static ImmutableMaskExpression.StructItem.Builder builder() {
+      return ImmutableMaskExpression.StructItem.builder();
+    }
+
+    /**
+     * Creates a StructItem for a single field with no nested selection.
+     *
+     * @param field the zero-based field index within the struct
+     * @return a new StructItem instance
+     */
+    static StructItem of(int field) {
+      return builder().field(field).build();
+    }
+
+    /**
+     * Creates a StructItem for a single field with a nested selection.
+     *
+     * @param field the zero-based field index within the struct
+     * @param child the nested child selection for complex types
+     * @return a new StructItem instance
+     */
+    static StructItem of(int field, Select child) {
+      return builder().field(field).child(child).build();
+    }
+  }
+
+ // ---------------------------------------------------------------------------
+ // List selection
+ // ---------------------------------------------------------------------------
+
+ /** Selects elements from a list type by index or slice. */
+ @Value.Immutable
+ interface ListSelect extends Select {
+ /**
+ * Returns the list of selection items (individual elements or slices).
+ *
+ * @return the list of selection items
+ */
+ List getSelection();
+
+ /**
+ * Returns the optional child selection applied to each selected element.
+ *
+ * @return the optional child selection
+ */
+ Optional getChild();
+
+ /**
+ * Creates a new builder for constructing a ListSelect.
+ *
+ * @return a new builder instance
+ */
+ static ImmutableMaskExpression.ListSelect.Builder builder() {
+ return ImmutableMaskExpression.ListSelect.builder();
+ }
+
+ @Override
+ default R accept(
+ MaskExpressionVisitor visitor, C context) throws E {
+ return visitor.visit(this, context);
+ }
+ }
+
+  /** A single selection within a list – either an element or a slice. */
+  @Value.Immutable
+  interface ListSelectItem {
+    /**
+     * Returns the optional list element selection.
+     *
+     * @return the optional list element
+     */
+    Optional<ListElement> getItem();
+
+    /**
+     * Returns the optional list slice selection.
+     *
+     * @return the optional list slice
+     */
+    Optional<ListSlice> getSlice();
+
+    /**
+     * Creates a new builder for constructing a ListSelectItem.
+     *
+     * @return a new builder instance
+     */
+    static ImmutableMaskExpression.ListSelectItem.Builder builder() {
+      return ImmutableMaskExpression.ListSelectItem.builder();
+    }
+
+    /**
+     * Creates a ListSelectItem for a single element selection.
+     *
+     * @param element the list element to select
+     * @return a new ListSelectItem instance
+     */
+    static ListSelectItem ofItem(ListElement element) {
+      return builder().item(element).build();
+    }
+
+    /**
+     * Creates a ListSelectItem for a slice selection.
+     *
+     * @param slice the list slice to select
+     * @return a new ListSelectItem instance
+     */
+    static ListSelectItem ofSlice(ListSlice slice) {
+      return builder().slice(slice).build();
+    }
+  }
+
+  /** Picks one element out of a list, addressed by its zero-based index. */
+  @Value.Immutable
+  interface ListElement {
+    /**
+     * Index of the selected element; counting starts at zero.
+     *
+     * @return the element index
+     */
+    int getField();
+
+    /**
+     * Entry point for building a ListElement step by step.
+     *
+     * @return a fresh builder
+     */
+    static ImmutableMaskExpression.ListElement.Builder builder() {
+      return ImmutableMaskExpression.ListElement.builder();
+    }
+
+    /**
+     * Shorthand that builds a ListElement from just its index.
+     *
+     * @param field the zero-based element index within the list
+     * @return the resulting ListElement
+     */
+    static ListElement of(int field) {
+      return ImmutableMaskExpression.ListElement.builder().field(field).build();
+    }
+  }
+
+  /** Picks a contiguous run of elements out of a list. */
+  @Value.Immutable
+  interface ListSlice {
+    /**
+     * First index covered by the slice (zero-based, inclusive).
+     *
+     * @return the start index
+     */
+    int getStart();
+
+    /**
+     * Index at which the slice stops (zero-based, exclusive).
+     *
+     * @return the end index
+     */
+    int getEnd();
+
+    /**
+     * Entry point for building a ListSlice step by step.
+     *
+     * @return a fresh builder
+     */
+    static ImmutableMaskExpression.ListSlice.Builder builder() {
+      return ImmutableMaskExpression.ListSlice.builder();
+    }
+
+    /**
+     * Shorthand that builds a ListSlice from its two bounds.
+     *
+     * @param start the zero-based start index (inclusive)
+     * @param end the zero-based end index (exclusive)
+     * @return the resulting ListSlice
+     */
+    static ListSlice of(int start, int end) {
+      return ImmutableMaskExpression.ListSlice.builder().start(start).end(end).build();
+    }
+  }
+
+ // ---------------------------------------------------------------------------
+ // Map selection
+ // ---------------------------------------------------------------------------
+
+ /** Selects entries from a map type by exact key or key expression. */
+ @Value.Immutable
+ interface MapSelect extends Select {
+ /**
+ * Returns the optional exact key for map selection.
+ *
+ * @return the optional map key
+ */
+ Optional getKey();
+
+ /**
+ * Returns the optional key expression for wildcard map selection.
+ *
+ * @return the optional map key expression
+ */
+ Optional getExpression();
+
+ /**
+ * Returns the optional child selection applied to each selected map value.
+ *
+ * @return the optional child selection
+ */
+ Optional getChild();
+
+ /**
+ * Creates a new builder for constructing a MapSelect.
+ *
+ * @return a new builder instance
+ */
+ static ImmutableMaskExpression.MapSelect.Builder builder() {
+ return ImmutableMaskExpression.MapSelect.builder();
+ }
+
+ /**
+ * Creates a MapSelect for a single key selection.
+ *
+ * @param key the exact key to select
+ * @return a new MapSelect instance
+ */
+ static MapSelect ofKey(MapKey key) {
+ return builder().key(key).build();
+ }
+
+ /**
+ * Creates a MapSelect for a wildcard key expression selection.
+ *
+ * @param expression the key expression to select
+ * @return a new MapSelect instance
+ */
+ static MapSelect ofExpression(MapKeyExpression expression) {
+ return builder().expression(expression).build();
+ }
+
+ @Override
+ default R accept(
+ MaskExpressionVisitor visitor, C context) throws E {
+ return visitor.visit(this, context);
+ }
+ }
+
+  /** Picks the map entry whose key matches exactly. */
+  @Value.Immutable
+  interface MapKey {
+    /**
+     * Key string that an entry has to match exactly.
+     *
+     * @return the map key
+     */
+    String getMapKey();
+
+    /**
+     * Entry point for building a MapKey step by step.
+     *
+     * @return a fresh builder
+     */
+    static ImmutableMaskExpression.MapKey.Builder builder() {
+      return ImmutableMaskExpression.MapKey.builder();
+    }
+
+    /**
+     * Shorthand that builds a MapKey from the key string alone.
+     *
+     * @param mapKey the key string to match
+     * @return the resulting MapKey
+     */
+    static MapKey of(String mapKey) {
+      return ImmutableMaskExpression.MapKey.builder().mapKey(mapKey).build();
+    }
+  }
+
+  /** Picks map entries whose keys match a wildcard key expression. */
+  @Value.Immutable
+  interface MapKeyExpression {
+    /**
+     * Wildcard expression string that keys are matched against.
+     *
+     * @return the map key expression
+     */
+    String getMapKeyExpression();
+
+    /**
+     * Entry point for building a MapKeyExpression step by step.
+     *
+     * @return a fresh builder
+     */
+    static ImmutableMaskExpression.MapKeyExpression.Builder builder() {
+      return ImmutableMaskExpression.MapKeyExpression.builder();
+    }
+
+    /**
+     * Shorthand that builds a MapKeyExpression from the expression string alone.
+     *
+     * @param mapKeyExpression the wildcard expression string
+     * @return the resulting MapKeyExpression
+     */
+    static MapKeyExpression of(String mapKeyExpression) {
+      return ImmutableMaskExpression.MapKeyExpression.builder().mapKeyExpression(mapKeyExpression).build();
+    }
+  }
+}
diff --git a/core/src/main/java/io/substrait/expression/MaskExpressionTypeProjector.java b/core/src/main/java/io/substrait/expression/MaskExpressionTypeProjector.java
new file mode 100644
index 000000000..d1d9945b6
--- /dev/null
+++ b/core/src/main/java/io/substrait/expression/MaskExpressionTypeProjector.java
@@ -0,0 +1,133 @@
+package io.substrait.expression;
+
+import io.substrait.type.Type;
+import io.substrait.type.TypeCreator;
+import io.substrait.util.EmptyVisitationContext;
+import java.util.List;
+
+/**
+ * Applies a {@link MaskExpression} projection to a {@link Type.Struct}, returning a pruned struct.
+ *
+ * <p>{@link MaskExpression#getMaintainSingularStruct()} is not consulted by this class.
+ */
+public final class MaskExpressionTypeProjector {
+
+  private MaskExpressionTypeProjector() {}
+
+  /**
+   * Applies the given projection to a struct type, returning a pruned struct.
+   *
+   * @param projection the mask expression projection
+   * @param structType the struct type to project
+   * @return a pruned struct containing only the selected fields
+   * @throws IndexOutOfBoundsException if a struct item selects a field the struct does not have
+   * @throws ClassCastException if a struct item's child selection does not match its field type
+   */
+  public static Type.Struct project(MaskExpression projection, Type.Struct structType) {
+    return projectStruct(projection.getSelect(), structType);
+  }
+
+  /** Builds a struct from the (possibly pruned) types of the selected fields, in item order. */
+  private static Type.Struct projectStruct(
+      MaskExpression.StructSelect structSelect, Type.Struct structType) {
+    List<Type> fields = structType.fields();
+    List<MaskExpression.StructItem> items = structSelect.getStructItems();
+
+    return TypeCreator.of(structType.nullable())
+        .struct(items.stream().map(item -> projectItem(item, fieldAt(fields, item.getField()))));
+  }
+
+  /** Looks up a field type, failing with a descriptive message when the index is out of range. */
+  private static Type fieldAt(List<Type> fields, int index) {
+    if (index < 0 || index >= fields.size()) {
+      throw new IndexOutOfBoundsException(
+          "Mask selects field " + index + " but struct has " + fields.size() + " fields");
+    }
+    return fields.get(index);
+  }
+
+  /** Applies an item's child selection to its field type; an item without a child keeps it. */
+  private static Type projectItem(MaskExpression.StructItem item, Type fieldType) {
+    if (!item.getChild().isPresent()) {
+      return fieldType;
+    }
+
+    MaskExpression.Select select = item.getChild().get();
+
+    return select.accept(
+        new MaskExpressionVisitor<Type, EmptyVisitationContext, RuntimeException>() {
+          @Override
+          public Type visit(
+              MaskExpression.StructSelect structSelect, EmptyVisitationContext context) {
+            return projectStruct(structSelect, (Type.Struct) fieldType);
+          }
+
+          @Override
+          public Type visit(MaskExpression.ListSelect listSelect, EmptyVisitationContext context) {
+            return projectList(listSelect, (Type.ListType) fieldType);
+          }
+
+          @Override
+          public Type visit(MaskExpression.MapSelect mapSelect, EmptyVisitationContext context) {
+            return projectMap(mapSelect, (Type.Map) fieldType);
+          }
+        },
+        EmptyVisitationContext.INSTANCE);
+  }
+
+  /** Prunes a list's element type through the select's child, when one is present. */
+  private static Type.ListType projectList(
+      MaskExpression.ListSelect listSelect, Type.ListType listType) {
+    if (!listSelect.getChild().isPresent()) {
+      return listType;
+    }
+
+    Type elementType = listType.elementType();
+    Type pruned = projectNested(listSelect.getChild().get(), elementType);
+    // projectNested hands back the same instance when the child did not apply to the element.
+    return pruned == elementType ? listType : TypeCreator.of(listType.nullable()).list(pruned);
+  }
+
+  /** Prunes a map's value type through the select's child; the key type is never altered. */
+  private static Type.Map projectMap(MaskExpression.MapSelect mapSelect, Type.Map mapType) {
+    if (!mapSelect.getChild().isPresent()) {
+      return mapType;
+    }
+
+    Type valueType = mapType.value();
+    Type pruned = projectNested(mapSelect.getChild().get(), valueType);
+    return pruned == valueType
+        ? mapType
+        : TypeCreator.of(mapType.nullable()).map(mapType.key(), pruned);
+  }
+
+  /**
+   * Applies a child selection to a list element or map value type, recursing through nested
+   * structs, lists and maps. A selection whose kind does not match {@code type} leaves it as is.
+   */
+  private static Type projectNested(MaskExpression.Select select, Type type) {
+    return select.accept(
+        new MaskExpressionVisitor<Type, EmptyVisitationContext, RuntimeException>() {
+          @Override
+          public Type visit(
+              MaskExpression.StructSelect structSelect, EmptyVisitationContext context) {
+            return type instanceof Type.Struct
+                ? projectStruct(structSelect, (Type.Struct) type)
+                : type;
+          }
+
+          @Override
+          public Type visit(MaskExpression.ListSelect listSelect, EmptyVisitationContext context) {
+            return type instanceof Type.ListType
+                ? projectList(listSelect, (Type.ListType) type)
+                : type;
+          }
+
+          @Override
+          public Type visit(MaskExpression.MapSelect mapSelect, EmptyVisitationContext context) {
+            return type instanceof Type.Map ? projectMap(mapSelect, (Type.Map) type) : type;
+          }
+        },
+        EmptyVisitationContext.INSTANCE);
+  }
+}
diff --git a/core/src/main/java/io/substrait/expression/MaskExpressionVisitor.java b/core/src/main/java/io/substrait/expression/MaskExpressionVisitor.java
new file mode 100644
index 000000000..b5eb251c7
--- /dev/null
+++ b/core/src/main/java/io/substrait/expression/MaskExpressionVisitor.java
@@ -0,0 +1,43 @@
+package io.substrait.expression;
+
+import io.substrait.util.VisitationContext;
+
+/**
+ * Visitor for {@link MaskExpression} select nodes.
+ *
+ * @param <R> result type returned by each visit
+ * @param <C> visitation context type
+ * @param <E> throwable type that visit methods may throw
+ */
+public interface MaskExpressionVisitor<R, C extends VisitationContext, E extends Throwable> {
+
+  /**
+   * Visit a struct select.
+   *
+   * @param structSelect the struct select
+   * @param context visitation context
+   * @return visit result
+   * @throws E on visit failure
+   */
+  R visit(MaskExpression.StructSelect structSelect, C context) throws E;
+
+  /**
+   * Visit a list select.
+   *
+   * @param listSelect the list select
+   * @param context visitation context
+   * @return visit result
+   * @throws E on visit failure
+   */
+  R visit(MaskExpression.ListSelect listSelect, C context) throws E;
+
+  /**
+   * Visit a map select.
+   *
+   * @param mapSelect the map select
+   * @param context visitation context
+   * @return visit result
+   * @throws E on visit failure
+   */
+  R visit(MaskExpression.MapSelect mapSelect, C context) throws E;
+}
diff --git a/core/src/main/java/io/substrait/expression/proto/MaskExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/MaskExpressionProtoConverter.java
new file mode 100644
index 000000000..d5ef0d58a
--- /dev/null
+++ b/core/src/main/java/io/substrait/expression/proto/MaskExpressionProtoConverter.java
@@ -0,0 +1,137 @@
+package io.substrait.expression.proto;
+
+import io.substrait.expression.MaskExpression;
+import io.substrait.expression.MaskExpression.ListSelect;
+import io.substrait.expression.MaskExpression.ListSelectItem;
+import io.substrait.expression.MaskExpression.MapSelect;
+import io.substrait.expression.MaskExpression.Select;
+import io.substrait.expression.MaskExpression.StructItem;
+import io.substrait.expression.MaskExpression.StructSelect;
+import io.substrait.expression.MaskExpressionVisitor;
+import io.substrait.proto.Expression;
+import io.substrait.util.EmptyVisitationContext;
+
+/**
+ * Converts from {@link io.substrait.expression.MaskExpression} to {@link Expression.MaskExpression}
+ */
+public final class MaskExpressionProtoConverter {
+
+ private MaskExpressionProtoConverter() {}
+
+ private static final MaskExpressionVisitor<
+ Expression.MaskExpression.Select, EmptyVisitationContext, RuntimeException>
+ SELECT_TO_PROTO_VISITOR =
+ new MaskExpressionVisitor<
+ Expression.MaskExpression.Select, EmptyVisitationContext, RuntimeException>() {
+ @Override
+ public Expression.MaskExpression.Select visit(
+ MaskExpression.StructSelect structSelect, EmptyVisitationContext context) {
+ return Expression.MaskExpression.Select.newBuilder()
+ .setStruct(toProto(structSelect))
+ .build();
+ }
+
+ @Override
+ public Expression.MaskExpression.Select visit(
+ MaskExpression.ListSelect listSelect, EmptyVisitationContext context) {
+ return Expression.MaskExpression.Select.newBuilder()
+ .setList(toProtoListSelect(listSelect))
+ .build();
+ }
+
+ @Override
+ public Expression.MaskExpression.Select visit(
+ MaskExpression.MapSelect mapSelect, EmptyVisitationContext context) {
+ return Expression.MaskExpression.Select.newBuilder()
+ .setMap(toProtoMapSelect(mapSelect))
+ .build();
+ }
+ };
+
+ /**
+ * Converts a POJO {@link MaskExpression} to its proto representation.
+ *
+ * @param mask the POJO {@link MaskExpression}
+ * @return the proto {@link Expression.MaskExpression}
+ */
+ public static Expression.MaskExpression toProto(MaskExpression mask) {
+ return Expression.MaskExpression.newBuilder()
+ .setSelect(toProto(mask.getSelect()))
+ .setMaintainSingularStruct(mask.getMaintainSingularStruct())
+ .build();
+ }
+
+ private static Expression.MaskExpression.StructSelect toProto(StructSelect structSelect) {
+ Expression.MaskExpression.StructSelect.Builder builder =
+ Expression.MaskExpression.StructSelect.newBuilder();
+ for (StructItem item : structSelect.getStructItems()) {
+ builder.addStructItems(toProto(item));
+ }
+ return builder.build();
+ }
+
+ private static Expression.MaskExpression.StructItem toProto(StructItem structItem) {
+ Expression.MaskExpression.StructItem.Builder builder =
+ Expression.MaskExpression.StructItem.newBuilder().setField(structItem.getField());
+ structItem.getChild().ifPresent(child -> builder.setChild(toProtoSelect(child)));
+ return builder.build();
+ }
+
+ private static Expression.MaskExpression.Select toProtoSelect(Select select) {
+ return select.accept(SELECT_TO_PROTO_VISITOR, EmptyVisitationContext.INSTANCE);
+ }
+
+ private static Expression.MaskExpression.ListSelect toProtoListSelect(ListSelect listSelect) {
+ Expression.MaskExpression.ListSelect.Builder builder =
+ Expression.MaskExpression.ListSelect.newBuilder();
+ for (ListSelectItem item : listSelect.getSelection()) {
+ builder.addSelection(toProtoListSelectItem(item));
+ }
+ listSelect.getChild().ifPresent(child -> builder.setChild(toProtoSelect(child)));
+ return builder.build();
+ }
+
+ private static Expression.MaskExpression.ListSelect.ListSelectItem toProtoListSelectItem(
+ ListSelectItem item) {
+ Expression.MaskExpression.ListSelect.ListSelectItem.Builder builder =
+ Expression.MaskExpression.ListSelect.ListSelectItem.newBuilder();
+ if (item.getItem().isPresent()) {
+ builder.setItem(
+ Expression.MaskExpression.ListSelect.ListSelectItem.ListElement.newBuilder()
+ .setField(item.getItem().get().getField())
+ .build());
+ } else if (item.getSlice().isPresent()) {
+ builder.setSlice(
+ Expression.MaskExpression.ListSelect.ListSelectItem.ListSlice.newBuilder()
+ .setStart(item.getSlice().get().getStart())
+ .setEnd(item.getSlice().get().getEnd())
+ .build());
+ } else {
+ throw new IllegalArgumentException("ListSelectItem must have either item or slice set");
+ }
+ return builder.build();
+ }
+
+ private static Expression.MaskExpression.MapSelect toProtoMapSelect(MapSelect mapSelect) {
+ Expression.MaskExpression.MapSelect.Builder builder =
+ Expression.MaskExpression.MapSelect.newBuilder();
+ mapSelect
+ .getKey()
+ .ifPresent(
+ key ->
+ builder.setKey(
+ Expression.MaskExpression.MapSelect.MapKey.newBuilder()
+ .setMapKey(key.getMapKey())
+ .build()));
+ mapSelect
+ .getExpression()
+ .ifPresent(
+ expr ->
+ builder.setExpression(
+ Expression.MaskExpression.MapSelect.MapKeyExpression.newBuilder()
+ .setMapKeyExpression(expr.getMapKeyExpression())
+ .build()));
+ mapSelect.getChild().ifPresent(child -> builder.setChild(toProtoSelect(child)));
+ return builder.build();
+ }
+}
diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoMaskExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoMaskExpressionConverter.java
new file mode 100644
index 000000000..6148b0f3d
--- /dev/null
+++ b/core/src/main/java/io/substrait/expression/proto/ProtoMaskExpressionConverter.java
@@ -0,0 +1,110 @@
+package io.substrait.expression.proto;
+
+import io.substrait.expression.ImmutableMaskExpression;
+import io.substrait.expression.MaskExpression;
+import io.substrait.expression.MaskExpression.ListSelect;
+import io.substrait.expression.MaskExpression.ListSelectItem;
+import io.substrait.expression.MaskExpression.MapSelect;
+import io.substrait.expression.MaskExpression.Select;
+import io.substrait.expression.MaskExpression.StructItem;
+import io.substrait.expression.MaskExpression.StructSelect;
+import io.substrait.proto.Expression;
+
+/**
+ * Converts from {@link Expression.MaskExpression} to {@link io.substrait.expression.MaskExpression}
+ */
+public final class ProtoMaskExpressionConverter {
+
+ private ProtoMaskExpressionConverter() {}
+
+ /**
+ * Converts a proto {@link Expression.MaskExpression} to its POJO representation.
+ *
+ * @param proto the proto {@link Expression.MaskExpression}
+ * @return the POJO {@link MaskExpression}
+ */
+ public static MaskExpression fromProto(Expression.MaskExpression proto) {
+ return MaskExpression.builder()
+ .select(fromProto(proto.getSelect()))
+ .maintainSingularStruct(proto.getMaintainSingularStruct())
+ .build();
+ }
+
+ private static StructSelect fromProto(Expression.MaskExpression.StructSelect proto) {
+ ImmutableMaskExpression.StructSelect.Builder builder = StructSelect.builder();
+ for (Expression.MaskExpression.StructItem item : proto.getStructItemsList()) {
+ builder.addStructItems(fromProto(item));
+ }
+ return builder.build();
+ }
+
+ private static StructItem fromProto(Expression.MaskExpression.StructItem proto) {
+ ImmutableMaskExpression.StructItem.Builder builder =
+ StructItem.builder().field(proto.getField());
+ if (proto.hasChild()) {
+ builder.child(fromProtoSelect(proto.getChild()));
+ }
+ return builder.build();
+ }
+
+ private static Select fromProtoSelect(Expression.MaskExpression.Select proto) {
+ switch (proto.getTypeCase()) {
+ case STRUCT:
+ return fromProto(proto.getStruct());
+ case LIST:
+ return fromProtoListSelect(proto.getList());
+ case MAP:
+ return fromProtoMapSelect(proto.getMap());
+ default:
+ throw new IllegalArgumentException(
+ "Unknown MaskExpression.Select type: " + proto.getTypeCase());
+ }
+ }
+
+ private static ListSelect fromProtoListSelect(Expression.MaskExpression.ListSelect proto) {
+ ImmutableMaskExpression.ListSelect.Builder builder = ListSelect.builder();
+ for (Expression.MaskExpression.ListSelect.ListSelectItem item : proto.getSelectionList()) {
+ builder.addSelection(fromProtoListSelectItem(item));
+ }
+ if (proto.hasChild()) {
+ builder.child(fromProtoSelect(proto.getChild()));
+ }
+ return builder.build();
+ }
+
+ private static ListSelectItem fromProtoListSelectItem(
+ Expression.MaskExpression.ListSelect.ListSelectItem proto) {
+ ImmutableMaskExpression.ListSelectItem.Builder builder = ListSelectItem.builder();
+ switch (proto.getTypeCase()) {
+ case ITEM:
+ builder.item(MaskExpression.ListElement.of(proto.getItem().getField()));
+ break;
+ case SLICE:
+ builder.slice(
+ MaskExpression.ListSlice.of(proto.getSlice().getStart(), proto.getSlice().getEnd()));
+ break;
+ default:
+ throw new IllegalArgumentException("Unknown ListSelectItem type: " + proto.getTypeCase());
+ }
+ return builder.build();
+ }
+
+ private static MapSelect fromProtoMapSelect(Expression.MaskExpression.MapSelect proto) {
+ ImmutableMaskExpression.MapSelect.Builder builder = MapSelect.builder();
+ switch (proto.getSelectCase()) {
+ case KEY:
+ builder.key(MaskExpression.MapKey.of(proto.getKey().getMapKey()));
+ break;
+ case EXPRESSION:
+ builder.expression(
+ MaskExpression.MapKeyExpression.of(proto.getExpression().getMapKeyExpression()));
+ break;
+ default:
+ throw new IllegalArgumentException("Unknown MapSelect type: " + proto.getSelectCase());
+ }
+ if (proto.hasChild()) {
+ builder.child(fromProtoSelect(proto.getChild()));
+ }
+ return builder.build();
+ }
+}
diff --git a/core/src/main/java/io/substrait/relation/AbstractReadRel.java b/core/src/main/java/io/substrait/relation/AbstractReadRel.java
index 51df27fcf..ab3445f69 100644
--- a/core/src/main/java/io/substrait/relation/AbstractReadRel.java
+++ b/core/src/main/java/io/substrait/relation/AbstractReadRel.java
@@ -1,6 +1,8 @@
package io.substrait.relation;
import io.substrait.expression.Expression;
+import io.substrait.expression.MaskExpression;
+import io.substrait.expression.MaskExpressionTypeProjector;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import java.util.Optional;
@@ -13,11 +15,13 @@ public abstract class AbstractReadRel extends ZeroInputRel implements HasExtensi
public abstract Optional getBestEffortFilter();
- // TODO:
- // public abstract Optional
+  public abstract Optional<MaskExpression> getProjection();
@Override
protected final Type.Struct deriveRecordType() {
- return getInitialSchema().struct();
+ Type.Struct base = getInitialSchema().struct();
+ return getProjection()
+ .map(projection -> MaskExpressionTypeProjector.project(projection, base))
+ .orElse(base);
}
}
diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java
index 015695638..b8a28f78c 100644
--- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java
+++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java
@@ -2,7 +2,9 @@
import io.substrait.expression.Expression;
import io.substrait.expression.FieldReference;
+import io.substrait.expression.MaskExpression;
import io.substrait.expression.proto.ProtoExpressionConverter;
+import io.substrait.expression.proto.ProtoMaskExpressionConverter;
import io.substrait.extension.AdvancedExtension;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.extension.ExtensionLookup;
@@ -455,7 +457,8 @@ protected NamedScan newNamedScan(ReadRel rel) {
? new ProtoExpressionConverter(
lookup, extensions, namedStruct.struct(), this)
.from(rel.getFilter())
- : null));
+ : null))
+ .projection(optionalMaskExpression(rel));
builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
@@ -475,6 +478,7 @@ protected ExtensionTable newExtensionTable(final ReadRel rel) {
ExtensionTable.from(detail).initialSchema(namedStruct);
builder
+ .projection(optionalMaskExpression(rel))
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
.remap(optionalRelmap(rel.getCommon()))
.hint(optionalHint(rel.getCommon()));
@@ -507,7 +511,8 @@ protected LocalFiles newLocalFiles(ReadRel rel) {
? new ProtoExpressionConverter(
lookup, extensions, namedStruct.struct(), this)
.from(rel.getFilter())
- : null));
+ : null))
+ .projection(optionalMaskExpression(rel));
builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
@@ -614,7 +619,8 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) {
rel.hasBestEffortFilter() ? converter.from(rel.getBestEffortFilter()) : null))
.filter(Optional.ofNullable(rel.hasFilter() ? converter.from(rel.getFilter()) : null))
.initialSchema(NamedStruct.fromProto(rel.getBaseSchema(), protoTypeConverter))
- .rows(expressions);
+ .rows(expressions)
+ .projection(optionalMaskExpression(rel));
builder
.commonExtension(optionalAdvancedExtension(rel.getCommon()))
@@ -1210,6 +1216,11 @@ protected Optional optionalAdvancedExtension(
: null);
}
+  protected Optional<MaskExpression> optionalMaskExpression(ReadRel rel) {
+    return Optional.ofNullable(
+        rel.hasProjection() ? ProtoMaskExpressionConverter.fromProto(rel.getProjection()) : null);
+  }
+
/** Override to provide a custom converter for {@link ExtensionLeafRel#getDetail()} data */
protected Extension.LeafRelDetail detailFromExtensionLeafRel(com.google.protobuf.Any any) {
return emptyDetail();
diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java
index 429d26661..0b9aebada 100644
--- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java
+++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java
@@ -5,6 +5,7 @@
import io.substrait.expression.FunctionArg;
import io.substrait.expression.proto.ExpressionProtoConverter;
import io.substrait.expression.proto.ExpressionProtoConverter.BoundConverter;
+import io.substrait.expression.proto.MaskExpressionProtoConverter;
import io.substrait.extension.ExtensionCollector;
import io.substrait.extension.ExtensionProtoConverter;
import io.substrait.extension.SimpleExtension;
@@ -297,6 +298,9 @@ public Rel visit(NamedScan namedScan, EmptyVisitationContext context) throws Run
namedScan.getFilter().ifPresent(f -> builder.setFilter(toProto(f)));
namedScan.getBestEffortFilter().ifPresent(f -> builder.setBestEffortFilter(toProto(f)));
+ namedScan
+ .getProjection()
+ .ifPresent(p -> builder.setProjection(MaskExpressionProtoConverter.toProto(p)));
namedScan
.getExtension()
@@ -319,6 +323,9 @@ public Rel visit(LocalFiles localFiles, EmptyVisitationContext context) throws R
.setBaseSchema(localFiles.getInitialSchema().toProto(typeProtoConverter));
localFiles.getFilter().ifPresent(t -> builder.setFilter(toProto(t)));
localFiles.getBestEffortFilter().ifPresent(t -> builder.setBestEffortFilter(toProto(t)));
+ localFiles
+ .getProjection()
+ .ifPresent(p -> builder.setProjection(MaskExpressionProtoConverter.toProto(p)));
localFiles
.getExtension()
@@ -337,6 +344,10 @@ public Rel visit(ExtensionTable extensionTable, EmptyVisitationContext context)
.setBaseSchema(extensionTable.getInitialSchema().toProto(typeProtoConverter))
.setExtensionTable(extensionTableBuilder);
+ extensionTable
+ .getProjection()
+ .ifPresent(p -> builder.setProjection(MaskExpressionProtoConverter.toProto(p)));
+
extensionTable
.getExtension()
.ifPresent(ae -> builder.setAdvancedExtension(extensionProtoConverter.toProto(ae)));
@@ -770,6 +781,9 @@ public Rel visit(VirtualTableScan virtualTableScan, EmptyVisitationContext conte
virtualTableScan.getFilter().ifPresent(f -> builder.setFilter(toProto(f)));
virtualTableScan.getBestEffortFilter().ifPresent(f -> builder.setBestEffortFilter(toProto(f)));
+ virtualTableScan
+ .getProjection()
+ .ifPresent(p -> builder.setProjection(MaskExpressionProtoConverter.toProto(p)));
virtualTableScan
.getExtension()
diff --git a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java
index c9a37d8f6..606b9743f 100644
--- a/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java
+++ b/core/src/test/java/io/substrait/type/proto/ReadRelRoundtripTest.java
@@ -1,10 +1,17 @@
package io.substrait.type.proto;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
import io.substrait.TestBase;
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
+import io.substrait.expression.MaskExpression;
+import io.substrait.relation.LocalFiles;
import io.substrait.relation.NamedScan;
+import io.substrait.relation.Rel;
import io.substrait.relation.VirtualTableScan;
+import io.substrait.relation.files.FileFormat;
+import io.substrait.relation.files.FileOrFiles;
import io.substrait.type.NamedStruct;
import io.substrait.type.Type;
import java.util.List;
@@ -83,4 +90,518 @@ void virtualTableWithNullable() {
.build();
verifyRoundTrip(virtTable);
}
+
+  /**
+   * Builds a {@link NamedScan} over a three-column schema with a projection that selects
+   * top-level fields 0 and 2, then hands it to {@code verifyRoundTrip}.
+   */
+  @Test
+  void namedScanWithSimpleProjection() {
+    List<String> tableName = Stream.of("my_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("col_a", "col_b", "col_c").collect(Collectors.toList());
+    List<Type> columnTypes = Stream.of(R.I32, R.STRING, R.I64).collect(Collectors.toList());
+
+    // Select columns 0 and 2
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(MaskExpression.StructItem.of(0))
+                    .addStructItems(MaskExpression.StructItem.of(2))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(namedScan);
+  }
+
+  /**
+   * Builds a {@link NamedScan} whose projection narrows the struct column (field 0) to its
+   * subfields 0 and 2, also selects field 1, and sets {@code maintainSingularStruct(true)}, then
+   * hands it to {@code verifyRoundTrip}.
+   */
+  @Test
+  void namedScanWithNestedProjection() {
+    List<String> tableName = Stream.of("nested_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("outer_struct", "simple_col").collect(Collectors.toList());
+    List<Type> columnTypes =
+        Stream.of(R.struct(R.I32, R.STRING, R.I64), R.I32).collect(Collectors.toList());
+
+    // Within field 0, keep only subfields 0 and 2
+    MaskExpression.StructSelect innerSelect =
+        MaskExpression.StructSelect.builder()
+            .addStructItems(MaskExpression.StructItem.of(0))
+            .addStructItems(MaskExpression.StructItem.of(2))
+            .build();
+
+    // Top level: field 0 (narrowed by innerSelect) and field 1
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(MaskExpression.StructItem.of(0, innerSelect))
+                    .addStructItems(MaskExpression.StructItem.of(1))
+                    .build())
+            .maintainSingularStruct(true)
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(namedScan);
+  }
+
+  /**
+   * Builds a {@link NamedScan} whose projection applies a list selection (one element item and
+   * one slice item) to field 0 and also selects field 1, then hands it to {@code
+   * verifyRoundTrip}.
+   */
+  @Test
+  void namedScanWithListProjection() {
+    List<String> tableName = Stream.of("list_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("list_col", "id").collect(Collectors.toList());
+    List<Type> columnTypes = Stream.of(R.list(R.I32), R.I64).collect(Collectors.toList());
+
+    // List selection for field 0: element 0 plus the slice built from (2, 5)
+    MaskExpression.ListSelect listSelect =
+        MaskExpression.ListSelect.builder()
+            .addSelection(MaskExpression.ListSelectItem.ofItem(MaskExpression.ListElement.of(0)))
+            .addSelection(MaskExpression.ListSelectItem.ofSlice(MaskExpression.ListSlice.of(2, 5)))
+            .build();
+
+    // Select field 0 with the list selection, and field 1
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(MaskExpression.StructItem.of(0, listSelect))
+                    .addStructItems(MaskExpression.StructItem.of(1))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(namedScan);
+  }
+
+  /**
+   * Builds a {@link NamedScan} whose projection applies a map-key selection ({@code "my_key"}) to
+   * field 0 and also selects field 1, then hands it to {@code verifyRoundTrip}.
+   */
+  @Test
+  void namedScanWithMapProjection() {
+    List<String> tableName = Stream.of("map_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("map_col", "id").collect(Collectors.toList());
+    List<Type> columnTypes = Stream.of(R.map(R.STRING, R.I32), R.I64).collect(Collectors.toList());
+
+    // Select field 0 with map key selection, and field 1
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(
+                        MaskExpression.StructItem.of(
+                            0, MaskExpression.MapSelect.ofKey(MaskExpression.MapKey.of("my_key"))))
+                    .addStructItems(MaskExpression.StructItem.of(1))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(namedScan);
+  }
+
+  /**
+   * Builds a {@link NamedScan} over a single map column whose projection selects field 0 through
+   * the map key expression {@code "prefix_*"}, then hands it to {@code verifyRoundTrip}.
+   */
+  @Test
+  void namedScanWithMapKeyExpressionProjection() {
+    List<String> tableName = Stream.of("map_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("map_col").collect(Collectors.toList());
+    List<Type> columnTypes = Stream.of(R.map(R.STRING, R.I32)).collect(Collectors.toList());
+
+    // Select field 0 with a map key-expression selection
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(
+                        MaskExpression.StructItem.of(
+                            0,
+                            MaskExpression.MapSelect.ofExpression(
+                                MaskExpression.MapKeyExpression.of("prefix_*"))))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(namedScan);
+  }
+
+  /**
+   * Builds a two-column {@link VirtualTableScan} containing one row of i64 literals and a
+   * projection that selects only field 0, then hands it to {@code verifyRoundTrip}.
+   */
+  @Test
+  void virtualTableWithProjection() {
+    // Keep only the first column
+    MaskExpression.StructSelect firstFieldOnly =
+        MaskExpression.StructSelect.builder()
+            .addStructItems(MaskExpression.StructItem.of(0))
+            .build();
+    MaskExpression projection = MaskExpression.builder().select(firstFieldOnly).build();
+
+    NamedStruct schema =
+        NamedStruct.of(
+            Stream.of("column1", "column2").collect(Collectors.toList()), R.struct(R.I64, R.I64));
+
+    VirtualTableScan virtTable =
+        VirtualTableScan.builder()
+            .initialSchema(schema)
+            .addRows(
+                Expression.NestedStruct.builder()
+                    .addFields(ExpressionCreator.i64(false, 1))
+                    .addFields(ExpressionCreator.i64(false, 2))
+                    .build())
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(virtTable);
+  }
+
+  /**
+   * Builds a {@link LocalFiles} relation with one Parquet file entry and a projection that
+   * selects only field 1, then hands it to {@code verifyRoundTrip}.
+   */
+  @Test
+  void localFilesWithProjection() {
+    // Keep only the second column
+    MaskExpression.StructSelect secondFieldOnly =
+        MaskExpression.StructSelect.builder()
+            .addStructItems(MaskExpression.StructItem.of(1))
+            .build();
+    MaskExpression projection = MaskExpression.builder().select(secondFieldOnly).build();
+
+    NamedStruct schema =
+        NamedStruct.of(
+            Stream.of("col_a", "col_b").collect(Collectors.toList()), R.struct(R.I32, R.STRING));
+
+    FileOrFiles parquetFile =
+        FileOrFiles.builder()
+            .pathType(FileOrFiles.PathType.URI_PATH)
+            .path("/data/file.parquet")
+            .partitionIndex(0)
+            .start(0)
+            .length(1024)
+            .fileFormat(FileFormat.ParquetReadOptions.builder().build())
+            .build();
+
+    LocalFiles localFiles =
+        LocalFiles.builder()
+            .initialSchema(schema)
+            .addItems(parquetFile)
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(localFiles);
+  }
+
+  /**
+   * Builds a {@link NamedScan} that carries both a filter (field 0 equals field 1) and a
+   * projection selecting fields 0 and 1 with {@code maintainSingularStruct(true)}, then hands it
+   * to {@code verifyRoundTrip}.
+   */
+  @Test
+  void namedScanWithProjectionAndFilter() {
+    List<String> tableName = Stream.of("filtered_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("col_a", "col_b").collect(Collectors.toList());
+    List<Type> columnTypes = Stream.of(R.I64, R.I64).collect(Collectors.toList());
+
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(MaskExpression.StructItem.of(0))
+                    .addStructItems(MaskExpression.StructItem.of(1))
+                    .build())
+            .maintainSingularStruct(true)
+            .build();
+
+    // The field references in the filter are built against the unprojected base scan.
+    NamedScan baseScan = sb.namedScan(tableName, columnNames, columnTypes);
+    NamedScan filteredScan =
+        NamedScan.builder()
+            .from(baseScan)
+            .filter(sb.equal(sb.fieldReference(baseScan, 0), sb.fieldReference(baseScan, 1)))
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(filteredScan);
+  }
+
+  /**
+   * Builds a {@link NamedScan} whose projection applies a list slice selection with a child
+   * struct selection (subfields 0 and 2) to field 0 and also selects field 1, then hands it to
+   * {@code verifyRoundTrip}.
+   */
+  @Test
+  void namedScanWithListSelectWithChild() {
+    // list<struct<i32, string, i64>> - select list elements, then pick struct subfields
+    List<String> tableName = Stream.of("nested_list_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("items", "id").collect(Collectors.toList());
+    List<Type> columnTypes =
+        Stream.of(R.list(R.struct(R.I32, R.STRING, R.I64)), R.I64).collect(Collectors.toList());
+
+    // Child selection applied to the list's struct elements: subfields 0 and 2
+    MaskExpression.StructSelect elementSelect =
+        MaskExpression.StructSelect.builder()
+            .addStructItems(MaskExpression.StructItem.of(0))
+            .addStructItems(MaskExpression.StructItem.of(2))
+            .build();
+
+    MaskExpression.ListSelect listSelect =
+        MaskExpression.ListSelect.builder()
+            .addSelection(MaskExpression.ListSelectItem.ofSlice(MaskExpression.ListSlice.of(0, 5)))
+            .child(elementSelect)
+            .build();
+
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(MaskExpression.StructItem.of(0, listSelect))
+                    .addStructItems(MaskExpression.StructItem.of(1))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(namedScan);
+  }
+
+  /**
+   * Builds a {@link NamedScan} whose projection applies a map-key selection ({@code
+   * "user_info"}) with a child struct selection (subfield 1) to field 0 and also selects field 1,
+   * then hands it to {@code verifyRoundTrip}.
+   */
+  @Test
+  void namedScanWithMapSelectWithChild() {
+    // map<string, struct<i32, string>> - select by key, then pick struct subfields
+    List<String> tableName = Stream.of("nested_map_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("entries", "id").collect(Collectors.toList());
+    List<Type> columnTypes =
+        Stream.of(R.map(R.STRING, R.struct(R.I32, R.STRING)), R.I64).collect(Collectors.toList());
+
+    // Child selection applied to the map's struct values: subfield 1
+    MaskExpression.StructSelect valueSelect =
+        MaskExpression.StructSelect.builder()
+            .addStructItems(MaskExpression.StructItem.of(1))
+            .build();
+
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(
+                        MaskExpression.StructItem.of(
+                            0,
+                            MaskExpression.MapSelect.builder()
+                                .key(MaskExpression.MapKey.of("user_info"))
+                                .child(valueSelect)
+                                .build()))
+                    .addStructItems(MaskExpression.StructItem.of(1))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(namedScan);
+  }
+
+  /**
+   * Builds a {@link NamedScan} over a single map column whose projection applies the map key
+   * expression {@code "user_*"} with a child struct selection (subfield 0) to field 0, then hands
+   * it to {@code verifyRoundTrip}.
+   */
+  @Test
+  void namedScanWithMapKeyExpressionAndChild() {
+    // map<string, struct<i32, string>> - select by key expression, then pick struct subfield
+    List<String> tableName = Stream.of("map_expr_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("data").collect(Collectors.toList());
+    List<Type> columnTypes =
+        Stream.of(R.map(R.STRING, R.struct(R.I32, R.STRING))).collect(Collectors.toList());
+
+    // Child selection applied to the map's struct values: subfield 0
+    MaskExpression.StructSelect valueSelect =
+        MaskExpression.StructSelect.builder()
+            .addStructItems(MaskExpression.StructItem.of(0))
+            .build();
+
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(
+                        MaskExpression.StructItem.of(
+                            0,
+                            MaskExpression.MapSelect.builder()
+                                .expression(MaskExpression.MapKeyExpression.of("user_*"))
+                                .child(valueSelect)
+                                .build()))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    verifyRoundTrip(namedScan);
+  }
+
+  /**
+   * Asserts that {@code getRecordType()} of a {@link NamedScan} with a projection selecting
+   * fields 0 and 2 of an (I32, STRING, I64) schema equals {@code struct(I32, I64)}.
+   */
+  @Test
+  void recordTypeReflectsSimpleProjection() {
+    List<String> tableName = Stream.of("my_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("col_a", "col_b", "col_c").collect(Collectors.toList());
+    List<Type> columnTypes = Stream.of(R.I32, R.STRING, R.I64).collect(Collectors.toList());
+
+    // Select fields 0 (I32) and 2 (I64), skipping field 1 (STRING)
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(MaskExpression.StructItem.of(0))
+                    .addStructItems(MaskExpression.StructItem.of(2))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    Type.Struct expected = R.struct(R.I32, R.I64);
+    assertEquals(expected, namedScan.getRecordType());
+  }
+
+  /**
+   * Asserts that {@code getRecordType()} of a {@link NamedScan} whose projection narrows the
+   * struct column to subfields 0 and 2 and also selects field 1 equals {@code struct(struct(I32,
+   * I64), I32)}.
+   */
+  @Test
+  void recordTypeReflectsNestedStructProjection() {
+    List<String> tableName = Stream.of("nested_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("outer_struct", "simple_col").collect(Collectors.toList());
+    List<Type> columnTypes =
+        Stream.of(R.struct(R.I32, R.STRING, R.I64), R.I32).collect(Collectors.toList());
+
+    // Child struct selection: subfields 0 (I32) and 2 (I64)
+    MaskExpression.StructSelect innerSelect =
+        MaskExpression.StructSelect.builder()
+            .addStructItems(MaskExpression.StructItem.of(0))
+            .addStructItems(MaskExpression.StructItem.of(2))
+            .build();
+
+    // Select field 0 with the child selection, plus field 1 (I32)
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(MaskExpression.StructItem.of(0, innerSelect))
+                    .addStructItems(MaskExpression.StructItem.of(1))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    Type.Struct expected = R.struct(R.struct(R.I32, R.I64), R.I32);
+    assertEquals(expected, namedScan.getRecordType());
+  }
+
+  /**
+   * Asserts that {@code getRecordType()} of a {@link NamedScan} whose projection applies a list
+   * selection with a child struct selection (subfields 0 and 2) to field 0 and also selects field
+   * 1 equals {@code struct(list(struct(I32, I64)), I64)}.
+   */
+  @Test
+  void recordTypeReflectsListWithChildProjection() {
+    List<String> tableName = Stream.of("nested_list_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("items", "id").collect(Collectors.toList());
+    List<Type> columnTypes =
+        Stream.of(R.list(R.struct(R.I32, R.STRING, R.I64)), R.I64).collect(Collectors.toList());
+
+    // List child selection: struct subfields 0 (I32) and 2 (I64)
+    MaskExpression.StructSelect elementSelect =
+        MaskExpression.StructSelect.builder()
+            .addStructItems(MaskExpression.StructItem.of(0))
+            .addStructItems(MaskExpression.StructItem.of(2))
+            .build();
+
+    MaskExpression.ListSelect listSelect =
+        MaskExpression.ListSelect.builder()
+            .addSelection(MaskExpression.ListSelectItem.ofSlice(MaskExpression.ListSlice.of(0, 5)))
+            .child(elementSelect)
+            .build();
+
+    // Select field 0 with the list selection, plus field 1
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(MaskExpression.StructItem.of(0, listSelect))
+                    .addStructItems(MaskExpression.StructItem.of(1))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    Type.Struct expected = R.struct(R.list(R.struct(R.I32, R.I64)), R.I64);
+    assertEquals(expected, namedScan.getRecordType());
+  }
+
+  /**
+   * Asserts that {@code getRecordType()} of a {@link NamedScan} whose projection applies a
+   * map-key selection with a child struct selection (subfield 1) to field 0 and also selects
+   * field 1 equals {@code struct(map(STRING, struct(STRING)), I64)}.
+   */
+  @Test
+  void recordTypeReflectsMapWithChildProjection() {
+    List<String> tableName = Stream.of("nested_map_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("entries", "id").collect(Collectors.toList());
+    List<Type> columnTypes =
+        Stream.of(R.map(R.STRING, R.struct(R.I32, R.STRING)), R.I64).collect(Collectors.toList());
+
+    // Map child selection: struct subfield 1 (STRING)
+    MaskExpression.StructSelect valueSelect =
+        MaskExpression.StructSelect.builder()
+            .addStructItems(MaskExpression.StructItem.of(1))
+            .build();
+
+    // Select field 0 with the map selection, plus field 1
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(
+                        MaskExpression.StructItem.of(
+                            0,
+                            MaskExpression.MapSelect.builder()
+                                .key(MaskExpression.MapKey.of("any_key"))
+                                .child(valueSelect)
+                                .build()))
+                    .addStructItems(MaskExpression.StructItem.of(1))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .build();
+
+    Type.Struct expected = R.struct(R.map(R.STRING, R.struct(R.STRING)), R.I64);
+    assertEquals(expected, namedScan.getRecordType());
+  }
+
+  /**
+   * Asserts that {@code getRecordType()} of a {@link NamedScan} built without a projection equals
+   * the full {@code struct(I32, STRING, I64)} schema.
+   */
+  @Test
+  void recordTypeWithoutProjection() {
+    List<String> tableName = Stream.of("full_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("col_a", "col_b", "col_c").collect(Collectors.toList());
+    List<Type> columnTypes = Stream.of(R.I32, R.STRING, R.I64).collect(Collectors.toList());
+
+    NamedScan namedScan = sb.namedScan(tableName, columnNames, columnTypes);
+
+    // Without projection, getRecordType() should return the full base schema
+    Type.Struct expected = R.struct(R.I32, R.STRING, R.I64);
+    assertEquals(expected, namedScan.getRecordType());
+  }
+
+  /**
+   * Asserts that {@code getRecordType()} of a {@link NamedScan} with a projection selecting
+   * fields 0 and 2 and a remap of {@code [1, 0]} equals {@code struct(I64, I32)}.
+   */
+  @Test
+  void recordTypeWithProjectionAndRemap() {
+    List<String> tableName = Stream.of("remap_table").collect(Collectors.toList());
+    List<String> columnNames = Stream.of("col_a", "col_b", "col_c").collect(Collectors.toList());
+    List<Type> columnTypes = Stream.of(R.I32, R.STRING, R.I64).collect(Collectors.toList());
+
+    // Select fields 0 (I32) and 2 (I64), then remap to reverse order: [1, 0] -> (I64, I32)
+    MaskExpression projection =
+        MaskExpression.builder()
+            .select(
+                MaskExpression.StructSelect.builder()
+                    .addStructItems(MaskExpression.StructItem.of(0))
+                    .addStructItems(MaskExpression.StructItem.of(2))
+                    .build())
+            .build();
+
+    NamedScan namedScan =
+        NamedScan.builder()
+            .from(sb.namedScan(tableName, columnNames, columnTypes))
+            .projection(projection)
+            .remap(Rel.Remap.of(Stream.of(1, 0).collect(Collectors.toList())))
+            .build();
+
+    // Projection yields struct(I32, I64), remap [1, 0] reorders to struct(I64, I32)
+    Type.Struct expected = R.struct(R.I64, R.I32);
+    assertEquals(expected, namedScan.getRecordType());
+  }
}