Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion core/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ dependencies {
testImplementation(platform(libs.junit.bom))
testImplementation(libs.protobuf.java.util)
testImplementation(libs.guava)
testImplementation(libs.bundles.jackson)
testImplementation(libs.classgraph)

testImplementation(libs.junit.jupiter)
testRuntimeOnly(libs.junit.platform.launcher)
Expand Down Expand Up @@ -229,13 +231,18 @@ sourceSets {
main {
antlr { setSrcDirs(listOf(file("${rootProject.projectDir}/substrait/grammar"))) }
proto.srcDir("../substrait/proto")
resources.srcDir("../substrait/extensions")
// Extension YAMLs are relocated into substrait/extensions/ on the classpath
// via processResources below, rather than landing at the classpath root.
resources.srcDir("build/generated/sources/manifest/")
java.srcDir(file("build/generated/sources/antlr/main/java/"))
java.srcDir("build/generated/sources/version/")
}
}

tasks.named<ProcessResources>("processResources") {
from("../substrait/extensions") { into("substrait/extensions") }
}

project.configure<IdeaModel> {
module {
resourceDirs.addAll(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ public class DefaultExtensionCatalog {
public static final String FUNCTIONS_AGGREGATE_APPROX =
"extension:io.substrait:functions_aggregate_approx";

/** Extension identifier for aggregate functions with decimal output. */
public static final String FUNCTIONS_AGGREGATE_DECIMAL_OUTPUT =
"extension:io.substrait:functions_aggregate_decimal_output";

/** Extension identifier for generic aggregate functions. */
public static final String FUNCTIONS_AGGREGATE_GENERIC =
"extension:io.substrait:functions_aggregate_generic";
Expand Down Expand Up @@ -82,12 +86,13 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() {
"logarithmic",
"rounding",
"rounding_decimal",
"set",
"string")
.stream()
.map(c -> String.format("/functions_%s.yaml", c))
.map(c -> String.format("/substrait/extensions/functions_%s.yaml", c))
.collect(Collectors.toList());

defaultFiles.add("/extension_types.yaml");
defaultFiles.add("/substrait/extensions/extension_types.yaml");

return SimpleExtension.load(defaultFiles);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,11 @@ public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) {
anchor.key(), anchor.urn()));
}

/** Returns true if the given URN has any functions or types loaded in this collection. */
public boolean containsUrn(String urn) {
return urnSupplier.get().contains(urn) || types().stream().anyMatch(t -> t.urn().equals(urn));
}

private void checkUrn(String name) {
if (urnSupplier.get().contains(name)) {
return;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,97 @@
package io.substrait.extension;

import static io.substrait.extension.DefaultExtensionCatalog.DEFAULT_COLLECTION;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
import io.github.classgraph.ClassGraph;
import io.github.classgraph.ResourceList;
import io.github.classgraph.ScanResult;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Set;
import org.junit.jupiter.api.Test;

/**
* Verifies that every extension YAML in substrait/extensions is loaded by {@link
* DefaultExtensionCatalog}.
*/
class DefaultExtensionCatalogTest {

/**
* Extension YAML files that are intentionally not loaded by {@link DefaultExtensionCatalog}.
* Adding a file here requires a comment explaining why it cannot be loaded.
*/
private static final Set<String> UNSUPPORTED_FILES =
Set.of(
// aggregate_decimal_output defines count and approx_count_distinct with decimal<38,0>
// return types instead of i64. When loaded alongside aggregate_generic, the same
// function key (e.g. count:any) maps to the same Calcite operator twice, which breaks
// the reverse lookup in FunctionConverter.getSqlOperatorFromSubstraitFunc.
"functions_aggregate_decimal_output.yaml",
// functions_geometry.yaml defines user-defined types (u!geometry) that are not
// supported by Calcite's type conversion in isthmus.
"functions_geometry.yaml",
// type_variations.yaml only defines type variations, which are not tracked by
// ExtensionCollection (no functions or types), so containsUrn cannot verify it.
"type_variations.yaml",
// unknown.yaml uses an "unknown" type that is not a recognized type literal or
// parameterized type, causing Function.constructKey to fail at load time.
"unknown.yaml");

private static final ObjectMapper YAML_MAPPER = new ObjectMapper(new YAMLFactory());

@Test
void defaultCollectionLoads() {
assertNotNull(DefaultExtensionCatalog.DEFAULT_COLLECTION);
assertNotNull(DEFAULT_COLLECTION);
}

@Test
void allExtensionYamlFilesAccountedFor() throws IOException {
List<String> yamlFiles = getExtensionYamlFiles();
assertFalse(yamlFiles.isEmpty(), "No YAML files found in substrait/extensions/");

for (String fileName : yamlFiles) {
String urn = parseUrn(fileName);
assertNotNull(urn, fileName + " does not contain a URN field");
if (UNSUPPORTED_FILES.contains(fileName)) {
assertFalse(
DEFAULT_COLLECTION.containsUrn(urn),
fileName
+ " is in UNSUPPORTED_FILES but is loaded by DefaultExtensionCatalog"
+ " — remove it from UNSUPPORTED_FILES");
} else {
assertTrue(
DEFAULT_COLLECTION.containsUrn(urn),
fileName + " not loaded by DefaultExtensionCatalog (urn: " + urn + ")");
}
}
}

private static String parseUrn(String resourceName) throws IOException {
String resourcePath = "substrait/extensions/" + resourceName;
try (InputStream is =
DefaultExtensionCatalogTest.class.getClassLoader().getResourceAsStream(resourcePath)) {
assertNotNull(is, "Resource not found on classpath: " + resourcePath);
JsonNode doc = YAML_MAPPER.readTree(is);
JsonNode urnNode = doc.get("urn");
return urnNode == null ? null : urnNode.asText();
}
}

private static List<String> getExtensionYamlFiles() {
try (ScanResult scan =
new ClassGraph().acceptPathsNonRecursive("substrait/extensions").scan()) {
ResourceList resources = scan.getResourcesWithExtension(".yaml");
return resources.stream()
.map(r -> r.getPath().substring("substrait/extensions/".length()))
.sorted()
.toList();
}
}
}
Loading