diff --git a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java index 0f58da6af..54d4f63a3 100644 --- a/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java +++ b/core/src/main/java/io/substrait/dsl/SubstraitBuilder.java @@ -1507,7 +1507,24 @@ public Expression.WindowFunctionInvocation windowFn( * @return a new {@link Type.UserDefined} */ public Type.UserDefined userDefinedType(String urn, String typeName) { - return Type.UserDefined.builder().urn(urn).name(typeName).nullable(false).build(); + return userDefinedType(urn, typeName, 0); + } + + /** + * Creates a user-defined type with the specified URN, type name, and type variation. + * + * @param urn the URN of the extension containing the type + * @param typeName the name of the user-defined type + * @param typeVariationReference the type variation reference + * @return a new {@link Type.UserDefined} + */ + public Type.UserDefined userDefinedType(String urn, String typeName, int typeVariationReference) { + return Type.UserDefined.builder() + .urn(urn) + .name(typeName) + .nullable(false) + .typeVariationReference(typeVariationReference) + .build(); } // Misc diff --git a/core/src/main/java/io/substrait/type/Type.java b/core/src/main/java/io/substrait/type/Type.java index e82f8c1cf..801e83880 100644 --- a/core/src/main/java/io/substrait/type/Type.java +++ b/core/src/main/java/io/substrait/type/Type.java @@ -422,11 +422,11 @@ public R accept(final TypeVisitor typeVisitor) th } @Value.Immutable - abstract class UserDefined implements Type { + interface UserDefined extends Type { - public abstract String urn(); + String urn(); - public abstract String name(); + String name(); /** * Returns the type parameters for this user-defined type. @@ -437,16 +437,29 @@ abstract class UserDefined implements Type { * @return a list of type parameters, or an empty list if this type is not parameterized */ @Value.Default - public java.util.List typeParameters() { + default java.util.List typeParameters() { return java.util.Collections.emptyList(); } - public static ImmutableType.UserDefined.Builder builder() { + /** + * Returns the type variation reference for this user-defined type. + * + *

Type variations allow different physical representations or semantics for the same logical + * type. The reference value maps to an {@code ExtensionTypeVariation} declaration in the plan. + * + * @return the type variation reference, or {@code 0} if using the default variation + */ + @Value.Default + default int typeVariationReference() { + return 0; + } + + static ImmutableType.UserDefined.Builder builder() { return ImmutableType.UserDefined.builder(); } @Override - public R accept(TypeVisitor typeVisitor) throws E { + default R accept(TypeVisitor typeVisitor) throws E { return typeVisitor.visit(this); } } diff --git a/core/src/main/java/io/substrait/type/TypeCreator.java b/core/src/main/java/io/substrait/type/TypeCreator.java index 6a897417e..0362223bd 100644 --- a/core/src/main/java/io/substrait/type/TypeCreator.java +++ b/core/src/main/java/io/substrait/type/TypeCreator.java @@ -112,7 +112,16 @@ public Type.Map map(Type key, Type value) { } public Type userDefined(String urn, String name) { - return Type.UserDefined.builder().nullable(nullable).urn(urn).name(name).build(); + return userDefined(urn, name, 0); + } + + public Type userDefined(String urn, String name, int typeVariationReference) { + return Type.UserDefined.builder() + .nullable(nullable) + .urn(urn) + .name(name) + .typeVariationReference(typeVariationReference) + .build(); } public static TypeCreator of(boolean nullability) { diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java index 3909d4033..898fe87d0 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoConverter.java @@ -175,6 +175,7 @@ public final T visit(final Type.Map expr) { public final T visit(final Type.UserDefined expr) { int ref = extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name())); - return typeContainer(expr).userDefined(ref, expr.typeParameters()); + return typeContainer(expr) + .userDefined(ref, expr.typeVariationReference(), expr.typeParameters()); } } diff --git a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java index 47842382f..c6490c9b1 100644 --- a/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java +++ b/core/src/main/java/io/substrait/type/proto/BaseProtoTypes.java @@ -131,10 +131,14 @@ public final T struct(T... types) { public abstract T map(T key, T value); - public abstract T userDefined(int ref); + public T userDefined(int ref, int typeVariationReference) { + return userDefined(ref, typeVariationReference, java.util.Collections.emptyList()); + } public abstract T userDefined( - int ref, java.util.List typeParameters); + int ref, + int typeVariationReference, + java.util.List typeParameters); protected abstract T wrap(Object o); diff --git a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java index 24231aebc..00b2532b7 100644 --- a/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java +++ b/core/src/main/java/io/substrait/type/proto/ProtoTypeConverter.java @@ -107,6 +107,7 @@ public Type from(io.substrait.proto.Type type) { userDefined.getTypeParametersList().stream() .map(this::from) .collect(java.util.stream.Collectors.toList())) + .typeVariationReference(userDefined.getTypeVariationReference()) .build(); } case USER_DEFINED_TYPE_REFERENCE: diff --git a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java index 102c295a2..aa7b4d6a6 100644 --- a/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java +++ b/core/src/main/java/io/substrait/type/proto/TypeProtoConverter.java @@ -180,18 +180,15 @@ public Type map(Type key, Type value) { Type.Map.newBuilder().setKey(key).setValue(value).setNullability(nullability).build()); } - @Override - public Type userDefined(int ref) { - return wrap( - Type.UserDefined.newBuilder().setTypeReference(ref).setNullability(nullability).build()); - } - @Override public Type userDefined( - int ref, java.util.List typeParameters) { + int ref, + int typeVariationReference, + java.util.List typeParameters) { return wrap( Type.UserDefined.newBuilder() .setTypeReference(ref) + .setTypeVariationReference(typeVariationReference) .setNullability(nullability) .addAllTypeParameters( typeParameters.stream() diff --git a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java index 3a3c2ec63..c5d709a26 100644 --- a/core/src/test/java/io/substrait/extension/TypeExtensionTest.java +++ b/core/src/test/java/io/substrait/extension/TypeExtensionTest.java @@ -78,6 +78,24 @@ void roundtripCustomType() { io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); Plan planReturned = protoPlanConverter.from(protoPlan); assertEquals(plan, planReturned); + + // verify default typeVariationReference is 0 + assertEquals(0, ((Type.UserDefined) customType1).typeVariationReference()); + } + + @Test + void roundtripCustomTypeWithVariationReference() { + Type customTypeWithVariation = sb.userDefinedType(URN, "customType1", 42); + + List tableName = Stream.of("example").collect(Collectors.toList()); + List columnNames = Stream.of("custom_type_column").collect(Collectors.toList()); + List types = Stream.of(customTypeWithVariation).collect(Collectors.toList()); + + Plan plan = sb.plan(sb.root(sb.namedScan(tableName, columnNames, types))); + + io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan); + Plan planReturned = protoPlanConverter.from(protoPlan); + assertEquals(plan, planReturned); } @Test