diff --git a/core/build.gradle.kts b/core/build.gradle.kts index 421cee46c..5e84490ae 100644 --- a/core/build.gradle.kts +++ b/core/build.gradle.kts @@ -88,6 +88,8 @@ dependencies { testImplementation(platform(libs.junit.bom)) testImplementation(libs.protobuf.java.util) testImplementation(libs.guava) + testImplementation(libs.bundles.jackson) + testImplementation(libs.classgraph) testImplementation(libs.junit.jupiter) testRuntimeOnly(libs.junit.platform.launcher) @@ -229,13 +231,18 @@ sourceSets { main { antlr { setSrcDirs(listOf(file("${rootProject.projectDir}/substrait/grammar"))) } proto.srcDir("../substrait/proto") - resources.srcDir("../substrait/extensions") + // Extension YAMLs are relocated into substrait/extensions/ on the classpath + // via processResources below, rather than landing at the classpath root. resources.srcDir("build/generated/sources/manifest/") java.srcDir(file("build/generated/sources/antlr/main/java/")) java.srcDir("build/generated/sources/version/") } } +tasks.named("processResources") { + from("../substrait/extensions") { into("substrait/extensions") } +} + project.configure { module { resourceDirs.addAll( diff --git a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java index d8b6b1fa6..25dc76cad 100644 --- a/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java +++ b/core/src/main/java/io/substrait/extension/DefaultExtensionCatalog.java @@ -14,6 +14,10 @@ public class DefaultExtensionCatalog { public static final String FUNCTIONS_AGGREGATE_APPROX = "extension:io.substrait:functions_aggregate_approx"; + /** Extension identifier for aggregate functions with decimal output. */ + public static final String FUNCTIONS_AGGREGATE_DECIMAL_OUTPUT = + "extension:io.substrait:functions_aggregate_decimal_output"; + /** Extension identifier for generic aggregate functions. */ public static final String FUNCTIONS_AGGREGATE_GENERIC = "extension:io.substrait:functions_aggregate_generic"; @@ -82,12 +86,13 @@ private static SimpleExtension.ExtensionCollection loadDefaultCollection() { "logarithmic", "rounding", "rounding_decimal", + "set", "string") .stream() - .map(c -> String.format("/functions_%s.yaml", c)) + .map(c -> String.format("/substrait/extensions/functions_%s.yaml", c)) .collect(Collectors.toList()); - defaultFiles.add("/extension_types.yaml"); + defaultFiles.add("/substrait/extensions/extension_types.yaml"); return SimpleExtension.load(defaultFiles); } diff --git a/core/src/main/java/io/substrait/extension/SimpleExtension.java b/core/src/main/java/io/substrait/extension/SimpleExtension.java index d98c632f6..1b3f05820 100644 --- a/core/src/main/java/io/substrait/extension/SimpleExtension.java +++ b/core/src/main/java/io/substrait/extension/SimpleExtension.java @@ -718,6 +718,11 @@ public ScalarFunctionVariant getScalarFunction(FunctionAnchor anchor) { anchor.key(), anchor.urn())); } + /** Returns true if the given URN has any functions or types loaded in this collection. */ + public boolean containsUrn(String urn) { + return urnSupplier.get().contains(urn) || types().stream().anyMatch(t -> t.urn().equals(urn)); + } + private void checkUrn(String name) { if (urnSupplier.get().contains(name)) { return; diff --git a/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java index 7aed4e961..c98637de7 100644 --- a/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java +++ b/core/src/test/java/io/substrait/extension/DefaultExtensionCatalogTest.java @@ -1,13 +1,97 @@ package io.substrait.extension; +import static io.substrait.extension.DefaultExtensionCatalog.DEFAULT_COLLECTION; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import io.github.classgraph.ClassGraph; +import io.github.classgraph.ResourceList; +import io.github.classgraph.ScanResult; +import java.io.IOException; +import java.io.InputStream; +import java.util.List; +import java.util.Set; import org.junit.jupiter.api.Test; +/** + * Verifies that every extension YAML in substrait/extensions is loaded by {@link + * DefaultExtensionCatalog}. + */ class DefaultExtensionCatalogTest { + /** + * Extension YAML files that are intentionally not loaded by {@link DefaultExtensionCatalog}. + * Adding a file here requires a comment explaining why it cannot be loaded. + */ + private static final Set UNSUPPORTED_FILES = + Set.of( + // aggregate_decimal_output defines count and approx_count_distinct with decimal<38,0> + // return types instead of i64. When loaded alongside aggregate_generic, the same + // function key (e.g. count:any) maps to the same Calcite operator twice, which breaks + // the reverse lookup in FunctionConverter.getSqlOperatorFromSubstraitFunc. + "functions_aggregate_decimal_output.yaml", + // functions_geometry.yaml defines user-defined types (u!geometry) that are not + // supported by Calcite's type conversion in isthmus. + "functions_geometry.yaml", + // type_variations.yaml only defines type variations, which are not tracked by + // ExtensionCollection (no functions or types), so containsUrn cannot verify it. + "type_variations.yaml", + // unknown.yaml uses an "unknown" type that is not a recognized type literal or + // parameterized type, causing Function.constructKey to fail at load time. + "unknown.yaml"); + + private static final ObjectMapper YAML_MAPPER = new ObjectMapper(new YAMLFactory()); + @Test void defaultCollectionLoads() { - assertNotNull(DefaultExtensionCatalog.DEFAULT_COLLECTION); + assertNotNull(DEFAULT_COLLECTION); + } + + @Test + void allExtensionYamlFilesAccountedFor() throws IOException { + List yamlFiles = getExtensionYamlFiles(); + assertFalse(yamlFiles.isEmpty(), "No YAML files found in substrait/extensions/"); + + for (String fileName : yamlFiles) { + String urn = parseUrn(fileName); + assertNotNull(urn, fileName + " does not contain a URN field"); + if (UNSUPPORTED_FILES.contains(fileName)) { + assertFalse( + DEFAULT_COLLECTION.containsUrn(urn), + fileName + + " is in UNSUPPORTED_FILES but is loaded by DefaultExtensionCatalog" + + " — remove it from UNSUPPORTED_FILES"); + } else { + assertTrue( + DEFAULT_COLLECTION.containsUrn(urn), + fileName + " not loaded by DefaultExtensionCatalog (urn: " + urn + ")"); + } + } + } + + private static String parseUrn(String resourceName) throws IOException { + String resourcePath = "substrait/extensions/" + resourceName; + try (InputStream is = + DefaultExtensionCatalogTest.class.getClassLoader().getResourceAsStream(resourcePath)) { + assertNotNull(is, "Resource not found on classpath: " + resourcePath); + JsonNode doc = YAML_MAPPER.readTree(is); + JsonNode urnNode = doc.get("urn"); + return urnNode == null ? null : urnNode.asText(); + } + } + + private static List getExtensionYamlFiles() { + try (ScanResult scan = + new ClassGraph().acceptPathsNonRecursive("substrait/extensions").scan()) { + ResourceList resources = scan.getResourcesWithExtension(".yaml"); + return resources.stream() + .map(r -> r.getPath().substring("substrait/extensions/".length())) + .sorted() + .toList(); + } } }