Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -1507,7 +1507,24 @@ public Expression.WindowFunctionInvocation windowFn(
* @return a new {@link Type.UserDefined}
*/
public Type.UserDefined userDefinedType(String urn, String typeName) {
return Type.UserDefined.builder().urn(urn).name(typeName).nullable(false).build();
return userDefinedType(urn, typeName, 0);
}

/**
* Creates a user-defined type with the specified URN, type name, and type variation.
*
* @param urn the URN of the extension containing the type
* @param typeName the name of the user-defined type
* @param typeVariationReference the type variation reference
* @return a new {@link Type.UserDefined}
*/
public Type.UserDefined userDefinedType(String urn, String typeName, int typeVariationReference) {
return Type.UserDefined.builder()
.urn(urn)
.name(typeName)
.nullable(false)
.typeVariationReference(typeVariationReference)
.build();
}

// Misc
Expand Down
25 changes: 19 additions & 6 deletions core/src/main/java/io/substrait/type/Type.java
Original file line number Diff line number Diff line change
Expand Up @@ -422,11 +422,11 @@ public <R, E extends Throwable> R accept(final TypeVisitor<R, E> typeVisitor) th
}

@Value.Immutable
abstract class UserDefined implements Type {
interface UserDefined extends Type {

public abstract String urn();
String urn();

public abstract String name();
String name();

/**
* Returns the type parameters for this user-defined type.
Expand All @@ -437,16 +437,29 @@ abstract class UserDefined implements Type {
* @return a list of type parameters, or an empty list if this type is not parameterized
*/
@Value.Default
public java.util.List<Parameter> typeParameters() {
default java.util.List<Parameter> typeParameters() {
return java.util.Collections.emptyList();
}

public static ImmutableType.UserDefined.Builder builder() {
/**
* Returns the type variation reference for this user-defined type.
*
* <p>Type variations allow different physical representations or semantics for the same logical
* type. The reference value maps to an {@code ExtensionTypeVariation} declaration in the plan.
*
* @return the type variation reference, or {@code 0} if using the default variation
*/
@Value.Default
default int typeVariationReference() {
return 0;
}

static ImmutableType.UserDefined.Builder builder() {
return ImmutableType.UserDefined.builder();
}

@Override
public <R, E extends Throwable> R accept(TypeVisitor<R, E> typeVisitor) throws E {
default <R, E extends Throwable> R accept(TypeVisitor<R, E> typeVisitor) throws E {
return typeVisitor.visit(this);
}
}
Expand Down
11 changes: 10 additions & 1 deletion core/src/main/java/io/substrait/type/TypeCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,16 @@ public Type.Map map(Type key, Type value) {
}

public Type userDefined(String urn, String name) {
return Type.UserDefined.builder().nullable(nullable).urn(urn).name(name).build();
return userDefined(urn, name, 0);
}

public Type userDefined(String urn, String name, int typeVariationReference) {
return Type.UserDefined.builder()
.nullable(nullable)
.urn(urn)
.name(name)
.typeVariationReference(typeVariationReference)
.build();
}

public static TypeCreator of(boolean nullability) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ public final T visit(final Type.Map expr) {
public final T visit(final Type.UserDefined expr) {
int ref =
extensionCollector.getTypeReference(SimpleExtension.TypeAnchor.of(expr.urn(), expr.name()));
return typeContainer(expr).userDefined(ref, expr.typeParameters());
return typeContainer(expr)
.userDefined(ref, expr.typeVariationReference(), expr.typeParameters());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,14 @@ public final T struct(T... types) {

public abstract T map(T key, T value);

public abstract T userDefined(int ref);
public T userDefined(int ref, int typeVariationReference) {
return userDefined(ref, typeVariationReference, java.util.Collections.emptyList());
}

public abstract T userDefined(
int ref, java.util.List<io.substrait.type.Type.Parameter> typeParameters);
int ref,
int typeVariationReference,
java.util.List<io.substrait.type.Type.Parameter> typeParameters);

protected abstract T wrap(Object o);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ public Type from(io.substrait.proto.Type type) {
userDefined.getTypeParametersList().stream()
.map(this::from)
.collect(java.util.stream.Collectors.toList()))
.typeVariationReference(userDefined.getTypeVariationReference())
.build();
}
case USER_DEFINED_TYPE_REFERENCE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,18 +180,15 @@ public Type map(Type key, Type value) {
Type.Map.newBuilder().setKey(key).setValue(value).setNullability(nullability).build());
}

@Override
public Type userDefined(int ref) {
return wrap(
Type.UserDefined.newBuilder().setTypeReference(ref).setNullability(nullability).build());
}

@Override
public Type userDefined(
int ref, java.util.List<io.substrait.type.Type.Parameter> typeParameters) {
int ref,
int typeVariationReference,
java.util.List<io.substrait.type.Type.Parameter> typeParameters) {
return wrap(
Type.UserDefined.newBuilder()
.setTypeReference(ref)
.setTypeVariationReference(typeVariationReference)
.setNullability(nullability)
.addAllTypeParameters(
typeParameters.stream()
Expand Down
18 changes: 18 additions & 0 deletions core/src/test/java/io/substrait/extension/TypeExtensionTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,24 @@ void roundtripCustomType() {
io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan);
Plan planReturned = protoPlanConverter.from(protoPlan);
assertEquals(plan, planReturned);

// verify default typeVariationReference is 0
assertEquals(0, ((Type.UserDefined) customType1).typeVariationReference());
}

@Test
void roundtripCustomTypeWithVariationReference() {
Type customTypeWithVariation = sb.userDefinedType(URN, "customType1", 42);

List<String> tableName = Stream.of("example").collect(Collectors.toList());
List<String> columnNames = Stream.of("custom_type_column").collect(Collectors.toList());
List<Type> types = Stream.of(customTypeWithVariation).collect(Collectors.toList());

Plan plan = sb.plan(sb.root(sb.namedScan(tableName, columnNames, types)));

io.substrait.proto.Plan protoPlan = planProtoConverter.toProto(plan);
Plan planReturned = protoPlanConverter.from(protoPlan);
assertEquals(plan, planReturned);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of adding this here, what do you think about just adding a JSON file here and adding it to the tests here? This way you don't have to do any of the hand construction of the plan, and you don't have to check for the presence of other fields. Just personal preference though, up to you ultimately.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I simplified it to just compare the full plan roundtrip which keeps it pretty minimal. I looked into extending PlanRoundtripTest to support custom extensions for UDTs but it's a bit involved so maybe in a followup.

}

@Test
Expand Down
Loading