From 70c3db265e5429577632f9d0c13cf9012168135f Mon Sep 17 00:00:00 2001 From: Gary Date: Tue, 27 Jan 2026 20:22:56 +0800 Subject: [PATCH 01/27] implement schema template cast --- .../doris/nereids/jobs/executor/Rewriter.java | 9 + .../apache/doris/nereids/rules/RuleType.java | 1 + .../rules/rewrite/VariantSchemaCast.java | 169 ++++++++++++++++++ .../doris/nereids/types/VariantField.java | 34 ++++ .../doris/nereids/types/VariantType.java | 17 ++ 5 files changed, 230 insertions(+) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 9e3186c6cec7b2..c608aab54b219f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -168,6 +168,7 @@ import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAggProject; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoin; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoinProject; +import org.apache.doris.nereids.rules.rewrite.VariantSchemaCast; import org.apache.doris.nereids.rules.rewrite.VariantSubPathPruning; import org.apache.doris.nereids.rules.rewrite.batch.ApplyToJoin; import org.apache.doris.nereids.rules.rewrite.batch.CorrelateApplyToUnCorrelateApply; @@ -914,6 +915,14 @@ private static List getWholeTreeRewriteJobs( bottomUp(new RewriteSearchToSlots()) )); + // Auto cast variant element access based on schema template + // This should run before VariantSubPathPruning + rewriteJobs.addAll(jobs( + topic("variant schema cast", + custom(RuleType.VARIANT_SCHEMA_CAST, VariantSchemaCast::new) + ) + )); + if (needSubPathPushDown) { rewriteJobs.addAll(jobs( topic("variant element_at push down", diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 2587481256b798..1dcce5c159bff3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -224,6 +224,7 @@ public enum RuleType { ADD_PROJECT_FOR_JOIN(RuleTypeClass.REWRITE), ADD_PROJECT_FOR_UNIQUE_FUNCTION(RuleTypeClass.REWRITE), + VARIANT_SCHEMA_CAST(RuleTypeClass.REWRITE), VARIANT_SUB_PATH_PRUNING(RuleTypeClass.REWRITE), NESTED_COLUMN_PRUNING(RuleTypeClass.REWRITE), CLEAR_CONTEXT_STATUS(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java new file mode 100644 index 00000000000000..7cdec24866e263 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; +import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.VariantField; +import org.apache.doris.nereids.types.VariantType; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.List; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +/** + * Automatically cast variant element access expressions based on schema template. + * + * For example, if a variant column is defined as: + * payload VARIANT<'number_*': BIGINT, 'string_*': STRING> + * + * Then payload['number_latency'] will be automatically cast to BIGINT, + * and payload['string_message'] will be automatically cast to STRING. + * + * This allows users to use variant sub-fields directly in WHERE, ORDER BY, + * and other clauses without explicit CAST. 
+ */ +public class VariantSchemaCast implements CustomRewriter { + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + return plan.accept(PlanRewriter.INSTANCE, null); + } + + private static class PlanRewriter extends DefaultPlanRewriter { + public static final PlanRewriter INSTANCE = new PlanRewriter(); + + private static final Function EXPRESSION_REWRITER = expr -> { + if (!(expr instanceof ElementAt)) { + return expr; + } + ElementAt elementAt = (ElementAt) expr; + Expression left = elementAt.left(); + Expression right = elementAt.right(); + + // Only process if left is VariantType and right is a string literal + if (!(left.getDataType() instanceof VariantType)) { + return expr; + } + if (!(right instanceof StringLikeLiteral)) { + return expr; + } + + VariantType variantType = (VariantType) left.getDataType(); + String fieldName = ((StringLikeLiteral) right).getStringValue(); + + // Find matching field in schema template + Optional matchingField = variantType.findMatchingField(fieldName); + if (!matchingField.isPresent()) { + return expr; + } + + DataType targetType = matchingField.get().getDataType(); + + // Wrap with Cast + return new Cast(elementAt, targetType); + }; + + private Expression rewriteExpression(Expression expr) { + return expr.rewriteDownShortCircuit(EXPRESSION_REWRITER); + } + + @Override + public Plan visitLogicalFilter(LogicalFilter filter, Void context) { + filter = (LogicalFilter) super.visit(filter, context); + Set newConjuncts = filter.getConjuncts().stream() + .map(this::rewriteExpression) + .collect(ImmutableSet.toImmutableSet()); + return filter.withConjuncts(newConjuncts); + } + + @Override + public Plan visitLogicalProject(LogicalProject project, Void context) { + project = (LogicalProject) super.visit(project, context); + List newProjects = project.getProjects().stream() + .map(expr -> (NamedExpression) rewriteExpression(expr)) + .collect(ImmutableList.toImmutableList()); + return 
project.withProjects(newProjects); + } + + @Override + public Plan visitLogicalSort(LogicalSort sort, Void context) { + sort = (LogicalSort) super.visit(sort, context); + List newOrderKeys = sort.getOrderKeys().stream() + .map(orderKey -> orderKey.withExpression(rewriteExpression(orderKey.getExpr()))) + .collect(ImmutableList.toImmutableList()); + return sort.withOrderKeys(newOrderKeys); + } + + @Override + public Plan visitLogicalTopN(LogicalTopN topN, Void context) { + topN = (LogicalTopN) super.visit(topN, context); + List newOrderKeys = topN.getOrderKeys().stream() + .map(orderKey -> orderKey.withExpression(rewriteExpression(orderKey.getExpr()))) + .collect(ImmutableList.toImmutableList()); + return topN.withOrderKeys(newOrderKeys); + } + + @Override + public Plan visitLogicalJoin(LogicalJoin join, Void context) { + join = (LogicalJoin) super.visit(join, context); + List newHashConditions = join.getHashJoinConjuncts().stream() + .map(this::rewriteExpression) + .collect(ImmutableList.toImmutableList()); + List newOtherConditions = join.getOtherJoinConjuncts().stream() + .map(this::rewriteExpression) + .collect(ImmutableList.toImmutableList()); + List newMarkConditions = join.getMarkJoinConjuncts().stream() + .map(this::rewriteExpression) + .collect(ImmutableList.toImmutableList()); + return join.withJoinConjuncts(newHashConditions, newOtherConditions, + newMarkConditions, join.getJoinReorderContext()); + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate aggregate, Void context) { + aggregate = (LogicalAggregate) super.visit(aggregate, context); + List newGroupByKeys = aggregate.getGroupByExpressions().stream() + .map(this::rewriteExpression) + .collect(ImmutableList.toImmutableList()); + List newOutputs = aggregate.getOutputExpressions().stream() + .map(expr -> (NamedExpression) rewriteExpression(expr)) + .collect(ImmutableList.toImmutableList()); + return aggregate.withGroupByAndOutput(newGroupByKeys, newOutputs); + } + } +} diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java index 5faed6893be958..57eeebae93baa0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java @@ -67,6 +67,40 @@ public String getComment() { return comment; } + public TPatternType getPatternType() { + return patternType; + } + + /** + * Check if the given field name matches this field's pattern. + * @param fieldName the field name to check + * @return true if the field name matches the pattern + */ + public boolean matches(String fieldName) { + if (patternType == TPatternType.MATCH_NAME) { + return pattern.equals(fieldName); + } else { + // MATCH_NAME_GLOB: convert glob pattern to regex + // Escape regex special characters except *, then replace * with .* + String regex = pattern + .replace(".", "\\.") + .replace("?", "\\?") + .replace("[", "\\[") + .replace("]", "\\]") + .replace("(", "\\(") + .replace(")", "\\)") + .replace("{", "\\{") + .replace("}", "\\}") + .replace("+", "\\+") + .replace("^", "\\^") + .replace("$", "\\$") + .replace("|", "\\|") + .replace("\\", "\\\\") + .replace("*", ".*"); + return fieldName.matches(regex); + } + } + public org.apache.doris.catalog.VariantField toCatalogDataType() { return new org.apache.doris.catalog.VariantField( pattern, dataType.toCatalogDataType(), comment, patternType); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java index 337658520e4123..af25e1f9061f2f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java @@ -26,6 +26,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.Optional; import 
java.util.stream.Collectors; /** @@ -232,6 +233,22 @@ public List getPredefinedFields() { return predefinedFields; } + /** + * Find the first matching VariantField for the given field name. + * The matching is done in definition order, so the first matching pattern wins. + * + * @param fieldName the field name to match + * @return Optional containing the matching VariantField, or empty if no match + */ + public Optional findMatchingField(String fieldName) { + for (VariantField field : predefinedFields) { + if (field.matches(fieldName)) { + return Optional.of(field); + } + } + return Optional.empty(); + } + public int getVariantMaxSubcolumnsCount() { return variantMaxSubcolumnsCount; } From a365d06ac86d49be5edae316c1dfd0410f7532e7 Mon Sep 17 00:00:00 2001 From: Gary Date: Wed, 28 Jan 2026 22:40:23 +0800 Subject: [PATCH 02/27] fix and test --- .../doris/nereids/jobs/executor/Rewriter.java | 18 +- .../rules/rewrite/VariantSchemaCast.java | 53 ++- .../functions/scalar/ElementAt.java | 8 +- .../doris/nereids/types/VariantField.java | 99 ++++- .../rules/rewrite/VariantSchemaCastTest.java | 386 ++++++++++++++++++ .../nereids/types/VariantFieldMatchTest.java | 173 ++++++++ .../test_schema_template_auto_cast.out | 34 ++ .../test_schema_template_auto_cast.groovy | 107 +++++ 8 files changed, 823 insertions(+), 55 deletions(-) create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCastTest.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java create mode 100644 regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out create mode 100644 regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index c608aab54b219f..7e48f15af0d493 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -276,6 +276,11 @@ public class Rewriter extends AbstractBatchJobExecutor { new EliminateSemiJoin() ) ), + // Auto cast variant element access based on schema template + // This must run before NormalizeSort which converts ORDER BY expressions to slots + topic("variant schema cast before normalize", + custom(RuleType.VARIANT_SCHEMA_CAST, VariantSchemaCast::new) + ), // The rule modification needs to be done after the subquery is unnested, // because for scalarSubQuery, the connection condition is stored in apply in // the analyzer phase, @@ -513,6 +518,11 @@ public class Rewriter extends AbstractBatchJobExecutor { new SimplifyEncodeDecode() ) ), + // Auto cast variant element access based on schema template + // This must run before NormalizeSort which converts ORDER BY expressions to slots + topic("variant schema cast before normalize", + custom(RuleType.VARIANT_SCHEMA_CAST, VariantSchemaCast::new) + ), // The rule modification needs to be done after the subquery is unnested, // because for scalarSubQuery, the connection condition is stored in apply in the analyzer phase, // but when normalizeAggregate/normalizeSort is performed, the members in apply cannot be obtained, @@ -915,14 +925,6 @@ private static List getWholeTreeRewriteJobs( bottomUp(new RewriteSearchToSlots()) )); - // Auto cast variant element access based on schema template - // This should run before VariantSubPathPruning - rewriteJobs.addAll(jobs( - topic("variant schema cast", - custom(RuleType.VARIANT_SCHEMA_CAST, VariantSchemaCast::new) - ) - )); - if (needSubPathPushDown) { rewriteJobs.addAll(jobs( topic("variant element_at push down", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java index 
7cdec24866e263..223410f549f103 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java @@ -19,13 +19,13 @@ import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; @@ -61,26 +61,29 @@ public class VariantSchemaCast implements CustomRewriter { @Override public Plan rewriteRoot(Plan plan, JobContext jobContext) { - return plan.accept(PlanRewriter.INSTANCE, null); + return plan.accept(new PlanRewriter(), null); } private static class PlanRewriter extends DefaultPlanRewriter { - public static final PlanRewriter INSTANCE = new PlanRewriter(); - private static final Function EXPRESSION_REWRITER = expr -> { - if (!(expr instanceof ElementAt)) { - return expr; + private final Function expressionRewriter = expr -> { + // Handle ElementAt expressions + if (expr instanceof ElementAt) { + return rewriteElementAt((ElementAt) expr); } - ElementAt elementAt = (ElementAt) expr; + return expr; + }; + + private Expression rewriteElementAt(ElementAt elementAt) { Expression left = elementAt.left(); Expression right = elementAt.right(); // Only process if left is VariantType and right is a string 
literal if (!(left.getDataType() instanceof VariantType)) { - return expr; + return elementAt; } if (!(right instanceof StringLikeLiteral)) { - return expr; + return elementAt; } VariantType variantType = (VariantType) left.getDataType(); @@ -89,17 +92,25 @@ private static class PlanRewriter extends DefaultPlanRewriter { // Find matching field in schema template Optional matchingField = variantType.findMatchingField(fieldName); if (!matchingField.isPresent()) { - return expr; + return elementAt; } DataType targetType = matchingField.get().getDataType(); - - // Wrap with Cast return new Cast(elementAt, targetType); - }; + } private Expression rewriteExpression(Expression expr) { - return expr.rewriteDownShortCircuit(EXPRESSION_REWRITER); + return expr.rewriteDownShortCircuit(expressionRewriter); + } + + private NamedExpression rewriteNamedExpression(NamedExpression expr) { + Expression rewritten = rewriteExpression(expr); + if (rewritten instanceof NamedExpression) { + return (NamedExpression) rewritten; + } + // If the result is not a NamedExpression (e.g., Cast), wrap it in an Alias + // Preserve the original ExprId to maintain consistency + return new Alias(expr.getExprId(), rewritten, expr.getName()); } @Override @@ -115,7 +126,7 @@ public Plan visitLogicalFilter(LogicalFilter filter, Void contex public Plan visitLogicalProject(LogicalProject project, Void context) { project = (LogicalProject) super.visit(project, context); List newProjects = project.getProjects().stream() - .map(expr -> (NamedExpression) rewriteExpression(expr)) + .map(this::rewriteNamedExpression) .collect(ImmutableList.toImmutableList()); return project.withProjects(newProjects); } @@ -153,17 +164,5 @@ public Plan visitLogicalJoin(LogicalJoin join, V return join.withJoinConjuncts(newHashConditions, newOtherConditions, newMarkConditions, join.getJoinReorderContext()); } - - @Override - public Plan visitLogicalAggregate(LogicalAggregate aggregate, Void context) { - aggregate = 
(LogicalAggregate) super.visit(aggregate, context); - List newGroupByKeys = aggregate.getGroupByExpressions().stream() - .map(this::rewriteExpression) - .collect(ImmutableList.toImmutableList()); - List newOutputs = aggregate.getOutputExpressions().stream() - .map(expr -> (NamedExpression) rewriteExpression(expr)) - .collect(ImmutableList.toImmutableList()); - return aggregate.withGroupByAndOutput(newGroupByKeys, newOutputs); - } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ElementAt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ElementAt.java index 1716ae1f91e714..dd715c3c54d065 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ElementAt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ElementAt.java @@ -104,10 +104,10 @@ public FunctionSignature computeSignature(FunctionSignature signature) { DataType expressionType = arguments.get(0).getDataType(); DataType sigType = signature.argumentsTypes.get(0); if (expressionType instanceof VariantType && sigType instanceof VariantType) { - // only keep the variant max subcolumns count - VariantType variantType = new VariantType(((VariantType) expressionType).getVariantMaxSubcolumnsCount()); - signature = signature.withArgumentType(0, variantType); - signature = signature.withReturnType(variantType); + // Preserve predefinedFields for schema template matching + VariantType originalType = (VariantType) expressionType; + signature = signature.withArgumentType(0, originalType); + signature = signature.withReturnType(originalType); } return super.computeSignature(signature); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java index 57eeebae93baa0..a3d0776e0a8f30 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java @@ -73,6 +73,14 @@ public TPatternType getPatternType() { /** * Check if the given field name matches this field's pattern. + * This method aligns with BE's fnmatch(pattern, path, FNM_PATHNAME) behavior. + * + * Supported glob syntax: + * - '*' matches any sequence of characters except '/' + * - '?' matches any single character except '/' + * - '[...]' matches any character in the brackets + * - '[!...]' or '[^...]' matches any character not in the brackets + * * @param fieldName the field name to check * @return true if the field name matches the pattern */ @@ -81,26 +89,85 @@ public boolean matches(String fieldName) { return pattern.equals(fieldName); } else { // MATCH_NAME_GLOB: convert glob pattern to regex - // Escape regex special characters except *, then replace * with .* - String regex = pattern - .replace(".", "\\.") - .replace("?", "\\?") - .replace("[", "\\[") - .replace("]", "\\]") - .replace("(", "\\(") - .replace(")", "\\)") - .replace("{", "\\{") - .replace("}", "\\}") - .replace("+", "\\+") - .replace("^", "\\^") - .replace("$", "\\$") - .replace("|", "\\|") - .replace("\\", "\\\\") - .replace("*", ".*"); + // This aligns with BE's fnmatch(pattern, path, FNM_PATHNAME) + String regex = globToRegex(pattern); return fieldName.matches(regex); } } + /** + * Convert glob pattern to regex pattern, aligning with fnmatch(FNM_PATHNAME) behavior. + */ + private static String globToRegex(String glob) { + StringBuilder regex = new StringBuilder(); + int i = 0; + int len = glob.length(); + + while (i < len) { + char c = glob.charAt(i); + switch (c) { + case '*': + // '*' matches any sequence of characters except '/' (FNM_PATHNAME) + regex.append("[^/]*"); + break; + case '?': + // '?' 
matches any single character except '/' (FNM_PATHNAME) + regex.append("[^/]"); + break; + case '[': + // Character class - find the closing bracket + int j = i + 1; + // Handle negation: [! or [^ + if (j < len && (glob.charAt(j) == '!' || glob.charAt(j) == '^')) { + j++; + } + // Handle ] as first character in class + if (j < len && glob.charAt(j) == ']') { + j++; + } + // Find closing ] + while (j < len && glob.charAt(j) != ']') { + j++; + } + if (j >= len) { + // No closing bracket, treat [ as literal + regex.append("\\["); + } else { + // Extract the character class content + String classContent = glob.substring(i + 1, j); + regex.append('['); + // Convert [! to [^ + if (classContent.startsWith("!")) { + regex.append('^').append(classContent.substring(1)); + } else { + regex.append(classContent); + } + regex.append(']'); + i = j; // Move past the closing ] + } + break; + // Escape regex special characters + case '\\': + case '.': + case '(': + case ')': + case '{': + case '}': + case '+': + case '^': + case '$': + case '|': + regex.append('\\').append(c); + break; + default: + regex.append(c); + break; + } + i++; + } + return regex.toString(); + } + public org.apache.doris.catalog.VariantField toCatalogDataType() { return new org.apache.doris.catalog.VariantField( pattern, dataType.toCatalogDataType(), comment, patternType); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCastTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCastTest.java new file mode 100644 index 00000000000000..eebb2535d2eb25 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCastTest.java @@ -0,0 +1,386 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.And; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.LessThan; +import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; +import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.VariantField; +import org.apache.doris.nereids.types.VariantType; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.List; +import java.util.function.Function; + +/** + * Unit tests for VariantSchemaCast expression rewriting. 
+ */ +public class VariantSchemaCastTest { + + // Expression rewriter extracted from VariantSchemaCast for testing + private static final Function EXPRESSION_REWRITER = expr -> { + if (!(expr instanceof ElementAt)) { + return expr; + } + ElementAt elementAt = (ElementAt) expr; + Expression left = elementAt.left(); + Expression right = elementAt.right(); + + if (!(left.getDataType() instanceof VariantType)) { + return expr; + } + if (!(right instanceof VarcharLiteral)) { + return expr; + } + + VariantType variantType = (VariantType) left.getDataType(); + String fieldName = ((VarcharLiteral) right).getStringValue(); + + return variantType.findMatchingField(fieldName) + .map(field -> (Expression) new Cast(elementAt, field.getDataType())) + .orElse(expr); + }; + + private Expression rewriteExpression(Expression expr) { + return expr.rewriteDownShortCircuit(EXPRESSION_REWRITER); + } + + @Test + public void testRewriteElementAtWithMatchingPattern() { + // Create variant type with schema template + VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(numberField)); + + // Create element_at expression: variant['number_latency'] + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + ElementAt elementAt = new ElementAt(variantSlot, new VarcharLiteral("number_latency")); + + // Rewrite + Expression result = rewriteExpression(elementAt); + + // Should be wrapped with Cast + Assertions.assertTrue(result instanceof Cast); + Cast cast = (Cast) result; + Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); + Assertions.assertTrue(cast.child() instanceof ElementAt); + } + + @Test + public void testRewriteElementAtWithNoMatchingPattern() { + // Create variant type with schema template + VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(numberField)); + + // Create element_at 
expression: variant['string_message'] (no match) + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + ElementAt elementAt = new ElementAt(variantSlot, new VarcharLiteral("string_message")); + + // Rewrite + Expression result = rewriteExpression(elementAt); + + // Should NOT be wrapped with Cast + Assertions.assertTrue(result instanceof ElementAt); + Assertions.assertFalse(result instanceof Cast); + } + + @Test + public void testRewriteElementAtWithEmptySchemaTemplate() { + // Create variant type without schema template + VariantType variantType = new VariantType(0); + + // Create element_at expression + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + ElementAt elementAt = new ElementAt(variantSlot, new VarcharLiteral("any_field")); + + // Rewrite + Expression result = rewriteExpression(elementAt); + + // Should NOT be wrapped with Cast + Assertions.assertTrue(result instanceof ElementAt); + Assertions.assertFalse(result instanceof Cast); + } + + @Test + public void testRewriteCompoundExpression() { + // Create variant type with schema template + VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(numberField)); + + // Create compound expression: variant['number_latency'] > 100 + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + ElementAt elementAt = new ElementAt(variantSlot, new VarcharLiteral("number_latency")); + GreaterThan greaterThan = new GreaterThan(elementAt, new BigIntLiteral(100)); + + // Rewrite + Expression result = rewriteExpression(greaterThan); + + // Should be GreaterThan with Cast(ElementAt) on left + Assertions.assertTrue(result instanceof GreaterThan); + GreaterThan rewrittenGt = (GreaterThan) result; + Assertions.assertTrue(rewrittenGt.left() instanceof Cast); + Cast cast = (Cast) rewrittenGt.left(); + Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); + } + + @Test + public void 
testRewriteMultiplePatterns() { + // Create variant type with multiple patterns + VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantField stringField = new VariantField("string_*", StringType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(numberField, stringField)); + + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + + // Test number pattern + ElementAt numberElementAt = new ElementAt(variantSlot, new VarcharLiteral("number_count")); + Expression numberResult = rewriteExpression(numberElementAt); + Assertions.assertTrue(numberResult instanceof Cast); + Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) numberResult).getDataType()); + + // Test string pattern + ElementAt stringElementAt = new ElementAt(variantSlot, new VarcharLiteral("string_msg")); + Expression stringResult = rewriteExpression(stringElementAt); + Assertions.assertTrue(stringResult instanceof Cast); + Assertions.assertEquals(StringType.INSTANCE, ((Cast) stringResult).getDataType()); + } + + @Test + public void testRewriteAndCondition() { + // Create variant type with schema template + VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(numberField)); + + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + + // Create AND condition: variant['number_a'] > 10 AND variant['number_b'] < 100 + ElementAt elementAtA = new ElementAt(variantSlot, new VarcharLiteral("number_a")); + ElementAt elementAtB = new ElementAt(variantSlot, new VarcharLiteral("number_b")); + GreaterThan gt = new GreaterThan(elementAtA, new BigIntLiteral(10)); + LessThan lt = new LessThan(elementAtB, new BigIntLiteral(100)); + And andExpr = new And(gt, lt); + + // Rewrite + Expression result = rewriteExpression(andExpr); + + // Should be And with Cast on both sides + Assertions.assertTrue(result instanceof And); + And rewrittenAnd = (And) result; 
+ + // Left side: Cast(ElementAt) > 10 + Assertions.assertTrue(rewrittenAnd.child(0) instanceof GreaterThan); + GreaterThan leftGt = (GreaterThan) rewrittenAnd.child(0); + Assertions.assertTrue(leftGt.left() instanceof Cast); + Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) leftGt.left()).getDataType()); + + // Right side: Cast(ElementAt) < 100 + Assertions.assertTrue(rewrittenAnd.child(1) instanceof LessThan); + LessThan rightLt = (LessThan) rewrittenAnd.child(1); + Assertions.assertTrue(rightLt.left() instanceof Cast); + Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) rightLt.left()).getDataType()); + } + + @Test + public void testRewriteOrCondition() { + // Create variant type with multiple patterns + VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantField stringField = new VariantField("string_*", StringType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(numberField, stringField)); + + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + + // Create OR condition: variant['number_a'] > 10 OR variant['string_b'] = 'test' + ElementAt elementAtA = new ElementAt(variantSlot, new VarcharLiteral("number_a")); + ElementAt elementAtB = new ElementAt(variantSlot, new VarcharLiteral("string_b")); + GreaterThan gt = new GreaterThan(elementAtA, new BigIntLiteral(10)); + EqualTo eq = new EqualTo(elementAtB, new VarcharLiteral("test")); + Or orExpr = new Or(gt, eq); + + // Rewrite + Expression result = rewriteExpression(orExpr); + + // Should be Or with Cast on both sides + Assertions.assertTrue(result instanceof Or); + Or rewrittenOr = (Or) result; + + // Left side: Cast(ElementAt) > 10 with BIGINT + Assertions.assertTrue(rewrittenOr.child(0) instanceof GreaterThan); + GreaterThan leftGt = (GreaterThan) rewrittenOr.child(0); + Assertions.assertTrue(leftGt.left() instanceof Cast); + Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) leftGt.left()).getDataType()); + + // Right side: 
Cast(ElementAt) = 'test' with STRING + Assertions.assertTrue(rewrittenOr.child(1) instanceof EqualTo); + EqualTo rightEq = (EqualTo) rewrittenOr.child(1); + Assertions.assertTrue(rightEq.left() instanceof Cast); + Assertions.assertEquals(StringType.INSTANCE, ((Cast) rightEq.left()).getDataType()); + } + + @Test + public void testFirstMatchWins() { + // Create variant type with overlapping patterns - first match should win + // 'num*' matches 'number_val', 'number_*' also matches 'number_val' + // First pattern 'num*' should be used + VariantField numField = new VariantField("num*", BigIntType.INSTANCE, ""); + VariantField numberField = new VariantField("number_*", DoubleType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(numField, numberField)); + + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + + // 'number_val' matches both patterns, but 'num*' is first + ElementAt elementAt = new ElementAt(variantSlot, new VarcharLiteral("number_val")); + Expression result = rewriteExpression(elementAt); + + Assertions.assertTrue(result instanceof Cast); + Cast cast = (Cast) result; + // Should be BIGINT (from 'num*'), not DOUBLE (from 'number_*') + Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); + } + + @Test + public void testMixedMatchingAndNonMatching() { + // Create variant type with one pattern + VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(numberField)); + + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + + // Create condition: variant['number_a'] > variant['other_field'] + // number_a matches, other_field does not + ElementAt matchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("number_a")); + ElementAt nonMatchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("other_field")); + GreaterThan gt = new GreaterThan(matchingElementAt, nonMatchingElementAt); + + // Rewrite + 
Expression result = rewriteExpression(gt); + + Assertions.assertTrue(result instanceof GreaterThan); + GreaterThan rewrittenGt = (GreaterThan) result; + + // Left side should be Cast(ElementAt) + Assertions.assertTrue(rewrittenGt.left() instanceof Cast); + Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) rewrittenGt.left()).getDataType()); + + // Right side should remain as ElementAt (no cast) + Assertions.assertTrue(rewrittenGt.right() instanceof ElementAt); + Assertions.assertFalse(rewrittenGt.right() instanceof Cast); + } + + @Test + public void testGlobPatternWithQuestionMark() { + // Test glob pattern with ? (matches single character) + VariantField field = new VariantField("val?", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field)); + + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + + // 'val1' should match 'val?' + ElementAt matchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("val1")); + Expression matchResult = rewriteExpression(matchingElementAt); + Assertions.assertTrue(matchResult instanceof Cast); + Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) matchResult).getDataType()); + + // 'val12' should NOT match 'val?' (? matches only one char) + ElementAt nonMatchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("val12")); + Expression nonMatchResult = rewriteExpression(nonMatchingElementAt); + Assertions.assertTrue(nonMatchResult instanceof ElementAt); + Assertions.assertFalse(nonMatchResult instanceof Cast); + } + + @Test + public void testGlobPatternWithBrackets() { + // Test glob pattern with [...] 
(character class) + VariantField field = new VariantField("type_[abc]", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field)); + + MockVariantSlot variantSlot = new MockVariantSlot(variantType); + + // 'type_a' should match 'type_[abc]' + ElementAt matchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("type_a")); + Expression matchResult = rewriteExpression(matchingElementAt); + Assertions.assertTrue(matchResult instanceof Cast); + + // 'type_d' should NOT match 'type_[abc]' + ElementAt nonMatchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("type_d")); + Expression nonMatchResult = rewriteExpression(nonMatchingElementAt); + Assertions.assertTrue(nonMatchResult instanceof ElementAt); + Assertions.assertFalse(nonMatchResult instanceof Cast); + } + + /** + * Mock Expression class for providing VariantType in tests. + */ + private static class MockVariantSlot extends Expression { + private final VariantType variantType; + + public MockVariantSlot(VariantType variantType) { + super(Collections.emptyList()); + this.variantType = variantType; + } + + @Override + public DataType getDataType() { + return variantType; + } + + @Override + public boolean nullable() { + return true; + } + + @Override + public Expression withChildren(List children) { + return this; + } + + @Override + public R accept(ExpressionVisitor visitor, C context) { + return visitor.visit(this, context); + } + + @Override + public int arity() { + return 0; + } + + @Override + public Expression child(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public List children() { + return Collections.emptyList(); + } + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java new file mode 100644 index 00000000000000..f981ef75228d5a --- /dev/null +++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.types; + +import org.apache.doris.thrift.TPatternType; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.Optional; + +/** + * Unit tests for VariantField pattern matching and VariantType field lookup. + */ +public class VariantFieldMatchTest { + + // ==================== VariantField.matches() tests ==================== + + @Test + public void testExactMatch() { + VariantField field = new VariantField("number_latency", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME.name()); + + Assertions.assertTrue(field.matches("number_latency")); + Assertions.assertFalse(field.matches("number_latency_ms")); + Assertions.assertFalse(field.matches("other_field")); + } + + @Test + public void testGlobMatchSuffix() { + // Pattern: number_* should match number_latency, number_count, etc. 
+ VariantField field = new VariantField("number_*", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("number_latency")); + Assertions.assertTrue(field.matches("number_count")); + Assertions.assertTrue(field.matches("number_")); + Assertions.assertFalse(field.matches("string_message")); + Assertions.assertFalse(field.matches("numbering")); + } + + @Test + public void testGlobMatchPrefix() { + // Pattern: *_latency should match number_latency, string_latency, etc. + VariantField field = new VariantField("*_latency", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("number_latency")); + Assertions.assertTrue(field.matches("string_latency")); + Assertions.assertTrue(field.matches("_latency")); + Assertions.assertFalse(field.matches("latency_ms")); + } + + @Test + public void testGlobMatchMiddle() { + // Pattern: num_*_ms should match num_latency_ms, num_count_ms, etc. + VariantField field = new VariantField("num_*_ms", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("num_latency_ms")); + Assertions.assertTrue(field.matches("num_count_ms")); + Assertions.assertTrue(field.matches("num__ms")); + Assertions.assertFalse(field.matches("num_latency")); + Assertions.assertFalse(field.matches("number_latency_ms")); + } + + @Test + public void testGlobMatchAll() { + // Pattern: * should match everything + VariantField field = new VariantField("*", StringType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("anything")); + Assertions.assertTrue(field.matches("")); + Assertions.assertTrue(field.matches("a.b.c")); + } + + @Test + public void testGlobMatchWithDot() { + // Pattern: metrics.* should match metrics.score, metrics.count, etc. 
+ VariantField field = new VariantField("metrics.*", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("metrics.score")); + Assertions.assertTrue(field.matches("metrics.count")); + Assertions.assertFalse(field.matches("metricsXscore")); + Assertions.assertFalse(field.matches("metrics")); + } + + @Test + public void testDefaultPatternTypeIsGlob() { + // Default constructor should use MATCH_NAME_GLOB + VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); + + Assertions.assertTrue(field.matches("number_latency")); + Assertions.assertEquals(TPatternType.MATCH_NAME_GLOB, field.getPatternType()); + } + + // ==================== VariantType.findMatchingField() tests ==================== + + @Test + public void testFindMatchingFieldSinglePattern() { + VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field)); + + Optional result = variantType.findMatchingField("number_latency"); + Assertions.assertTrue(result.isPresent()); + Assertions.assertEquals(BigIntType.INSTANCE, result.get().getDataType()); + } + + @Test + public void testFindMatchingFieldMultiplePatterns() { + VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantField stringField = new VariantField("string_*", StringType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(numberField, stringField)); + + // Test number pattern + Optional numberResult = variantType.findMatchingField("number_latency"); + Assertions.assertTrue(numberResult.isPresent()); + Assertions.assertEquals(BigIntType.INSTANCE, numberResult.get().getDataType()); + + // Test string pattern + Optional stringResult = variantType.findMatchingField("string_message"); + Assertions.assertTrue(stringResult.isPresent()); + Assertions.assertEquals(StringType.INSTANCE, stringResult.get().getDataType()); + } + + @Test + public void 
testFindMatchingFieldNoMatch() { + VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field)); + + Optional result = variantType.findMatchingField("string_message"); + Assertions.assertFalse(result.isPresent()); + } + + @Test + public void testFindMatchingFieldFirstMatchWins() { + // When multiple patterns match, the first one should win + VariantField field1 = new VariantField("num*", BigIntType.INSTANCE, ""); + VariantField field2 = new VariantField("number_*", DoubleType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field1, field2)); + + Optional result = variantType.findMatchingField("number_latency"); + Assertions.assertTrue(result.isPresent()); + // First pattern "num*" should match, returning BigIntType + Assertions.assertEquals(BigIntType.INSTANCE, result.get().getDataType()); + } + + @Test + public void testFindMatchingFieldEmptyPredefinedFields() { + VariantType variantType = new VariantType(0); + + Optional result = variantType.findMatchingField("any_field"); + Assertions.assertFalse(result.isPresent()); + } +} diff --git a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out new file mode 100644 index 00000000000000..ec8a181b820c06 --- /dev/null +++ b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out @@ -0,0 +1,34 @@ +-- This file is automatically generated. 
You should know what you did if you want to edit this +-- !where_simple -- +2 +3 + +-- !where_and -- +2 +4 + +-- !where_or -- +1 +3 +4 + +-- !order_by -- +3 50 +2 30 +4 15 +1 10 + +-- !topn -- +3 50 +2 30 + +-- !select_alias -- +1 10 alice +2 30 bob +3 50 charlie +4 15 alice + +-- !join_on -- +1 100 100 +3 300 300 + diff --git a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy new file mode 100644 index 00000000000000..310853a59f46f4 --- /dev/null +++ b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy @@ -0,0 +1,107 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +suite("test_schema_template_auto_cast", "p0") { + sql """ set describe_extend_variant_column = true """ + sql """ set enable_match_without_inverted_index = false """ + sql """ set enable_common_expr_pushdown = true """ + sql """ set default_variant_enable_typed_paths_to_sparse = false """ + sql """ set default_variant_enable_doc_mode = false """ + + def tableName = "test_variant_schema_auto_cast" + + // Test 1: WHERE clause with auto-cast + sql "DROP TABLE IF EXISTS ${tableName}" + sql """CREATE TABLE ${tableName} ( + `id` bigint NULL, + `data` variant<'num_*': BIGINT, 'str_*': STRING> NOT NULL + ) ENGINE=OLAP DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( "replication_allocation" = "tag.location.default: 1")""" + + sql """insert into ${tableName} values(1, '{"num_a": 10, "num_b": 20, "str_name": "alice"}')""" + sql """insert into ${tableName} values(2, '{"num_a": 30, "num_b": 40, "str_name": "bob"}')""" + sql """insert into ${tableName} values(3, '{"num_a": 50, "num_b": 60, "str_name": "charlie"}')""" + sql """insert into ${tableName} values(4, '{"num_a": 15, "num_b": 25, "str_name": "alice"}')""" + + // Simple WHERE + qt_where_simple """ SELECT id FROM ${tableName} + WHERE data['num_a'] > 20 ORDER BY id """ + + // AND condition + qt_where_and """ SELECT id FROM ${tableName} + WHERE data['num_a'] > 10 AND data['num_b'] < 50 + ORDER BY id """ + + // OR condition + qt_where_or """ SELECT id FROM ${tableName} + WHERE data['num_a'] > 40 OR data['str_name'] = 'alice' + ORDER BY id """ + + // Test 2: ORDER BY with auto-cast + qt_order_by """ SELECT id, data['num_a'] FROM ${tableName} + ORDER BY data['num_a'] DESC """ + + // Test 3: TopN (ORDER BY + LIMIT) + qt_topn """ SELECT id, data['num_a'] FROM ${tableName} + ORDER BY data['num_a'] DESC LIMIT 2 """ + + // Test 4: SELECT projection with aliases + qt_select_alias """ SELECT id, + data['num_a'] as num_a_val, + data['str_name'] as name + FROM ${tableName} + ORDER BY id """ + + // Test 5: 
JOIN ON clause + def leftTable = "test_variant_join_left" + def rightTable = "test_variant_join_right" + + sql "DROP TABLE IF EXISTS ${leftTable}" + sql "DROP TABLE IF EXISTS ${rightTable}" + + sql """CREATE TABLE ${leftTable} ( + `id` bigint NULL, + `data` variant<'key_*': BIGINT> NOT NULL + ) ENGINE=OLAP DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( "replication_allocation" = "tag.location.default: 1")""" + + sql """CREATE TABLE ${rightTable} ( + `id` bigint NULL, + `info` variant<'key_*': BIGINT> NOT NULL + ) ENGINE=OLAP DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( "replication_allocation" = "tag.location.default: 1")""" + + sql """insert into ${leftTable} values(1, '{"key_val": 100}')""" + sql """insert into ${leftTable} values(2, '{"key_val": 200}')""" + sql """insert into ${leftTable} values(3, '{"key_val": 300}')""" + + sql """insert into ${rightTable} values(1, '{"key_val": 100}')""" + sql """insert into ${rightTable} values(2, '{"key_val": 250}')""" + sql """insert into ${rightTable} values(3, '{"key_val": 300}')""" + + qt_join_on """ SELECT l.id, l.data['key_val'], r.info['key_val'] + FROM ${leftTable} l JOIN ${rightTable} r + ON l.data['key_val'] = r.info['key_val'] + ORDER BY l.id """ + + sql "DROP TABLE IF EXISTS ${leftTable}" + sql "DROP TABLE IF EXISTS ${rightTable}" + sql "DROP TABLE IF EXISTS ${tableName}" +} From 2fb3545ddcb3b2fc79ad6a643f2b6a176614a0d6 Mon Sep 17 00:00:00 2001 From: Gary Date: Thu, 29 Jan 2026 22:25:18 +0800 Subject: [PATCH 03/27] fix pipline --- .../rules/rewrite/VariantSchemaCast.java | 49 +++++++++---------- .../test_schema_template_auto_cast.out | 6 --- .../test_schema_template_auto_cast.groovy | 9 +--- 3 files changed, 25 insertions(+), 39 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java index 
223410f549f103..d66f89ca6cd26e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java @@ -19,10 +19,9 @@ import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.properties.OrderKey; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Match; import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.plans.Plan; @@ -43,7 +42,6 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.function.Function; /** * Automatically cast variant element access expressions based on schema template. 
@@ -66,14 +64,6 @@ public Plan rewriteRoot(Plan plan, JobContext jobContext) { private static class PlanRewriter extends DefaultPlanRewriter { - private final Function expressionRewriter = expr -> { - // Handle ElementAt expressions - if (expr instanceof ElementAt) { - return rewriteElementAt((ElementAt) expr); - } - return expr; - }; - private Expression rewriteElementAt(ElementAt elementAt) { Expression left = elementAt.left(); Expression right = elementAt.right(); @@ -100,17 +90,29 @@ private Expression rewriteElementAt(ElementAt elementAt) { } private Expression rewriteExpression(Expression expr) { - return expr.rewriteDownShortCircuit(expressionRewriter); - } + // Skip Match expressions - they require SlotRef as left operand + if (expr instanceof Match) { + return expr; + } - private NamedExpression rewriteNamedExpression(NamedExpression expr) { - Expression rewritten = rewriteExpression(expr); - if (rewritten instanceof NamedExpression) { - return (NamedExpression) rewritten; + // Recursively rewrite children first + boolean childrenChanged = false; + ImmutableList.Builder newChildren = ImmutableList.builder(); + for (Expression child : expr.children()) { + Expression newChild = rewriteExpression(child); + newChildren.add(newChild); + if (newChild != child) { + childrenChanged = true; + } } - // If the result is not a NamedExpression (e.g., Cast), wrap it in an Alias - // Preserve the original ExprId to maintain consistency - return new Alias(expr.getExprId(), rewritten, expr.getName()); + + Expression newExpr = childrenChanged ? 
expr.withChildren(newChildren.build()) : expr; + + // Then apply the rewriter to the current expression + if (newExpr instanceof ElementAt) { + return rewriteElementAt((ElementAt) newExpr); + } + return newExpr; } @Override @@ -124,11 +126,8 @@ public Plan visitLogicalFilter(LogicalFilter filter, Void contex @Override public Plan visitLogicalProject(LogicalProject project, Void context) { - project = (LogicalProject) super.visit(project, context); - List newProjects = project.getProjects().stream() - .map(this::rewriteNamedExpression) - .collect(ImmutableList.toImmutableList()); - return project.withProjects(newProjects); + // Don't rewrite SELECT projections - BE expects Variant type for ElementAt in SELECT + return super.visit(project, context); } @Override diff --git a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out index ec8a181b820c06..d840414edc119c 100644 --- a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out +++ b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out @@ -22,12 +22,6 @@ 3 50 2 30 --- !select_alias -- -1 10 alice -2 30 bob -3 50 charlie -4 15 alice - -- !join_on -- 1 100 100 3 300 300 diff --git a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy index 310853a59f46f4..d1dd88cddd7f54 100644 --- a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy +++ b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy @@ -60,14 +60,7 @@ suite("test_schema_template_auto_cast", "p0") { qt_topn """ SELECT id, data['num_a'] FROM ${tableName} ORDER BY data['num_a'] DESC LIMIT 2 """ - // Test 4: SELECT projection with aliases - qt_select_alias """ SELECT id, - data['num_a'] as num_a_val, - data['str_name'] as name - FROM ${tableName} 
- ORDER BY id """ - - // Test 5: JOIN ON clause + // Test 4: JOIN ON clause def leftTable = "test_variant_join_left" def rightTable = "test_variant_join_right" From 3d6e7c2e8e9ad6fdfb96cc89e03adfa4003c7f40 Mon Sep 17 00:00:00 2001 From: Gary Date: Fri, 30 Jan 2026 16:09:38 +0800 Subject: [PATCH 04/27] add slot ref and delete join case --- .../rules/rewrite/VariantSchemaCast.java | 65 ++++++++++++++----- .../trees/expressions/SlotReference.java | 9 +++ .../apache/doris/nereids/types/DataType.java | 17 ++--- .../test_schema_template_auto_cast.out | 4 -- .../test_schema_template_auto_cast.groovy | 36 ---------- 5 files changed, 66 insertions(+), 65 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java index d66f89ca6cd26e..7b2abeb9999a2f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java @@ -22,11 +22,11 @@ import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Match; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; @@ -89,6 +89,49 @@ private Expression rewriteElementAt(ElementAt elementAt) { return new Cast(elementAt, 
targetType); } + private Expression rewriteSlotReference(SlotReference slotRef) { + // Check if the SlotReference's DataType is VariantType with predefinedFields + if (!(slotRef.getDataType() instanceof VariantType)) { + return slotRef; + } + + VariantType variantType = (VariantType) slotRef.getDataType(); + if (variantType.getPredefinedFields().isEmpty()) { + return slotRef; + } + + // Extract field name from SlotReference name pattern like "data['field_name']" + String slotName = slotRef.getName(); + + // Parse field name from pattern like "column['field']" or "column[\"field\"]" + int bracketStart = slotName.indexOf('['); + if (bracketStart < 0) { + return slotRef; + } + + int bracketEnd = slotName.lastIndexOf(']'); + if (bracketEnd <= bracketStart) { + return slotRef; + } + + // Extract the content between brackets and remove quotes + String bracketContent = slotName.substring(bracketStart + 1, bracketEnd); + String fieldName = bracketContent; + if ((bracketContent.startsWith("'") && bracketContent.endsWith("'")) + || (bracketContent.startsWith("\"") && bracketContent.endsWith("\""))) { + fieldName = bracketContent.substring(1, bracketContent.length() - 1); + } + + // Find matching field in schema template + Optional matchingField = variantType.findMatchingField(fieldName); + if (!matchingField.isPresent()) { + return slotRef; + } + + DataType targetType = matchingField.get().getDataType(); + return new Cast(slotRef, targetType); + } + private Expression rewriteExpression(Expression expr) { // Skip Match expressions - they require SlotRef as left operand if (expr instanceof Match) { @@ -112,6 +155,10 @@ private Expression rewriteExpression(Expression expr) { if (newExpr instanceof ElementAt) { return rewriteElementAt((ElementAt) newExpr); } + // Handle SlotReference that represents variant element access (e.g., data['field']) + if (newExpr instanceof SlotReference) { + return rewriteSlotReference((SlotReference) newExpr); + } return newExpr; } @@ -147,21 +194,5 
@@ public Plan visitLogicalTopN(LogicalTopN topN, Void context) { .collect(ImmutableList.toImmutableList()); return topN.withOrderKeys(newOrderKeys); } - - @Override - public Plan visitLogicalJoin(LogicalJoin join, Void context) { - join = (LogicalJoin) super.visit(join, context); - List newHashConditions = join.getHashJoinConjuncts().stream() - .map(this::rewriteExpression) - .collect(ImmutableList.toImmutableList()); - List newOtherConditions = join.getOtherJoinConjuncts().stream() - .map(this::rewriteExpression) - .collect(ImmutableList.toImmutableList()); - List newMarkConditions = join.getMarkJoinConjuncts().stream() - .map(this::rewriteExpression) - .collect(ImmutableList.toImmutableList()); - return join.withJoinConjuncts(newHashConditions, newOtherConditions, - newMarkConditions, join.getJoinReorderContext()); - } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index 1c77dc669acb72..130c7920d14e3c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -155,6 +155,15 @@ public static SlotReference fromColumn(ExprId exprId, TableIf table, Column colu return fromColumn(exprId, table, column, column.getName(), qualifier); } + /** + * Get SlotReference from a column with custom name. 
+ * @param exprId the expression id + * @param table the table which contains the column + * @param column the column which contains type info + * @param name the name of SlotReference + * @param qualifier the qualifier of SlotReference + * @return SlotReference created from column + */ public static SlotReference fromColumn( ExprId exprId, TableIf table, Column column, String name, List qualifier) { DataType dataType = DataType.fromCatalogType(column.getType()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java index 911dc2e4e2cd51..9f4f3fc862ef53 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java @@ -480,19 +480,20 @@ public static DataType fromCatalogType(Type type) { // In the past, variant metadata used the ScalarType type. // Now, we use VariantType, which inherits from ScalarType, as the new metadata storage. if (type instanceof org.apache.doris.catalog.VariantType) { - List variantFields = ((org.apache.doris.catalog.VariantType) type) + org.apache.doris.catalog.VariantType catalogVariantType = (org.apache.doris.catalog.VariantType) type; + List variantFields = catalogVariantType .getPredefinedFields().stream() .map(cf -> new VariantField(cf.getPattern(), fromCatalogType(cf.getType()), cf.getComment() == null ? 
"" : cf.getComment(), cf.getPatternType().toString())) .collect(ImmutableList.toImmutableList()); return new VariantType(variantFields, - ((org.apache.doris.catalog.VariantType) type).getVariantMaxSubcolumnsCount(), - ((org.apache.doris.catalog.VariantType) type).getEnableTypedPathsToSparse(), - ((org.apache.doris.catalog.VariantType) type).getVariantMaxSparseColumnStatisticsSize(), - ((org.apache.doris.catalog.VariantType) type).getVariantSparseHashShardCount(), - ((org.apache.doris.catalog.VariantType) type).getEnableVariantDocMode(), - ((org.apache.doris.catalog.VariantType) type).getvariantDocMaterializationMinRows(), - ((org.apache.doris.catalog.VariantType) type).getVariantDocShardCount()); + catalogVariantType.getVariantMaxSubcolumnsCount(), + catalogVariantType.getEnableTypedPathsToSparse(), + catalogVariantType.getVariantMaxSparseColumnStatisticsSize(), + catalogVariantType.getVariantSparseHashShardCount(), + catalogVariantType.getEnableVariantDocMode(), + catalogVariantType.getvariantDocMaterializationMinRows(), + catalogVariantType.getVariantDocShardCount()); } return VariantType.INSTANCE; } else { diff --git a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out index d840414edc119c..f4f33670782294 100644 --- a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out +++ b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out @@ -22,7 +22,3 @@ 3 50 2 30 --- !join_on -- -1 100 100 -3 300 300 - diff --git a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy index d1dd88cddd7f54..ac1594cc8db102 100644 --- a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy +++ b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy @@ -60,41 +60,5 @@ 
suite("test_schema_template_auto_cast", "p0") { qt_topn """ SELECT id, data['num_a'] FROM ${tableName} ORDER BY data['num_a'] DESC LIMIT 2 """ - // Test 4: JOIN ON clause - def leftTable = "test_variant_join_left" - def rightTable = "test_variant_join_right" - - sql "DROP TABLE IF EXISTS ${leftTable}" - sql "DROP TABLE IF EXISTS ${rightTable}" - - sql """CREATE TABLE ${leftTable} ( - `id` bigint NULL, - `data` variant<'key_*': BIGINT> NOT NULL - ) ENGINE=OLAP DUPLICATE KEY(`id`) - DISTRIBUTED BY HASH(`id`) BUCKETS 1 - PROPERTIES ( "replication_allocation" = "tag.location.default: 1")""" - - sql """CREATE TABLE ${rightTable} ( - `id` bigint NULL, - `info` variant<'key_*': BIGINT> NOT NULL - ) ENGINE=OLAP DUPLICATE KEY(`id`) - DISTRIBUTED BY HASH(`id`) BUCKETS 1 - PROPERTIES ( "replication_allocation" = "tag.location.default: 1")""" - - sql """insert into ${leftTable} values(1, '{"key_val": 100}')""" - sql """insert into ${leftTable} values(2, '{"key_val": 200}')""" - sql """insert into ${leftTable} values(3, '{"key_val": 300}')""" - - sql """insert into ${rightTable} values(1, '{"key_val": 100}')""" - sql """insert into ${rightTable} values(2, '{"key_val": 250}')""" - sql """insert into ${rightTable} values(3, '{"key_val": 300}')""" - - qt_join_on """ SELECT l.id, l.data['key_val'], r.info['key_val'] - FROM ${leftTable} l JOIN ${rightTable} r - ON l.data['key_val'] = r.info['key_val'] - ORDER BY l.id """ - - sql "DROP TABLE IF EXISTS ${leftTable}" - sql "DROP TABLE IF EXISTS ${rightTable}" sql "DROP TABLE IF EXISTS ${tableName}" } From 947bc8c6ce19b2b78efd2bd13f286429fff966ae Mon Sep 17 00:00:00 2001 From: Gary Date: Sat, 31 Jan 2026 00:59:03 +0800 Subject: [PATCH 05/27] add join, select, group, having tests --- .../test_schema_template_auto_cast.out | 20 ++++++++ .../test_schema_template_auto_cast.groovy | 50 +++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out 
b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out index f4f33670782294..216c8b88f77789 100644 --- a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out +++ b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out @@ -22,3 +22,23 @@ 3 50 2 30 +-- !select_arithmetic -- +1 30 +2 70 +3 110 +4 40 + +-- !group_by -- +alice 25 +bob 30 +charlie 50 + +-- !having -- +alice 25 +bob 30 +charlie 50 + +-- !join_on -- +1 first +2 second + diff --git a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy index ac1594cc8db102..6921ccd1c2a0bf 100644 --- a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy +++ b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy @@ -60,5 +60,55 @@ suite("test_schema_template_auto_cast", "p0") { qt_topn """ SELECT id, data['num_a'] FROM ${tableName} ORDER BY data['num_a'] DESC LIMIT 2 """ + // Test 4: SELECT with auto-cast (arithmetic operations) + qt_select_arithmetic """ SELECT id, data['num_a'] + data['num_b'] as sum_val + FROM ${tableName} ORDER BY id """ + + // Test 5: GROUP BY with auto-cast + qt_group_by """ SELECT data['str_name'], SUM(data['num_a']) as total + FROM ${tableName} GROUP BY data['str_name'] ORDER BY data['str_name'] """ + + // Test 6: HAVING with auto-cast + qt_having """ SELECT data['str_name'], SUM(data['num_a']) as total + FROM ${tableName} GROUP BY data['str_name'] + HAVING SUM(data['num_a']) > 20 ORDER BY data['str_name'] """ + sql "DROP TABLE IF EXISTS ${tableName}" + + // Test 7: JOIN ON with auto-cast + def leftTable = "test_variant_join_left" + def rightTable = "test_variant_join_right" + + sql "DROP TABLE IF EXISTS ${leftTable}" + sql "DROP TABLE IF EXISTS ${rightTable}" + + sql """CREATE TABLE ${leftTable} ( + `id` bigint NULL, + `data` variant<'key_*': BIGINT> NOT NULL + 
) ENGINE=OLAP DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( "replication_allocation" = "tag.location.default: 1")""" + + sql """CREATE TABLE ${rightTable} ( + `id` bigint NULL, + `info` variant<'key_*': BIGINT, 'name_*': STRING> NOT NULL + ) ENGINE=OLAP DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( "replication_allocation" = "tag.location.default: 1")""" + + sql """insert into ${leftTable} values(1, '{"key_id": 100}')""" + sql """insert into ${leftTable} values(2, '{"key_id": 200}')""" + sql """insert into ${leftTable} values(3, '{"key_id": 300}')""" + + sql """insert into ${rightTable} values(1, '{"key_id": 100, "name_val": "first"}')""" + sql """insert into ${rightTable} values(2, '{"key_id": 200, "name_val": "second"}')""" + sql """insert into ${rightTable} values(3, '{"key_id": 400, "name_val": "fourth"}')""" + + qt_join_on """ SELECT l.id, r.info['name_val'] + FROM ${leftTable} l JOIN ${rightTable} r + ON l.data['key_id'] = r.info['key_id'] + ORDER BY l.id """ + + sql "DROP TABLE IF EXISTS ${leftTable}" + sql "DROP TABLE IF EXISTS ${rightTable}" } From da360f6f8ab0f585d571c2b7f03b74e94237bc2e Mon Sep 17 00:00:00 2001 From: Gary Date: Sat, 31 Jan 2026 02:10:41 +0800 Subject: [PATCH 06/27] cover more --- .../rules/rewrite/CheckMatchExpression.java | 12 +- .../rules/rewrite/VariantSchemaCast.java | 103 +++++++++++++----- 2 files changed, 81 insertions(+), 34 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java index 623c3085962b47..5ab28c7d5dbb36 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java @@ -49,11 +49,13 @@ private Plan checkChildren(LogicalFilter filter) { for (Expression expr : expressions) { if (expr 
instanceof Match) { Match matchExpression = (Match) expr; - boolean isSlotReference = matchExpression.left() instanceof SlotReference; - boolean isCastChildWithSlotReference = (matchExpression.left() instanceof Cast - && matchExpression.left().child(0) instanceof SlotReference); - if (!(isSlotReference || isCastChildWithSlotReference) - || !(matchExpression.right() instanceof Literal)) { + // Unwrap all Cast layers to find the innermost expression + Expression left = matchExpression.left(); + while (left instanceof Cast) { + left = left.child(0); + } + boolean isSlotReference = left instanceof SlotReference; + if (!isSlotReference || !(matchExpression.right() instanceof Literal)) { throw new AnalysisException(String.format("Only support match left operand is SlotRef," + " right operand is Literal. But meet expression %s", matchExpression)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java index 7b2abeb9999a2f..84e108e4c47e1f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java @@ -19,14 +19,19 @@ import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Match; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.plans.Plan; +import 
org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSort; import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; @@ -64,7 +69,23 @@ public Plan rewriteRoot(Plan plan, JobContext jobContext) { private static class PlanRewriter extends DefaultPlanRewriter { - private Expression rewriteElementAt(ElementAt elementAt) { + private final java.util.function.Function expressionRewriter = expr -> { + // Skip Match expressions - they require SlotRef as left operand + if (expr instanceof Match) { + return expr; + } + // Handle ElementAt expressions + if (expr instanceof ElementAt) { + return rewriteElementAt((ElementAt) expr); + } + // Handle SlotReference that represents variant element access (e.g., data['field']) + if (expr instanceof SlotReference) { + return rewriteSlotReference((SlotReference) expr); + } + return expr; + }; + + private static Expression rewriteElementAt(ElementAt elementAt) { Expression left = elementAt.left(); Expression right = elementAt.right(); @@ -89,7 +110,7 @@ private Expression rewriteElementAt(ElementAt elementAt) { return new Cast(elementAt, targetType); } - private Expression rewriteSlotReference(SlotReference slotRef) { + private static Expression rewriteSlotReference(SlotReference slotRef) { // Check if the SlotReference's DataType is VariantType with predefinedFields if (!(slotRef.getDataType() instanceof VariantType)) { return slotRef; @@ -133,33 +154,17 @@ private Expression rewriteSlotReference(SlotReference slotRef) { } private Expression rewriteExpression(Expression expr) { - // Skip Match expressions - they require SlotRef as left operand - if (expr instanceof Match) { - return expr; - } - - 
// Recursively rewrite children first - boolean childrenChanged = false; - ImmutableList.Builder newChildren = ImmutableList.builder(); - for (Expression child : expr.children()) { - Expression newChild = rewriteExpression(child); - newChildren.add(newChild); - if (newChild != child) { - childrenChanged = true; - } - } - - Expression newExpr = childrenChanged ? expr.withChildren(newChildren.build()) : expr; + return expr.rewriteDownShortCircuit(expressionRewriter); + } - // Then apply the rewriter to the current expression - if (newExpr instanceof ElementAt) { - return rewriteElementAt((ElementAt) newExpr); - } - // Handle SlotReference that represents variant element access (e.g., data['field']) - if (newExpr instanceof SlotReference) { - return rewriteSlotReference((SlotReference) newExpr); + private NamedExpression rewriteNamedExpression(NamedExpression expr) { + Expression rewritten = rewriteExpression(expr); + if (rewritten instanceof NamedExpression) { + return (NamedExpression) rewritten; } - return newExpr; + // If the result is not a NamedExpression (e.g., Cast), wrap it in an Alias + // Preserve the original ExprId to maintain consistency + return new Alias(expr.getExprId(), rewritten, expr.getName()); } @Override @@ -173,8 +178,11 @@ public Plan visitLogicalFilter(LogicalFilter filter, Void contex @Override public Plan visitLogicalProject(LogicalProject project, Void context) { - // Don't rewrite SELECT projections - BE expects Variant type for ElementAt in SELECT - return super.visit(project, context); + project = (LogicalProject) super.visit(project, context); + List newProjects = project.getProjects().stream() + .map(this::rewriteNamedExpression) + .collect(ImmutableList.toImmutableList()); + return project.withProjects(newProjects); } @Override @@ -194,5 +202,42 @@ public Plan visitLogicalTopN(LogicalTopN topN, Void context) { .collect(ImmutableList.toImmutableList()); return topN.withOrderKeys(newOrderKeys); } + + @Override + public Plan 
visitLogicalAggregate(LogicalAggregate aggregate, Void context) { + aggregate = (LogicalAggregate) super.visit(aggregate, context); + List newGroupByExprs = aggregate.getGroupByExpressions().stream() + .map(this::rewriteExpression) + .collect(ImmutableList.toImmutableList()); + List newOutputExprs = aggregate.getOutputExpressions().stream() + .map(this::rewriteNamedExpression) + .collect(ImmutableList.toImmutableList()); + return aggregate.withGroupByAndOutput(newGroupByExprs, newOutputExprs); + } + + @Override + public Plan visitLogicalHaving(LogicalHaving having, Void context) { + having = (LogicalHaving) super.visit(having, context); + Set newConjuncts = having.getConjuncts().stream() + .map(this::rewriteExpression) + .collect(ImmutableSet.toImmutableSet()); + return having.withConjuncts(newConjuncts); + } + + @Override + public Plan visitLogicalJoin(LogicalJoin join, Void context) { + join = (LogicalJoin) super.visit(join, context); + List newHashConditions = join.getHashJoinConjuncts().stream() + .map(this::rewriteExpression) + .collect(ImmutableList.toImmutableList()); + List newOtherConditions = join.getOtherJoinConjuncts().stream() + .map(this::rewriteExpression) + .collect(ImmutableList.toImmutableList()); + List newMarkConditions = join.getMarkJoinConjuncts().stream() + .map(this::rewriteExpression) + .collect(ImmutableList.toImmutableList()); + return join.withJoinConjuncts(newHashConditions, newOtherConditions, + newMarkConditions, join.getJoinReorderContext()); + } } } From c50250102125234137538f2fd0a63a0f050ba431 Mon Sep 17 00:00:00 2001 From: Gary Date: Sat, 31 Jan 2026 04:47:49 +0800 Subject: [PATCH 07/27] all tests pass --- .../rules/analysis/ExpressionAnalyzer.java | 37 ++++++++++++++++++- .../rules/rewrite/VariantSchemaCast.java | 16 ++++++++ 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index a4fb7f3ae76593..8b0be52709e1f1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -84,6 +84,7 @@ import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.trees.plans.PlaceholderId; @@ -99,6 +100,8 @@ import org.apache.doris.nereids.types.StructField; import org.apache.doris.nereids.types.StructType; import org.apache.doris.nereids.types.TinyIntType; +import org.apache.doris.nereids.types.VariantField; +import org.apache.doris.nereids.types.VariantType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; import org.apache.doris.nereids.util.Utils; @@ -273,7 +276,9 @@ public Expression visitDereferenceExpression(DereferenceExpression dereferenceEx } else if (dataType.isMapType()) { return new ElementAt(expression, dereferenceExpression.child(1)); } else if (dataType.isVariantType()) { - return new ElementAt(expression, dereferenceExpression.child(1)); + ElementAt elementAt = new ElementAt(expression, dereferenceExpression.child(1)); + return wrapVariantElementAtWithCast(elementAt, (VariantType) dataType, + dereferenceExpression.fieldName); } throw new AnalysisException("Can not dereference field: " + dereferenceExpression.fieldName); } @@ -636,6 +641,20 @@ public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, Expre return 
TypeCoercionUtils.processBinaryArithmetic(binaryArithmetic); } + @Override + public Expression visitElementAt(ElementAt elementAt, ExpressionRewriteContext context) { + Expression left = elementAt.left().accept(this, context); + Expression right = elementAt.right().accept(this, context); + ElementAt newElementAt = (ElementAt) elementAt.withChildren(left, right); + // Auto-cast for variant schema template + if (left.getDataType() instanceof VariantType && right instanceof StringLikeLiteral) { + VariantType variantType = (VariantType) left.getDataType(); + String fieldName = ((StringLikeLiteral) right).getStringValue(); + return wrapVariantElementAtWithCast(newElementAt, variantType, fieldName); + } + return newElementAt; + } + @Override public Expression visitOr(Or or, ExpressionRewriteContext context) { List children = ExpressionUtils.extractDisjunction(or); @@ -1099,7 +1118,8 @@ private Optional bindNestedFields(UnboundSlot unboundSlot, Slot slot expression = new ElementAt(expression, new StringLiteral(fieldName)); continue; } else if (dataType.isVariantType()) { - expression = new ElementAt(expression, new StringLiteral(fieldName)); + ElementAt elementAt = new ElementAt(expression, new StringLiteral(fieldName)); + expression = wrapVariantElementAtWithCast(elementAt, (VariantType) dataType, fieldName); continue; } throw new AnalysisException("No such field '" + fieldName + "' in '" + lastFieldName + "'"); @@ -1115,6 +1135,19 @@ public static boolean sameTableName(String boundSlot, String unboundSlot) { } } + /** + * Wrap ElementAt with Cast if the variant type has a matching predefined field. + * This enables auto-cast for variant schema template. 
+ */ + private static Expression wrapVariantElementAtWithCast( + ElementAt elementAt, VariantType variantType, String fieldName) { + Optional matchingField = variantType.findMatchingField(fieldName); + if (matchingField.isPresent()) { + return new Cast(elementAt, matchingField.get().getDataType()); + } + return elementAt; + } + private boolean shouldBindSlotBy(int namePartSize, Slot boundSlot) { return namePartSize <= boundSlot.getQualifier().size() + 1; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java index 84e108e4c47e1f..145f1dbc68fb53 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java @@ -43,6 +43,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import java.util.List; import java.util.Optional; @@ -62,6 +64,8 @@ */ public class VariantSchemaCast implements CustomRewriter { + private static final Logger LOG = LogManager.getLogger(VariantSchemaCast.class); + @Override public Plan rewriteRoot(Plan plan, JobContext jobContext) { return plan.accept(new PlanRewriter(), null); @@ -89,24 +93,36 @@ private static Expression rewriteElementAt(ElementAt elementAt) { Expression left = elementAt.left(); Expression right = elementAt.right(); + // Debug logging + LOG.info("Processing ElementAt: {}", elementAt); + LOG.info("Left type: {}, class: {}", left.getDataType(), + left.getDataType().getClass().getName()); + // Only process if left is VariantType and right is a string literal if (!(left.getDataType() instanceof VariantType)) { + LOG.info("Left is not VariantType, skipping"); return elementAt; } if (!(right instanceof StringLikeLiteral)) { + LOG.info("Right 
is not StringLikeLiteral, skipping"); return elementAt; } VariantType variantType = (VariantType) left.getDataType(); String fieldName = ((StringLikeLiteral) right).getStringValue(); + LOG.info("predefinedFields: {}, fieldName: {}", + variantType.getPredefinedFields(), fieldName); + // Find matching field in schema template Optional matchingField = variantType.findMatchingField(fieldName); if (!matchingField.isPresent()) { + LOG.info("No matching field found for: {}", fieldName); return elementAt; } DataType targetType = matchingField.get().getDataType(); + LOG.info("Found matching field, target type: {}", targetType); return new Cast(elementAt, targetType); } From cd777b776723ba66fa7cf50edc4dacf3a7ea8cec Mon Sep 17 00:00:00 2001 From: Gary Date: Sat, 31 Jan 2026 05:34:05 +0800 Subject: [PATCH 08/27] maybe done --- .../doris/nereids/jobs/executor/Rewriter.java | 11 - .../apache/doris/nereids/rules/RuleType.java | 1 - .../rules/rewrite/VariantSchemaCast.java | 259 ------------------ ...st.java => VariantSchemaTemplateTest.java} | 6 +- 4 files changed, 3 insertions(+), 274 deletions(-) delete mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java rename fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/{VariantSchemaCastTest.java => VariantSchemaTemplateTest.java} (98%) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 7e48f15af0d493..9e3186c6cec7b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -168,7 +168,6 @@ import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAggProject; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoin; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoinProject; -import 
org.apache.doris.nereids.rules.rewrite.VariantSchemaCast; import org.apache.doris.nereids.rules.rewrite.VariantSubPathPruning; import org.apache.doris.nereids.rules.rewrite.batch.ApplyToJoin; import org.apache.doris.nereids.rules.rewrite.batch.CorrelateApplyToUnCorrelateApply; @@ -276,11 +275,6 @@ public class Rewriter extends AbstractBatchJobExecutor { new EliminateSemiJoin() ) ), - // Auto cast variant element access based on schema template - // This must run before NormalizeSort which converts ORDER BY expressions to slots - topic("variant schema cast before normalize", - custom(RuleType.VARIANT_SCHEMA_CAST, VariantSchemaCast::new) - ), // The rule modification needs to be done after the subquery is unnested, // because for scalarSubQuery, the connection condition is stored in apply in // the analyzer phase, @@ -518,11 +512,6 @@ public class Rewriter extends AbstractBatchJobExecutor { new SimplifyEncodeDecode() ) ), - // Auto cast variant element access based on schema template - // This must run before NormalizeSort which converts ORDER BY expressions to slots - topic("variant schema cast before normalize", - custom(RuleType.VARIANT_SCHEMA_CAST, VariantSchemaCast::new) - ), // The rule modification needs to be done after the subquery is unnested, // because for scalarSubQuery, the connection condition is stored in apply in the analyzer phase, // but when normalizeAggregate/normalizeSort is performed, the members in apply cannot be obtained, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 1dcce5c159bff3..2587481256b798 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -224,7 +224,6 @@ public enum RuleType { ADD_PROJECT_FOR_JOIN(RuleTypeClass.REWRITE), ADD_PROJECT_FOR_UNIQUE_FUNCTION(RuleTypeClass.REWRITE), - 
VARIANT_SCHEMA_CAST(RuleTypeClass.REWRITE), VARIANT_SUB_PATH_PRUNING(RuleTypeClass.REWRITE), NESTED_COLUMN_PRUNING(RuleTypeClass.REWRITE), CLEAR_CONTEXT_STATUS(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java deleted file mode 100644 index 145f1dbc68fb53..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java +++ /dev/null @@ -1,259 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.nereids.jobs.JobContext; -import org.apache.doris.nereids.properties.OrderKey; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.Match; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; -import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.trees.plans.logical.LogicalSort; -import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; -import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; -import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; -import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.VariantField; -import org.apache.doris.nereids.types.VariantType; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import java.util.List; -import java.util.Optional; -import java.util.Set; - -/** - * Automatically cast variant element access expressions based on schema template. 
- * - * For example, if a variant column is defined as: - * payload VARIANT<'number_*': BIGINT, 'string_*': STRING> - * - * Then payload['number_latency'] will be automatically cast to BIGINT, - * and payload['string_message'] will be automatically cast to STRING. - * - * This allows users to use variant sub-fields directly in WHERE, ORDER BY, - * and other clauses without explicit CAST. - */ -public class VariantSchemaCast implements CustomRewriter { - - private static final Logger LOG = LogManager.getLogger(VariantSchemaCast.class); - - @Override - public Plan rewriteRoot(Plan plan, JobContext jobContext) { - return plan.accept(new PlanRewriter(), null); - } - - private static class PlanRewriter extends DefaultPlanRewriter { - - private final java.util.function.Function expressionRewriter = expr -> { - // Skip Match expressions - they require SlotRef as left operand - if (expr instanceof Match) { - return expr; - } - // Handle ElementAt expressions - if (expr instanceof ElementAt) { - return rewriteElementAt((ElementAt) expr); - } - // Handle SlotReference that represents variant element access (e.g., data['field']) - if (expr instanceof SlotReference) { - return rewriteSlotReference((SlotReference) expr); - } - return expr; - }; - - private static Expression rewriteElementAt(ElementAt elementAt) { - Expression left = elementAt.left(); - Expression right = elementAt.right(); - - // Debug logging - LOG.info("Processing ElementAt: {}", elementAt); - LOG.info("Left type: {}, class: {}", left.getDataType(), - left.getDataType().getClass().getName()); - - // Only process if left is VariantType and right is a string literal - if (!(left.getDataType() instanceof VariantType)) { - LOG.info("Left is not VariantType, skipping"); - return elementAt; - } - if (!(right instanceof StringLikeLiteral)) { - LOG.info("Right is not StringLikeLiteral, skipping"); - return elementAt; - } - - VariantType variantType = (VariantType) left.getDataType(); - String fieldName = 
((StringLikeLiteral) right).getStringValue(); - - LOG.info("predefinedFields: {}, fieldName: {}", - variantType.getPredefinedFields(), fieldName); - - // Find matching field in schema template - Optional matchingField = variantType.findMatchingField(fieldName); - if (!matchingField.isPresent()) { - LOG.info("No matching field found for: {}", fieldName); - return elementAt; - } - - DataType targetType = matchingField.get().getDataType(); - LOG.info("Found matching field, target type: {}", targetType); - return new Cast(elementAt, targetType); - } - - private static Expression rewriteSlotReference(SlotReference slotRef) { - // Check if the SlotReference's DataType is VariantType with predefinedFields - if (!(slotRef.getDataType() instanceof VariantType)) { - return slotRef; - } - - VariantType variantType = (VariantType) slotRef.getDataType(); - if (variantType.getPredefinedFields().isEmpty()) { - return slotRef; - } - - // Extract field name from SlotReference name pattern like "data['field_name']" - String slotName = slotRef.getName(); - - // Parse field name from pattern like "column['field']" or "column[\"field\"]" - int bracketStart = slotName.indexOf('['); - if (bracketStart < 0) { - return slotRef; - } - - int bracketEnd = slotName.lastIndexOf(']'); - if (bracketEnd <= bracketStart) { - return slotRef; - } - - // Extract the content between brackets and remove quotes - String bracketContent = slotName.substring(bracketStart + 1, bracketEnd); - String fieldName = bracketContent; - if ((bracketContent.startsWith("'") && bracketContent.endsWith("'")) - || (bracketContent.startsWith("\"") && bracketContent.endsWith("\""))) { - fieldName = bracketContent.substring(1, bracketContent.length() - 1); - } - - // Find matching field in schema template - Optional matchingField = variantType.findMatchingField(fieldName); - if (!matchingField.isPresent()) { - return slotRef; - } - - DataType targetType = matchingField.get().getDataType(); - return new Cast(slotRef, 
targetType); - } - - private Expression rewriteExpression(Expression expr) { - return expr.rewriteDownShortCircuit(expressionRewriter); - } - - private NamedExpression rewriteNamedExpression(NamedExpression expr) { - Expression rewritten = rewriteExpression(expr); - if (rewritten instanceof NamedExpression) { - return (NamedExpression) rewritten; - } - // If the result is not a NamedExpression (e.g., Cast), wrap it in an Alias - // Preserve the original ExprId to maintain consistency - return new Alias(expr.getExprId(), rewritten, expr.getName()); - } - - @Override - public Plan visitLogicalFilter(LogicalFilter filter, Void context) { - filter = (LogicalFilter) super.visit(filter, context); - Set newConjuncts = filter.getConjuncts().stream() - .map(this::rewriteExpression) - .collect(ImmutableSet.toImmutableSet()); - return filter.withConjuncts(newConjuncts); - } - - @Override - public Plan visitLogicalProject(LogicalProject project, Void context) { - project = (LogicalProject) super.visit(project, context); - List newProjects = project.getProjects().stream() - .map(this::rewriteNamedExpression) - .collect(ImmutableList.toImmutableList()); - return project.withProjects(newProjects); - } - - @Override - public Plan visitLogicalSort(LogicalSort sort, Void context) { - sort = (LogicalSort) super.visit(sort, context); - List newOrderKeys = sort.getOrderKeys().stream() - .map(orderKey -> orderKey.withExpression(rewriteExpression(orderKey.getExpr()))) - .collect(ImmutableList.toImmutableList()); - return sort.withOrderKeys(newOrderKeys); - } - - @Override - public Plan visitLogicalTopN(LogicalTopN topN, Void context) { - topN = (LogicalTopN) super.visit(topN, context); - List newOrderKeys = topN.getOrderKeys().stream() - .map(orderKey -> orderKey.withExpression(rewriteExpression(orderKey.getExpr()))) - .collect(ImmutableList.toImmutableList()); - return topN.withOrderKeys(newOrderKeys); - } - - @Override - public Plan visitLogicalAggregate(LogicalAggregate aggregate, 
Void context) { - aggregate = (LogicalAggregate) super.visit(aggregate, context); - List newGroupByExprs = aggregate.getGroupByExpressions().stream() - .map(this::rewriteExpression) - .collect(ImmutableList.toImmutableList()); - List newOutputExprs = aggregate.getOutputExpressions().stream() - .map(this::rewriteNamedExpression) - .collect(ImmutableList.toImmutableList()); - return aggregate.withGroupByAndOutput(newGroupByExprs, newOutputExprs); - } - - @Override - public Plan visitLogicalHaving(LogicalHaving having, Void context) { - having = (LogicalHaving) super.visit(having, context); - Set newConjuncts = having.getConjuncts().stream() - .map(this::rewriteExpression) - .collect(ImmutableSet.toImmutableSet()); - return having.withConjuncts(newConjuncts); - } - - @Override - public Plan visitLogicalJoin(LogicalJoin join, Void context) { - join = (LogicalJoin) super.visit(join, context); - List newHashConditions = join.getHashJoinConjuncts().stream() - .map(this::rewriteExpression) - .collect(ImmutableList.toImmutableList()); - List newOtherConditions = join.getOtherJoinConjuncts().stream() - .map(this::rewriteExpression) - .collect(ImmutableList.toImmutableList()); - List newMarkConditions = join.getMarkJoinConjuncts().stream() - .map(this::rewriteExpression) - .collect(ImmutableList.toImmutableList()); - return join.withJoinConjuncts(newHashConditions, newOtherConditions, - newMarkConditions, join.getJoinReorderContext()); - } - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCastTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaTemplateTest.java similarity index 98% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCastTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaTemplateTest.java index eebb2535d2eb25..171e9d924a6978 100644 --- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCastTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaTemplateTest.java @@ -44,11 +44,11 @@ import java.util.function.Function; /** - * Unit tests for VariantSchemaCast expression rewriting. + * Unit tests for variant schema template auto-cast expression rewriting. */ -public class VariantSchemaCastTest { +public class VariantSchemaTemplateTest { - // Expression rewriter extracted from VariantSchemaCast for testing + // Expression rewriter for variant schema template auto-cast private static final Function EXPRESSION_REWRITER = expr -> { if (!(expr instanceof ElementAt)) { return expr; From c21e1182dadc032e39e82be4ef9876b82b71f5bb Mon Sep 17 00:00:00 2001 From: Gary Date: Sat, 31 Jan 2026 11:02:18 +0800 Subject: [PATCH 09/27] use processBoundFunction --- .../doris/nereids/rules/analysis/ExpressionAnalyzer.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 8b0be52709e1f1..39db2423768f80 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -652,7 +652,8 @@ public Expression visitElementAt(ElementAt elementAt, ExpressionRewriteContext c String fieldName = ((StringLikeLiteral) right).getStringValue(); return wrapVariantElementAtWithCast(newElementAt, variantType, fieldName); } - return newElementAt; + // For non-variant cases (array/map), apply normal type coercion + return TypeCoercionUtils.processBoundFunction(newElementAt); } @Override From b7f18250bb16a287b59fd40a3473e4b19c9e7e76 Mon Sep 17 00:00:00 2001 From: Gary Date: Sat, 31 Jan 2026 19:36:21 +0800 Subject: [PATCH 10/27] enhance fe ut --- 
.../doris/nereids/types/VariantField.java | 35 ++++++++- .../nereids/types/VariantFieldMatchTest.java | 76 +++++++++++++++++++ 2 files changed, 109 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java index a3d0776e0a8f30..1313e8579b3f3e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java @@ -97,6 +97,13 @@ public boolean matches(String fieldName) { /** * Convert glob pattern to regex pattern, aligning with fnmatch(FNM_PATHNAME) behavior. + * + * fnmatch with FNM_PATHNAME flag behavior: + * - '*' matches any sequence of characters except '/' + * - '?' matches any single character except '/' + * - '[...]' matches any character in the brackets + * - '[!...]' or '[^...]' matches any character not in the brackets + * - '\' escapes the next character (e.g., '\*' matches literal '*') */ private static String globToRegex(String glob) { StringBuilder regex = new StringBuilder(); @@ -106,6 +113,22 @@ private static String globToRegex(String glob) { while (i < len) { char c = glob.charAt(i); switch (c) { + case '\\': + // Escape sequence: next character should be matched literally + // This aligns with fnmatch behavior where \* matches literal * + if (i + 1 < len) { + i++; + char nextChar = glob.charAt(i); + // Escape the next character for regex if it's a regex special char + if (isRegexSpecialChar(nextChar)) { + regex.append('\\'); + } + regex.append(nextChar); + } else { + // Trailing backslash, treat as literal backslash + regex.append("\\\\"); + } + break; case '*': // '*' matches any sequence of characters except '/' (FNM_PATHNAME) regex.append("[^/]*"); @@ -146,8 +169,7 @@ private static String globToRegex(String glob) { i = j; // Move past the closing ] } break; - // Escape regex special characters - case '\\': + // 
Escape regex special characters (except backslash which is handled above) case '.': case '(': case ')': @@ -168,6 +190,15 @@ private static String globToRegex(String glob) { return regex.toString(); } + /** + * Check if a character is a regex special character that needs escaping. + */ + private static boolean isRegexSpecialChar(char c) { + return c == '\\' || c == '.' || c == '(' || c == ')' || c == '[' + || c == ']' || c == '{' || c == '}' || c == '+' || c == '*' + || c == '?' || c == '^' || c == '$' || c == '|'; + } + public org.apache.doris.catalog.VariantField toCatalogDataType() { return new org.apache.doris.catalog.VariantField( pattern, dataType.toCatalogDataType(), comment, patternType); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java index f981ef75228d5a..27969e0a6cf42f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java @@ -170,4 +170,80 @@ public void testFindMatchingFieldEmptyPredefinedFields() { Optional result = variantType.findMatchingField("any_field"); Assertions.assertFalse(result.isPresent()); } + + // ==================== Escape sequence tests (aligning with fnmatch behavior) ==================== + + @Test + public void testGlobEscapeAsterisk() { + // Pattern: int_\* should match literal "int_*", not "int_" followed by anything + VariantField field = new VariantField("int_\\*", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("int_*")); + Assertions.assertFalse(field.matches("int_nested")); + Assertions.assertFalse(field.matches("int_")); + } + + @Test + public void testGlobEscapeQuestionMark() { + // Pattern: int_\? 
should match literal "int_?", not "int_" followed by any single char + VariantField field = new VariantField("int_\\?", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("int_?")); + Assertions.assertFalse(field.matches("int_1")); + Assertions.assertFalse(field.matches("int_")); + } + + @Test + public void testGlobEscapeBracket() { + // Pattern: int_\[ should match literal "int_[" + VariantField field = new VariantField("int_\\[", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("int_[")); + Assertions.assertFalse(field.matches("int_a")); + } + + @Test + public void testGlobEscapeBackslash() { + // Pattern: int_\\ should match literal "int_\" + VariantField field = new VariantField("int_\\\\", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("int_\\")); + Assertions.assertFalse(field.matches("int_")); + } + + @Test + public void testGlobWithSlashSeparator() { + // With FNM_PATHNAME, '*' should not match '/' + VariantField field = new VariantField("int_*", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("int_nested")); + Assertions.assertTrue(field.matches("int_nested.level1")); // '.' is matched by '*' + Assertions.assertFalse(field.matches("int_nested/level1")); // '/' is NOT matched by '*' + } + + @Test + public void testGlobCharacterClass() { + // Character class tests + VariantField field1 = new VariantField("int_[0-9]", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(field1.matches("int_1")); + Assertions.assertFalse(field1.matches("int_a")); + + // Negated character class with ! 
+ VariantField field2 = new VariantField("int_[!0-9]", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(field2.matches("int_a")); + Assertions.assertFalse(field2.matches("int_1")); + + // Negated character class with ^ + VariantField field3 = new VariantField("int_[^0-9]", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(field3.matches("int_a")); + Assertions.assertFalse(field3.matches("int_1")); + } } From 72c0f66fe68bafe3e9edbb8e9e112e389e47a057 Mon Sep 17 00:00:00 2001 From: Gary Date: Sat, 31 Jan 2026 21:55:08 +0800 Subject: [PATCH 11/27] reapply VariantSchemaCast rules --- .../doris/nereids/jobs/executor/Rewriter.java | 9 + .../apache/doris/nereids/rules/RuleType.java | 1 + .../rules/analysis/CheckAfterRewrite.java | 22 +- .../rules/analysis/ExpressionAnalyzer.java | 37 +-- .../rules/rewrite/VariantSchemaCast.java | 259 ++++++++++++++++++ .../test_schema_template_auto_cast.out | 41 +++ .../test_schema_template_auto_cast.groovy | 51 +++- 7 files changed, 382 insertions(+), 38 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 9e3186c6cec7b2..7f10573da76877 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -169,6 +169,7 @@ import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoin; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoinProject; import org.apache.doris.nereids.rules.rewrite.VariantSubPathPruning; +import org.apache.doris.nereids.rules.rewrite.VariantSchemaCast; import org.apache.doris.nereids.rules.rewrite.batch.ApplyToJoin; import 
org.apache.doris.nereids.rules.rewrite.batch.CorrelateApplyToUnCorrelateApply; import org.apache.doris.nereids.rules.rewrite.batch.EliminateUselessPlanUnderApply; @@ -914,6 +915,14 @@ private static List getWholeTreeRewriteJobs( bottomUp(new RewriteSearchToSlots()) )); + // Auto cast variant element access based on schema template + // This should run before VariantSubPathPruning + rewriteJobs.addAll(jobs( + topic("variant schema cast", + custom(RuleType.VARIANT_SCHEMA_CAST, VariantSchemaCast::new) + ) + )); + if (needSubPathPushDown) { rewriteJobs.addAll(jobs( topic("variant element_at push down", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 2587481256b798..1dcce5c159bff3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -224,6 +224,7 @@ public enum RuleType { ADD_PROJECT_FOR_JOIN(RuleTypeClass.REWRITE), ADD_PROJECT_FOR_UNIQUE_FUNCTION(RuleTypeClass.REWRITE), + VARIANT_SCHEMA_CAST(RuleTypeClass.REWRITE), VARIANT_SUB_PATH_PRUNING(RuleTypeClass.REWRITE), NESTED_COLUMN_PRUNING(RuleTypeClass.REWRITE), CLEAR_CONTEXT_STATUS(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java index 915ac92a9b5aa3..9b641b81755be0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; 
import org.apache.doris.nereids.trees.expressions.Match; import org.apache.doris.nereids.trees.expressions.Slot; @@ -193,7 +194,7 @@ private void checkMetricTypeIsUsedCorrectly(Plan plan) { } else if (plan instanceof LogicalJoin) { LogicalJoin join = (LogicalJoin) plan; for (Expression conjunct : join.getHashJoinConjuncts()) { - if (conjunct.anyMatch(e -> ((Expression) e).getDataType().isVariantType())) { + if (containsVariantTypeOutsideCast(conjunct)) { throw new AnalysisException("variant type could not in join equal conditions: " + conjunct.toSql()); } else if (conjunct.anyMatch(e -> ((Expression) e).getDataType().isVarBinaryType())) { throw new AnalysisException( @@ -201,7 +202,7 @@ private void checkMetricTypeIsUsedCorrectly(Plan plan) { } } for (Expression conjunct : join.getMarkJoinConjuncts()) { - if (conjunct.anyMatch(e -> ((Expression) e).getDataType().isVariantType())) { + if (containsVariantTypeOutsideCast(conjunct)) { throw new AnalysisException("variant type could not in join equal conditions: " + conjunct.toSql()); } else if (conjunct.anyMatch(e -> ((Expression) e).getDataType().isVarBinaryType())) { throw new AnalysisException( @@ -211,6 +212,23 @@ private void checkMetricTypeIsUsedCorrectly(Plan plan) { } } + private boolean containsVariantTypeOutsideCast(Expression expr) { + return containsVariantTypeOutsideCast(expr, false); + } + + private boolean containsVariantTypeOutsideCast(Expression expr, boolean underCast) { + boolean nextUnderCast = underCast || expr instanceof Cast; + if (!nextUnderCast && expr.getDataType().isVariantType()) { + return true; + } + for (Expression child : expr.children()) { + if (containsVariantTypeOutsideCast(child, nextUnderCast)) { + return true; + } + } + return false; + } + private void checkMatchIsUsedCorrectly(Plan plan) { for (Expression expression : plan.getExpressions()) { if (expression instanceof Match) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 39db2423768f80..866b460945e11f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -84,7 +84,6 @@ import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; -import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.trees.plans.PlaceholderId; @@ -100,7 +99,6 @@ import org.apache.doris.nereids.types.StructField; import org.apache.doris.nereids.types.StructType; import org.apache.doris.nereids.types.TinyIntType; -import org.apache.doris.nereids.types.VariantField; import org.apache.doris.nereids.types.VariantType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; @@ -276,9 +274,7 @@ public Expression visitDereferenceExpression(DereferenceExpression dereferenceEx } else if (dataType.isMapType()) { return new ElementAt(expression, dereferenceExpression.child(1)); } else if (dataType.isVariantType()) { - ElementAt elementAt = new ElementAt(expression, dereferenceExpression.child(1)); - return wrapVariantElementAtWithCast(elementAt, (VariantType) dataType, - dereferenceExpression.fieldName); + return new ElementAt(expression, dereferenceExpression.child(1)); } throw new AnalysisException("Can not dereference field: " + dereferenceExpression.fieldName); } @@ -641,21 +637,6 @@ public Expression visitBinaryArithmetic(BinaryArithmetic binaryArithmetic, Expre return 
TypeCoercionUtils.processBinaryArithmetic(binaryArithmetic); } - @Override - public Expression visitElementAt(ElementAt elementAt, ExpressionRewriteContext context) { - Expression left = elementAt.left().accept(this, context); - Expression right = elementAt.right().accept(this, context); - ElementAt newElementAt = (ElementAt) elementAt.withChildren(left, right); - // Auto-cast for variant schema template - if (left.getDataType() instanceof VariantType && right instanceof StringLikeLiteral) { - VariantType variantType = (VariantType) left.getDataType(); - String fieldName = ((StringLikeLiteral) right).getStringValue(); - return wrapVariantElementAtWithCast(newElementAt, variantType, fieldName); - } - // For non-variant cases (array/map), apply normal type coercion - return TypeCoercionUtils.processBoundFunction(newElementAt); - } - @Override public Expression visitOr(Or or, ExpressionRewriteContext context) { List children = ExpressionUtils.extractDisjunction(or); @@ -1119,8 +1100,7 @@ private Optional bindNestedFields(UnboundSlot unboundSlot, Slot slot expression = new ElementAt(expression, new StringLiteral(fieldName)); continue; } else if (dataType.isVariantType()) { - ElementAt elementAt = new ElementAt(expression, new StringLiteral(fieldName)); - expression = wrapVariantElementAtWithCast(elementAt, (VariantType) dataType, fieldName); + expression = new ElementAt(expression, new StringLiteral(fieldName)); continue; } throw new AnalysisException("No such field '" + fieldName + "' in '" + lastFieldName + "'"); @@ -1136,19 +1116,6 @@ public static boolean sameTableName(String boundSlot, String unboundSlot) { } } - /** - * Wrap ElementAt with Cast if the variant type has a matching predefined field. - * This enables auto-cast for variant schema template. 
- */ - private static Expression wrapVariantElementAtWithCast( - ElementAt elementAt, VariantType variantType, String fieldName) { - Optional matchingField = variantType.findMatchingField(fieldName); - if (matchingField.isPresent()) { - return new Cast(elementAt, matchingField.get().getDataType()); - } - return elementAt; - } - private boolean shouldBindSlotBy(int namePartSize, Slot boundSlot) { return namePartSize <= boundSlot.getQualifier().size() + 1; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java new file mode 100644 index 00000000000000..68f1951b029103 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java @@ -0,0 +1,259 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.properties.OrderKey; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.WindowExpression; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; +import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.trees.plans.logical.LogicalSort; +import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias; +import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; +import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; +import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; +import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; +import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.VariantField; +import org.apache.doris.nereids.types.VariantType; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; + +import java.util.HashMap; +import 
java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Function; + +/** + * Automatically cast variant element access expressions based on schema template. + * + * This rule only targets non-select clauses (e.g. WHERE, GROUP BY, HAVING, ORDER BY, JOIN ON). + * It should run before VariantSubPathPruning so ElementAt is still present. + */ +public class VariantSchemaCast implements CustomRewriter { + + @Override + public Plan rewriteRoot(Plan plan, JobContext jobContext) { + return plan.accept(PlanRewriter.INSTANCE, null); + } + + private static class PlanRewriter extends DefaultPlanRewriter { + public static final PlanRewriter INSTANCE = new PlanRewriter(); + + private static final Function EXPRESSION_REWRITER = expr -> { + if (!(expr instanceof ElementAt)) { + return expr; + } + ElementAt elementAt = (ElementAt) expr; + Expression left = elementAt.left(); + Expression right = elementAt.right(); + + if (!(left.getDataType() instanceof VariantType)) { + return expr; + } + if (!(right instanceof StringLikeLiteral)) { + return expr; + } + + VariantType variantType = (VariantType) left.getDataType(); + String fieldName = ((StringLikeLiteral) right).getStringValue(); + Optional matchingField = variantType.findMatchingField(fieldName); + if (!matchingField.isPresent()) { + return expr; + } + + DataType targetType = matchingField.get().getDataType(); + return new Cast(elementAt, targetType); + }; + + private Expression rewriteExpression(Expression expr) { + return expr.rewriteDownShortCircuit(EXPRESSION_REWRITER); + } + + private Expression rewriteExpression(Expression expr, Map aliasMap) { + if (aliasMap.isEmpty()) { + return rewriteExpression(expr); + } + return expr.rewriteDownShortCircuit(node -> { + if (node instanceof SlotReference) { + ExprId exprId = ((SlotReference) node).getExprId(); + Expression aliasExpr = aliasMap.get(exprId); + if (aliasExpr != null) { + Expression rewrittenAlias = 
rewriteExpression(aliasExpr); + if (rewrittenAlias instanceof Cast) { + return new Cast(node, ((Cast) rewrittenAlias).getDataType()); + } + } + } + return EXPRESSION_REWRITER.apply(node); + }); + } + + private Map buildAliasMap(Plan plan) { + Map aliasMap = new HashMap<>(); + Plan current = plan; + while (current instanceof LogicalSubQueryAlias) { + current = ((LogicalSubQueryAlias) current).child(); + } + if (current instanceof LogicalProject) { + collectAliasMap(aliasMap, ((LogicalProject) current).getProjects()); + } else if (current instanceof LogicalAggregate) { + collectAliasMap(aliasMap, ((LogicalAggregate) current).getOutputExpressions()); + } + return aliasMap; + } + + private void collectAliasMap(Map aliasMap, List outputs) { + for (NamedExpression output : outputs) { + if (output instanceof Alias) { + Alias alias = (Alias) output; + aliasMap.put(alias.getExprId(), alias.child()); + } + } + } + + @Override + public Plan visitLogicalFilter(LogicalFilter filter, Void context) { + filter = (LogicalFilter) super.visit(filter, context); + Map aliasMap = buildAliasMap(filter.child()); + Set newConjuncts = filter.getConjuncts().stream() + .map(expr -> rewriteExpression(expr, aliasMap)) + .collect(ImmutableSet.toImmutableSet()); + return filter.withConjuncts(newConjuncts); + } + + @Override + public Plan visitLogicalHaving(LogicalHaving having, Void context) { + having = (LogicalHaving) super.visit(having, context); + Map aliasMap = buildAliasMap(having.child()); + Set newConjuncts = having.getConjuncts().stream() + .map(expr -> rewriteExpression(expr, aliasMap)) + .collect(ImmutableSet.toImmutableSet()); + return having.withConjuncts(newConjuncts); + } + + @Override + public Plan visitLogicalSort(LogicalSort sort, Void context) { + sort = (LogicalSort) super.visit(sort, context); + Map aliasMap = buildAliasMap(sort.child()); + List newOrderKeys = sort.getOrderKeys().stream() + .map(orderKey -> orderKey.withExpression(rewriteExpression(orderKey.getExpr(), 
aliasMap))) + .collect(ImmutableList.toImmutableList()); + return sort.withOrderKeys(newOrderKeys); + } + + @Override + public Plan visitLogicalTopN(LogicalTopN topN, Void context) { + topN = (LogicalTopN) super.visit(topN, context); + Map aliasMap = buildAliasMap(topN.child()); + List newOrderKeys = topN.getOrderKeys().stream() + .map(orderKey -> orderKey.withExpression(rewriteExpression(orderKey.getExpr(), aliasMap))) + .collect(ImmutableList.toImmutableList()); + return topN.withOrderKeys(newOrderKeys); + } + + @Override + public Plan visitLogicalPartitionTopN(LogicalPartitionTopN topN, Void context) { + topN = (LogicalPartitionTopN) super.visit(topN, context); + Map aliasMap = buildAliasMap(topN.child()); + List newPartitionKeys = topN.getPartitionKeys().stream() + .map(expr -> rewriteExpression(expr, aliasMap)) + .collect(ImmutableList.toImmutableList()); + List newOrderKeys = topN.getOrderKeys().stream() + .map(orderExpr -> (OrderExpression) orderExpr.withChildren(ImmutableList.of( + rewriteExpression(orderExpr.child(), aliasMap)))) + .collect(ImmutableList.toImmutableList()); + return topN.withPartitionKeysAndOrderKeys(newPartitionKeys, newOrderKeys); + } + + @Override + public Plan visitLogicalJoin(LogicalJoin join, Void context) { + join = (LogicalJoin) super.visit(join, context); + Map aliasMap = buildAliasMap(join.left()); + aliasMap.putAll(buildAliasMap(join.right())); + List newHashConditions = join.getHashJoinConjuncts().stream() + .map(expr -> rewriteExpression(expr, aliasMap)) + .collect(ImmutableList.toImmutableList()); + List newOtherConditions = join.getOtherJoinConjuncts().stream() + .map(expr -> rewriteExpression(expr, aliasMap)) + .collect(ImmutableList.toImmutableList()); + List newMarkConditions = join.getMarkJoinConjuncts().stream() + .map(expr -> rewriteExpression(expr, aliasMap)) + .collect(ImmutableList.toImmutableList()); + return join.withJoinConjuncts(newHashConditions, newOtherConditions, + newMarkConditions, 
join.getJoinReorderContext()); + } + + @Override + public Plan visitLogicalAggregate(LogicalAggregate aggregate, Void context) { + aggregate = (LogicalAggregate) super.visit(aggregate, context); + Map aliasMap = buildAliasMap(aggregate.child()); + List newGroupByKeys = aggregate.getGroupByExpressions().stream() + .map(expr -> rewriteExpression(expr, aliasMap)) + .collect(ImmutableList.toImmutableList()); + List outputs = aggregate.getOutputExpressions(); + return aggregate.withGroupByAndOutput(newGroupByKeys, outputs); + } + + @Override + public Plan visitLogicalWindow(LogicalWindow window, Void context) { + window = (LogicalWindow) super.visit(window, context); + Map aliasMap = buildAliasMap(window.child()); + List newExprs = window.getWindowExpressions().stream() + .map(expr -> rewriteWindowExpression(expr, aliasMap)) + .collect(ImmutableList.toImmutableList()); + return window.withExpressionsAndChild(newExprs, window.child()); + } + + private NamedExpression rewriteWindowExpression(NamedExpression expr, Map aliasMap) { + if (expr instanceof Alias) { + Alias alias = (Alias) expr; + if (alias.child() instanceof WindowExpression) { + WindowExpression windowExpr = (WindowExpression) alias.child(); + List newPartitionKeys = windowExpr.getPartitionKeys().stream() + .map(partitionKey -> rewriteExpression(partitionKey, aliasMap)) + .collect(ImmutableList.toImmutableList()); + List newOrderKeys = windowExpr.getOrderKeys().stream() + .map(orderExpr -> (OrderExpression) orderExpr.withChildren(ImmutableList.of( + rewriteExpression(orderExpr.child(), aliasMap)))) + .collect(ImmutableList.toImmutableList()); + return alias.withChildren(ImmutableList.of( + windowExpr.withPartitionKeysOrderKeys(newPartitionKeys, newOrderKeys))); + } + } + return expr; + } + } +} diff --git a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out index 216c8b88f77789..65e525900df5ad 100644 --- 
a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out +++ b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out @@ -38,7 +38,48 @@ alice 25 bob 30 charlie 50 +-- !order_by_alias -- +10 +15 +30 +50 + +-- !order_by_alias_subquery -- +1 10 +4 15 +2 30 +3 50 + +-- !group_by_alias_subquery -- +10 1 +15 1 +30 1 +50 1 + +-- !window_partition_order -- +1 1 +2 1 +3 1 +4 2 + -- !join_on -- 1 first 2 second +-- !join_on_alias_subquery -- +1 first +2 second + +-- !match_name_exact_where -- +2 + +-- !match_name_glob_where -- +1 + +-- !match_name_exact_order -- +1 +2 + +-- !match_name_glob_order -- +1 +2 diff --git a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy index 6921ccd1c2a0bf..00158894aca79c 100644 --- a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy +++ b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy @@ -73,9 +73,27 @@ suite("test_schema_template_auto_cast", "p0") { FROM ${tableName} GROUP BY data['str_name'] HAVING SUM(data['num_a']) > 20 ORDER BY data['str_name'] """ + // Test 7: ORDER BY with alias from project + qt_order_by_alias """ SELECT data['num_a'] AS num_a FROM ${tableName} + ORDER BY num_a """ + + // Test 8: ORDER BY with alias from subquery + qt_order_by_alias_subquery """ SELECT * FROM (SELECT id, data['num_a'] AS num_a FROM ${tableName}) t + ORDER BY num_a, id """ + + // Test 9: GROUP BY with alias from subquery + qt_group_by_alias_subquery """ SELECT num_a, COUNT(*) AS cnt + FROM (SELECT data['num_a'] AS num_a FROM ${tableName}) t + GROUP BY num_a ORDER BY num_a """ + + // Test 10: WINDOW partition/order by with auto-cast + qt_window_partition_order """ SELECT id, + row_number() OVER (PARTITION BY data['str_name'] ORDER BY data['num_a']) AS rn + FROM ${tableName} ORDER BY id """ + sql "DROP TABLE IF EXISTS 
${tableName}" - // Test 7: JOIN ON with auto-cast + // Test 11: JOIN ON with auto-cast def leftTable = "test_variant_join_left" def rightTable = "test_variant_join_right" @@ -109,6 +127,37 @@ suite("test_schema_template_auto_cast", "p0") { ON l.data['key_id'] = r.info['key_id'] ORDER BY l.id """ + // Test 12: JOIN ON with alias from subquery + qt_join_on_alias_subquery """ SELECT l.id, r.name_val + FROM (SELECT id, data['key_id'] AS key_id FROM ${leftTable}) l + JOIN (SELECT id, info['key_id'] AS key_id, info['name_val'] AS name_val FROM ${rightTable}) r + ON l.key_id = r.key_id + ORDER BY l.id """ + sql "DROP TABLE IF EXISTS ${leftTable}" sql "DROP TABLE IF EXISTS ${rightTable}" + + // Test 13: MATCH_NAME and MATCH_NAME_GLOB + def exactTable = "test_variant_schema_auto_cast_exact" + sql "DROP TABLE IF EXISTS ${exactTable}" + sql """CREATE TABLE ${exactTable} ( + `id` bigint NULL, + `data` variant<'exact_key': BIGINT, 'glob_*': BIGINT> NOT NULL + ) ENGINE=OLAP DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( "replication_allocation" = "tag.location.default: 1")""" + + sql """insert into ${exactTable} values(1, '{"exact_key": 10, "glob_1": 20, "glob_2": 5}')""" + sql """insert into ${exactTable} values(2, '{"exact_key": 30, "glob_2": 40}')""" + + qt_match_name_exact_where """ SELECT id FROM ${exactTable} + WHERE data['exact_key'] > 10 ORDER BY id """ + qt_match_name_glob_where """ SELECT id FROM ${exactTable} + WHERE data['glob_1'] >= 20 ORDER BY id """ + qt_match_name_exact_order """ SELECT id FROM ${exactTable} + ORDER BY data['exact_key'] """ + qt_match_name_glob_order """ SELECT id FROM ${exactTable} + ORDER BY data['glob_2'], id """ + + sql "DROP TABLE IF EXISTS ${exactTable}" } From 60b0555eca05ec4cfbc4adb2ea475c1bbbabf125 Mon Sep 17 00:00:00 2001 From: Gary Date: Sun, 1 Feb 2026 02:10:15 +0800 Subject: [PATCH 12/27] may be aborted --- .../doris/nereids/jobs/executor/Rewriter.java | 9 - .../apache/doris/nereids/rules/RuleType.java | 1 - 
.../rules/analysis/BindExpression.java | 123 +++++++-- .../rules/analysis/ExpressionAnalyzer.java | 138 +++++++++- .../rules/rewrite/VariantSchemaCast.java | 259 ------------------ .../org/apache/doris/qe/SessionVariable.java | 37 +++ .../test_schema_template_auto_cast.groovy | 28 +- 7 files changed, 296 insertions(+), 299 deletions(-) delete mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 7f10573da76877..9e3186c6cec7b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -169,7 +169,6 @@ import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoin; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoinProject; import org.apache.doris.nereids.rules.rewrite.VariantSubPathPruning; -import org.apache.doris.nereids.rules.rewrite.VariantSchemaCast; import org.apache.doris.nereids.rules.rewrite.batch.ApplyToJoin; import org.apache.doris.nereids.rules.rewrite.batch.CorrelateApplyToUnCorrelateApply; import org.apache.doris.nereids.rules.rewrite.batch.EliminateUselessPlanUnderApply; @@ -915,14 +914,6 @@ private static List getWholeTreeRewriteJobs( bottomUp(new RewriteSearchToSlots()) )); - // Auto cast variant element access based on schema template - // This should run before VariantSubPathPruning - rewriteJobs.addAll(jobs( - topic("variant schema cast", - custom(RuleType.VARIANT_SCHEMA_CAST, VariantSchemaCast::new) - ) - )); - if (needSubPathPushDown) { rewriteJobs.addAll(jobs( topic("variant element_at push down", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 
1dcce5c159bff3..2587481256b798 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -224,7 +224,6 @@ public enum RuleType { ADD_PROJECT_FOR_JOIN(RuleTypeClass.REWRITE), ADD_PROJECT_FOR_UNIQUE_FUNCTION(RuleTypeClass.REWRITE), - VARIANT_SCHEMA_CAST(RuleTypeClass.REWRITE), VARIANT_SUB_PATH_PRUNING(RuleTypeClass.REWRITE), NESTED_COLUMN_PRUNING(RuleTypeClass.REWRITE), CLEAR_CONTEXT_STATUS(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index 54220dbe4d142b..2239f5b034b79e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -119,6 +119,7 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -368,7 +369,8 @@ private LogicalSetOperation bindSetOperation(LogicalSetOperation setOperation) { private LogicalOneRowRelation bindOneRowRelation(MatchingContext ctx) { OneRowRelation oneRowRelation = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(oneRowRelation, cascadesContext, ImmutableList.of()); + SimpleExprAnalyzer analyzer = + buildSimpleExprAnalyzer(oneRowRelation, cascadesContext, ImmutableList.of(), true); List projects = analyzer.analyzeToList(oneRowRelation.getProjects()); return new LogicalOneRowRelation(oneRowRelation.getRelationId(), projects); } @@ -451,8 +453,9 @@ private LogicalHaving bindHavingAggregate( }); FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry(); + Map aliasMap = buildAliasMap(having.child()); ExpressionAnalyzer 
havingAnalyzer = new ExpressionAnalyzer(having, aggOutputScope, cascadesContext, - false, true) { + false, true, false, aliasMap) { private boolean currentIsInAggregateFunction; @Override @@ -516,8 +519,9 @@ private LogicalHaving bindHavingByScopes( LogicalHaving having, Plan child, CascadesContext cascadesContext, Scope defaultScope, Supplier backupScope) { + Map aliasMap = buildAliasMap(child); SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer( - having, cascadesContext, defaultScope, false, true, + having, cascadesContext, defaultScope, false, true, aliasMap, (self, unboundSlot) -> { List slots = self.bindSlotByScope(unboundSlot, defaultScope); if (!slots.isEmpty()) { @@ -658,7 +662,9 @@ private LogicalSort bindSortWithSetOperation( CascadesContext cascadesContext = ctx.cascadesContext; List childOutput = sort.child().getOutput(); - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(sort, cascadesContext, sort.children()); + Map aliasMap = buildAliasMap(sort.child()); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(sort, cascadesContext, sort.children(), + false, aliasMap); Builder boundKeys = ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size()); for (OrderKey orderKey : sort.getOrderKeys()) { Expression boundKey = bindWithOrdinal(orderKey.getExpr(), analyzer, childOutput); @@ -673,7 +679,10 @@ private LogicalJoin bindJoin(MatchingContext checkConflictAlias(join); - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(join, cascadesContext, join.children()); + Map aliasMap = buildAliasMap(join.left()); + aliasMap.putAll(buildAliasMap(join.right())); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(join, cascadesContext, join.children(), + false, aliasMap); Builder hashJoinConjuncts = ImmutableList.builderWithExpectedSize( join.getHashJoinConjuncts().size()); @@ -747,16 +756,18 @@ private LogicalPlan bindUsingJoin(MatchingContext> Scope leftScope = toScope(cascadesContext, using.left().getOutput(), 
using.left().getAsteriskOutput()); Scope rightScope = toScope(cascadesContext, using.right().getOutput(), using.right().getAsteriskOutput()); ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(using, cascadesContext); + Map leftAliasMap = buildAliasMap(using.left()); + Map rightAliasMap = buildAliasMap(using.right()); Builder hashEqExprs = ImmutableList.builderWithExpectedSize(unboundHashJoinConjunct.size()); List rightConjunctsSlots = Lists.newArrayList(); for (Expression usingColumn : unboundHashJoinConjunct) { ExpressionAnalyzer leftExprAnalyzer = new ExpressionAnalyzer( - using, leftScope, cascadesContext, true, false); + using, leftScope, cascadesContext, true, false, false, leftAliasMap); Expression usingLeftSlot = leftExprAnalyzer.analyze(usingColumn, rewriteContext); ExpressionAnalyzer rightExprAnalyzer = new ExpressionAnalyzer( - using, rightScope, cascadesContext, true, false); + using, rightScope, cascadesContext, true, false, false, rightAliasMap); Expression usingRightSlot = rightExprAnalyzer.analyze(usingColumn, rewriteContext); rightConjunctsSlots.add((Slot) usingRightSlot); hashEqExprs.add(new EqualTo(usingLeftSlot, usingRightSlot)); @@ -773,7 +784,7 @@ private Plan bindProject(MatchingContext> ctx) { LogicalProject project = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(project, cascadesContext, project.children()); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(project, cascadesContext, project.children(), true); Builder boundProjectionsBuilder = ImmutableList.builderWithExpectedSize(project.getProjects().size()); StatementContext statementContext = ctx.statementContext; @@ -843,7 +854,7 @@ private Plan bindLoadProject(MatchingContext> ctx) { LogicalLoadProject project = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(project, cascadesContext, project.children()); + 
SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(project, cascadesContext, project.children(), true); Builder boundProjections = ImmutableList.builderWithExpectedSize(project.getProjects().size()); StatementContext statementContext = ctx.statementContext; for (Expression expression : project.getProjects()) { @@ -909,7 +920,9 @@ private Plan bindFilter(MatchingContext> ctx) { LogicalFilter filter = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children()); + Map aliasMap = buildAliasMap(filter.child()); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children(), + false, aliasMap); ImmutableSet.Builder boundConjuncts = ImmutableSet.builder(); boolean changed = false; for (Expression expr : filter.getConjuncts()) { @@ -931,7 +944,9 @@ private Plan bindPreFilter(MatchingContext> ctx) { LogicalPreFilter filter = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children()); + Map aliasMap = buildAliasMap(filter.child()); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children(), + false, aliasMap); ImmutableSet.Builder boundConjuncts = ImmutableSet.builder(); for (Expression conjunct : filter.getConjuncts()) { Expression boundExpr = analyzer.analyze(conjunct); @@ -1060,8 +1075,9 @@ private void bindQualifyByProject(LogicalProject project, Cascad ); Scope backupScope = toScope(cascadesContext, project.getOutput()); + Map aliasMap = buildAliasMap(project); SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer( - qualify, cascadesContext, defaultScope.get(), true, true, + qualify, cascadesContext, defaultScope.get(), true, true, aliasMap, (self, unboundSlot) -> { List slots = self.bindSlotByScope(unboundSlot, defaultScope.get()); if (!slots.isEmpty()) { @@ -1114,8 +1130,9 @@ 
private void bindQualifyByAggregate(Aggregate aggregate, Cascade }; }); + Map aliasMap = buildAliasMap(aggregate); ExpressionAnalyzer qualifyAnalyzer = new ExpressionAnalyzer(qualify, aggOutputScope, cascadesContext, - true, true) { + true, true, false, aliasMap) { @Override protected List bindSlotByThisScope(UnboundSlot unboundSlot) { return bindByGroupByThenAggOutputThenAggChildOutput.get().bindSlot(this, unboundSlot); @@ -1151,7 +1168,7 @@ private Plan bindAggregate(MatchingContext> ctx) { LogicalAggregate agg = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer aggOutputAnalyzer = buildSimpleExprAnalyzer(agg, cascadesContext, agg.children()); + SimpleExprAnalyzer aggOutputAnalyzer = buildSimpleExprAnalyzer(agg, cascadesContext, agg.children(), true); List boundAggOutput = aggOutputAnalyzer.analyzeToList(agg.getOutputExpressions()); List boundProjections = new ArrayList<>(boundAggOutput.size()); for (int i = 0; i < boundAggOutput.size(); i++) { @@ -1319,7 +1336,8 @@ private Plan bindRepeat(MatchingContext> ctx) { LogicalRepeat repeat = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer repeatOutputAnalyzer = buildSimpleExprAnalyzer(repeat, cascadesContext, repeat.children()); + SimpleExprAnalyzer repeatOutputAnalyzer = buildSimpleExprAnalyzer(repeat, cascadesContext, + repeat.children(), true); List boundRepeatOutput = repeatOutputAnalyzer.analyzeToList(repeat.getOutputExpressions()); Supplier aggOutputScope = buildAggOutputScope(boundRepeatOutput, cascadesContext); Builder> boundGroupingSetsBuilder = @@ -1403,8 +1421,9 @@ private List bindGroupBy( Supplier aggOutputScope, CascadesContext cascadesContext) { Scope childOutputScope = toScope(cascadesContext, agg.child().getOutput()); + Map aliasMap = buildAliasMap(agg); SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer( - agg, cascadesContext, childOutputScope, true, true, + agg, cascadesContext, childOutputScope, true, true, aliasMap, 
(self, unboundSlot) -> { // see: https://github.com/apache/doris/pull/15240 // @@ -1468,6 +1487,44 @@ private Supplier buildAggOutputScope( }); } + private Map buildAliasMap(Plan plan) { + Map aliasMap = new HashMap<>(); + Plan current = unwrapSubQueryAlias(plan); + while (current instanceof LogicalProject) { + int before = aliasMap.size(); + collectAliasMap(aliasMap, ((LogicalProject) current).getProjects()); + if (aliasMap.size() > before) { + break; + } + // passthrough project (e.g. SELECT *), keep searching in child + if (current.arity() != 1) { + break; + } + current = unwrapSubQueryAlias(current.child(0)); + } + if (aliasMap.isEmpty() && current instanceof LogicalAggregate) { + collectAliasMap(aliasMap, ((LogicalAggregate) current).getOutputExpressions()); + } + return aliasMap; + } + + private Plan unwrapSubQueryAlias(Plan plan) { + Plan current = plan; + while (current instanceof LogicalSubQueryAlias) { + current = ((LogicalSubQueryAlias) current).child(); + } + return current; + } + + private void collectAliasMap(Map aliasMap, List outputs) { + for (NamedExpression output : outputs) { + if (output instanceof Alias) { + Alias alias = (Alias) output; + aliasMap.put(alias.getExprId(), alias.child()); + } + } + } + private Plan bindSortWithoutSetOperation(MatchingContext> ctx) { CascadesContext cascadesContext = ctx.cascadesContext; LogicalSort sort = ctx.root; @@ -1511,12 +1568,13 @@ private Plan bindSortWithoutSetOperation(MatchingContext> ctx) // bind order_col1 with alias_col1, then, bind it with inner_col1 List inputSlots = input.getOutput(); Scope inputScope = toScope(cascadesContext, inputSlots); + Map aliasMap = buildAliasMap(input); final Plan finalInput = input; Supplier inputChildrenScope = Suppliers.memoize( () -> toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(finalInput.children()))); SimpleExprAnalyzer bindInInputScopeThenInputChildScope = buildCustomSlotBinderAnalyzer( - sort, cascadesContext, inputScope, true, false, + sort, 
cascadesContext, inputScope, true, false, aliasMap, (self, unboundSlot) -> { // first, try to bind slot in Scope(input.output) List slotsInInput = self.bindExactSlotsByThisScope(unboundSlot, inputScope); @@ -1531,7 +1589,7 @@ private Plan bindSortWithoutSetOperation(MatchingContext> ctx) }); SimpleExprAnalyzer bindInInputChildScope = getAnalyzerForOrderByAggFunc(finalInput, cascadesContext, sort, - inputChildrenScope, inputScope); + inputChildrenScope, inputScope, aliasMap); Builder boundOrderKeys = ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size()); FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry(); Map bindUniqueIdReplaceMap = getBelowAggregateGroupByUniqueFuncReplaceMap(sort); @@ -1636,20 +1694,40 @@ private Scope toScope(CascadesContext cascadesContext, List slots, List children) { + return buildSimpleExprAnalyzer(currentPlan, cascadesContext, children, false, Collections.emptyMap()); + } + + protected SimpleExprAnalyzer buildSimpleExprAnalyzer( + Plan currentPlan, CascadesContext cascadesContext, List children, boolean autoCastInSelect) { + return buildSimpleExprAnalyzer(currentPlan, cascadesContext, children, autoCastInSelect, + Collections.emptyMap()); + } + + protected SimpleExprAnalyzer buildSimpleExprAnalyzer( + Plan currentPlan, CascadesContext cascadesContext, List children, boolean autoCastInSelect, + Map aliasMap) { Scope scope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(children), PlanUtils.fastGetChildrenAsteriskOutputs(children)); ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext); ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan, - scope, cascadesContext, true, true); + scope, cascadesContext, true, true, autoCastInSelect, aliasMap); return expr -> expressionAnalyzer.analyze(expr, rewriteContext); } private SimpleExprAnalyzer buildCustomSlotBinderAnalyzer( Plan currentPlan, CascadesContext 
cascadesContext, Scope defaultScope, boolean enableExactMatch, boolean bindSlotInOuterScope, CustomSlotBinderAnalyzer customSlotBinder) { + return buildCustomSlotBinderAnalyzer(currentPlan, cascadesContext, defaultScope, enableExactMatch, + bindSlotInOuterScope, Collections.emptyMap(), customSlotBinder); + } + + private SimpleExprAnalyzer buildCustomSlotBinderAnalyzer( + Plan currentPlan, CascadesContext cascadesContext, Scope defaultScope, + boolean enableExactMatch, boolean bindSlotInOuterScope, Map aliasMap, + CustomSlotBinderAnalyzer customSlotBinder) { ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext); ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan, defaultScope, cascadesContext, - enableExactMatch, bindSlotInOuterScope) { + enableExactMatch, bindSlotInOuterScope, false, aliasMap) { @Override protected List bindSlotByThisScope(UnboundSlot unboundSlot) { return customSlotBinder.bindSlot(this, unboundSlot); @@ -1706,7 +1784,8 @@ private boolean hasAggregateFunction(Expression expression, FunctionRegistry fun } private SimpleExprAnalyzer getAnalyzerForOrderByAggFunc(Plan finalInput, CascadesContext cascadesContext, - LogicalSort sort, Supplier inputChildrenScope, Scope inputScope) { + LogicalSort sort, Supplier inputChildrenScope, Scope inputScope, + Map aliasMap) { ImmutableList.Builder outputSlots = ImmutableList.builder(); if (finalInput instanceof LogicalAggregate) { LogicalAggregate aggregate = (LogicalAggregate) finalInput; @@ -1719,7 +1798,7 @@ private SimpleExprAnalyzer getAnalyzerForOrderByAggFunc(Plan finalInput, Cascade } Scope outputWithoutAggFunc = toScope(cascadesContext, outputSlots.build()); SimpleExprAnalyzer bindInInputChildScope = buildCustomSlotBinderAnalyzer( - sort, cascadesContext, inputScope, true, false, + sort, cascadesContext, inputScope, true, false, aliasMap, (analyzer, unboundSlot) -> { if (finalInput instanceof LogicalAggregate) { List 
boundInOutputWithoutAggFunc = analyzer.bindSlotByScope(unboundSlot, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 866b460945e11f..08bec8e0a54be8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -84,6 +84,7 @@ import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.trees.plans.PlaceholderId; @@ -99,6 +100,7 @@ import org.apache.doris.nereids.types.StructField; import org.apache.doris.nereids.types.StructType; import org.apache.doris.nereids.types.TinyIntType; +import org.apache.doris.nereids.types.VariantField; import org.apache.doris.nereids.types.VariantType; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.TypeCoercionUtils; @@ -119,7 +121,9 @@ import org.apache.commons.lang3.StringUtils; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -156,14 +160,34 @@ protected Expression processCompoundNewChildren(CompoundPredicate cp, List aliasMap; + private int suppressVariantElementAtCastDepth = 0; /** ExpressionAnalyzer */ public ExpressionAnalyzer(Plan currentPlan, Scope scope, @Nullable CascadesContext cascadesContext, boolean enableExactMatch, boolean 
bindSlotInOuterScope) { + this(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope, false); + } + + /** ExpressionAnalyzer */ + public ExpressionAnalyzer(Plan currentPlan, Scope scope, + @Nullable CascadesContext cascadesContext, boolean enableExactMatch, boolean bindSlotInOuterScope, + boolean autoCastInSelect) { + this(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope, + autoCastInSelect, Collections.emptyMap()); + } + + /** ExpressionAnalyzer */ + public ExpressionAnalyzer(Plan currentPlan, Scope scope, + @Nullable CascadesContext cascadesContext, boolean enableExactMatch, boolean bindSlotInOuterScope, + boolean autoCastInSelect, Map aliasMap) { super(scope, cascadesContext); this.currentPlan = currentPlan; this.enableExactMatch = enableExactMatch; this.bindSlotInOuterScope = bindSlotInOuterScope; + this.autoCastInSelect = autoCastInSelect; + this.aliasMap = aliasMap == null ? Collections.emptyMap() : aliasMap; this.wantToParseSqlFromSqlCache = cascadesContext != null && CacheAnalyzer.canUseSqlCache(cascadesContext.getConnectContext().getSessionVariable()); } @@ -274,11 +298,24 @@ public Expression visitDereferenceExpression(DereferenceExpression dereferenceEx } else if (dataType.isMapType()) { return new ElementAt(expression, dereferenceExpression.child(1)); } else if (dataType.isVariantType()) { - return new ElementAt(expression, dereferenceExpression.child(1)); + Expression elementAt = new ElementAt(expression, dereferenceExpression.child(1)); + if (isEnableVariantSchemaAutoCast(context)) { + return wrapVariantElementAtWithCast(elementAt); + } + return elementAt; } throw new AnalysisException("Can not dereference field: " + dereferenceExpression.fieldName); } + @Override + public Expression visitElementAt(ElementAt elementAt, ExpressionRewriteContext context) { + elementAt = (ElementAt) super.visitElementAt(elementAt, context); + if (isEnableVariantSchemaAutoCast(context)) { + return 
wrapVariantElementAtWithCast(elementAt); + } + return elementAt; + } + @Override public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteContext context) { Optional outerScope = getScope().getOuterScope(); @@ -317,7 +354,7 @@ public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteCon } else if (firstBound.containsType(ElementAt.class, StructElement.class)) { context.cascadesContext.getStatementContext().setHasNestedColumns(true); } - return firstBound; + return maybeCastBoundSlot(firstBound, context); default: if (enableExactMatch) { // select t1.k k, t2.k @@ -415,7 +452,7 @@ private UnboundFunction processHighOrderFunction(UnboundFunction unboundFunction ExpressionAnalyzer lambdaAnalyzer = new ExpressionAnalyzer(currentPlan, new Scope(Optional.of(getScope()), boundedSlots), context == null ? null : context.cascadesContext, - true, true) { + true, true, autoCastInSelect, aliasMap) { @Override protected void couldNotFoundColumn(UnboundSlot unboundSlot, String tableName) { throw new AnalysisException("Unknown lambda slot '" @@ -701,6 +738,89 @@ protected Expression processCompoundNewChildren(CompoundPredicate cp, List 0) { + return elementAt; + } + Expression left = elementAt.left(); + Expression right = elementAt.right(); + if (!(left.getDataType() instanceof VariantType)) { + return expr; + } + if (!(right instanceof StringLikeLiteral)) { + return expr; + } + VariantType variantType = (VariantType) left.getDataType(); + String fieldName = ((StringLikeLiteral) right).getStringValue(); + Optional matchingField = variantType.findMatchingField(fieldName); + if (!matchingField.isPresent()) { + return expr; + } + DataType targetType = matchingField.get().getDataType(); + return new Cast(elementAt, targetType); + } + + private boolean shouldSuppressVariantElementAtCast(Cast cast) { + if (!cast.isExplicitType()) { + return false; + } + Expression child = cast.child(); + return child instanceof ElementAt || child instanceof 
DereferenceExpression || child instanceof UnboundSlot; + } + + private Expression maybeCastBoundSlot(Expression bound, ExpressionRewriteContext context) { + if (!(bound instanceof SlotReference)) { + return bound; + } + if (suppressVariantElementAtCastDepth > 0 || aliasMap.isEmpty()) { + return bound; + } + if (!isEnableVariantSchemaAutoCast(context)) { + return bound; + } + if (!bound.getDataType().isVariantType()) { + return bound; + } + Expression aliasExpr = aliasMap.get(((SlotReference) bound).getExprId()); + if (aliasExpr == null) { + return bound; + } + Optional targetType = resolveVariantTemplateType(aliasExpr); + if (!targetType.isPresent()) { + return bound; + } + return new Cast(bound, targetType.get()); + } + + private Optional resolveVariantTemplateType(Expression expr) { + if (!(expr instanceof ElementAt)) { + return Optional.empty(); + } + Expression rewritten = wrapVariantElementAtWithCast(expr); + if (rewritten instanceof Cast) { + return Optional.of(((Cast) rewritten).getDataType()); + } + return Optional.empty(); + } + @Override public Expression visitNot(Not not, ExpressionRewriteContext context) { // maybe is `not subquery`, we should bind it first @@ -870,7 +990,17 @@ public Expression visitMatch(Match match, ExpressionRewriteContext context) { @Override public Expression visitCast(Cast cast, ExpressionRewriteContext context) { - cast = (Cast) super.visitCast(cast, context); + boolean suppressVariantElementAtCast = shouldSuppressVariantElementAtCast(cast); + if (suppressVariantElementAtCast) { + suppressVariantElementAtCastDepth++; + } + try { + cast = (Cast) super.visitCast(cast, context); + } finally { + if (suppressVariantElementAtCast) { + suppressVariantElementAtCastDepth--; + } + } // NOTICE: just for compatibility with legacy planner. 
if (cast.child().getDataType().isComplexType() || cast.getDataType().isComplexType()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java deleted file mode 100644 index 68f1951b029103..00000000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaCast.java +++ /dev/null @@ -1,259 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.nereids.jobs.JobContext; -import org.apache.doris.nereids.properties.OrderKey; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.ExprId; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.OrderExpression; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.WindowExpression; -import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; -import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.trees.plans.logical.LogicalSort; -import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias; -import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; -import org.apache.doris.nereids.trees.plans.logical.LogicalWindow; -import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; -import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; -import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.VariantField; -import org.apache.doris.nereids.types.VariantType; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; - -import java.util.HashMap; -import 
java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.function.Function; - -/** - * Automatically cast variant element access expressions based on schema template. - * - * This rule only targets non-select clauses (e.g. WHERE, GROUP BY, HAVING, ORDER BY, JOIN ON). - * It should run before VariantSubPathPruning so ElementAt is still present. - */ -public class VariantSchemaCast implements CustomRewriter { - - @Override - public Plan rewriteRoot(Plan plan, JobContext jobContext) { - return plan.accept(PlanRewriter.INSTANCE, null); - } - - private static class PlanRewriter extends DefaultPlanRewriter { - public static final PlanRewriter INSTANCE = new PlanRewriter(); - - private static final Function EXPRESSION_REWRITER = expr -> { - if (!(expr instanceof ElementAt)) { - return expr; - } - ElementAt elementAt = (ElementAt) expr; - Expression left = elementAt.left(); - Expression right = elementAt.right(); - - if (!(left.getDataType() instanceof VariantType)) { - return expr; - } - if (!(right instanceof StringLikeLiteral)) { - return expr; - } - - VariantType variantType = (VariantType) left.getDataType(); - String fieldName = ((StringLikeLiteral) right).getStringValue(); - Optional matchingField = variantType.findMatchingField(fieldName); - if (!matchingField.isPresent()) { - return expr; - } - - DataType targetType = matchingField.get().getDataType(); - return new Cast(elementAt, targetType); - }; - - private Expression rewriteExpression(Expression expr) { - return expr.rewriteDownShortCircuit(EXPRESSION_REWRITER); - } - - private Expression rewriteExpression(Expression expr, Map aliasMap) { - if (aliasMap.isEmpty()) { - return rewriteExpression(expr); - } - return expr.rewriteDownShortCircuit(node -> { - if (node instanceof SlotReference) { - ExprId exprId = ((SlotReference) node).getExprId(); - Expression aliasExpr = aliasMap.get(exprId); - if (aliasExpr != null) { - Expression rewrittenAlias = 
rewriteExpression(aliasExpr); - if (rewrittenAlias instanceof Cast) { - return new Cast(node, ((Cast) rewrittenAlias).getDataType()); - } - } - } - return EXPRESSION_REWRITER.apply(node); - }); - } - - private Map buildAliasMap(Plan plan) { - Map aliasMap = new HashMap<>(); - Plan current = plan; - while (current instanceof LogicalSubQueryAlias) { - current = ((LogicalSubQueryAlias) current).child(); - } - if (current instanceof LogicalProject) { - collectAliasMap(aliasMap, ((LogicalProject) current).getProjects()); - } else if (current instanceof LogicalAggregate) { - collectAliasMap(aliasMap, ((LogicalAggregate) current).getOutputExpressions()); - } - return aliasMap; - } - - private void collectAliasMap(Map aliasMap, List outputs) { - for (NamedExpression output : outputs) { - if (output instanceof Alias) { - Alias alias = (Alias) output; - aliasMap.put(alias.getExprId(), alias.child()); - } - } - } - - @Override - public Plan visitLogicalFilter(LogicalFilter filter, Void context) { - filter = (LogicalFilter) super.visit(filter, context); - Map aliasMap = buildAliasMap(filter.child()); - Set newConjuncts = filter.getConjuncts().stream() - .map(expr -> rewriteExpression(expr, aliasMap)) - .collect(ImmutableSet.toImmutableSet()); - return filter.withConjuncts(newConjuncts); - } - - @Override - public Plan visitLogicalHaving(LogicalHaving having, Void context) { - having = (LogicalHaving) super.visit(having, context); - Map aliasMap = buildAliasMap(having.child()); - Set newConjuncts = having.getConjuncts().stream() - .map(expr -> rewriteExpression(expr, aliasMap)) - .collect(ImmutableSet.toImmutableSet()); - return having.withConjuncts(newConjuncts); - } - - @Override - public Plan visitLogicalSort(LogicalSort sort, Void context) { - sort = (LogicalSort) super.visit(sort, context); - Map aliasMap = buildAliasMap(sort.child()); - List newOrderKeys = sort.getOrderKeys().stream() - .map(orderKey -> orderKey.withExpression(rewriteExpression(orderKey.getExpr(), 
aliasMap))) - .collect(ImmutableList.toImmutableList()); - return sort.withOrderKeys(newOrderKeys); - } - - @Override - public Plan visitLogicalTopN(LogicalTopN topN, Void context) { - topN = (LogicalTopN) super.visit(topN, context); - Map aliasMap = buildAliasMap(topN.child()); - List newOrderKeys = topN.getOrderKeys().stream() - .map(orderKey -> orderKey.withExpression(rewriteExpression(orderKey.getExpr(), aliasMap))) - .collect(ImmutableList.toImmutableList()); - return topN.withOrderKeys(newOrderKeys); - } - - @Override - public Plan visitLogicalPartitionTopN(LogicalPartitionTopN topN, Void context) { - topN = (LogicalPartitionTopN) super.visit(topN, context); - Map aliasMap = buildAliasMap(topN.child()); - List newPartitionKeys = topN.getPartitionKeys().stream() - .map(expr -> rewriteExpression(expr, aliasMap)) - .collect(ImmutableList.toImmutableList()); - List newOrderKeys = topN.getOrderKeys().stream() - .map(orderExpr -> (OrderExpression) orderExpr.withChildren(ImmutableList.of( - rewriteExpression(orderExpr.child(), aliasMap)))) - .collect(ImmutableList.toImmutableList()); - return topN.withPartitionKeysAndOrderKeys(newPartitionKeys, newOrderKeys); - } - - @Override - public Plan visitLogicalJoin(LogicalJoin join, Void context) { - join = (LogicalJoin) super.visit(join, context); - Map aliasMap = buildAliasMap(join.left()); - aliasMap.putAll(buildAliasMap(join.right())); - List newHashConditions = join.getHashJoinConjuncts().stream() - .map(expr -> rewriteExpression(expr, aliasMap)) - .collect(ImmutableList.toImmutableList()); - List newOtherConditions = join.getOtherJoinConjuncts().stream() - .map(expr -> rewriteExpression(expr, aliasMap)) - .collect(ImmutableList.toImmutableList()); - List newMarkConditions = join.getMarkJoinConjuncts().stream() - .map(expr -> rewriteExpression(expr, aliasMap)) - .collect(ImmutableList.toImmutableList()); - return join.withJoinConjuncts(newHashConditions, newOtherConditions, - newMarkConditions, 
join.getJoinReorderContext()); - } - - @Override - public Plan visitLogicalAggregate(LogicalAggregate aggregate, Void context) { - aggregate = (LogicalAggregate) super.visit(aggregate, context); - Map aliasMap = buildAliasMap(aggregate.child()); - List newGroupByKeys = aggregate.getGroupByExpressions().stream() - .map(expr -> rewriteExpression(expr, aliasMap)) - .collect(ImmutableList.toImmutableList()); - List outputs = aggregate.getOutputExpressions(); - return aggregate.withGroupByAndOutput(newGroupByKeys, outputs); - } - - @Override - public Plan visitLogicalWindow(LogicalWindow window, Void context) { - window = (LogicalWindow) super.visit(window, context); - Map aliasMap = buildAliasMap(window.child()); - List newExprs = window.getWindowExpressions().stream() - .map(expr -> rewriteWindowExpression(expr, aliasMap)) - .collect(ImmutableList.toImmutableList()); - return window.withExpressionsAndChild(newExprs, window.child()); - } - - private NamedExpression rewriteWindowExpression(NamedExpression expr, Map aliasMap) { - if (expr instanceof Alias) { - Alias alias = (Alias) expr; - if (alias.child() instanceof WindowExpression) { - WindowExpression windowExpr = (WindowExpression) alias.child(); - List newPartitionKeys = windowExpr.getPartitionKeys().stream() - .map(partitionKey -> rewriteExpression(partitionKey, aliasMap)) - .collect(ImmutableList.toImmutableList()); - List newOrderKeys = windowExpr.getOrderKeys().stream() - .map(orderExpr -> (OrderExpression) orderExpr.withChildren(ImmutableList.of( - rewriteExpression(orderExpr.child(), aliasMap)))) - .collect(ImmutableList.toImmutableList()); - return alias.withChildren(ImmutableList.of( - windowExpr.withPartitionKeysOrderKeys(newPartitionKeys, newOrderKeys))); - } - } - return expr; - } - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index d4bb2b3fa4658d..c89ced47a32b98 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -773,6 +773,9 @@ public class SessionVariable implements Serializable, Writable { // enable variant flatten nested as session variable, default is false, // which means do not flatten nested when create table public static final String ENABLE_VARIANT_FLATTEN_NESTED = "enable_variant_flatten_nested"; + public static final String ENABLE_VARIANT_SCHEMA_AUTO_CAST = "enable_variant_schema_auto_cast"; + public static final String ENABLE_VARIANT_SCHEMA_AUTO_CAST_IN_SELECT = + "enable_variant_schema_auto_cast_in_select"; // CLOUD_VARIABLES_BEGIN public static final String CLOUD_CLUSTER = "cloud_cluster"; @@ -3232,6 +3235,32 @@ public boolean isEnableESParallelScroll() { ) public int defaultVariantMaxSubcolumnsCount = 0; + @VariableMgr.VarAttr( + name = ENABLE_VARIANT_SCHEMA_AUTO_CAST, + needForward = true, + affectQueryResultInExecution = true, + description = { + "是否启用基于 schema template 的 variant 自动 cast(非 SELECT 子句),默认关闭。", + "Whether to enable schema-template-based auto cast for variant expressions " + + "(non-SELECT clauses). The default is false." + } + ) + public boolean enableVariantSchemaAutoCast = false; + + @VariableMgr.VarAttr( + name = ENABLE_VARIANT_SCHEMA_AUTO_CAST_IN_SELECT, + needForward = true, + affectQueryResultInExecution = true, + description = { + "是否在 SELECT 子句中启用基于 schema template 的 variant 自动 cast,默认关闭," + + "需先开启 enable_variant_schema_auto_cast。", + "Whether to enable schema-template-based auto cast for variant expressions " + + "in SELECT clause. The default is false and requires " + + "enable_variant_schema_auto_cast = true." 
+ } + ) + public boolean enableVariantSchemaAutoCastInSelect = false; + @VariableMgr.VarAttr( name = DEFAULT_VARIANT_ENABLE_TYPED_PATHS_TO_SPARSE, needForward = true, @@ -5824,6 +5853,14 @@ public boolean getEnableVariantFlattenNested() { return enableVariantFlattenNested; } + public boolean isEnableVariantSchemaAutoCast() { + return enableVariantSchemaAutoCast; + } + + public boolean isEnableVariantSchemaAutoCastInSelect() { + return enableVariantSchemaAutoCast && enableVariantSchemaAutoCastInSelect; + } + public void setProfileLevel(String profileLevel) { int profileLevelTmp = Integer.valueOf(profileLevel); if (profileLevelTmp < 1 || profileLevelTmp > 3) { diff --git a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy index 00158894aca79c..822926b2d4b351 100644 --- a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy +++ b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy @@ -21,6 +21,7 @@ suite("test_schema_template_auto_cast", "p0") { sql """ set enable_common_expr_pushdown = true """ sql """ set default_variant_enable_typed_paths_to_sparse = false """ sql """ set default_variant_enable_doc_mode = false """ + sql """ set enable_variant_schema_auto_cast = true """ def tableName = "test_variant_schema_auto_cast" @@ -60,11 +61,19 @@ suite("test_schema_template_auto_cast", "p0") { qt_topn """ SELECT id, data['num_a'] FROM ${tableName} ORDER BY data['num_a'] DESC LIMIT 2 """ - // Test 4: SELECT with auto-cast (arithmetic operations) + // Test 4: SELECT with auto-cast (arithmetic operations) when enabled + sql """ set enable_variant_schema_auto_cast_in_select = true """ qt_select_arithmetic """ SELECT id, data['num_a'] + data['num_b'] as sum_val FROM ${tableName} ORDER BY id """ + sql """ set enable_variant_schema_auto_cast_in_select = false """ + test { + sql """ SELECT id, data['num_a'] + 
data['num_b'] as sum_val + FROM ${tableName} ORDER BY id """ + exception "Cannot cast from variant" + } // Test 5: GROUP BY with auto-cast + sql """ set enable_variant_schema_auto_cast_in_select = true """ qt_group_by """ SELECT data['str_name'], SUM(data['num_a']) as total FROM ${tableName} GROUP BY data['str_name'] ORDER BY data['str_name'] """ @@ -74,6 +83,7 @@ suite("test_schema_template_auto_cast", "p0") { HAVING SUM(data['num_a']) > 20 ORDER BY data['str_name'] """ // Test 7: ORDER BY with alias from project + sql """ set enable_variant_schema_auto_cast_in_select = false """ qt_order_by_alias """ SELECT data['num_a'] AS num_a FROM ${tableName} ORDER BY num_a """ @@ -87,13 +97,23 @@ suite("test_schema_template_auto_cast", "p0") { GROUP BY num_a ORDER BY num_a """ // Test 10: WINDOW partition/order by with auto-cast + sql """ set enable_variant_schema_auto_cast_in_select = true """ qt_window_partition_order """ SELECT id, row_number() OVER (PARTITION BY data['str_name'] ORDER BY data['num_a']) AS rn FROM ${tableName} ORDER BY id """ + sql """ set enable_variant_schema_auto_cast_in_select = false """ + + // Test 11: disable auto-cast should error in non-select clauses + sql """ set enable_variant_schema_auto_cast = false """ + test { + sql """ SELECT id FROM ${tableName} ORDER BY data['num_a'] """ + exception "Doris hll, bitmap, array, map, struct, jsonb, variant column must use with specific function" + } + sql """ set enable_variant_schema_auto_cast = true """ sql "DROP TABLE IF EXISTS ${tableName}" - // Test 11: JOIN ON with auto-cast + // Test 12: JOIN ON with auto-cast def leftTable = "test_variant_join_left" def rightTable = "test_variant_join_right" @@ -127,7 +147,7 @@ suite("test_schema_template_auto_cast", "p0") { ON l.data['key_id'] = r.info['key_id'] ORDER BY l.id """ - // Test 12: JOIN ON with alias from subquery + // Test 13: JOIN ON with alias from subquery qt_join_on_alias_subquery """ SELECT l.id, r.name_val FROM (SELECT id, data['key_id'] AS 
key_id FROM ${leftTable}) l JOIN (SELECT id, info['key_id'] AS key_id, info['name_val'] AS name_val FROM ${rightTable}) r @@ -137,7 +157,7 @@ suite("test_schema_template_auto_cast", "p0") { sql "DROP TABLE IF EXISTS ${leftTable}" sql "DROP TABLE IF EXISTS ${rightTable}" - // Test 13: MATCH_NAME and MATCH_NAME_GLOB + // Test 14: MATCH_NAME and MATCH_NAME_GLOB def exactTable = "test_variant_schema_auto_cast_exact" sql "DROP TABLE IF EXISTS ${exactTable}" sql """CREATE TABLE ${exactTable} ( From 0a04b037e0b4aea7cbc2c25f0488667e6e48868e Mon Sep 17 00:00:00 2001 From: Gary Date: Sun, 1 Feb 2026 04:26:53 +0800 Subject: [PATCH 13/27] enable_variant_schema_auto_cast_in_select is very complex --- .../rules/analysis/BindExpression.java | 30 +++-- .../rules/analysis/ExpressionAnalyzer.java | 87 +++++++++++++-- .../test_schema_template_auto_cast.out | 65 ++++++++++- .../test_schema_template_auto_cast.groovy | 104 +++++++++++++++++- 4 files changed, 262 insertions(+), 24 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index 2239f5b034b79e..2387104fc1c576 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -1490,17 +1490,9 @@ private Supplier buildAggOutputScope( private Map buildAliasMap(Plan plan) { Map aliasMap = new HashMap<>(); Plan current = unwrapSubQueryAlias(plan); - while (current instanceof LogicalProject) { - int before = aliasMap.size(); - collectAliasMap(aliasMap, ((LogicalProject) current).getProjects()); - if (aliasMap.size() > before) { - break; - } - // passthrough project (e.g. 
SELECT *), keep searching in child - if (current.arity() != 1) { - break; - } - current = unwrapSubQueryAlias(current.child(0)); + collectAliasMapFromProjectChain(aliasMap, current); + if (aliasMap.isEmpty() && current instanceof LogicalAggregate && current.arity() == 1) { + collectAliasMapFromProjectChain(aliasMap, unwrapSubQueryAlias(current.child(0))); } if (aliasMap.isEmpty() && current instanceof LogicalAggregate) { collectAliasMap(aliasMap, ((LogicalAggregate) current).getOutputExpressions()); @@ -1516,6 +1508,22 @@ private Plan unwrapSubQueryAlias(Plan plan) { return current; } + private void collectAliasMapFromProjectChain(Map aliasMap, Plan start) { + Plan current = start; + while (current instanceof LogicalProject) { + int before = aliasMap.size(); + collectAliasMap(aliasMap, ((LogicalProject) current).getProjects()); + if (aliasMap.size() > before) { + break; + } + // passthrough project (e.g. SELECT *), keep searching in child + if (current.arity() != 1) { + break; + } + current = unwrapSubQueryAlias(current.child(0)); + } + } + private void collectAliasMap(Map aliasMap, List outputs) { for (NamedExpression output : outputs) { if (output instanceof Alias) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 08bec8e0a54be8..7b5e6e74469852 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -287,7 +287,20 @@ public Expression visitUnboundAlias(UnboundAlias unboundAlias, ExpressionRewrite @Override public Expression visitDereferenceExpression(DereferenceExpression dereferenceExpression, ExpressionRewriteContext context) { - Expression expression = dereferenceExpression.child(0).accept(this, context); + boolean suppressChildCast = 
isEnableVariantSchemaAutoCast(context) + && (dereferenceExpression.child(0) instanceof DereferenceExpression + || dereferenceExpression.child(0) instanceof ElementAt); + if (suppressChildCast) { + suppressVariantElementAtCastDepth++; + } + Expression expression; + try { + expression = dereferenceExpression.child(0).accept(this, context); + } finally { + if (suppressChildCast) { + suppressVariantElementAtCastDepth--; + } + } DataType dataType = expression.getDataType(); if (dataType.isStructType()) { StructType structType = (StructType) dataType; @@ -309,7 +322,20 @@ public Expression visitDereferenceExpression(DereferenceExpression dereferenceEx @Override public Expression visitElementAt(ElementAt elementAt, ExpressionRewriteContext context) { - elementAt = (ElementAt) super.visitElementAt(elementAt, context); + boolean suppressLeftCast = isEnableVariantSchemaAutoCast(context) && elementAt.left() instanceof ElementAt; + if (suppressLeftCast) { + suppressVariantElementAtCastDepth++; + } + Expression left; + try { + left = elementAt.left().accept(this, context); + } finally { + if (suppressLeftCast) { + suppressVariantElementAtCastDepth--; + } + } + Expression right = elementAt.right().accept(this, context); + elementAt = (ElementAt) elementAt.withChildren(left, right); if (isEnableVariantSchemaAutoCast(context)) { return wrapVariantElementAtWithCast(elementAt); } @@ -760,17 +786,12 @@ private Expression wrapVariantElementAtWithCast(Expression expr) { if (suppressVariantElementAtCastDepth > 0) { return elementAt; } - Expression left = elementAt.left(); - Expression right = elementAt.right(); - if (!(left.getDataType() instanceof VariantType)) { - return expr; - } - if (!(right instanceof StringLikeLiteral)) { + Optional path = resolveVariantElementAtPath(elementAt); + if (!path.isPresent()) { return expr; } - VariantType variantType = (VariantType) left.getDataType(); - String fieldName = ((StringLikeLiteral) right).getStringValue(); - Optional matchingField = 
variantType.findMatchingField(fieldName); + VariantType variantType = (VariantType) path.get().root.getDataType(); + Optional matchingField = variantType.findMatchingField(path.get().path); if (!matchingField.isPresent()) { return expr; } @@ -778,6 +799,50 @@ private Expression wrapVariantElementAtWithCast(Expression expr) { return new Cast(elementAt, targetType); } + private Optional resolveVariantElementAtPath(ElementAt elementAt) { + List segments = new ArrayList<>(); + Expression current = elementAt; + Expression root = null; + while (current instanceof ElementAt) { + ElementAt currentElementAt = (ElementAt) current; + Optional key = getVariantPathKey(currentElementAt.right()); + if (!key.isPresent()) { + return Optional.empty(); + } + segments.add(0, key.get()); + Expression left = currentElementAt.left(); + if (left instanceof Cast && !((Cast) left).isExplicitType()) { + left = ((Cast) left).child(); + } + current = left; + root = left; + } + if (root == null || !(root.getDataType() instanceof VariantType)) { + return Optional.empty(); + } + if (segments.isEmpty()) { + return Optional.empty(); + } + return Optional.of(new VariantElementAtPath(root, String.join(".", segments))); + } + + private Optional getVariantPathKey(Expression expr) { + if (expr instanceof StringLikeLiteral) { + return Optional.of(((StringLikeLiteral) expr).getStringValue()); + } + return Optional.empty(); + } + + private static final class VariantElementAtPath { + private final Expression root; + private final String path; + + private VariantElementAtPath(Expression root, String path) { + this.root = root; + this.path = path; + } + } + private boolean shouldSuppressVariantElementAtCast(Cast cast) { if (!cast.isExplicitType()) { return false; diff --git a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out index 65e525900df5ad..9ea29fb6c6d53b 100644 --- 
a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out +++ b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out @@ -22,6 +22,16 @@ 3 50 2 30 +-- !order_by_select_on -- +3 50 +2 30 +4 15 +1 10 + +-- !topn_select_on -- +3 50 +2 30 + -- !select_arithmetic -- 1 30 2 70 @@ -44,13 +54,25 @@ charlie 50 30 50 +-- !order_by_alias_select_on -- +10 +15 +30 +50 + -- !order_by_alias_subquery -- 1 10 4 15 2 30 3 50 --- !group_by_alias_subquery -- +-- !order_by_alias_subquery_select_on -- +1 10 +4 15 +2 30 +3 50 + +-- !group_by_alias_subquery_select_on -- 10 1 15 1 30 1 @@ -66,10 +88,18 @@ charlie 50 1 first 2 second +-- !join_on_select_on -- +1 first +2 second + -- !join_on_alias_subquery -- 1 first 2 second +-- !join_on_alias_subquery_select_on -- +1 first +2 second + -- !match_name_exact_where -- 2 @@ -83,3 +113,36 @@ charlie 50 -- !match_name_glob_order -- 1 2 + +-- !leaf_int1_select_on -- +1 + +-- !leaf_int1_add_select_on -- +2 + +-- !leaf_int_nested_chain_select_on -- +1011111 + +-- !leaf_int_nested_dot_select_on -- +1011111 + +-- !leaf_int_nested_deref_select_on -- +1011111 + +-- !leaf_int_nested_chain_add_select_on -- +1011112 + +-- !leaf_int_nested_dot_add_select_on -- +1011112 + +-- !leaf_int1_select_off -- +1 + +-- !leaf_int_nested_chain_select_off -- +1011111 + +-- !leaf_int_nested_dot_select_off -- +1011111 + +-- !leaf_int_nested_deref_select_off -- +1011111 diff --git a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy index 822926b2d4b351..67be81500aec2d 100644 --- a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy +++ b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy @@ -60,6 +60,12 @@ suite("test_schema_template_auto_cast", "p0") { // Test 3: TopN (ORDER BY + LIMIT) qt_topn """ SELECT id, data['num_a'] FROM ${tableName} ORDER BY 
data['num_a'] DESC LIMIT 2 """ + sql """ set enable_variant_schema_auto_cast_in_select = true """ + qt_order_by_select_on """ SELECT id, data['num_a'] FROM ${tableName} + ORDER BY data['num_a'] DESC """ + qt_topn_select_on """ SELECT id, data['num_a'] FROM ${tableName} + ORDER BY data['num_a'] DESC LIMIT 2 """ + sql """ set enable_variant_schema_auto_cast_in_select = false """ // Test 4: SELECT with auto-cast (arithmetic operations) when enabled sql """ set enable_variant_schema_auto_cast_in_select = true """ @@ -81,20 +87,48 @@ suite("test_schema_template_auto_cast", "p0") { qt_having """ SELECT data['str_name'], SUM(data['num_a']) as total FROM ${tableName} GROUP BY data['str_name'] HAVING SUM(data['num_a']) > 20 ORDER BY data['str_name'] """ + sql """ set enable_variant_schema_auto_cast_in_select = false """ + test { + sql """ SELECT data['str_name'], SUM(data['num_a']) as total + FROM ${tableName} GROUP BY data['str_name'] ORDER BY data['str_name'] """ + exception "sum requires a numeric, boolean or string parameter" + } + test { + sql """ SELECT data['str_name'], SUM(data['num_a']) as total + FROM ${tableName} GROUP BY data['str_name'] + HAVING SUM(data['num_a']) > 20 ORDER BY data['str_name'] """ + exception "sum requires a numeric, boolean or string parameter" + } // Test 7: ORDER BY with alias from project sql """ set enable_variant_schema_auto_cast_in_select = false """ qt_order_by_alias """ SELECT data['num_a'] AS num_a FROM ${tableName} ORDER BY num_a """ + sql """ set enable_variant_schema_auto_cast_in_select = true """ + qt_order_by_alias_select_on """ SELECT data['num_a'] AS num_a FROM ${tableName} + ORDER BY num_a """ + sql """ set enable_variant_schema_auto_cast_in_select = false """ // Test 8: ORDER BY with alias from subquery qt_order_by_alias_subquery """ SELECT * FROM (SELECT id, data['num_a'] AS num_a FROM ${tableName}) t ORDER BY num_a, id """ + sql """ set enable_variant_schema_auto_cast_in_select = true """ + 
qt_order_by_alias_subquery_select_on """ SELECT * FROM (SELECT id, data['num_a'] AS num_a FROM ${tableName}) t + ORDER BY num_a, id """ + sql """ set enable_variant_schema_auto_cast_in_select = false """ // Test 9: GROUP BY with alias from subquery - qt_group_by_alias_subquery """ SELECT num_a, COUNT(*) AS cnt + test { + sql """ SELECT num_a, COUNT(*) AS cnt + FROM (SELECT data['num_a'] AS num_a FROM ${tableName}) t + GROUP BY num_a ORDER BY num_a """ + exception "must appear in the GROUP BY clause" + } + sql """ set enable_variant_schema_auto_cast_in_select = true """ + qt_group_by_alias_subquery_select_on """ SELECT num_a, COUNT(*) AS cnt FROM (SELECT data['num_a'] AS num_a FROM ${tableName}) t GROUP BY num_a ORDER BY num_a """ + sql """ set enable_variant_schema_auto_cast_in_select = false """ // Test 10: WINDOW partition/order by with auto-cast sql """ set enable_variant_schema_auto_cast_in_select = true """ @@ -146,6 +180,12 @@ suite("test_schema_template_auto_cast", "p0") { FROM ${leftTable} l JOIN ${rightTable} r ON l.data['key_id'] = r.info['key_id'] ORDER BY l.id """ + sql """ set enable_variant_schema_auto_cast_in_select = true """ + qt_join_on_select_on """ SELECT l.id, r.info['name_val'] + FROM ${leftTable} l JOIN ${rightTable} r + ON l.data['key_id'] = r.info['key_id'] + ORDER BY l.id """ + sql """ set enable_variant_schema_auto_cast_in_select = false """ // Test 13: JOIN ON with alias from subquery qt_join_on_alias_subquery """ SELECT l.id, r.name_val @@ -153,6 +193,13 @@ suite("test_schema_template_auto_cast", "p0") { JOIN (SELECT id, info['key_id'] AS key_id, info['name_val'] AS name_val FROM ${rightTable}) r ON l.key_id = r.key_id ORDER BY l.id """ + sql """ set enable_variant_schema_auto_cast_in_select = true """ + qt_join_on_alias_subquery_select_on """ SELECT l.id, r.name_val + FROM (SELECT id, data['key_id'] AS key_id FROM ${leftTable}) l + JOIN (SELECT id, info['key_id'] AS key_id, info['name_val'] AS name_val FROM ${rightTable}) r + ON 
l.key_id = r.key_id + ORDER BY l.id """ + sql """ set enable_variant_schema_auto_cast_in_select = false """ sql "DROP TABLE IF EXISTS ${leftTable}" sql "DROP TABLE IF EXISTS ${rightTable}" @@ -180,4 +227,59 @@ suite("test_schema_template_auto_cast", "p0") { ORDER BY data['glob_2'], id """ sql "DROP TABLE IF EXISTS ${exactTable}" + + // Test 15: leaf vs non-leaf path auto cast limitation + def leafTable = "test_variant_schema_auto_cast_leaf" + sql "DROP TABLE IF EXISTS ${leafTable}" + sql """CREATE TABLE ${leafTable} ( + `id` bigint NULL, + `data` variant<'int_*': BIGINT> NOT NULL + ) ENGINE=OLAP DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( "replication_allocation" = "tag.location.default: 1")""" + + sql """insert into ${leafTable} values( + 1, + '{"int_1": 1, "int_nested": {"level1_num_1": 1011111, "level1_num_2": 102}}' + )""" + + sql """ set enable_variant_schema_auto_cast_in_select = true """ + qt_leaf_int1_select_on """ SELECT data['int_1'] FROM ${leafTable} ORDER BY id """ + qt_leaf_int1_add_select_on """ SELECT data['int_1'] + 1 FROM ${leafTable} ORDER BY id """ + test { + // still fails: FE can't distinguish leaf/non-leaf, may cast int_nested to int + sql """ SELECT data['int_nested'] FROM ${leafTable} """ + exception "Bad cast" + } + qt_leaf_int_nested_chain_select_on """ SELECT data['int_nested']['level1_num_1'] + FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_dot_select_on """ SELECT data['int_nested.level1_num_1'] FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_deref_select_on """ SELECT data.int_nested.level1_num_1 FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_chain_add_select_on """ SELECT data['int_nested']['level1_num_1'] + 1 + FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_dot_add_select_on """ SELECT data['int_nested.level1_num_1'] + 1 + FROM ${leafTable} ORDER BY id """ + + sql """ set enable_variant_schema_auto_cast_in_select = false """ + qt_leaf_int1_select_off """ SELECT data['int_1'] 
FROM ${leafTable} ORDER BY id """ + test { + sql """ SELECT data['int_1'] + 1 FROM ${leafTable} ORDER BY id """ + exception "Cannot cast from variant" + } + sql """ SELECT data['int_nested'] FROM ${leafTable} """ + qt_leaf_int_nested_chain_select_off """ SELECT data['int_nested']['level1_num_1'] + FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_dot_select_off """ SELECT data['int_nested.level1_num_1'] + FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_deref_select_off """ SELECT data.int_nested.level1_num_1 FROM ${leafTable} ORDER BY id """ + test { + sql """ SELECT data['int_nested']['level1_num_1'] + 1 FROM ${leafTable} ORDER BY id """ + exception "Cannot cast from variant" + } + test { + sql """ SELECT data['int_nested.level1_num_1'] + 1 FROM ${leafTable} ORDER BY id """ + exception "Cannot cast from variant" + } + + sql "DROP TABLE IF EXISTS ${leafTable}" } From f2e44c2780e9a8db560a4377d910bbbee4316b64 Mon Sep 17 00:00:00 2001 From: Gary Date: Sun, 1 Feb 2026 05:51:03 +0800 Subject: [PATCH 14/27] hope last commit --- .../rules/analysis/BindExpression.java | 45 +++---- .../rules/analysis/ExpressionAnalyzer.java | 39 +++--- .../org/apache/doris/qe/SessionVariable.java | 26 +--- .../test_schema_template_auto_cast.out | 74 ++++------- .../test_schema_template_auto_cast.groovy | 124 +++++------------- 5 files changed, 101 insertions(+), 207 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index 2387104fc1c576..36f978f771e836 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -370,7 +370,7 @@ private LogicalOneRowRelation bindOneRowRelation(MatchingContext OneRowRelation oneRowRelation = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; SimpleExprAnalyzer 
analyzer = - buildSimpleExprAnalyzer(oneRowRelation, cascadesContext, ImmutableList.of(), true); + buildSimpleExprAnalyzer(oneRowRelation, cascadesContext, ImmutableList.of()); List projects = analyzer.analyzeToList(oneRowRelation.getProjects()); return new LogicalOneRowRelation(oneRowRelation.getRelationId(), projects); } @@ -455,7 +455,7 @@ private LogicalHaving bindHavingAggregate( FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry(); Map aliasMap = buildAliasMap(having.child()); ExpressionAnalyzer havingAnalyzer = new ExpressionAnalyzer(having, aggOutputScope, cascadesContext, - false, true, false, aliasMap) { + false, true, aliasMap) { private boolean currentIsInAggregateFunction; @Override @@ -663,8 +663,7 @@ private LogicalSort bindSortWithSetOperation( List childOutput = sort.child().getOutput(); Map aliasMap = buildAliasMap(sort.child()); - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(sort, cascadesContext, sort.children(), - false, aliasMap); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(sort, cascadesContext, sort.children(), aliasMap); Builder boundKeys = ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size()); for (OrderKey orderKey : sort.getOrderKeys()) { Expression boundKey = bindWithOrdinal(orderKey.getExpr(), analyzer, childOutput); @@ -681,8 +680,7 @@ private LogicalJoin bindJoin(MatchingContext Map aliasMap = buildAliasMap(join.left()); aliasMap.putAll(buildAliasMap(join.right())); - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(join, cascadesContext, join.children(), - false, aliasMap); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(join, cascadesContext, join.children(), aliasMap); Builder hashJoinConjuncts = ImmutableList.builderWithExpectedSize( join.getHashJoinConjuncts().size()); @@ -763,11 +761,11 @@ private LogicalPlan bindUsingJoin(MatchingContext> List rightConjunctsSlots = Lists.newArrayList(); for (Expression usingColumn : 
unboundHashJoinConjunct) { ExpressionAnalyzer leftExprAnalyzer = new ExpressionAnalyzer( - using, leftScope, cascadesContext, true, false, false, leftAliasMap); + using, leftScope, cascadesContext, true, false, leftAliasMap); Expression usingLeftSlot = leftExprAnalyzer.analyze(usingColumn, rewriteContext); ExpressionAnalyzer rightExprAnalyzer = new ExpressionAnalyzer( - using, rightScope, cascadesContext, true, false, false, rightAliasMap); + using, rightScope, cascadesContext, true, false, rightAliasMap); Expression usingRightSlot = rightExprAnalyzer.analyze(usingColumn, rewriteContext); rightConjunctsSlots.add((Slot) usingRightSlot); hashEqExprs.add(new EqualTo(usingLeftSlot, usingRightSlot)); @@ -784,7 +782,7 @@ private Plan bindProject(MatchingContext> ctx) { LogicalProject project = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(project, cascadesContext, project.children(), true); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(project, cascadesContext, project.children()); Builder boundProjectionsBuilder = ImmutableList.builderWithExpectedSize(project.getProjects().size()); StatementContext statementContext = ctx.statementContext; @@ -854,7 +852,7 @@ private Plan bindLoadProject(MatchingContext> ctx) { LogicalLoadProject project = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(project, cascadesContext, project.children(), true); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(project, cascadesContext, project.children()); Builder boundProjections = ImmutableList.builderWithExpectedSize(project.getProjects().size()); StatementContext statementContext = ctx.statementContext; for (Expression expression : project.getProjects()) { @@ -921,8 +919,7 @@ private Plan bindFilter(MatchingContext> ctx) { CascadesContext cascadesContext = ctx.cascadesContext; Map aliasMap = buildAliasMap(filter.child()); - 
SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children(), - false, aliasMap); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children(), aliasMap); ImmutableSet.Builder boundConjuncts = ImmutableSet.builder(); boolean changed = false; for (Expression expr : filter.getConjuncts()) { @@ -945,8 +942,7 @@ private Plan bindPreFilter(MatchingContext> ctx) { CascadesContext cascadesContext = ctx.cascadesContext; Map aliasMap = buildAliasMap(filter.child()); - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children(), - false, aliasMap); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children(), aliasMap); ImmutableSet.Builder boundConjuncts = ImmutableSet.builder(); for (Expression conjunct : filter.getConjuncts()) { Expression boundExpr = analyzer.analyze(conjunct); @@ -1132,7 +1128,7 @@ private void bindQualifyByAggregate(Aggregate aggregate, Cascade Map aliasMap = buildAliasMap(aggregate); ExpressionAnalyzer qualifyAnalyzer = new ExpressionAnalyzer(qualify, aggOutputScope, cascadesContext, - true, true, false, aliasMap) { + true, true, aliasMap) { @Override protected List bindSlotByThisScope(UnboundSlot unboundSlot) { return bindByGroupByThenAggOutputThenAggChildOutput.get().bindSlot(this, unboundSlot); @@ -1168,7 +1164,7 @@ private Plan bindAggregate(MatchingContext> ctx) { LogicalAggregate agg = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer aggOutputAnalyzer = buildSimpleExprAnalyzer(agg, cascadesContext, agg.children(), true); + SimpleExprAnalyzer aggOutputAnalyzer = buildSimpleExprAnalyzer(agg, cascadesContext, agg.children()); List boundAggOutput = aggOutputAnalyzer.analyzeToList(agg.getOutputExpressions()); List boundProjections = new ArrayList<>(boundAggOutput.size()); for (int i = 0; i < boundAggOutput.size(); i++) { @@ -1337,7 +1333,7 @@ private Plan 
bindRepeat(MatchingContext> ctx) { CascadesContext cascadesContext = ctx.cascadesContext; SimpleExprAnalyzer repeatOutputAnalyzer = buildSimpleExprAnalyzer(repeat, cascadesContext, - repeat.children(), true); + repeat.children()); List boundRepeatOutput = repeatOutputAnalyzer.analyzeToList(repeat.getOutputExpressions()); Supplier aggOutputScope = buildAggOutputScope(boundRepeatOutput, cascadesContext); Builder> boundGroupingSetsBuilder = @@ -1702,23 +1698,16 @@ private Scope toScope(CascadesContext cascadesContext, List slots, List children) { - return buildSimpleExprAnalyzer(currentPlan, cascadesContext, children, false, Collections.emptyMap()); + return buildSimpleExprAnalyzer(currentPlan, cascadesContext, children, Collections.emptyMap()); } protected SimpleExprAnalyzer buildSimpleExprAnalyzer( - Plan currentPlan, CascadesContext cascadesContext, List children, boolean autoCastInSelect) { - return buildSimpleExprAnalyzer(currentPlan, cascadesContext, children, autoCastInSelect, - Collections.emptyMap()); - } - - protected SimpleExprAnalyzer buildSimpleExprAnalyzer( - Plan currentPlan, CascadesContext cascadesContext, List children, boolean autoCastInSelect, - Map aliasMap) { + Plan currentPlan, CascadesContext cascadesContext, List children, Map aliasMap) { Scope scope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(children), PlanUtils.fastGetChildrenAsteriskOutputs(children)); ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext); ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan, - scope, cascadesContext, true, true, autoCastInSelect, aliasMap); + scope, cascadesContext, true, true, aliasMap); return expr -> expressionAnalyzer.analyze(expr, rewriteContext); } @@ -1735,7 +1724,7 @@ private SimpleExprAnalyzer buildCustomSlotBinderAnalyzer( CustomSlotBinderAnalyzer customSlotBinder) { ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(currentPlan, 
cascadesContext); ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan, defaultScope, cascadesContext, - enableExactMatch, bindSlotInOuterScope, false, aliasMap) { + enableExactMatch, bindSlotInOuterScope, aliasMap) { @Override protected List bindSlotByThisScope(UnboundSlot unboundSlot) { return customSlotBinder.bindSlot(this, unboundSlot); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 7b5e6e74469852..e460b7585fd4be 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -160,33 +160,23 @@ protected Expression processCompoundNewChildren(CompoundPredicate cp, List aliasMap; private int suppressVariantElementAtCastDepth = 0; /** ExpressionAnalyzer */ public ExpressionAnalyzer(Plan currentPlan, Scope scope, @Nullable CascadesContext cascadesContext, boolean enableExactMatch, boolean bindSlotInOuterScope) { - this(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope, false); + this(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope, Collections.emptyMap()); } /** ExpressionAnalyzer */ public ExpressionAnalyzer(Plan currentPlan, Scope scope, @Nullable CascadesContext cascadesContext, boolean enableExactMatch, boolean bindSlotInOuterScope, - boolean autoCastInSelect) { - this(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope, - autoCastInSelect, Collections.emptyMap()); - } - - /** ExpressionAnalyzer */ - public ExpressionAnalyzer(Plan currentPlan, Scope scope, - @Nullable CascadesContext cascadesContext, boolean enableExactMatch, boolean bindSlotInOuterScope, - boolean autoCastInSelect, Map aliasMap) { + Map aliasMap) { super(scope, cascadesContext); this.currentPlan = currentPlan; 
this.enableExactMatch = enableExactMatch; this.bindSlotInOuterScope = bindSlotInOuterScope; - this.autoCastInSelect = autoCastInSelect; this.aliasMap = aliasMap == null ? Collections.emptyMap() : aliasMap; this.wantToParseSqlFromSqlCache = cascadesContext != null && CacheAnalyzer.canUseSqlCache(cascadesContext.getConnectContext().getSessionVariable()); @@ -380,6 +370,9 @@ public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteCon } else if (firstBound.containsType(ElementAt.class, StructElement.class)) { context.cascadesContext.getStatementContext().setHasNestedColumns(true); } + if (firstBound instanceof Alias) { + return maybeCastAliasExpression((Alias) firstBound, context); + } return maybeCastBoundSlot(firstBound, context); default: if (enableExactMatch) { @@ -478,7 +471,7 @@ private UnboundFunction processHighOrderFunction(UnboundFunction unboundFunction ExpressionAnalyzer lambdaAnalyzer = new ExpressionAnalyzer(currentPlan, new Scope(Optional.of(getScope()), boundedSlots), context == null ? 
null : context.cascadesContext, - true, true, autoCastInSelect, aliasMap) { + true, true, aliasMap) { @Override protected void couldNotFoundColumn(UnboundSlot unboundSlot, String tableName) { throw new AnalysisException("Unknown lambda slot '" @@ -772,10 +765,7 @@ private boolean isEnableVariantSchemaAutoCast(ExpressionRewriteContext context) if (sessionVariable == null || !sessionVariable.isEnableVariantSchemaAutoCast()) { return false; } - if (autoCastInSelect) { - return sessionVariable.isEnableVariantSchemaAutoCastInSelect(); - } - return true; + return sessionVariable.isEnableVariantSchemaAutoCast(); } private Expression wrapVariantElementAtWithCast(Expression expr) { @@ -875,6 +865,21 @@ private Expression maybeCastBoundSlot(Expression bound, ExpressionRewriteContext return new Cast(bound, targetType.get()); } + private Expression maybeCastAliasExpression(Alias alias, ExpressionRewriteContext context) { + if (suppressVariantElementAtCastDepth > 0 || !isEnableVariantSchemaAutoCast(context)) { + return alias; + } + Expression child = alias.child(); + if (!(child instanceof ElementAt)) { + return alias; + } + Expression casted = wrapVariantElementAtWithCast(child); + if (casted == child) { + return alias; + } + return alias.withChildren(ImmutableList.of(casted)); + } + private Optional resolveVariantTemplateType(Expression expr) { if (!(expr instanceof ElementAt)) { return Optional.empty(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index c89ced47a32b98..6811042ba178c9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -774,8 +774,6 @@ public class SessionVariable implements Serializable, Writable { // which means do not flatten nested when create table public static final String ENABLE_VARIANT_FLATTEN_NESTED = "enable_variant_flatten_nested"; public static 
final String ENABLE_VARIANT_SCHEMA_AUTO_CAST = "enable_variant_schema_auto_cast"; - public static final String ENABLE_VARIANT_SCHEMA_AUTO_CAST_IN_SELECT = - "enable_variant_schema_auto_cast_in_select"; // CLOUD_VARIABLES_BEGIN public static final String CLOUD_CLUSTER = "cloud_cluster"; @@ -3240,27 +3238,13 @@ public boolean isEnableESParallelScroll() { needForward = true, affectQueryResultInExecution = true, description = { - "是否启用基于 schema template 的 variant 自动 cast(非 SELECT 子句),默认关闭。", - "Whether to enable schema-template-based auto cast for variant expressions " - + "(non-SELECT clauses). The default is false." + "是否启用基于 schema template 的 variant 自动 cast,默认关闭。", + "Whether to enable schema-template-based auto cast for variant expressions. " + + "The default is false." } ) public boolean enableVariantSchemaAutoCast = false; - @VariableMgr.VarAttr( - name = ENABLE_VARIANT_SCHEMA_AUTO_CAST_IN_SELECT, - needForward = true, - affectQueryResultInExecution = true, - description = { - "是否在 SELECT 子句中启用基于 schema template 的 variant 自动 cast,默认关闭," - + "需先开启 enable_variant_schema_auto_cast。", - "Whether to enable schema-template-based auto cast for variant expressions " - + "in SELECT clause. The default is false and requires " - + "enable_variant_schema_auto_cast = true." 
- } - ) - public boolean enableVariantSchemaAutoCastInSelect = false; - @VariableMgr.VarAttr( name = DEFAULT_VARIANT_ENABLE_TYPED_PATHS_TO_SPARSE, needForward = true, @@ -5857,10 +5841,6 @@ public boolean isEnableVariantSchemaAutoCast() { return enableVariantSchemaAutoCast; } - public boolean isEnableVariantSchemaAutoCastInSelect() { - return enableVariantSchemaAutoCast && enableVariantSchemaAutoCastInSelect; - } - public void setProfileLevel(String profileLevel) { int profileLevelTmp = Integer.valueOf(profileLevel); if (profileLevelTmp < 1 || profileLevelTmp > 3) { diff --git a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out index 9ea29fb6c6d53b..5fdcfefd1c9a73 100644 --- a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out +++ b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out @@ -22,16 +22,6 @@ 3 50 2 30 --- !order_by_select_on -- -3 50 -2 30 -4 15 -1 10 - --- !topn_select_on -- -3 50 -2 30 - -- !select_arithmetic -- 1 30 2 70 @@ -54,25 +44,13 @@ charlie 50 30 50 --- !order_by_alias_select_on -- -10 -15 -30 -50 - -- !order_by_alias_subquery -- 1 10 4 15 2 30 3 50 --- !order_by_alias_subquery_select_on -- -1 10 -4 15 -2 30 -3 50 - --- !group_by_alias_subquery_select_on -- +-- !group_by_alias_subquery -- 10 1 15 1 30 1 @@ -88,18 +66,10 @@ charlie 50 1 first 2 second --- !join_on_select_on -- -1 first -2 second - -- !join_on_alias_subquery -- 1 first 2 second --- !join_on_alias_subquery_select_on -- -1 first -2 second - -- !match_name_exact_where -- 2 @@ -114,35 +84,49 @@ charlie 50 1 2 --- !leaf_int1_select_on -- +-- !leaf_int1_select -- 1 --- !leaf_int1_add_select_on -- +-- !leaf_int1_add -- 2 --- !leaf_int_nested_chain_select_on -- +-- !leaf_int_nested_nonleaf -- +\\N + +-- !leaf_int_nested_chain_select -- 1011111 --- !leaf_int_nested_dot_select_on -- +-- !leaf_int_nested_dot_select -- 1011111 --- 
!leaf_int_nested_deref_select_on -- +-- !leaf_int_nested_deref_select -- 1011111 --- !leaf_int_nested_chain_add_select_on -- +-- !leaf_int_nested_chain_add -- +1011112 + +-- !leaf_int_nested_dot_add -- 1011112 --- !leaf_int_nested_dot_add_select_on -- +-- !leaf_int_nested_deref_add -- 1011112 --- !leaf_int1_select_off -- +-- !leaf_where_ok -- 1 --- !leaf_int_nested_chain_select_off -- -1011111 +-- !leaf_where_nonleaf -- --- !leaf_int_nested_dot_select_off -- -1011111 +-- !leaf_order_by_ok -- +1 --- !leaf_int_nested_deref_select_off -- -1011111 +-- !leaf_order_by_nonleaf -- +1 + +-- !leaf_group_by_ok -- +1 1 + +-- !leaf_group_by_nonleaf -- +\\N 1 + +-- !leaf_having_ok -- +1 1 diff --git a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy index 67be81500aec2d..071de9ec3d195e 100644 --- a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy +++ b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy @@ -60,26 +60,12 @@ suite("test_schema_template_auto_cast", "p0") { // Test 3: TopN (ORDER BY + LIMIT) qt_topn """ SELECT id, data['num_a'] FROM ${tableName} ORDER BY data['num_a'] DESC LIMIT 2 """ - sql """ set enable_variant_schema_auto_cast_in_select = true """ - qt_order_by_select_on """ SELECT id, data['num_a'] FROM ${tableName} - ORDER BY data['num_a'] DESC """ - qt_topn_select_on """ SELECT id, data['num_a'] FROM ${tableName} - ORDER BY data['num_a'] DESC LIMIT 2 """ - sql """ set enable_variant_schema_auto_cast_in_select = false """ - // Test 4: SELECT with auto-cast (arithmetic operations) when enabled - sql """ set enable_variant_schema_auto_cast_in_select = true """ + // Test 4: SELECT with auto-cast (arithmetic operations) qt_select_arithmetic """ SELECT id, data['num_a'] + data['num_b'] as sum_val FROM ${tableName} ORDER BY id """ - sql """ set enable_variant_schema_auto_cast_in_select = 
false """ - test { - sql """ SELECT id, data['num_a'] + data['num_b'] as sum_val - FROM ${tableName} ORDER BY id """ - exception "Cannot cast from variant" - } // Test 5: GROUP BY with auto-cast - sql """ set enable_variant_schema_auto_cast_in_select = true """ qt_group_by """ SELECT data['str_name'], SUM(data['num_a']) as total FROM ${tableName} GROUP BY data['str_name'] ORDER BY data['str_name'] """ @@ -87,57 +73,26 @@ suite("test_schema_template_auto_cast", "p0") { qt_having """ SELECT data['str_name'], SUM(data['num_a']) as total FROM ${tableName} GROUP BY data['str_name'] HAVING SUM(data['num_a']) > 20 ORDER BY data['str_name'] """ - sql """ set enable_variant_schema_auto_cast_in_select = false """ - test { - sql """ SELECT data['str_name'], SUM(data['num_a']) as total - FROM ${tableName} GROUP BY data['str_name'] ORDER BY data['str_name'] """ - exception "sum requires a numeric, boolean or string parameter" - } - test { - sql """ SELECT data['str_name'], SUM(data['num_a']) as total - FROM ${tableName} GROUP BY data['str_name'] - HAVING SUM(data['num_a']) > 20 ORDER BY data['str_name'] """ - exception "sum requires a numeric, boolean or string parameter" - } // Test 7: ORDER BY with alias from project - sql """ set enable_variant_schema_auto_cast_in_select = false """ qt_order_by_alias """ SELECT data['num_a'] AS num_a FROM ${tableName} ORDER BY num_a """ - sql """ set enable_variant_schema_auto_cast_in_select = true """ - qt_order_by_alias_select_on """ SELECT data['num_a'] AS num_a FROM ${tableName} - ORDER BY num_a """ - sql """ set enable_variant_schema_auto_cast_in_select = false """ // Test 8: ORDER BY with alias from subquery qt_order_by_alias_subquery """ SELECT * FROM (SELECT id, data['num_a'] AS num_a FROM ${tableName}) t ORDER BY num_a, id """ - sql """ set enable_variant_schema_auto_cast_in_select = true """ - qt_order_by_alias_subquery_select_on """ SELECT * FROM (SELECT id, data['num_a'] AS num_a FROM ${tableName}) t - ORDER BY num_a, id """ - 
sql """ set enable_variant_schema_auto_cast_in_select = false """ // Test 9: GROUP BY with alias from subquery - test { - sql """ SELECT num_a, COUNT(*) AS cnt - FROM (SELECT data['num_a'] AS num_a FROM ${tableName}) t - GROUP BY num_a ORDER BY num_a """ - exception "must appear in the GROUP BY clause" - } - sql """ set enable_variant_schema_auto_cast_in_select = true """ - qt_group_by_alias_subquery_select_on """ SELECT num_a, COUNT(*) AS cnt + qt_group_by_alias_subquery """ SELECT num_a, COUNT(*) AS cnt FROM (SELECT data['num_a'] AS num_a FROM ${tableName}) t GROUP BY num_a ORDER BY num_a """ - sql """ set enable_variant_schema_auto_cast_in_select = false """ // Test 10: WINDOW partition/order by with auto-cast - sql """ set enable_variant_schema_auto_cast_in_select = true """ qt_window_partition_order """ SELECT id, row_number() OVER (PARTITION BY data['str_name'] ORDER BY data['num_a']) AS rn FROM ${tableName} ORDER BY id """ - sql """ set enable_variant_schema_auto_cast_in_select = false """ - // Test 11: disable auto-cast should error in non-select clauses + // Test 11: disable auto-cast should error in ORDER BY sql """ set enable_variant_schema_auto_cast = false """ test { sql """ SELECT id FROM ${tableName} ORDER BY data['num_a'] """ @@ -180,12 +135,6 @@ suite("test_schema_template_auto_cast", "p0") { FROM ${leftTable} l JOIN ${rightTable} r ON l.data['key_id'] = r.info['key_id'] ORDER BY l.id """ - sql """ set enable_variant_schema_auto_cast_in_select = true """ - qt_join_on_select_on """ SELECT l.id, r.info['name_val'] - FROM ${leftTable} l JOIN ${rightTable} r - ON l.data['key_id'] = r.info['key_id'] - ORDER BY l.id """ - sql """ set enable_variant_schema_auto_cast_in_select = false """ // Test 13: JOIN ON with alias from subquery qt_join_on_alias_subquery """ SELECT l.id, r.name_val @@ -193,13 +142,6 @@ suite("test_schema_template_auto_cast", "p0") { JOIN (SELECT id, info['key_id'] AS key_id, info['name_val'] AS name_val FROM ${rightTable}) r ON 
l.key_id = r.key_id ORDER BY l.id """ - sql """ set enable_variant_schema_auto_cast_in_select = true """ - qt_join_on_alias_subquery_select_on """ SELECT l.id, r.name_val - FROM (SELECT id, data['key_id'] AS key_id FROM ${leftTable}) l - JOIN (SELECT id, info['key_id'] AS key_id, info['name_val'] AS name_val FROM ${rightTable}) r - ON l.key_id = r.key_id - ORDER BY l.id """ - sql """ set enable_variant_schema_auto_cast_in_select = false """ sql "DROP TABLE IF EXISTS ${leftTable}" sql "DROP TABLE IF EXISTS ${rightTable}" @@ -243,43 +185,37 @@ suite("test_schema_template_auto_cast", "p0") { '{"int_1": 1, "int_nested": {"level1_num_1": 1011111, "level1_num_2": 102}}' )""" - sql """ set enable_variant_schema_auto_cast_in_select = true """ - qt_leaf_int1_select_on """ SELECT data['int_1'] FROM ${leafTable} ORDER BY id """ - qt_leaf_int1_add_select_on """ SELECT data['int_1'] + 1 FROM ${leafTable} ORDER BY id """ - test { - // still fails: FE can't distinguish leaf/non-leaf, may cast int_nested to int - sql """ SELECT data['int_nested'] FROM ${leafTable} """ - exception "Bad cast" - } - qt_leaf_int_nested_chain_select_on """ SELECT data['int_nested']['level1_num_1'] - FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_dot_select_on """ SELECT data['int_nested.level1_num_1'] FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_deref_select_on """ SELECT data.int_nested.level1_num_1 FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_chain_add_select_on """ SELECT data['int_nested']['level1_num_1'] + 1 + qt_leaf_int1_select """ SELECT data['int_1'] FROM ${leafTable} ORDER BY id """ + qt_leaf_int1_add """ SELECT data['int_1'] + 1 FROM ${leafTable} ORDER BY id """ + // still fails: FE can't distinguish leaf/non-leaf, may cast int_nested to int + qt_leaf_int_nested_nonleaf """ SELECT data['int_nested'] FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_chain_select """ SELECT data['int_nested']['level1_num_1'] FROM ${leafTable} ORDER BY id """ - 
qt_leaf_int_nested_dot_add_select_on """ SELECT data['int_nested.level1_num_1'] + 1 + qt_leaf_int_nested_dot_select """ SELECT data['int_nested.level1_num_1'] FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_deref_select """ SELECT data.int_nested.level1_num_1 FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_chain_add """ SELECT data['int_nested']['level1_num_1'] + 1 FROM ${leafTable} ORDER BY id """ - - sql """ set enable_variant_schema_auto_cast_in_select = false """ - qt_leaf_int1_select_off """ SELECT data['int_1'] FROM ${leafTable} ORDER BY id """ - test { - sql """ SELECT data['int_1'] + 1 FROM ${leafTable} ORDER BY id """ - exception "Cannot cast from variant" - } - sql """ SELECT data['int_nested'] FROM ${leafTable} """ - qt_leaf_int_nested_chain_select_off """ SELECT data['int_nested']['level1_num_1'] + qt_leaf_int_nested_dot_add """ SELECT data['int_nested.level1_num_1'] + 1 FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_dot_select_off """ SELECT data['int_nested.level1_num_1'] + qt_leaf_int_nested_deref_add """ SELECT data.int_nested.level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_deref_select_off """ SELECT data.int_nested.level1_num_1 FROM ${leafTable} ORDER BY id """ - test { - sql """ SELECT data['int_nested']['level1_num_1'] + 1 FROM ${leafTable} ORDER BY id """ - exception "Cannot cast from variant" - } - test { - sql """ SELECT data['int_nested.level1_num_1'] + 1 FROM ${leafTable} ORDER BY id """ - exception "Cannot cast from variant" - } + + // Non-select clauses: leaf vs non-leaf + qt_leaf_where_ok """ SELECT id FROM ${leafTable} + WHERE data['int_1'] > 0 ORDER BY id """ + qt_leaf_where_nonleaf """ SELECT id FROM ${leafTable} + WHERE data['int_nested'] > 0 ORDER BY id """ + qt_leaf_order_by_ok """ SELECT id FROM ${leafTable} + ORDER BY data['int_1'] """ + qt_leaf_order_by_nonleaf """ SELECT id FROM ${leafTable} + ORDER BY data['int_nested'] """ + qt_leaf_group_by_ok """ SELECT data['int_1'], COUNT(*) 
AS cnt + FROM ${leafTable} GROUP BY data['int_1'] ORDER BY data['int_1'] """ + qt_leaf_group_by_nonleaf """ SELECT data['int_nested'], COUNT(*) AS cnt + FROM ${leafTable} GROUP BY data['int_nested'] ORDER BY data['int_nested'] """ + qt_leaf_having_ok """ SELECT data['int_1'], SUM(data['int_1']) AS total + FROM ${leafTable} GROUP BY data['int_1'] + HAVING SUM(data['int_1']) > 0 ORDER BY data['int_1'] """ sql "DROP TABLE IF EXISTS ${leafTable}" } From 08928130b3304ed6ff84cb681ae98b39f6b889cb Mon Sep 17 00:00:00 2001 From: Gary Date: Sun, 1 Feb 2026 12:20:42 +0800 Subject: [PATCH 15/27] use processBoundFunction --- .../nereids/rules/analysis/ExpressionAnalyzer.java | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index e460b7585fd4be..84a55dd400687c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -326,10 +326,15 @@ public Expression visitElementAt(ElementAt elementAt, ExpressionRewriteContext c } Expression right = elementAt.right().accept(this, context); elementAt = (ElementAt) elementAt.withChildren(left, right); - if (isEnableVariantSchemaAutoCast(context)) { - return wrapVariantElementAtWithCast(elementAt); + Expression coerced = TypeCoercionUtils.processBoundFunction(elementAt); + if (coerced instanceof ElementAt) { + ElementAt coercedElementAt = (ElementAt) coerced; + if (isEnableVariantSchemaAutoCast(context)) { + return wrapVariantElementAtWithCast(coercedElementAt); + } + return coercedElementAt; } - return elementAt; + return coerced; } @Override From 7de43f9456a88de53488f13af60bbe8f4dcf3b4e Mon Sep 17 00:00:00 2001 From: Gary Date: Mon, 2 Feb 2026 17:08:24 +0800 Subject: [PATCH 16/27] remove alias map 
--- .../rules/analysis/BindExpression.java | 114 +++--------------- .../rules/analysis/ExpressionAnalyzer.java | 50 +------- .../trees/expressions/SlotReference.java | 9 -- .../apache/doris/nereids/types/DataType.java | 17 ++- 4 files changed, 29 insertions(+), 161 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index 36f978f771e836..54220dbe4d142b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -119,7 +119,6 @@ import java.util.ArrayList; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -369,8 +368,7 @@ private LogicalSetOperation bindSetOperation(LogicalSetOperation setOperation) { private LogicalOneRowRelation bindOneRowRelation(MatchingContext ctx) { OneRowRelation oneRowRelation = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer analyzer = - buildSimpleExprAnalyzer(oneRowRelation, cascadesContext, ImmutableList.of()); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(oneRowRelation, cascadesContext, ImmutableList.of()); List projects = analyzer.analyzeToList(oneRowRelation.getProjects()); return new LogicalOneRowRelation(oneRowRelation.getRelationId(), projects); } @@ -453,9 +451,8 @@ private LogicalHaving bindHavingAggregate( }); FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry(); - Map aliasMap = buildAliasMap(having.child()); ExpressionAnalyzer havingAnalyzer = new ExpressionAnalyzer(having, aggOutputScope, cascadesContext, - false, true, aliasMap) { + false, true) { private boolean currentIsInAggregateFunction; @Override @@ -519,9 +516,8 @@ private LogicalHaving bindHavingByScopes( 
LogicalHaving having, Plan child, CascadesContext cascadesContext, Scope defaultScope, Supplier backupScope) { - Map aliasMap = buildAliasMap(child); SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer( - having, cascadesContext, defaultScope, false, true, aliasMap, + having, cascadesContext, defaultScope, false, true, (self, unboundSlot) -> { List slots = self.bindSlotByScope(unboundSlot, defaultScope); if (!slots.isEmpty()) { @@ -662,8 +658,7 @@ private LogicalSort bindSortWithSetOperation( CascadesContext cascadesContext = ctx.cascadesContext; List childOutput = sort.child().getOutput(); - Map aliasMap = buildAliasMap(sort.child()); - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(sort, cascadesContext, sort.children(), aliasMap); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(sort, cascadesContext, sort.children()); Builder boundKeys = ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size()); for (OrderKey orderKey : sort.getOrderKeys()) { Expression boundKey = bindWithOrdinal(orderKey.getExpr(), analyzer, childOutput); @@ -678,9 +673,7 @@ private LogicalJoin bindJoin(MatchingContext checkConflictAlias(join); - Map aliasMap = buildAliasMap(join.left()); - aliasMap.putAll(buildAliasMap(join.right())); - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(join, cascadesContext, join.children(), aliasMap); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(join, cascadesContext, join.children()); Builder hashJoinConjuncts = ImmutableList.builderWithExpectedSize( join.getHashJoinConjuncts().size()); @@ -754,18 +747,16 @@ private LogicalPlan bindUsingJoin(MatchingContext> Scope leftScope = toScope(cascadesContext, using.left().getOutput(), using.left().getAsteriskOutput()); Scope rightScope = toScope(cascadesContext, using.right().getOutput(), using.right().getAsteriskOutput()); ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(using, cascadesContext); - Map leftAliasMap = buildAliasMap(using.left()); - 
Map rightAliasMap = buildAliasMap(using.right()); Builder hashEqExprs = ImmutableList.builderWithExpectedSize(unboundHashJoinConjunct.size()); List rightConjunctsSlots = Lists.newArrayList(); for (Expression usingColumn : unboundHashJoinConjunct) { ExpressionAnalyzer leftExprAnalyzer = new ExpressionAnalyzer( - using, leftScope, cascadesContext, true, false, leftAliasMap); + using, leftScope, cascadesContext, true, false); Expression usingLeftSlot = leftExprAnalyzer.analyze(usingColumn, rewriteContext); ExpressionAnalyzer rightExprAnalyzer = new ExpressionAnalyzer( - using, rightScope, cascadesContext, true, false, rightAliasMap); + using, rightScope, cascadesContext, true, false); Expression usingRightSlot = rightExprAnalyzer.analyze(usingColumn, rewriteContext); rightConjunctsSlots.add((Slot) usingRightSlot); hashEqExprs.add(new EqualTo(usingLeftSlot, usingRightSlot)); @@ -918,8 +909,7 @@ private Plan bindFilter(MatchingContext> ctx) { LogicalFilter filter = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - Map aliasMap = buildAliasMap(filter.child()); - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children(), aliasMap); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children()); ImmutableSet.Builder boundConjuncts = ImmutableSet.builder(); boolean changed = false; for (Expression expr : filter.getConjuncts()) { @@ -941,8 +931,7 @@ private Plan bindPreFilter(MatchingContext> ctx) { LogicalPreFilter filter = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - Map aliasMap = buildAliasMap(filter.child()); - SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children(), aliasMap); + SimpleExprAnalyzer analyzer = buildSimpleExprAnalyzer(filter, cascadesContext, filter.children()); ImmutableSet.Builder boundConjuncts = ImmutableSet.builder(); for (Expression conjunct : filter.getConjuncts()) { Expression boundExpr = 
analyzer.analyze(conjunct); @@ -1071,9 +1060,8 @@ private void bindQualifyByProject(LogicalProject project, Cascad ); Scope backupScope = toScope(cascadesContext, project.getOutput()); - Map aliasMap = buildAliasMap(project); SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer( - qualify, cascadesContext, defaultScope.get(), true, true, aliasMap, + qualify, cascadesContext, defaultScope.get(), true, true, (self, unboundSlot) -> { List slots = self.bindSlotByScope(unboundSlot, defaultScope.get()); if (!slots.isEmpty()) { @@ -1126,9 +1114,8 @@ private void bindQualifyByAggregate(Aggregate aggregate, Cascade }; }); - Map aliasMap = buildAliasMap(aggregate); ExpressionAnalyzer qualifyAnalyzer = new ExpressionAnalyzer(qualify, aggOutputScope, cascadesContext, - true, true, aliasMap) { + true, true) { @Override protected List bindSlotByThisScope(UnboundSlot unboundSlot) { return bindByGroupByThenAggOutputThenAggChildOutput.get().bindSlot(this, unboundSlot); @@ -1332,8 +1319,7 @@ private Plan bindRepeat(MatchingContext> ctx) { LogicalRepeat repeat = ctx.root; CascadesContext cascadesContext = ctx.cascadesContext; - SimpleExprAnalyzer repeatOutputAnalyzer = buildSimpleExprAnalyzer(repeat, cascadesContext, - repeat.children()); + SimpleExprAnalyzer repeatOutputAnalyzer = buildSimpleExprAnalyzer(repeat, cascadesContext, repeat.children()); List boundRepeatOutput = repeatOutputAnalyzer.analyzeToList(repeat.getOutputExpressions()); Supplier aggOutputScope = buildAggOutputScope(boundRepeatOutput, cascadesContext); Builder> boundGroupingSetsBuilder = @@ -1417,9 +1403,8 @@ private List bindGroupBy( Supplier aggOutputScope, CascadesContext cascadesContext) { Scope childOutputScope = toScope(cascadesContext, agg.child().getOutput()); - Map aliasMap = buildAliasMap(agg); SimpleExprAnalyzer analyzer = buildCustomSlotBinderAnalyzer( - agg, cascadesContext, childOutputScope, true, true, aliasMap, + agg, cascadesContext, childOutputScope, true, true, (self, unboundSlot) -> { // 
see: https://github.com/apache/doris/pull/15240 // @@ -1483,52 +1468,6 @@ private Supplier buildAggOutputScope( }); } - private Map buildAliasMap(Plan plan) { - Map aliasMap = new HashMap<>(); - Plan current = unwrapSubQueryAlias(plan); - collectAliasMapFromProjectChain(aliasMap, current); - if (aliasMap.isEmpty() && current instanceof LogicalAggregate && current.arity() == 1) { - collectAliasMapFromProjectChain(aliasMap, unwrapSubQueryAlias(current.child(0))); - } - if (aliasMap.isEmpty() && current instanceof LogicalAggregate) { - collectAliasMap(aliasMap, ((LogicalAggregate) current).getOutputExpressions()); - } - return aliasMap; - } - - private Plan unwrapSubQueryAlias(Plan plan) { - Plan current = plan; - while (current instanceof LogicalSubQueryAlias) { - current = ((LogicalSubQueryAlias) current).child(); - } - return current; - } - - private void collectAliasMapFromProjectChain(Map aliasMap, Plan start) { - Plan current = start; - while (current instanceof LogicalProject) { - int before = aliasMap.size(); - collectAliasMap(aliasMap, ((LogicalProject) current).getProjects()); - if (aliasMap.size() > before) { - break; - } - // passthrough project (e.g. 
SELECT *), keep searching in child - if (current.arity() != 1) { - break; - } - current = unwrapSubQueryAlias(current.child(0)); - } - } - - private void collectAliasMap(Map aliasMap, List outputs) { - for (NamedExpression output : outputs) { - if (output instanceof Alias) { - Alias alias = (Alias) output; - aliasMap.put(alias.getExprId(), alias.child()); - } - } - } - private Plan bindSortWithoutSetOperation(MatchingContext> ctx) { CascadesContext cascadesContext = ctx.cascadesContext; LogicalSort sort = ctx.root; @@ -1572,13 +1511,12 @@ private Plan bindSortWithoutSetOperation(MatchingContext> ctx) // bind order_col1 with alias_col1, then, bind it with inner_col1 List inputSlots = input.getOutput(); Scope inputScope = toScope(cascadesContext, inputSlots); - Map aliasMap = buildAliasMap(input); final Plan finalInput = input; Supplier inputChildrenScope = Suppliers.memoize( () -> toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(finalInput.children()))); SimpleExprAnalyzer bindInInputScopeThenInputChildScope = buildCustomSlotBinderAnalyzer( - sort, cascadesContext, inputScope, true, false, aliasMap, + sort, cascadesContext, inputScope, true, false, (self, unboundSlot) -> { // first, try to bind slot in Scope(input.output) List slotsInInput = self.bindExactSlotsByThisScope(unboundSlot, inputScope); @@ -1593,7 +1531,7 @@ private Plan bindSortWithoutSetOperation(MatchingContext> ctx) }); SimpleExprAnalyzer bindInInputChildScope = getAnalyzerForOrderByAggFunc(finalInput, cascadesContext, sort, - inputChildrenScope, inputScope, aliasMap); + inputChildrenScope, inputScope); Builder boundOrderKeys = ImmutableList.builderWithExpectedSize(sort.getOrderKeys().size()); FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry(); Map bindUniqueIdReplaceMap = getBelowAggregateGroupByUniqueFuncReplaceMap(sort); @@ -1698,33 +1636,20 @@ private Scope toScope(CascadesContext cascadesContext, List slots, List children) { - return 
buildSimpleExprAnalyzer(currentPlan, cascadesContext, children, Collections.emptyMap()); - } - - protected SimpleExprAnalyzer buildSimpleExprAnalyzer( - Plan currentPlan, CascadesContext cascadesContext, List children, Map aliasMap) { Scope scope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(children), PlanUtils.fastGetChildrenAsteriskOutputs(children)); ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext); ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan, - scope, cascadesContext, true, true, aliasMap); + scope, cascadesContext, true, true); return expr -> expressionAnalyzer.analyze(expr, rewriteContext); } private SimpleExprAnalyzer buildCustomSlotBinderAnalyzer( Plan currentPlan, CascadesContext cascadesContext, Scope defaultScope, boolean enableExactMatch, boolean bindSlotInOuterScope, CustomSlotBinderAnalyzer customSlotBinder) { - return buildCustomSlotBinderAnalyzer(currentPlan, cascadesContext, defaultScope, enableExactMatch, - bindSlotInOuterScope, Collections.emptyMap(), customSlotBinder); - } - - private SimpleExprAnalyzer buildCustomSlotBinderAnalyzer( - Plan currentPlan, CascadesContext cascadesContext, Scope defaultScope, - boolean enableExactMatch, boolean bindSlotInOuterScope, Map aliasMap, - CustomSlotBinderAnalyzer customSlotBinder) { ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(currentPlan, cascadesContext); ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan, defaultScope, cascadesContext, - enableExactMatch, bindSlotInOuterScope, aliasMap) { + enableExactMatch, bindSlotInOuterScope) { @Override protected List bindSlotByThisScope(UnboundSlot unboundSlot) { return customSlotBinder.bindSlot(this, unboundSlot); @@ -1781,8 +1706,7 @@ private boolean hasAggregateFunction(Expression expression, FunctionRegistry fun } private SimpleExprAnalyzer getAnalyzerForOrderByAggFunc(Plan finalInput, CascadesContext 
cascadesContext, - LogicalSort sort, Supplier inputChildrenScope, Scope inputScope, - Map aliasMap) { + LogicalSort sort, Supplier inputChildrenScope, Scope inputScope) { ImmutableList.Builder outputSlots = ImmutableList.builder(); if (finalInput instanceof LogicalAggregate) { LogicalAggregate aggregate = (LogicalAggregate) finalInput; @@ -1795,7 +1719,7 @@ private SimpleExprAnalyzer getAnalyzerForOrderByAggFunc(Plan finalInput, Cascade } Scope outputWithoutAggFunc = toScope(cascadesContext, outputSlots.build()); SimpleExprAnalyzer bindInInputChildScope = buildCustomSlotBinderAnalyzer( - sort, cascadesContext, inputScope, true, false, aliasMap, + sort, cascadesContext, inputScope, true, false, (analyzer, unboundSlot) -> { if (finalInput instanceof LogicalAggregate) { List boundInOutputWithoutAggFunc = analyzer.bindSlotByScope(unboundSlot, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 84a55dd400687c..46ad853947d373 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -121,9 +121,7 @@ import org.apache.commons.lang3.StringUtils; import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -160,24 +158,15 @@ protected Expression processCompoundNewChildren(CompoundPredicate cp, List aliasMap; private int suppressVariantElementAtCastDepth = 0; /** ExpressionAnalyzer */ public ExpressionAnalyzer(Plan currentPlan, Scope scope, @Nullable CascadesContext cascadesContext, boolean enableExactMatch, boolean bindSlotInOuterScope) { - this(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope, Collections.emptyMap()); - } 
- - /** ExpressionAnalyzer */ - public ExpressionAnalyzer(Plan currentPlan, Scope scope, - @Nullable CascadesContext cascadesContext, boolean enableExactMatch, boolean bindSlotInOuterScope, - Map aliasMap) { super(scope, cascadesContext); this.currentPlan = currentPlan; this.enableExactMatch = enableExactMatch; this.bindSlotInOuterScope = bindSlotInOuterScope; - this.aliasMap = aliasMap == null ? Collections.emptyMap() : aliasMap; this.wantToParseSqlFromSqlCache = cascadesContext != null && CacheAnalyzer.canUseSqlCache(cascadesContext.getConnectContext().getSessionVariable()); } @@ -378,7 +367,7 @@ public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteCon if (firstBound instanceof Alias) { return maybeCastAliasExpression((Alias) firstBound, context); } - return maybeCastBoundSlot(firstBound, context); + return firstBound; default: if (enableExactMatch) { // select t1.k k, t2.k @@ -476,7 +465,7 @@ private UnboundFunction processHighOrderFunction(UnboundFunction unboundFunction ExpressionAnalyzer lambdaAnalyzer = new ExpressionAnalyzer(currentPlan, new Scope(Optional.of(getScope()), boundedSlots), context == null ? 
null : context.cascadesContext, - true, true, aliasMap) { + true, true) { @Override protected void couldNotFoundColumn(UnboundSlot unboundSlot, String tableName) { throw new AnalysisException("Unknown lambda slot '" @@ -846,30 +835,6 @@ private boolean shouldSuppressVariantElementAtCast(Cast cast) { return child instanceof ElementAt || child instanceof DereferenceExpression || child instanceof UnboundSlot; } - private Expression maybeCastBoundSlot(Expression bound, ExpressionRewriteContext context) { - if (!(bound instanceof SlotReference)) { - return bound; - } - if (suppressVariantElementAtCastDepth > 0 || aliasMap.isEmpty()) { - return bound; - } - if (!isEnableVariantSchemaAutoCast(context)) { - return bound; - } - if (!bound.getDataType().isVariantType()) { - return bound; - } - Expression aliasExpr = aliasMap.get(((SlotReference) bound).getExprId()); - if (aliasExpr == null) { - return bound; - } - Optional targetType = resolveVariantTemplateType(aliasExpr); - if (!targetType.isPresent()) { - return bound; - } - return new Cast(bound, targetType.get()); - } - private Expression maybeCastAliasExpression(Alias alias, ExpressionRewriteContext context) { if (suppressVariantElementAtCastDepth > 0 || !isEnableVariantSchemaAutoCast(context)) { return alias; @@ -885,17 +850,6 @@ private Expression maybeCastAliasExpression(Alias alias, ExpressionRewriteContex return alias.withChildren(ImmutableList.of(casted)); } - private Optional resolveVariantTemplateType(Expression expr) { - if (!(expr instanceof ElementAt)) { - return Optional.empty(); - } - Expression rewritten = wrapVariantElementAtWithCast(expr); - if (rewritten instanceof Cast) { - return Optional.of(((Cast) rewritten).getDataType()); - } - return Optional.empty(); - } - @Override public Expression visitNot(Not not, ExpressionRewriteContext context) { // maybe is `not subquery`, we should bind it first diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index 130c7920d14e3c..1c77dc669acb72 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -155,15 +155,6 @@ public static SlotReference fromColumn(ExprId exprId, TableIf table, Column colu return fromColumn(exprId, table, column, column.getName(), qualifier); } - /** - * Get SlotReference from a column with custom name. - * @param exprId the expression id - * @param table the table which contains the column - * @param column the column which contains type info - * @param name the name of SlotReference - * @param qualifier the qualifier of SlotReference - * @return SlotReference created from column - */ public static SlotReference fromColumn( ExprId exprId, TableIf table, Column column, String name, List qualifier) { DataType dataType = DataType.fromCatalogType(column.getType()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java index 9f4f3fc862ef53..911dc2e4e2cd51 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java @@ -480,20 +480,19 @@ public static DataType fromCatalogType(Type type) { // In the past, variant metadata used the ScalarType type. // Now, we use VariantType, which inherits from ScalarType, as the new metadata storage. 
if (type instanceof org.apache.doris.catalog.VariantType) { - org.apache.doris.catalog.VariantType catalogVariantType = (org.apache.doris.catalog.VariantType) type; - List variantFields = catalogVariantType + List variantFields = ((org.apache.doris.catalog.VariantType) type) .getPredefinedFields().stream() .map(cf -> new VariantField(cf.getPattern(), fromCatalogType(cf.getType()), cf.getComment() == null ? "" : cf.getComment(), cf.getPatternType().toString())) .collect(ImmutableList.toImmutableList()); return new VariantType(variantFields, - catalogVariantType.getVariantMaxSubcolumnsCount(), - catalogVariantType.getEnableTypedPathsToSparse(), - catalogVariantType.getVariantMaxSparseColumnStatisticsSize(), - catalogVariantType.getVariantSparseHashShardCount(), - catalogVariantType.getEnableVariantDocMode(), - catalogVariantType.getvariantDocMaterializationMinRows(), - catalogVariantType.getVariantDocShardCount()); + ((org.apache.doris.catalog.VariantType) type).getVariantMaxSubcolumnsCount(), + ((org.apache.doris.catalog.VariantType) type).getEnableTypedPathsToSparse(), + ((org.apache.doris.catalog.VariantType) type).getVariantMaxSparseColumnStatisticsSize(), + ((org.apache.doris.catalog.VariantType) type).getVariantSparseHashShardCount(), + ((org.apache.doris.catalog.VariantType) type).getEnableVariantDocMode(), + ((org.apache.doris.catalog.VariantType) type).getvariantDocMaterializationMinRows(), + ((org.apache.doris.catalog.VariantType) type).getVariantDocShardCount()); } return VariantType.INSTANCE; } else { From 06fedbd58187e1d43ea64108811a20929dc7dd22 Mon Sep 17 00:00:00 2001 From: Gary Date: Tue, 3 Feb 2026 00:19:13 +0800 Subject: [PATCH 17/27] simplify code and enhance tests --- .../rules/analysis/ExpressionAnalyzer.java | 11 +- ...ExpressionAnalyzerVariantAutoCastTest.java | 173 ++++++++ .../rewrite/VariantSchemaTemplateTest.java | 386 ------------------ .../nereids/types/VariantFieldMatchTest.java | 29 ++ .../test_schema_template_auto_cast.out | 142 
++++++- .../test_schema_template_auto_cast.groovy | 42 +- 6 files changed, 379 insertions(+), 404 deletions(-) create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java delete mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaTemplateTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index 46ad853947d373..c04a67f614ecdf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -316,12 +316,8 @@ public Expression visitElementAt(ElementAt elementAt, ExpressionRewriteContext c Expression right = elementAt.right().accept(this, context); elementAt = (ElementAt) elementAt.withChildren(left, right); Expression coerced = TypeCoercionUtils.processBoundFunction(elementAt); - if (coerced instanceof ElementAt) { - ElementAt coercedElementAt = (ElementAt) coerced; - if (isEnableVariantSchemaAutoCast(context)) { - return wrapVariantElementAtWithCast(coercedElementAt); - } - return coercedElementAt; + if (isEnableVariantSchemaAutoCast(context)) { + return wrapVariantElementAtWithCast(coerced); } return coerced; } @@ -840,9 +836,6 @@ private Expression maybeCastAliasExpression(Alias alias, ExpressionRewriteContex return alias; } Expression child = alias.child(); - if (!(child instanceof ElementAt)) { - return alias; - } Expression casted = wrapVariantElementAtWithCast(child); if (casted == child) { return alias; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java new file mode 100644 index 
00000000000000..cc99188fa6d938 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java @@ -0,0 +1,173 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.analysis; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.analyzer.Scope; +import org.apache.doris.nereids.analyzer.UnboundSlot; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.DereferenceExpression; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.VariantField; +import org.apache.doris.nereids.types.VariantType; +import org.apache.doris.qe.ConnectContext; 
+ +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Method; +import java.util.Optional; + +public class ExpressionAnalyzerVariantAutoCastTest { + + @AfterEach + public void cleanup() { + ConnectContext.remove(); + } + + private CascadesContext createContext(boolean enableAutoCast) { + ConnectContext ctx = new ConnectContext(); + ctx.getSessionVariable().enableVariantSchemaAutoCast = enableAutoCast; + ctx.setThreadLocalInfo(); + return CascadesContext.initTempContext(); + } + + private Expression analyze(Expression expr, Scope scope, boolean enableAutoCast) { + CascadesContext cascadesContext = createContext(enableAutoCast); + ExpressionAnalyzer analyzer = new ExpressionAnalyzer(null, scope, cascadesContext, true, true); + return analyzer.analyze(expr); + } + + private SlotReference buildVariantSlot(VariantType variantType) { + return new SlotReference(new ExprId(1), "data", variantType, true, ImmutableList.of()); + } + + @Test + public void testVisitElementAtAutoCastEnabled() { + VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field)); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt elementAt = new ElementAt(slot, new StringLiteral("number_latency")); + Expression result = analyze(elementAt, scope, true); + + Assertions.assertTrue(result instanceof Cast); + Cast cast = (Cast) result; + Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); + Assertions.assertTrue(cast.child() instanceof ElementAt); + } + + @Test + public void testVisitElementAtAutoCastDisabled() { + VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field)); + SlotReference slot = 
buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt elementAt = new ElementAt(slot, new StringLiteral("number_latency")); + Expression result = analyze(elementAt, scope, false); + + Assertions.assertTrue(result instanceof ElementAt); + Assertions.assertFalse(result instanceof Cast); + } + + @Test + public void testVisitDereferenceExpressionAutoCast() { + VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field)); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + DereferenceExpression deref = new DereferenceExpression(slot, new StringLiteral("number_latency")); + Expression result = analyze(deref, scope, true); + + Assertions.assertTrue(result instanceof Cast); + Cast cast = (Cast) result; + Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); + Assertions.assertTrue(cast.child() instanceof ElementAt); + } + + @Test + public void testResolveVariantElementAtPathChain() { + VariantField field = new VariantField("int_nested.level1_num_1", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field)); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt elementAt = new ElementAt( + new ElementAt(slot, new StringLiteral("int_nested")), + new StringLiteral("level1_num_1") + ); + Expression result = analyze(elementAt, scope, true); + + Assertions.assertTrue(result instanceof Cast); + Cast cast = (Cast) result; + Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); + Assertions.assertTrue(cast.child() instanceof ElementAt); + } + + @Test + public void testGetVariantPathKeyNonString() throws Exception { + VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field)); + SlotReference 
slot = buildVariantSlot(variantType); + + ElementAt elementAt = new ElementAt(slot, new BigIntLiteral(1)); + ExpressionAnalyzer analyzer = new ExpressionAnalyzer(null, new Scope(ImmutableList.of(slot)), + createContext(true), true, true); + + Method getVariantPathKey = ExpressionAnalyzer.class.getDeclaredMethod("getVariantPathKey", Expression.class); + getVariantPathKey.setAccessible(true); + @SuppressWarnings("unchecked") + Optional key = (Optional) getVariantPathKey.invoke(analyzer, new BigIntLiteral(1)); + Assertions.assertFalse(key.isPresent()); + + Method resolvePath = ExpressionAnalyzer.class.getDeclaredMethod("resolveVariantElementAtPath", ElementAt.class); + resolvePath.setAccessible(true); + @SuppressWarnings("unchecked") + Optional path = (Optional) resolvePath.invoke(analyzer, elementAt); + Assertions.assertFalse(path.isPresent()); + } + + @Test + public void testMaybeCastAliasExpression() { + VariantField field = new VariantField("int_nested.level1_num_1", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(field)); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + UnboundSlot unbound = new UnboundSlot("data", "int_nested", "level1_num_1"); + Expression result = analyze(unbound, scope, true); + + Assertions.assertTrue(result instanceof Alias); + Alias alias = (Alias) result; + Assertions.assertTrue(alias.child() instanceof Cast); + Cast cast = (Cast) alias.child(); + Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaTemplateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaTemplateTest.java deleted file mode 100644 index 171e9d924a6978..00000000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/VariantSchemaTemplateTest.java +++ /dev/null @@ -1,386 +0,0 @@ -// Licensed to the Apache 
Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite; - -import org.apache.doris.nereids.trees.expressions.And; -import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.GreaterThan; -import org.apache.doris.nereids.trees.expressions.LessThan; -import org.apache.doris.nereids.trees.expressions.Or; -import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; -import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; -import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; -import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; -import org.apache.doris.nereids.types.BigIntType; -import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.DoubleType; -import org.apache.doris.nereids.types.StringType; -import org.apache.doris.nereids.types.VariantField; -import org.apache.doris.nereids.types.VariantType; - -import com.google.common.collect.ImmutableList; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - 
-import java.util.Collections; -import java.util.List; -import java.util.function.Function; - -/** - * Unit tests for variant schema template auto-cast expression rewriting. - */ -public class VariantSchemaTemplateTest { - - // Expression rewriter for variant schema template auto-cast - private static final Function EXPRESSION_REWRITER = expr -> { - if (!(expr instanceof ElementAt)) { - return expr; - } - ElementAt elementAt = (ElementAt) expr; - Expression left = elementAt.left(); - Expression right = elementAt.right(); - - if (!(left.getDataType() instanceof VariantType)) { - return expr; - } - if (!(right instanceof VarcharLiteral)) { - return expr; - } - - VariantType variantType = (VariantType) left.getDataType(); - String fieldName = ((VarcharLiteral) right).getStringValue(); - - return variantType.findMatchingField(fieldName) - .map(field -> (Expression) new Cast(elementAt, field.getDataType())) - .orElse(expr); - }; - - private Expression rewriteExpression(Expression expr) { - return expr.rewriteDownShortCircuit(EXPRESSION_REWRITER); - } - - @Test - public void testRewriteElementAtWithMatchingPattern() { - // Create variant type with schema template - VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(numberField)); - - // Create element_at expression: variant['number_latency'] - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - ElementAt elementAt = new ElementAt(variantSlot, new VarcharLiteral("number_latency")); - - // Rewrite - Expression result = rewriteExpression(elementAt); - - // Should be wrapped with Cast - Assertions.assertTrue(result instanceof Cast); - Cast cast = (Cast) result; - Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); - Assertions.assertTrue(cast.child() instanceof ElementAt); - } - - @Test - public void testRewriteElementAtWithNoMatchingPattern() { - // Create variant type with schema template - VariantField 
numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(numberField)); - - // Create element_at expression: variant['string_message'] (no match) - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - ElementAt elementAt = new ElementAt(variantSlot, new VarcharLiteral("string_message")); - - // Rewrite - Expression result = rewriteExpression(elementAt); - - // Should NOT be wrapped with Cast - Assertions.assertTrue(result instanceof ElementAt); - Assertions.assertFalse(result instanceof Cast); - } - - @Test - public void testRewriteElementAtWithEmptySchemaTemplate() { - // Create variant type without schema template - VariantType variantType = new VariantType(0); - - // Create element_at expression - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - ElementAt elementAt = new ElementAt(variantSlot, new VarcharLiteral("any_field")); - - // Rewrite - Expression result = rewriteExpression(elementAt); - - // Should NOT be wrapped with Cast - Assertions.assertTrue(result instanceof ElementAt); - Assertions.assertFalse(result instanceof Cast); - } - - @Test - public void testRewriteCompoundExpression() { - // Create variant type with schema template - VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(numberField)); - - // Create compound expression: variant['number_latency'] > 100 - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - ElementAt elementAt = new ElementAt(variantSlot, new VarcharLiteral("number_latency")); - GreaterThan greaterThan = new GreaterThan(elementAt, new BigIntLiteral(100)); - - // Rewrite - Expression result = rewriteExpression(greaterThan); - - // Should be GreaterThan with Cast(ElementAt) on left - Assertions.assertTrue(result instanceof GreaterThan); - GreaterThan rewrittenGt = (GreaterThan) result; - 
Assertions.assertTrue(rewrittenGt.left() instanceof Cast); - Cast cast = (Cast) rewrittenGt.left(); - Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); - } - - @Test - public void testRewriteMultiplePatterns() { - // Create variant type with multiple patterns - VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantField stringField = new VariantField("string_*", StringType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(numberField, stringField)); - - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - - // Test number pattern - ElementAt numberElementAt = new ElementAt(variantSlot, new VarcharLiteral("number_count")); - Expression numberResult = rewriteExpression(numberElementAt); - Assertions.assertTrue(numberResult instanceof Cast); - Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) numberResult).getDataType()); - - // Test string pattern - ElementAt stringElementAt = new ElementAt(variantSlot, new VarcharLiteral("string_msg")); - Expression stringResult = rewriteExpression(stringElementAt); - Assertions.assertTrue(stringResult instanceof Cast); - Assertions.assertEquals(StringType.INSTANCE, ((Cast) stringResult).getDataType()); - } - - @Test - public void testRewriteAndCondition() { - // Create variant type with schema template - VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(numberField)); - - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - - // Create AND condition: variant['number_a'] > 10 AND variant['number_b'] < 100 - ElementAt elementAtA = new ElementAt(variantSlot, new VarcharLiteral("number_a")); - ElementAt elementAtB = new ElementAt(variantSlot, new VarcharLiteral("number_b")); - GreaterThan gt = new GreaterThan(elementAtA, new BigIntLiteral(10)); - LessThan lt = new LessThan(elementAtB, new BigIntLiteral(100)); - And andExpr = new And(gt, 
lt); - - // Rewrite - Expression result = rewriteExpression(andExpr); - - // Should be And with Cast on both sides - Assertions.assertTrue(result instanceof And); - And rewrittenAnd = (And) result; - - // Left side: Cast(ElementAt) > 10 - Assertions.assertTrue(rewrittenAnd.child(0) instanceof GreaterThan); - GreaterThan leftGt = (GreaterThan) rewrittenAnd.child(0); - Assertions.assertTrue(leftGt.left() instanceof Cast); - Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) leftGt.left()).getDataType()); - - // Right side: Cast(ElementAt) < 100 - Assertions.assertTrue(rewrittenAnd.child(1) instanceof LessThan); - LessThan rightLt = (LessThan) rewrittenAnd.child(1); - Assertions.assertTrue(rightLt.left() instanceof Cast); - Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) rightLt.left()).getDataType()); - } - - @Test - public void testRewriteOrCondition() { - // Create variant type with multiple patterns - VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantField stringField = new VariantField("string_*", StringType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(numberField, stringField)); - - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - - // Create OR condition: variant['number_a'] > 10 OR variant['string_b'] = 'test' - ElementAt elementAtA = new ElementAt(variantSlot, new VarcharLiteral("number_a")); - ElementAt elementAtB = new ElementAt(variantSlot, new VarcharLiteral("string_b")); - GreaterThan gt = new GreaterThan(elementAtA, new BigIntLiteral(10)); - EqualTo eq = new EqualTo(elementAtB, new VarcharLiteral("test")); - Or orExpr = new Or(gt, eq); - - // Rewrite - Expression result = rewriteExpression(orExpr); - - // Should be Or with Cast on both sides - Assertions.assertTrue(result instanceof Or); - Or rewrittenOr = (Or) result; - - // Left side: Cast(ElementAt) > 10 with BIGINT - Assertions.assertTrue(rewrittenOr.child(0) instanceof GreaterThan); - GreaterThan leftGt 
= (GreaterThan) rewrittenOr.child(0); - Assertions.assertTrue(leftGt.left() instanceof Cast); - Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) leftGt.left()).getDataType()); - - // Right side: Cast(ElementAt) = 'test' with STRING - Assertions.assertTrue(rewrittenOr.child(1) instanceof EqualTo); - EqualTo rightEq = (EqualTo) rewrittenOr.child(1); - Assertions.assertTrue(rightEq.left() instanceof Cast); - Assertions.assertEquals(StringType.INSTANCE, ((Cast) rightEq.left()).getDataType()); - } - - @Test - public void testFirstMatchWins() { - // Create variant type with overlapping patterns - first match should win - // 'num*' matches 'number_val', 'number_*' also matches 'number_val' - // First pattern 'num*' should be used - VariantField numField = new VariantField("num*", BigIntType.INSTANCE, ""); - VariantField numberField = new VariantField("number_*", DoubleType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(numField, numberField)); - - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - - // 'number_val' matches both patterns, but 'num*' is first - ElementAt elementAt = new ElementAt(variantSlot, new VarcharLiteral("number_val")); - Expression result = rewriteExpression(elementAt); - - Assertions.assertTrue(result instanceof Cast); - Cast cast = (Cast) result; - // Should be BIGINT (from 'num*'), not DOUBLE (from 'number_*') - Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); - } - - @Test - public void testMixedMatchingAndNonMatching() { - // Create variant type with one pattern - VariantField numberField = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(numberField)); - - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - - // Create condition: variant['number_a'] > variant['other_field'] - // number_a matches, other_field does not - ElementAt matchingElementAt = new ElementAt(variantSlot, new 
VarcharLiteral("number_a")); - ElementAt nonMatchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("other_field")); - GreaterThan gt = new GreaterThan(matchingElementAt, nonMatchingElementAt); - - // Rewrite - Expression result = rewriteExpression(gt); - - Assertions.assertTrue(result instanceof GreaterThan); - GreaterThan rewrittenGt = (GreaterThan) result; - - // Left side should be Cast(ElementAt) - Assertions.assertTrue(rewrittenGt.left() instanceof Cast); - Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) rewrittenGt.left()).getDataType()); - - // Right side should remain as ElementAt (no cast) - Assertions.assertTrue(rewrittenGt.right() instanceof ElementAt); - Assertions.assertFalse(rewrittenGt.right() instanceof Cast); - } - - @Test - public void testGlobPatternWithQuestionMark() { - // Test glob pattern with ? (matches single character) - VariantField field = new VariantField("val?", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(field)); - - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - - // 'val1' should match 'val?' - ElementAt matchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("val1")); - Expression matchResult = rewriteExpression(matchingElementAt); - Assertions.assertTrue(matchResult instanceof Cast); - Assertions.assertEquals(BigIntType.INSTANCE, ((Cast) matchResult).getDataType()); - - // 'val12' should NOT match 'val?' (? matches only one char) - ElementAt nonMatchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("val12")); - Expression nonMatchResult = rewriteExpression(nonMatchingElementAt); - Assertions.assertTrue(nonMatchResult instanceof ElementAt); - Assertions.assertFalse(nonMatchResult instanceof Cast); - } - - @Test - public void testGlobPatternWithBrackets() { - // Test glob pattern with [...] 
(character class) - VariantField field = new VariantField("type_[abc]", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(field)); - - MockVariantSlot variantSlot = new MockVariantSlot(variantType); - - // 'type_a' should match 'type_[abc]' - ElementAt matchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("type_a")); - Expression matchResult = rewriteExpression(matchingElementAt); - Assertions.assertTrue(matchResult instanceof Cast); - - // 'type_d' should NOT match 'type_[abc]' - ElementAt nonMatchingElementAt = new ElementAt(variantSlot, new VarcharLiteral("type_d")); - Expression nonMatchResult = rewriteExpression(nonMatchingElementAt); - Assertions.assertTrue(nonMatchResult instanceof ElementAt); - Assertions.assertFalse(nonMatchResult instanceof Cast); - } - - /** - * Mock Expression class for providing VariantType in tests. - */ - private static class MockVariantSlot extends Expression { - private final VariantType variantType; - - public MockVariantSlot(VariantType variantType) { - super(Collections.emptyList()); - this.variantType = variantType; - } - - @Override - public DataType getDataType() { - return variantType; - } - - @Override - public boolean nullable() { - return true; - } - - @Override - public Expression withChildren(List children) { - return this; - } - - @Override - public R accept(ExpressionVisitor visitor, C context) { - return visitor.visit(this, context); - } - - @Override - public int arity() { - return 0; - } - - @Override - public Expression child(int index) { - throw new IndexOutOfBoundsException(); - } - - @Override - public List children() { - return Collections.emptyList(); - } - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java index 27969e0a6cf42f..146d7151324c38 100644 --- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java @@ -42,6 +42,15 @@ public void testExactMatch() { Assertions.assertFalse(field.matches("other_field")); } + @Test + public void testExactMatchDoesNotTreatGlob() { + VariantField field = new VariantField("num_*", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME.name()); + + Assertions.assertTrue(field.matches("num_*")); + Assertions.assertFalse(field.matches("num_a")); + } + @Test public void testGlobMatchSuffix() { // Pattern: number_* should match number_latency, number_count, etc. @@ -103,6 +112,16 @@ public void testGlobMatchWithDot() { Assertions.assertFalse(field.matches("metrics")); } + @Test + public void testGlobMatchDotLiteral() { + // '.' should be treated as literal in glob and escaped in regex + VariantField field = new VariantField("a.b", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("a.b")); + Assertions.assertFalse(field.matches("acb")); + } + @Test public void testDefaultPatternTypeIsGlob() { // Default constructor should use MATCH_NAME_GLOB @@ -215,6 +234,16 @@ public void testGlobEscapeBackslash() { Assertions.assertFalse(field.matches("int_")); } + @Test + public void testGlobUnclosedBracket() { + // No closing bracket: '[' treated as literal + VariantField field = new VariantField("int_[0-9", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + + Assertions.assertTrue(field.matches("int_[0-9")); + Assertions.assertFalse(field.matches("int_1")); + } + @Test public void testGlobWithSlashSeparator() { // With FNM_PATHNAME, '*' should not match '/' diff --git a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out index 5fdcfefd1c9a73..c150ee2a4bce94 100644 --- 
a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out +++ b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out @@ -86,47 +86,183 @@ charlie 50 -- !leaf_int1_select -- 1 +2 +1 +3 -- !leaf_int1_add -- 2 +3 +2 +4 -- !leaf_int_nested_nonleaf -- \\N +\\N +\\N +\\N -- !leaf_int_nested_chain_select -- 1011111 +2022222 +3033333 +4044444 -- !leaf_int_nested_dot_select -- 1011111 +2022222 +3033333 +4044444 -- !leaf_int_nested_deref_select -- 1011111 +2022222 +3033333 +4044444 + +-- !leaf_int_nested_mixed_select_1 -- +1011111 +2022222 +3033333 +4044444 + +-- !leaf_int_nested_mixed_select_2 -- +1011111 +2022222 +3033333 +4044444 + +-- !leaf_int_nested_mixed_select_3 -- +1011111 +2022222 +3033333 +4044444 + +-- !leaf_int_nested_paren_root_select -- +1011111 +2022222 +3033333 +4044444 -- !leaf_int_nested_chain_add -- 1011112 +2022223 +3033334 +4044445 -- !leaf_int_nested_dot_add -- 1011112 +2022223 +3033334 +4044445 -- !leaf_int_nested_deref_add -- 1011112 +2022223 +3033334 +4044445 + +-- !leaf_int_nested_mixed_add_1 -- +1011112 +2022223 +3033334 +4044445 + +-- !leaf_int_nested_mixed_add_2 -- +1011112 +2022223 +3033334 +4044445 + +-- !leaf_int_nested_mixed_add_3 -- +1011112 +2022223 +3033334 +4044445 + +-- !leaf_int_nested_paren_root_add -- +1011112 +2022223 +3033334 +4044445 -- !leaf_where_ok -- 1 +2 +3 +4 -- !leaf_where_nonleaf -- +-- !leaf_where_mixed_1 -- +2 +3 +4 + +-- !leaf_where_mixed_2 -- +2 +3 +4 + +-- !leaf_where_mixed_3 -- +2 +3 +4 + +-- !leaf_where_paren_root -- +2 +3 +4 + -- !leaf_order_by_ok -- 1 +3 +2 +4 -- !leaf_order_by_nonleaf -- 1 +2 +3 +4 + +-- !leaf_order_by_mixed_1 -- +1 +2 +3 +4 + +-- !leaf_order_by_mixed_2 -- +1 +2 +3 +4 + +-- !leaf_order_by_paren_root -- +1 +2 +3 +4 -- !leaf_group_by_ok -- -1 1 +1 2 +2 1 +3 1 -- !leaf_group_by_nonleaf -- -\\N 1 +\\N 4 + +-- !leaf_group_by_mixed -- +1011111 1 +2022222 1 +3033333 1 +4044444 1 -- !leaf_having_ok -- -1 1 +1 2 +2 2 +3 3 + +-- !leaf_having_mixed -- +3033333 
3033333 +4044444 4044444 diff --git a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy index 071de9ec3d195e..25948b8ff6a6c6 100644 --- a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy +++ b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy @@ -180,10 +180,11 @@ suite("test_schema_template_auto_cast", "p0") { DISTRIBUTED BY HASH(`id`) BUCKETS 1 PROPERTIES ( "replication_allocation" = "tag.location.default: 1")""" - sql """insert into ${leafTable} values( - 1, - '{"int_1": 1, "int_nested": {"level1_num_1": 1011111, "level1_num_2": 102}}' - )""" + sql """insert into ${leafTable} values + (1, '{"int_1": 1, "int_nested": {"level1_num_1": 1011111, "level1_num_2": 102}}'), + (2, '{"int_1": 2, "int_nested": {"level1_num_1": 2022222, "level1_num_2": 202}}'), + (3, '{"int_1": 1, "int_nested": {"level1_num_1": 3033333, "level1_num_2": 302}}'), + (4, '{"int_1": 3, "int_nested": {"level1_num_1": 4044444, "level1_num_2": 402}}')""" qt_leaf_int1_select """ SELECT data['int_1'] FROM ${leafTable} ORDER BY id """ qt_leaf_int1_add """ SELECT data['int_1'] + 1 FROM ${leafTable} ORDER BY id """ @@ -193,29 +194,58 @@ suite("test_schema_template_auto_cast", "p0") { FROM ${leafTable} ORDER BY id """ qt_leaf_int_nested_dot_select """ SELECT data['int_nested.level1_num_1'] FROM ${leafTable} ORDER BY id """ qt_leaf_int_nested_deref_select """ SELECT data.int_nested.level1_num_1 FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_mixed_select_1 """ SELECT data['int_nested'].level1_num_1 FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_mixed_select_2 """ SELECT (data['int_nested']).level1_num_1 FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_mixed_select_3 """ SELECT (data.int_nested).level1_num_1 FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_paren_root_select """ SELECT 
(data).int_nested.level1_num_1 FROM ${leafTable} ORDER BY id """ qt_leaf_int_nested_chain_add """ SELECT data['int_nested']['level1_num_1'] + 1 FROM ${leafTable} ORDER BY id """ qt_leaf_int_nested_dot_add """ SELECT data['int_nested.level1_num_1'] + 1 FROM ${leafTable} ORDER BY id """ qt_leaf_int_nested_deref_add """ SELECT data.int_nested.level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_mixed_add_1 """ SELECT data['int_nested'].level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_mixed_add_2 """ SELECT (data['int_nested']).level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_mixed_add_3 """ SELECT (data.int_nested).level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ + qt_leaf_int_nested_paren_root_add """ SELECT (data).int_nested.level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ // Non-select clauses: leaf vs non-leaf qt_leaf_where_ok """ SELECT id FROM ${leafTable} WHERE data['int_1'] > 0 ORDER BY id """ qt_leaf_where_nonleaf """ SELECT id FROM ${leafTable} WHERE data['int_nested'] > 0 ORDER BY id """ + qt_leaf_where_mixed_1 """ SELECT id FROM ${leafTable} + WHERE data['int_nested'].level1_num_1 > 2000000 ORDER BY id """ + qt_leaf_where_mixed_2 """ SELECT id FROM ${leafTable} + WHERE (data['int_nested']).level1_num_1 > 2000000 ORDER BY id """ + qt_leaf_where_mixed_3 """ SELECT id FROM ${leafTable} + WHERE (data.int_nested).level1_num_1 > 2000000 ORDER BY id """ + qt_leaf_where_paren_root """ SELECT id FROM ${leafTable} + WHERE (data).int_nested.level1_num_1 > 2000000 ORDER BY id """ qt_leaf_order_by_ok """ SELECT id FROM ${leafTable} - ORDER BY data['int_1'] """ + ORDER BY data['int_1'], id """ qt_leaf_order_by_nonleaf """ SELECT id FROM ${leafTable} - ORDER BY data['int_nested'] """ + ORDER BY data['int_nested'], id """ + qt_leaf_order_by_mixed_1 """ SELECT id FROM ${leafTable} + ORDER BY data['int_nested'].level1_num_1 """ + qt_leaf_order_by_mixed_2 """ SELECT id FROM ${leafTable} + ORDER BY 
(data.int_nested).level1_num_1 """ + qt_leaf_order_by_paren_root """ SELECT id FROM ${leafTable} + ORDER BY (data).int_nested.level1_num_1 """ qt_leaf_group_by_ok """ SELECT data['int_1'], COUNT(*) AS cnt FROM ${leafTable} GROUP BY data['int_1'] ORDER BY data['int_1'] """ qt_leaf_group_by_nonleaf """ SELECT data['int_nested'], COUNT(*) AS cnt FROM ${leafTable} GROUP BY data['int_nested'] ORDER BY data['int_nested'] """ + qt_leaf_group_by_mixed """ SELECT data['int_nested'].level1_num_1, COUNT(*) AS cnt + FROM ${leafTable} GROUP BY data['int_nested'].level1_num_1 + ORDER BY data['int_nested'].level1_num_1 """ qt_leaf_having_ok """ SELECT data['int_1'], SUM(data['int_1']) AS total FROM ${leafTable} GROUP BY data['int_1'] HAVING SUM(data['int_1']) > 0 ORDER BY data['int_1'] """ + qt_leaf_having_mixed """ SELECT data['int_nested'].level1_num_1, SUM(data['int_nested'].level1_num_1) AS total + FROM ${leafTable} GROUP BY data['int_nested'].level1_num_1 + HAVING SUM(data['int_nested'].level1_num_1) > 3000000 + ORDER BY data['int_nested'].level1_num_1 """ sql "DROP TABLE IF EXISTS ${leafTable}" } From 11d0df9bb219dbb25782623630047691aaf3a77a Mon Sep 17 00:00:00 2001 From: Gary Date: Tue, 3 Feb 2026 00:52:04 +0800 Subject: [PATCH 18/27] revert multi cast of SlotReference --- .../nereids/rules/rewrite/CheckMatchExpression.java | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java index 5ab28c7d5dbb36..623c3085962b47 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CheckMatchExpression.java @@ -49,13 +49,11 @@ private Plan checkChildren(LogicalFilter filter) { for (Expression expr : expressions) { if (expr instanceof Match) { Match matchExpression = 
(Match) expr; - // Unwrap all Cast layers to find the innermost expression - Expression left = matchExpression.left(); - while (left instanceof Cast) { - left = left.child(0); - } - boolean isSlotReference = left instanceof SlotReference; - if (!isSlotReference || !(matchExpression.right() instanceof Literal)) { + boolean isSlotReference = matchExpression.left() instanceof SlotReference; + boolean isCastChildWithSlotReference = (matchExpression.left() instanceof Cast + && matchExpression.left().child(0) instanceof SlotReference); + if (!(isSlotReference || isCastChildWithSlotReference) + || !(matchExpression.right() instanceof Literal)) { throw new AnalysisException(String.format("Only support match left operand is SlotRef," + " right operand is Literal. But meet expression %s", matchExpression)); } From d13722585f0912d15851f149fccce206d2a994f6 Mon Sep 17 00:00:00 2001 From: Gary Date: Wed, 4 Feb 2026 00:46:52 +0800 Subject: [PATCH 19/27] fix review --- .../rules/analysis/ExpressionAnalyzer.java | 33 +-- .../doris/nereids/types/VariantField.java | 127 +-------- ...ExpressionAnalyzerVariantAutoCastTest.java | 267 ++++++++++++++---- .../nereids/types/VariantFieldMatchTest.java | 10 +- .../test_schema_template_auto_cast.out | 139 +++++---- .../test_schema_template_auto_cast.groovy | 105 +++++-- 6 files changed, 393 insertions(+), 288 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index c04a67f614ecdf..e2fe0e1e0fb9a2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -266,20 +266,7 @@ public Expression visitUnboundAlias(UnboundAlias unboundAlias, ExpressionRewrite @Override public Expression visitDereferenceExpression(DereferenceExpression dereferenceExpression, 
ExpressionRewriteContext context) { - boolean suppressChildCast = isEnableVariantSchemaAutoCast(context) - && (dereferenceExpression.child(0) instanceof DereferenceExpression - || dereferenceExpression.child(0) instanceof ElementAt); - if (suppressChildCast) { - suppressVariantElementAtCastDepth++; - } - Expression expression; - try { - expression = dereferenceExpression.child(0).accept(this, context); - } finally { - if (suppressChildCast) { - suppressVariantElementAtCastDepth--; - } - } + Expression expression = dereferenceExpression.child(0).accept(this, context); DataType dataType = expression.getDataType(); if (dataType.isStructType()) { StructType structType = (StructType) dataType; @@ -290,11 +277,7 @@ public Expression visitDereferenceExpression(DereferenceExpression dereferenceEx } else if (dataType.isMapType()) { return new ElementAt(expression, dereferenceExpression.child(1)); } else if (dataType.isVariantType()) { - Expression elementAt = new ElementAt(expression, dereferenceExpression.child(1)); - if (isEnableVariantSchemaAutoCast(context)) { - return wrapVariantElementAtWithCast(elementAt); - } - return elementAt; + return new ElementAt(expression, dereferenceExpression.child(1)); } throw new AnalysisException("Can not dereference field: " + dereferenceExpression.fieldName); } @@ -1012,17 +995,7 @@ public Expression visitMatch(Match match, ExpressionRewriteContext context) { @Override public Expression visitCast(Cast cast, ExpressionRewriteContext context) { - boolean suppressVariantElementAtCast = shouldSuppressVariantElementAtCast(cast); - if (suppressVariantElementAtCast) { - suppressVariantElementAtCastDepth++; - } - try { - cast = (Cast) super.visitCast(cast, context); - } finally { - if (suppressVariantElementAtCast) { - suppressVariantElementAtCastDepth--; - } - } + cast = (Cast) super.visitCast(cast, context); // NOTICE: just for compatibility with legacy planner. 
if (cast.child().getDataType().isComplexType() || cast.getDataType().isComplexType()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java index 1313e8579b3f3e..ccffa2f6ba5e40 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java @@ -20,7 +20,12 @@ import org.apache.doris.nereids.util.Utils; import org.apache.doris.thrift.TPatternType; +import java.nio.file.FileSystems; +import java.nio.file.InvalidPathException; +import java.nio.file.PathMatcher; +import java.nio.file.Paths; import java.util.Objects; +import java.util.regex.PatternSyntaxException; /** * A field inside a VariantType. @@ -67,10 +72,6 @@ public String getComment() { return comment; } - public TPatternType getPatternType() { - return patternType; - } - /** * Check if the given field name matches this field's pattern. * This method aligns with BE's fnmatch(pattern, path, FNM_PATHNAME) behavior. @@ -79,7 +80,7 @@ public TPatternType getPatternType() { * - '*' matches any sequence of characters except '/' * - '?' 
matches any single character except '/' * - '[...]' matches any character in the brackets - * - '[!...]' or '[^...]' matches any character not in the brackets + * - '[!...]' matches any character not in the brackets * * @param fieldName the field name to check * @return true if the field name matches the pattern @@ -87,116 +88,16 @@ public TPatternType getPatternType() { public boolean matches(String fieldName) { if (patternType == TPatternType.MATCH_NAME) { return pattern.equals(fieldName); - } else { - // MATCH_NAME_GLOB: convert glob pattern to regex - // This aligns with BE's fnmatch(pattern, path, FNM_PATHNAME) - String regex = globToRegex(pattern); - return fieldName.matches(regex); } - } - - /** - * Convert glob pattern to regex pattern, aligning with fnmatch(FNM_PATHNAME) behavior. - * - * fnmatch with FNM_PATHNAME flag behavior: - * - '*' matches any sequence of characters except '/' - * - '?' matches any single character except '/' - * - '[...]' matches any character in the brackets - * - '[!...]' or '[^...]' matches any character not in the brackets - * - '\' escapes the next character (e.g., '\*' matches literal '*') - */ - private static String globToRegex(String glob) { - StringBuilder regex = new StringBuilder(); - int i = 0; - int len = glob.length(); - - while (i < len) { - char c = glob.charAt(i); - switch (c) { - case '\\': - // Escape sequence: next character should be matched literally - // This aligns with fnmatch behavior where \* matches literal * - if (i + 1 < len) { - i++; - char nextChar = glob.charAt(i); - // Escape the next character for regex if it's a regex special char - if (isRegexSpecialChar(nextChar)) { - regex.append('\\'); - } - regex.append(nextChar); - } else { - // Trailing backslash, treat as literal backslash - regex.append("\\\\"); - } - break; - case '*': - // '*' matches any sequence of characters except '/' (FNM_PATHNAME) - regex.append("[^/]*"); - break; - case '?': - // '?' 
matches any single character except '/' (FNM_PATHNAME) - regex.append("[^/]"); - break; - case '[': - // Character class - find the closing bracket - int j = i + 1; - // Handle negation: [! or [^ - if (j < len && (glob.charAt(j) == '!' || glob.charAt(j) == '^')) { - j++; - } - // Handle ] as first character in class - if (j < len && glob.charAt(j) == ']') { - j++; - } - // Find closing ] - while (j < len && glob.charAt(j) != ']') { - j++; - } - if (j >= len) { - // No closing bracket, treat [ as literal - regex.append("\\["); - } else { - // Extract the character class content - String classContent = glob.substring(i + 1, j); - regex.append('['); - // Convert [! to [^ - if (classContent.startsWith("!")) { - regex.append('^').append(classContent.substring(1)); - } else { - regex.append(classContent); - } - regex.append(']'); - i = j; // Move past the closing ] - } - break; - // Escape regex special characters (except backslash which is handled above) - case '.': - case '(': - case ')': - case '{': - case '}': - case '+': - case '^': - case '$': - case '|': - regex.append('\\').append(c); - break; - default: - regex.append(c); - break; - } - i++; + if (patternType != TPatternType.MATCH_NAME_GLOB) { + return false; + } + try { + PathMatcher matcher = FileSystems.getDefault().getPathMatcher("glob:" + pattern); + return matcher.matches(Paths.get(fieldName)); + } catch (PatternSyntaxException | InvalidPathException e) { + return false; } - return regex.toString(); - } - - /** - * Check if a character is a regex special character that needs escaping. - */ - private static boolean isRegexSpecialChar(char c) { - return c == '\\' || c == '.' || c == '(' || c == ')' || c == '[' - || c == ']' || c == '{' || c == '}' || c == '+' || c == '*' - || c == '?' 
|| c == '^' || c == '$' || c == '|'; } public org.apache.doris.catalog.VariantField toCatalogDataType() { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java index cc99188fa6d938..b6d8434f6b6a0e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java @@ -21,15 +21,24 @@ import org.apache.doris.nereids.analyzer.Scope; import org.apache.doris.nereids.analyzer.UnboundSlot; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Between; import org.apache.doris.nereids.trees.expressions.Cast; -import org.apache.doris.nereids.trees.expressions.DereferenceExpression; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.Max; +import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.functions.scalar.ElementAt; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.IntegerType; +import 
org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.types.VariantField; import org.apache.doris.nereids.types.VariantType; import org.apache.doris.qe.ConnectContext; @@ -39,9 +48,6 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.lang.reflect.Method; -import java.util.Optional; - public class ExpressionAnalyzerVariantAutoCastTest { @AfterEach @@ -63,111 +69,250 @@ private Expression analyze(Expression expr, Scope scope, boolean enableAutoCast) } private SlotReference buildVariantSlot(VariantType variantType) { - return new SlotReference(new ExprId(1), "data", variantType, true, ImmutableList.of()); + return new SlotReference(new org.apache.doris.nereids.trees.expressions.ExprId(1), + "data", variantType, true, ImmutableList.of()); + } + + private VariantType buildVariantType() { + VariantField numField = new VariantField("num_*", BigIntType.INSTANCE, ""); + VariantField strField = new VariantField("str_*", StringType.INSTANCE, ""); + return new VariantType(ImmutableList.of(numField, strField)); + } + + private void assertCastElementAt(Expression expr) { + Assertions.assertTrue(expr instanceof Cast, "expect Cast wrapping ElementAt"); + Cast cast = (Cast) expr; + Assertions.assertTrue(cast.child() instanceof ElementAt, "cast child should be ElementAt"); } @Test - public void testVisitElementAtAutoCastEnabled() { - VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(field)); + public void testSelectAutoCastElementAt() { + VariantType variantType = buildVariantType(); SlotReference slot = buildVariantSlot(variantType); Scope scope = new Scope(ImmutableList.of(slot)); - ElementAt elementAt = new ElementAt(slot, new StringLiteral("number_latency")); + ElementAt elementAt = new ElementAt(slot, new StringLiteral("num_a")); Expression result = analyze(elementAt, scope, true); + assertCastElementAt(result); + } - 
Assertions.assertTrue(result instanceof Cast); - Cast cast = (Cast) result; - Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); - Assertions.assertTrue(cast.child() instanceof ElementAt); + @Test + public void testSelectDotSyntaxAutoCast() { + VariantType variantType = buildVariantType(); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + UnboundSlot unbound = new UnboundSlot("data", "num_a"); + Expression result = analyze(unbound, scope, true); + Assertions.assertTrue(result instanceof Alias); + Alias alias = (Alias) result; + assertCastElementAt(alias.child()); } @Test - public void testVisitElementAtAutoCastDisabled() { - VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(field)); + public void testWhereAutoCastComparison() { + VariantType variantType = buildVariantType(); SlotReference slot = buildVariantSlot(variantType); Scope scope = new Scope(ImmutableList.of(slot)); - ElementAt elementAt = new ElementAt(slot, new StringLiteral("number_latency")); - Expression result = analyze(elementAt, scope, false); + ElementAt elementAt = new ElementAt(slot, new StringLiteral("num_a")); + GreaterThan predicate = new GreaterThan(elementAt, new BigIntLiteral(10)); + Expression result = analyze(predicate, scope, true); + + Assertions.assertTrue(result instanceof GreaterThan); + GreaterThan gt = (GreaterThan) result; + assertCastElementAt(gt.left()); + } + + @Test + public void testOrderByExpressionAutoCast() { + VariantType variantType = buildVariantType(); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt elementAt = new ElementAt(slot, new StringLiteral("num_a")); + Expression result = analyze(elementAt, scope, true); + assertCastElementAt(result); + } + + @Test + public void testGroupByExpressionAutoCast() { + VariantType variantType = 
buildVariantType(); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt elementAt = new ElementAt(slot, new StringLiteral("str_name")); + Expression result = analyze(elementAt, scope, true); + assertCastElementAt(result); + } + + @Test + public void testAggregateFunctionAutoCast() { + VariantType variantType = buildVariantType(); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt elementAt = new ElementAt(slot, new StringLiteral("num_a")); + Sum sum = new Sum(elementAt); + Expression result = analyze(sum, scope, true); + + Assertions.assertTrue(result instanceof Sum); + Sum analyzedSum = (Sum) result; + assertCastElementAt(analyzedSum.child()); + } + + @Test + public void testHavingAutoCastWithAggregate() { + VariantType variantType = buildVariantType(); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt elementAt = new ElementAt(slot, new StringLiteral("num_a")); + Sum sum = new Sum(elementAt); + GreaterThan having = new GreaterThan(sum, new BigIntLiteral(100)); + Expression result = analyze(having, scope, true); + Assertions.assertTrue(result instanceof GreaterThan); + GreaterThan gt = (GreaterThan) result; + Assertions.assertTrue(gt.left() instanceof Sum); + Sum analyzedSum = (Sum) gt.left(); + assertCastElementAt(analyzedSum.child()); + } + + @Test + public void testNonLiteralKeyNoAutoCast() { + VariantType variantType = buildVariantType(); + SlotReference slot = buildVariantSlot(variantType); + SlotReference keySlot = new SlotReference(new ExprId(2), "col", StringType.INSTANCE, true, ImmutableList.of()); + Scope scope = new Scope(ImmutableList.of(slot, keySlot)); + + ElementAt elementAt = new ElementAt(slot, keySlot); + Expression result = analyze(elementAt, scope, true); Assertions.assertTrue(result instanceof ElementAt); Assertions.assertFalse(result 
instanceof Cast); } @Test - public void testVisitDereferenceExpressionAutoCast() { - VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(field)); + public void testNoMatchingTemplateNoAutoCast() { + VariantType variantType = buildVariantType(); SlotReference slot = buildVariantSlot(variantType); Scope scope = new Scope(ImmutableList.of(slot)); - DereferenceExpression deref = new DereferenceExpression(slot, new StringLiteral("number_latency")); - Expression result = analyze(deref, scope, true); + ElementAt elementAt = new ElementAt(slot, new StringLiteral("unknown")); + Expression result = analyze(elementAt, scope, true); + Assertions.assertTrue(result instanceof ElementAt); + Assertions.assertFalse(result instanceof Cast); + } + + @Test + public void testChainedPathOnlyOuterCast() { + VariantField nestedField = new VariantField("int_nested.level1_num_1", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(nestedField)); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt inner = new ElementAt(slot, new StringLiteral("int_nested")); + ElementAt outer = new ElementAt(inner, new StringLiteral("level1_num_1")); + Expression result = analyze(outer, scope, true); Assertions.assertTrue(result instanceof Cast); Cast cast = (Cast) result; - Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); Assertions.assertTrue(cast.child() instanceof ElementAt); + ElementAt castChild = (ElementAt) cast.child(); + Assertions.assertTrue(castChild.left() instanceof ElementAt); + Assertions.assertFalse(castChild.left() instanceof Cast); } @Test - public void testResolveVariantElementAtPathChain() { - VariantField field = new VariantField("int_nested.level1_num_1", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(field)); + public void 
testDotPathMergedAliasCast() { + VariantField nestedField = new VariantField("int_nested.level1_num_1", BigIntType.INSTANCE, ""); + VariantType variantType = new VariantType(ImmutableList.of(nestedField)); SlotReference slot = buildVariantSlot(variantType); Scope scope = new Scope(ImmutableList.of(slot)); - ElementAt elementAt = new ElementAt( - new ElementAt(slot, new StringLiteral("int_nested")), - new StringLiteral("level1_num_1") - ); - Expression result = analyze(elementAt, scope, true); + UnboundSlot unbound = new UnboundSlot("data", "int_nested", "level1_num_1"); + Expression result = analyze(unbound, scope, true); + Assertions.assertTrue(result instanceof Alias); + Alias alias = (Alias) result; + assertCastElementAt(alias.child()); + } + + @Test + public void testExplicitCastStillAutoCastsInner() { + VariantType variantType = buildVariantType(); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt elementAt = new ElementAt(slot, new StringLiteral("num_a")); + Cast explicit = new Cast(elementAt, IntegerType.INSTANCE); + Expression result = analyze(explicit, scope, true); Assertions.assertTrue(result instanceof Cast); - Cast cast = (Cast) result; - Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); - Assertions.assertTrue(cast.child() instanceof ElementAt); + Cast outer = (Cast) result; + Assertions.assertTrue(outer.child() instanceof Cast); + Cast inner = (Cast) outer.child(); + Assertions.assertTrue(inner.child() instanceof ElementAt); + } + + @Test + public void testWhereBetweenAndIn() { + VariantType variantType = buildVariantType(); + SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt elementAt = new ElementAt(slot, new StringLiteral("num_a")); + Between between = new Between(elementAt, new BigIntLiteral(10), new BigIntLiteral(20)); + Expression betweenResult = analyze(between, scope, true); + 
Assertions.assertTrue(betweenResult instanceof Between); + assertCastElementAt(((Between) betweenResult).getCompareExpr()); + + ElementAt elementAtStr = new ElementAt(slot, new StringLiteral("str_name")); + InPredicate inPredicate = new InPredicate(elementAtStr, + ImmutableList.of(new StringLiteral("alice"), new StringLiteral("bob"))); + Expression inResult = analyze(inPredicate, scope, true); + Assertions.assertTrue(inResult instanceof InPredicate); + assertCastElementAt(((InPredicate) inResult).getCompareExpr()); } @Test - public void testGetVariantPathKeyNonString() throws Exception { - VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(field)); + public void testAggregateMinMaxAvgCountDistinct() { + VariantType variantType = buildVariantType(); SlotReference slot = buildVariantSlot(variantType); + Scope scope = new Scope(ImmutableList.of(slot)); + + ElementAt elementAt = new ElementAt(slot, new StringLiteral("num_a")); + Min min = new Min(elementAt); + Max max = new Max(elementAt); + Avg avg = new Avg(elementAt); + Count countDistinct = new Count(true, elementAt); - ElementAt elementAt = new ElementAt(slot, new BigIntLiteral(1)); - ExpressionAnalyzer analyzer = new ExpressionAnalyzer(null, new Scope(ImmutableList.of(slot)), - createContext(true), true, true); + Expression minResult = analyze(min, scope, true); + Assertions.assertTrue(minResult instanceof Min); + assertCastElementAt(((Min) minResult).child()); - Method getVariantPathKey = ExpressionAnalyzer.class.getDeclaredMethod("getVariantPathKey", Expression.class); - getVariantPathKey.setAccessible(true); - @SuppressWarnings("unchecked") - Optional key = (Optional) getVariantPathKey.invoke(analyzer, new BigIntLiteral(1)); - Assertions.assertFalse(key.isPresent()); + Expression maxResult = analyze(max, scope, true); + Assertions.assertTrue(maxResult instanceof Max); + assertCastElementAt(((Max) maxResult).child()); - Method 
resolvePath = ExpressionAnalyzer.class.getDeclaredMethod("resolveVariantElementAtPath", ElementAt.class); - resolvePath.setAccessible(true); - @SuppressWarnings("unchecked") - Optional path = (Optional) resolvePath.invoke(analyzer, elementAt); - Assertions.assertFalse(path.isPresent()); + Expression avgResult = analyze(avg, scope, true); + Assertions.assertTrue(avgResult instanceof Avg); + assertCastElementAt(((Avg) avgResult).child()); + + Expression countResult = analyze(countDistinct, scope, true); + Assertions.assertTrue(countResult instanceof Count); + assertCastElementAt(((Count) countResult).child(0)); } @Test - public void testMaybeCastAliasExpression() { - VariantField field = new VariantField("int_nested.level1_num_1", BigIntType.INSTANCE, ""); - VariantType variantType = new VariantType(ImmutableList.of(field)); + public void testAutoCastDisabled() { + VariantType variantType = buildVariantType(); SlotReference slot = buildVariantSlot(variantType); Scope scope = new Scope(ImmutableList.of(slot)); - UnboundSlot unbound = new UnboundSlot("data", "int_nested", "level1_num_1"); - Expression result = analyze(unbound, scope, true); + ElementAt elementAt = new ElementAt(slot, new StringLiteral("num_a")); + Expression result = analyze(elementAt, scope, false); - Assertions.assertTrue(result instanceof Alias); - Alias alias = (Alias) result; - Assertions.assertTrue(alias.child() instanceof Cast); - Cast cast = (Cast) alias.child(); - Assertions.assertEquals(BigIntType.INSTANCE, cast.getDataType()); + Assertions.assertTrue(result instanceof ElementAt); + Assertions.assertFalse(result instanceof Cast); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java index 146d7151324c38..6551f14932e320 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java +++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java @@ -128,7 +128,6 @@ public void testDefaultPatternTypeIsGlob() { VariantField field = new VariantField("number_*", BigIntType.INSTANCE, ""); Assertions.assertTrue(field.matches("number_latency")); - Assertions.assertEquals(TPatternType.MATCH_NAME_GLOB, field.getPatternType()); } // ==================== VariantType.findMatchingField() tests ==================== @@ -236,11 +235,11 @@ public void testGlobEscapeBackslash() { @Test public void testGlobUnclosedBracket() { - // No closing bracket: '[' treated as literal + // No closing bracket: invalid glob for PathMatcher, expect no match VariantField field = new VariantField("int_[0-9", BigIntType.INSTANCE, "", TPatternType.MATCH_NAME_GLOB.name()); - Assertions.assertTrue(field.matches("int_[0-9")); + Assertions.assertFalse(field.matches("int_[0-9")); Assertions.assertFalse(field.matches("int_1")); } @@ -269,10 +268,5 @@ public void testGlobCharacterClass() { Assertions.assertTrue(field2.matches("int_a")); Assertions.assertFalse(field2.matches("int_1")); - // Negated character class with ^ - VariantField field3 = new VariantField("int_[^0-9]", BigIntType.INSTANCE, "", - TPatternType.MATCH_NAME_GLOB.name()); - Assertions.assertTrue(field3.matches("int_a")); - Assertions.assertFalse(field3.matches("int_1")); } } diff --git a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out index c150ee2a4bce94..3b9ecdd4580947 100644 --- a/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out +++ b/regression-test/data/variant_p0/predefine/test_schema_template_auto_cast.out @@ -12,12 +12,27 @@ 3 4 +-- !where_between -- +2 +4 + +-- !where_in -- +1 +3 +4 + -- !order_by -- 3 50 2 30 4 15 1 10 +-- !order_by_expr -- +3 51 +2 31 +4 16 +1 11 + -- !topn -- 3 50 2 30 @@ -28,16 +43,57 @@ 3 110 4 40 +-- !case_when -- +1 low +2 high +3 high 
+4 low + +-- !order_by_alias_expr -- +30 +40 +70 +110 + +-- !explicit_cast_select -- +10 +30 +50 +15 + +-- !explicit_cast_where -- +2 +3 + +-- !explicit_cast_order_by -- +3 +2 +4 +1 + -- !group_by -- alice 25 bob 30 charlie 50 +-- !group_by_multi_agg -- +alice 10 15 2 +bob 30 30 1 +charlie 50 50 1 + -- !having -- alice 25 bob 30 charlie 50 +-- !having_min -- +bob 30 +charlie 50 + +-- !having_non_agg -- +bob 30 +charlie 50 + -- !order_by_alias -- 10 15 @@ -56,12 +112,42 @@ charlie 50 30 1 50 1 +-- !order_by_alias_nested -- +10 +15 +30 +50 + +-- !group_by_alias_nested -- +10 1 +15 1 +30 1 +50 1 + -- !window_partition_order -- 1 1 2 1 3 1 4 2 +-- !window_sum -- +1 25 +2 30 +3 50 +4 25 + +-- !window_sum_order -- +1 10 +2 30 +3 50 +4 25 + +-- !agg_min_max -- +10 50 + +-- !agg_count_distinct -- +3 + -- !join_on -- 1 first 2 second @@ -120,30 +206,6 @@ charlie 50 3033333 4044444 --- !leaf_int_nested_mixed_select_1 -- -1011111 -2022222 -3033333 -4044444 - --- !leaf_int_nested_mixed_select_2 -- -1011111 -2022222 -3033333 -4044444 - --- !leaf_int_nested_mixed_select_3 -- -1011111 -2022222 -3033333 -4044444 - --- !leaf_int_nested_paren_root_select -- -1011111 -2022222 -3033333 -4044444 - -- !leaf_int_nested_chain_add -- 1011112 2022223 @@ -162,30 +224,6 @@ charlie 50 3033334 4044445 --- !leaf_int_nested_mixed_add_1 -- -1011112 -2022223 -3033334 -4044445 - --- !leaf_int_nested_mixed_add_2 -- -1011112 -2022223 -3033334 -4044445 - --- !leaf_int_nested_mixed_add_3 -- -1011112 -2022223 -3033334 -4044445 - --- !leaf_int_nested_paren_root_add -- -1011112 -2022223 -3033334 -4044445 - -- !leaf_where_ok -- 1 2 @@ -209,11 +247,6 @@ charlie 50 3 4 --- !leaf_where_paren_root -- -2 -3 -4 - -- !leaf_order_by_ok -- 1 3 diff --git a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy index 25948b8ff6a6c6..c8e80b03bdd4a2 100644 --- 
a/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy +++ b/regression-test/suites/variant_p0/predefine/test_schema_template_auto_cast.groovy @@ -53,10 +53,24 @@ suite("test_schema_template_auto_cast", "p0") { WHERE data['num_a'] > 40 OR data['str_name'] = 'alice' ORDER BY id """ + // BETWEEN condition + qt_where_between """ SELECT id FROM ${tableName} + WHERE data['num_a'] BETWEEN 15 AND 30 + ORDER BY id """ + + // IN condition + qt_where_in """ SELECT id FROM ${tableName} + WHERE data['str_name'] IN ('alice', 'charlie') + ORDER BY id """ + // Test 2: ORDER BY with auto-cast qt_order_by """ SELECT id, data['num_a'] FROM ${tableName} ORDER BY data['num_a'] DESC """ + // ORDER BY expression + qt_order_by_expr """ SELECT id, data['num_a'] + 1 AS n FROM ${tableName} + ORDER BY data['num_a'] + 1 DESC """ + // Test 3: TopN (ORDER BY + LIMIT) qt_topn """ SELECT id, data['num_a'] FROM ${tableName} ORDER BY data['num_a'] DESC LIMIT 2 """ @@ -65,15 +79,46 @@ suite("test_schema_template_auto_cast", "p0") { qt_select_arithmetic """ SELECT id, data['num_a'] + data['num_b'] as sum_val FROM ${tableName} ORDER BY id """ + // CASE WHEN with auto-cast + qt_case_when """ SELECT id, + CASE WHEN data['num_a'] > 20 THEN 'high' ELSE 'low' END AS level + FROM ${tableName} ORDER BY id """ + + // ORDER BY alias from expression + qt_order_by_alias_expr """ SELECT data['num_a'] + data['num_b'] AS sum_val FROM ${tableName} + ORDER BY sum_val """ + + // Explicit CAST should still trigger schema template auto cast + qt_explicit_cast_select """ SELECT CAST(data['num_a'] AS INT) FROM ${tableName} ORDER BY id """ + qt_explicit_cast_where """ SELECT id FROM ${tableName} + WHERE CAST(data['num_a'] AS INT) > 20 ORDER BY id """ + qt_explicit_cast_order_by """ SELECT id FROM ${tableName} + ORDER BY CAST(data['num_a'] AS INT) DESC """ + // Test 5: GROUP BY with auto-cast qt_group_by """ SELECT data['str_name'], SUM(data['num_a']) as total FROM ${tableName} GROUP BY 
data['str_name'] ORDER BY data['str_name'] """ + // GROUP BY with multiple aggregates + qt_group_by_multi_agg """ SELECT data['str_name'], + MIN(data['num_a']) AS min_a, MAX(data['num_a']) AS max_a, COUNT(*) AS cnt + FROM ${tableName} GROUP BY data['str_name'] ORDER BY data['str_name'] """ + // Test 6: HAVING with auto-cast qt_having """ SELECT data['str_name'], SUM(data['num_a']) as total FROM ${tableName} GROUP BY data['str_name'] HAVING SUM(data['num_a']) > 20 ORDER BY data['str_name'] """ + // HAVING with MIN + qt_having_min """ SELECT data['str_name'], MIN(data['num_a']) AS min_a + FROM ${tableName} GROUP BY data['str_name'] + HAVING MIN(data['num_a']) >= 15 ORDER BY data['str_name'] """ + + // HAVING with non-aggregate expression on group key + qt_having_non_agg """ SELECT data['str_name'], SUM(data['num_a']) AS total + FROM ${tableName} GROUP BY data['str_name'] + HAVING data['str_name'] != 'alice' ORDER BY data['str_name'] """ + // Test 7: ORDER BY with alias from project qt_order_by_alias """ SELECT data['num_a'] AS num_a FROM ${tableName} ORDER BY num_a """ @@ -87,11 +132,35 @@ suite("test_schema_template_auto_cast", "p0") { FROM (SELECT data['num_a'] AS num_a FROM ${tableName}) t GROUP BY num_a ORDER BY num_a """ + // ORDER BY with nested alias + qt_order_by_alias_nested """ SELECT * FROM ( + SELECT num_a FROM (SELECT data['num_a'] AS num_a FROM ${tableName}) s1 + ) s2 ORDER BY num_a """ + + // GROUP BY with nested alias + qt_group_by_alias_nested """ SELECT num_a, COUNT(*) AS cnt FROM ( + SELECT num_a FROM (SELECT data['num_a'] AS num_a FROM ${tableName}) s1 + ) s2 GROUP BY num_a ORDER BY num_a """ + // Test 10: WINDOW partition/order by with auto-cast qt_window_partition_order """ SELECT id, row_number() OVER (PARTITION BY data['str_name'] ORDER BY data['num_a']) AS rn FROM ${tableName} ORDER BY id """ + // WINDOW aggregate + qt_window_sum """ SELECT id, + SUM(data['num_a']) OVER (PARTITION BY data['str_name']) AS s + FROM ${tableName} ORDER BY id """ 
+ + // WINDOW partition + order by with both paths + qt_window_sum_order """ SELECT id, + SUM(data['num_a']) OVER (PARTITION BY data['str_name'] ORDER BY data['num_a']) AS s + FROM ${tableName} ORDER BY id """ + + // Aggregates without GROUP BY + qt_agg_min_max """ SELECT MIN(data['num_a']), MAX(data['num_a']) FROM ${tableName} """ + qt_agg_count_distinct """ SELECT COUNT(DISTINCT data['str_name']) FROM ${tableName} """ + // Test 11: disable auto-cast should error in ORDER BY sql """ set enable_variant_schema_auto_cast = false """ test { @@ -194,20 +263,12 @@ suite("test_schema_template_auto_cast", "p0") { FROM ${leafTable} ORDER BY id """ qt_leaf_int_nested_dot_select """ SELECT data['int_nested.level1_num_1'] FROM ${leafTable} ORDER BY id """ qt_leaf_int_nested_deref_select """ SELECT data.int_nested.level1_num_1 FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_mixed_select_1 """ SELECT data['int_nested'].level1_num_1 FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_mixed_select_2 """ SELECT (data['int_nested']).level1_num_1 FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_mixed_select_3 """ SELECT (data.int_nested).level1_num_1 FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_paren_root_select """ SELECT (data).int_nested.level1_num_1 FROM ${leafTable} ORDER BY id """ qt_leaf_int_nested_chain_add """ SELECT data['int_nested']['level1_num_1'] + 1 FROM ${leafTable} ORDER BY id """ qt_leaf_int_nested_dot_add """ SELECT data['int_nested.level1_num_1'] + 1 FROM ${leafTable} ORDER BY id """ qt_leaf_int_nested_deref_add """ SELECT data.int_nested.level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_mixed_add_1 """ SELECT data['int_nested'].level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_mixed_add_2 """ SELECT (data['int_nested']).level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ - qt_leaf_int_nested_mixed_add_3 """ SELECT (data.int_nested).level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ - 
qt_leaf_int_nested_paren_root_add """ SELECT (data).int_nested.level1_num_1 + 1 FROM ${leafTable} ORDER BY id """ // Non-select clauses: leaf vs non-leaf qt_leaf_where_ok """ SELECT id FROM ${leafTable} @@ -215,37 +276,35 @@ suite("test_schema_template_auto_cast", "p0") { qt_leaf_where_nonleaf """ SELECT id FROM ${leafTable} WHERE data['int_nested'] > 0 ORDER BY id """ qt_leaf_where_mixed_1 """ SELECT id FROM ${leafTable} - WHERE data['int_nested'].level1_num_1 > 2000000 ORDER BY id """ + WHERE data['int_nested']['level1_num_1'] > 2000000 ORDER BY id """ qt_leaf_where_mixed_2 """ SELECT id FROM ${leafTable} - WHERE (data['int_nested']).level1_num_1 > 2000000 ORDER BY id """ + WHERE data['int_nested.level1_num_1'] > 2000000 ORDER BY id """ qt_leaf_where_mixed_3 """ SELECT id FROM ${leafTable} - WHERE (data.int_nested).level1_num_1 > 2000000 ORDER BY id """ - qt_leaf_where_paren_root """ SELECT id FROM ${leafTable} - WHERE (data).int_nested.level1_num_1 > 2000000 ORDER BY id """ + WHERE data.int_nested.level1_num_1 > 2000000 ORDER BY id """ qt_leaf_order_by_ok """ SELECT id FROM ${leafTable} ORDER BY data['int_1'], id """ qt_leaf_order_by_nonleaf """ SELECT id FROM ${leafTable} ORDER BY data['int_nested'], id """ qt_leaf_order_by_mixed_1 """ SELECT id FROM ${leafTable} - ORDER BY data['int_nested'].level1_num_1 """ + ORDER BY data['int_nested']['level1_num_1'] """ qt_leaf_order_by_mixed_2 """ SELECT id FROM ${leafTable} - ORDER BY (data.int_nested).level1_num_1 """ + ORDER BY data['int_nested.level1_num_1'] """ qt_leaf_order_by_paren_root """ SELECT id FROM ${leafTable} - ORDER BY (data).int_nested.level1_num_1 """ + ORDER BY data.int_nested.level1_num_1 """ qt_leaf_group_by_ok """ SELECT data['int_1'], COUNT(*) AS cnt FROM ${leafTable} GROUP BY data['int_1'] ORDER BY data['int_1'] """ qt_leaf_group_by_nonleaf """ SELECT data['int_nested'], COUNT(*) AS cnt FROM ${leafTable} GROUP BY data['int_nested'] ORDER BY data['int_nested'] """ - qt_leaf_group_by_mixed """ 
SELECT data['int_nested'].level1_num_1, COUNT(*) AS cnt - FROM ${leafTable} GROUP BY data['int_nested'].level1_num_1 - ORDER BY data['int_nested'].level1_num_1 """ + qt_leaf_group_by_mixed """ SELECT data['int_nested.level1_num_1'], COUNT(*) AS cnt + FROM ${leafTable} GROUP BY data['int_nested.level1_num_1'] + ORDER BY data['int_nested.level1_num_1'] """ qt_leaf_having_ok """ SELECT data['int_1'], SUM(data['int_1']) AS total FROM ${leafTable} GROUP BY data['int_1'] HAVING SUM(data['int_1']) > 0 ORDER BY data['int_1'] """ - qt_leaf_having_mixed """ SELECT data['int_nested'].level1_num_1, SUM(data['int_nested'].level1_num_1) AS total - FROM ${leafTable} GROUP BY data['int_nested'].level1_num_1 - HAVING SUM(data['int_nested'].level1_num_1) > 3000000 - ORDER BY data['int_nested'].level1_num_1 """ + qt_leaf_having_mixed """ SELECT data['int_nested.level1_num_1'], SUM(data['int_nested.level1_num_1']) AS total + FROM ${leafTable} GROUP BY data['int_nested.level1_num_1'] + HAVING SUM(data['int_nested.level1_num_1']) > 3000000 + ORDER BY data['int_nested.level1_num_1'] """ sql "DROP TABLE IF EXISTS ${leafTable}" } From 668a9811bf35a8387698ca5addf23110d2694b51 Mon Sep 17 00:00:00 2001 From: Gary Date: Wed, 4 Feb 2026 00:59:24 +0800 Subject: [PATCH 20/27] fix FE UT --- .../ExpressionAnalyzerVariantAutoCastTest.java | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java index b6d8434f6b6a0e..4ab38d7606dcf8 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzerVariantAutoCastTest.java @@ -263,15 +263,19 @@ public void testWhereBetweenAndIn() { ElementAt elementAt = new ElementAt(slot, 
new StringLiteral("num_a")); Between between = new Between(elementAt, new BigIntLiteral(10), new BigIntLiteral(20)); Expression betweenResult = analyze(between, scope, true); - Assertions.assertTrue(betweenResult instanceof Between); - assertCastElementAt(((Between) betweenResult).getCompareExpr()); + Assertions.assertTrue(betweenResult.containsType(Cast.class)); + Assertions.assertTrue(betweenResult.containsType(ElementAt.class)); + Assertions.assertTrue(betweenResult.collectFirst( + expr -> expr instanceof Cast && ((Cast) expr).child() instanceof ElementAt).isPresent()); ElementAt elementAtStr = new ElementAt(slot, new StringLiteral("str_name")); InPredicate inPredicate = new InPredicate(elementAtStr, ImmutableList.of(new StringLiteral("alice"), new StringLiteral("bob"))); Expression inResult = analyze(inPredicate, scope, true); - Assertions.assertTrue(inResult instanceof InPredicate); - assertCastElementAt(((InPredicate) inResult).getCompareExpr()); + Assertions.assertTrue(inResult.containsType(Cast.class)); + Assertions.assertTrue(inResult.containsType(ElementAt.class)); + Assertions.assertTrue(inResult.collectFirst( + expr -> expr instanceof Cast && ((Cast) expr).child() instanceof ElementAt).isPresent()); } @Test From c75364b66b6d5aed9c663ba1b1c1f23baa9f6573 Mon Sep 17 00:00:00 2001 From: Gary Date: Wed, 4 Feb 2026 14:13:13 +0800 Subject: [PATCH 21/27] use PathMatcher to replace fnmatch --- be/src/vec/common/variant_util.cpp | 57 +++++++++++++- .../common/jni/utils/PathMatcherUtil.java | 74 +++++++++++++++++++ 2 files changed, 128 insertions(+), 3 deletions(-) create mode 100644 fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/PathMatcherUtil.java diff --git a/be/src/vec/common/variant_util.cpp b/be/src/vec/common/variant_util.cpp index 39e720630678ae..e792d7ad9c7c31 100644 --- a/be/src/vec/common/variant_util.cpp +++ b/be/src/vec/common/variant_util.cpp @@ -19,7 +19,6 @@ #include #include -#include #include #include #include 
@@ -34,6 +33,7 @@ #include #include +#include #include #include #include @@ -69,6 +69,7 @@ #include "runtime/primitive_type.h" #include "runtime/runtime_state.h" #include "util/defer_op.h" +#include "util/jni-util.h" #include "vec/columns/column.h" #include "vec/columns/column_array.h" #include "vec/columns/column_map.h" @@ -102,6 +103,56 @@ namespace doris::vectorized::variant_util { #include "common/compile_check_begin.h" +namespace { + +std::once_flag g_path_matcher_once; +std::atomic g_path_matcher_ready {false}; +jclass g_path_matcher_cl = nullptr; +jmethodID g_path_matcher_matches = nullptr; + +void init_java_path_matcher(JNIEnv* env) { + jclass local_cl = env->FindClass("org/apache/doris/common/jni/utils/PathMatcherUtil"); + if (local_cl == nullptr) { + env->ExceptionClear(); + return; + } + g_path_matcher_cl = reinterpret_cast(env->NewGlobalRef(local_cl)); + env->DeleteLocalRef(local_cl); + if (g_path_matcher_cl == nullptr) { + env->ExceptionClear(); + return; + } + g_path_matcher_matches = env->GetStaticMethodID(g_path_matcher_cl, "matches", + "(Ljava/lang/String;Ljava/lang/String;)Z"); + if (g_path_matcher_matches == nullptr) { + env->ExceptionClear(); + return; + } + g_path_matcher_ready.store(true, std::memory_order_release); +} + +bool java_glob_match(const char* pattern, const std::string& path) { + JNIEnv* env = nullptr; + Status st = Jni::Env::Get(&env); + CHECK(st.ok()) << st; + std::call_once(g_path_matcher_once, [&]() { init_java_path_matcher(env); }); + CHECK(g_path_matcher_ready.load(std::memory_order_acquire)) + << "PathMatcherUtil is not available in JVM"; + jstring jpattern = env->NewStringUTF(pattern); + jstring jpath = env->NewStringUTF(path.c_str()); + jboolean result = env->CallStaticBooleanMethod(g_path_matcher_cl, g_path_matcher_matches, + jpattern, jpath); + env->DeleteLocalRef(jpattern); + env->DeleteLocalRef(jpath); + if (env->ExceptionCheck()) { + Status err = Jni::Env::GetJniExceptionMsg(env); + CHECK(false) << err; + } + return 
result == JNI_TRUE; +} + +} // namespace + size_t get_number_of_dimensions(const IDataType& type) { if (const auto* type_array = typeid_cast(&type)) { return type_array->get_number_of_dimensions(); @@ -1307,8 +1358,8 @@ bool generate_sub_column_info(const TabletSchema& schema, int32_t col_unique_id, break; } case PatternTypePB::MATCH_NAME_GLOB: { - int result = fnmatch(pattern, path.c_str(), FNM_PATHNAME); - if (result == 0) { + bool matched = java_glob_match(pattern, path); + if (matched) { generate_result_column(*sub_column, &sub_column_info->column); generate_index(sub_column->name()); return true; diff --git a/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/PathMatcherUtil.java b/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/PathMatcherUtil.java new file mode 100644 index 00000000000000..0eb10a3ae3514d --- /dev/null +++ b/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/PathMatcherUtil.java @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.common.jni.utils; + +import java.nio.file.FileSystems; +import java.nio.file.PathMatcher; +import java.nio.file.Paths; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Utility for BE JNI callers to use Java PathMatcher glob semantics. + */ +public final class PathMatcherUtil { + private static final String CACHE_SIZE_KEY = "doris.pathmatcher.cache.max"; + private static final int DEFAULT_CACHE_MAX = 2048; + private static final int CACHE_MAX = initCacheMax(); + + private static final Map CACHE = Collections.synchronizedMap( + new LinkedHashMap(16, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > CACHE_MAX; + } + }); + + private PathMatcherUtil() { + } + + public static boolean matches(String pattern, String path) { + if (pattern == null || path == null) { + return false; + } + try { + PathMatcher matcher = CACHE.get(pattern); + if (matcher == null) { + matcher = FileSystems.getDefault().getPathMatcher("glob:" + pattern); + CACHE.put(pattern, matcher); + } + return matcher.matches(Paths.get(path)); + } catch (RuntimeException e) { + return false; + } + } + + private static int initCacheMax() { + String prop = System.getProperty(CACHE_SIZE_KEY); + if (prop == null || prop.isEmpty()) { + return DEFAULT_CACHE_MAX; + } + try { + int value = Integer.parseInt(prop); + return value > 0 ? 
value : DEFAULT_CACHE_MAX; + } catch (NumberFormatException e) { + return DEFAULT_CACHE_MAX; + } + } +} From 62927b245eaf4f429b6fc96efb13fd1e19f92832 Mon Sep 17 00:00:00 2001 From: Gary Date: Wed, 4 Feb 2026 14:31:29 +0800 Subject: [PATCH 22/27] delete shouldSuppressVariantElementAtCast --- .../nereids/rules/analysis/ExpressionAnalyzer.java | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java index e2fe0e1e0fb9a2..134610b5c345b1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ExpressionAnalyzer.java @@ -742,9 +742,6 @@ private boolean isEnableVariantSchemaAutoCast(ExpressionRewriteContext context) } private Expression wrapVariantElementAtWithCast(Expression expr) { - if (!(expr instanceof ElementAt)) { - return expr; - } ElementAt elementAt = (ElementAt) expr; if (suppressVariantElementAtCastDepth > 0) { return elementAt; @@ -806,14 +803,6 @@ private VariantElementAtPath(Expression root, String path) { } } - private boolean shouldSuppressVariantElementAtCast(Cast cast) { - if (!cast.isExplicitType()) { - return false; - } - Expression child = cast.child(); - return child instanceof ElementAt || child instanceof DereferenceExpression || child instanceof UnboundSlot; - } - private Expression maybeCastAliasExpression(Alias alias, ExpressionRewriteContext context) { if (suppressVariantElementAtCastDepth > 0 || !isEnableVariantSchemaAutoCast(context)) { return alias; From 151452d2a8d5d04bb49a2c51f3987e8515da6342 Mon Sep 17 00:00:00 2001 From: Gary Date: Wed, 4 Feb 2026 15:16:36 +0800 Subject: [PATCH 23/27] Revert "use PathMatcher to replace fnmatch" This reverts commit c75364b66b6d5aed9c663ba1b1c1f23baa9f6573. 
--- be/src/vec/common/variant_util.cpp | 57 +------------- .../common/jni/utils/PathMatcherUtil.java | 74 ------------------- 2 files changed, 3 insertions(+), 128 deletions(-) delete mode 100644 fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/PathMatcherUtil.java diff --git a/be/src/vec/common/variant_util.cpp b/be/src/vec/common/variant_util.cpp index e792d7ad9c7c31..39e720630678ae 100644 --- a/be/src/vec/common/variant_util.cpp +++ b/be/src/vec/common/variant_util.cpp @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -33,7 +34,6 @@ #include #include -#include #include #include #include @@ -69,7 +69,6 @@ #include "runtime/primitive_type.h" #include "runtime/runtime_state.h" #include "util/defer_op.h" -#include "util/jni-util.h" #include "vec/columns/column.h" #include "vec/columns/column_array.h" #include "vec/columns/column_map.h" @@ -103,56 +102,6 @@ namespace doris::vectorized::variant_util { #include "common/compile_check_begin.h" -namespace { - -std::once_flag g_path_matcher_once; -std::atomic g_path_matcher_ready {false}; -jclass g_path_matcher_cl = nullptr; -jmethodID g_path_matcher_matches = nullptr; - -void init_java_path_matcher(JNIEnv* env) { - jclass local_cl = env->FindClass("org/apache/doris/common/jni/utils/PathMatcherUtil"); - if (local_cl == nullptr) { - env->ExceptionClear(); - return; - } - g_path_matcher_cl = reinterpret_cast(env->NewGlobalRef(local_cl)); - env->DeleteLocalRef(local_cl); - if (g_path_matcher_cl == nullptr) { - env->ExceptionClear(); - return; - } - g_path_matcher_matches = env->GetStaticMethodID(g_path_matcher_cl, "matches", - "(Ljava/lang/String;Ljava/lang/String;)Z"); - if (g_path_matcher_matches == nullptr) { - env->ExceptionClear(); - return; - } - g_path_matcher_ready.store(true, std::memory_order_release); -} - -bool java_glob_match(const char* pattern, const std::string& path) { - JNIEnv* env = nullptr; - Status st = Jni::Env::Get(&env); - CHECK(st.ok()) << st; 
- std::call_once(g_path_matcher_once, [&]() { init_java_path_matcher(env); }); - CHECK(g_path_matcher_ready.load(std::memory_order_acquire)) - << "PathMatcherUtil is not available in JVM"; - jstring jpattern = env->NewStringUTF(pattern); - jstring jpath = env->NewStringUTF(path.c_str()); - jboolean result = env->CallStaticBooleanMethod(g_path_matcher_cl, g_path_matcher_matches, - jpattern, jpath); - env->DeleteLocalRef(jpattern); - env->DeleteLocalRef(jpath); - if (env->ExceptionCheck()) { - Status err = Jni::Env::GetJniExceptionMsg(env); - CHECK(false) << err; - } - return result == JNI_TRUE; -} - -} // namespace - size_t get_number_of_dimensions(const IDataType& type) { if (const auto* type_array = typeid_cast(&type)) { return type_array->get_number_of_dimensions(); @@ -1358,8 +1307,8 @@ bool generate_sub_column_info(const TabletSchema& schema, int32_t col_unique_id, break; } case PatternTypePB::MATCH_NAME_GLOB: { - bool matched = java_glob_match(pattern, path); - if (matched) { + int result = fnmatch(pattern, path.c_str(), FNM_PATHNAME); + if (result == 0) { generate_result_column(*sub_column, &sub_column_info->column); generate_index(sub_column->name()); return true; diff --git a/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/PathMatcherUtil.java b/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/PathMatcherUtil.java deleted file mode 100644 index 0eb10a3ae3514d..00000000000000 --- a/fe/be-java-extensions/java-common/src/main/java/org/apache/doris/common/jni/utils/PathMatcherUtil.java +++ /dev/null @@ -1,74 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.common.jni.utils; - -import java.nio.file.FileSystems; -import java.nio.file.PathMatcher; -import java.nio.file.Paths; -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.Map; - -/** - * Utility for BE JNI callers to use Java PathMatcher glob semantics. - */ -public final class PathMatcherUtil { - private static final String CACHE_SIZE_KEY = "doris.pathmatcher.cache.max"; - private static final int DEFAULT_CACHE_MAX = 2048; - private static final int CACHE_MAX = initCacheMax(); - - private static final Map CACHE = Collections.synchronizedMap( - new LinkedHashMap(16, 0.75f, true) { - @Override - protected boolean removeEldestEntry(Map.Entry eldest) { - return size() > CACHE_MAX; - } - }); - - private PathMatcherUtil() { - } - - public static boolean matches(String pattern, String path) { - if (pattern == null || path == null) { - return false; - } - try { - PathMatcher matcher = CACHE.get(pattern); - if (matcher == null) { - matcher = FileSystems.getDefault().getPathMatcher("glob:" + pattern); - CACHE.put(pattern, matcher); - } - return matcher.matches(Paths.get(path)); - } catch (RuntimeException e) { - return false; - } - } - - private static int initCacheMax() { - String prop = System.getProperty(CACHE_SIZE_KEY); - if (prop == null || prop.isEmpty()) { - return DEFAULT_CACHE_MAX; - } - try { - int value = Integer.parseInt(prop); - return value > 0 ? 
value : DEFAULT_CACHE_MAX; - } catch (NumberFormatException e) { - return DEFAULT_CACHE_MAX; - } - } -} From 2661c43fa0de678f1d8931fe5a2de3879b7fb82f Mon Sep 17 00:00:00 2001 From: Gary Date: Wed, 4 Feb 2026 21:44:02 +0800 Subject: [PATCH 24/27] glob -> regex --- be/src/vec/common/variant_util.cpp | 164 +++++++++++++++++- be/src/vec/common/variant_util.h | 7 + .../rowset/segment_v2/variant_util_test.cpp | 80 +++++++++ .../org/apache/doris/catalog/OlapTable.java | 8 +- .../apache/doris/common/GlobRegexUtil.java | 150 ++++++++++++++++ .../doris/nereids/types/VariantField.java | 22 +-- .../doris/common/GlobRegexUtilTest.java | 76 ++++++++ .../nereids/types/VariantFieldMatchTest.java | 4 +- 8 files changed, 487 insertions(+), 24 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/common/GlobRegexUtil.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/common/GlobRegexUtilTest.java diff --git a/be/src/vec/common/variant_util.cpp b/be/src/vec/common/variant_util.cpp index 39e720630678ae..d186b334efb7c0 100644 --- a/be/src/vec/common/variant_util.cpp +++ b/be/src/vec/common/variant_util.cpp @@ -19,7 +19,6 @@ #include #include -#include #include #include #include @@ -39,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -53,6 +53,7 @@ #include "common/config.h" #include "common/status.h" #include "exprs/json_functions.h" +#include "re2/re2.h" #include "olap/olap_common.h" #include "olap/rowset/beta_rowset.h" #include "olap/rowset/rowset.h" @@ -102,6 +103,160 @@ namespace doris::vectorized::variant_util { #include "common/compile_check_begin.h" +inline void append_escaped_regex_char(std::string* regex_output, char ch) { + switch (ch) { + case '.': + case '^': + case '$': + case '+': + case '(': + case ')': + case '|': + case '{': + case '}': + case '[': + case ']': + case '\\': + regex_output->push_back('\\'); + regex_output->push_back(ch); + break; + default: + regex_output->push_back(ch); + break; + } +} + +// 
Small LRU to cap compiled glob patterns +constexpr size_t kGlobRegexCacheCapacity = 256; + +struct GlobRegexCacheEntry { + std::shared_ptr re2; + std::list::iterator lru_it; +}; + +std::mutex g_glob_regex_cache_mutex; +std::list g_glob_regex_cache_lru; +std::unordered_map g_glob_regex_cache; + +std::shared_ptr get_or_build_re2(const std::string& glob_pattern) { + { + std::lock_guard lock(g_glob_regex_cache_mutex); + auto it = g_glob_regex_cache.find(glob_pattern); + if (it != g_glob_regex_cache.end()) { + g_glob_regex_cache_lru.splice(g_glob_regex_cache_lru.begin(), + g_glob_regex_cache_lru, it->second.lru_it); + return it->second.re2; + } + } + std::string regex_pattern; + Status st = glob_to_regex(glob_pattern, &regex_pattern); + if (!st.ok()) { + return nullptr; + } + auto compiled = std::make_shared(regex_pattern); + if (!compiled->ok()) { + return nullptr; + } + { + std::lock_guard lock(g_glob_regex_cache_mutex); + auto it = g_glob_regex_cache.find(glob_pattern); + if (it != g_glob_regex_cache.end()) { + g_glob_regex_cache_lru.splice(g_glob_regex_cache_lru.begin(), + g_glob_regex_cache_lru, it->second.lru_it); + return it->second.re2; + } + g_glob_regex_cache_lru.push_front(glob_pattern); + g_glob_regex_cache.emplace(glob_pattern, + GlobRegexCacheEntry{compiled, g_glob_regex_cache_lru.begin()}); + if (g_glob_regex_cache.size() > kGlobRegexCacheCapacity) { + const std::string& evict_key = g_glob_regex_cache_lru.back(); + g_glob_regex_cache.erase(evict_key); + g_glob_regex_cache_lru.pop_back(); + } + } + return compiled; +} + + +// Convert a restricted glob pattern into a regex. +// Supported: '*', '?', '[...]', '\\' escape. Others are treated as literals.
+Status glob_to_regex(const std::string& glob_pattern, std::string* regex_pattern) { + regex_pattern->clear(); + regex_pattern->append("^"); + bool is_escaped = false; + size_t pattern_length = glob_pattern.size(); + for (size_t index = 0; index < pattern_length; ++index) { + char current_char = glob_pattern[index]; + if (is_escaped) { + append_escaped_regex_char(regex_pattern, current_char); + is_escaped = false; + continue; + } + if (current_char == '\\') { + is_escaped = true; + continue; + } + if (current_char == '*') { + regex_pattern->append(".*"); + continue; + } + if (current_char == '?') { + regex_pattern->append("."); + continue; + } + if (current_char == '[') { + size_t class_index = index + 1; + bool class_closed = false; + bool is_class_escaped = false; + std::string class_buffer; + if (class_index < pattern_length && + (glob_pattern[class_index] == '!' || glob_pattern[class_index] == '^')) { + class_buffer.push_back('^'); + ++class_index; + } + for (; class_index < pattern_length; ++class_index) { + char class_char = glob_pattern[class_index]; + if (is_class_escaped) { + class_buffer.push_back(class_char); + is_class_escaped = false; + continue; + } + if (class_char == '\\') { + is_class_escaped = true; + continue; + } + if (class_char == ']') { + class_closed = true; + break; + } + class_buffer.push_back(class_char); + } + if (!class_closed) { + return Status::InvalidArgument("Unclosed character class in glob pattern: {}", glob_pattern); + } + regex_pattern->append("["); + regex_pattern->append(class_buffer); + regex_pattern->append("]"); + index = class_index; + continue; + } + append_escaped_regex_char(regex_pattern, current_char); + } + if (is_escaped) { + append_escaped_regex_char(regex_pattern, '\\'); + } + regex_pattern->append("$"); + return Status::OK(); +} + +bool glob_match_re2(const std::string& glob_pattern, const std::string& candidate_path) { + auto compiled = get_or_build_re2(glob_pattern); + if (compiled == nullptr) { + return false; 
+ } + return RE2::FullMatch(candidate_path, *compiled); +} + size_t get_number_of_dimensions(const IDataType& type) { if (const auto* type_array = typeid_cast(&type)) { return type_array->get_number_of_dimensions(); @@ -1307,8 +1462,7 @@ bool generate_sub_column_info(const TabletSchema& schema, int32_t col_unique_id, break; } case PatternTypePB::MATCH_NAME_GLOB: { - int result = fnmatch(pattern, path.c_str(), FNM_PATHNAME); - if (result == 0) { + if (glob_match_re2(pattern, path)) { generate_result_column(*sub_column, &sub_column_info->column); generate_index(sub_column->name()); return true; @@ -1788,8 +1942,6 @@ std::unordered_map materialize_docs_ return subcolumns; } -namespace { - Status _parse_and_materialize_variant_columns(Block& block, const std::vector& variant_pos, const std::vector& configs) { @@ -1864,8 +2016,6 @@ Status _parse_and_materialize_variant_columns(Block& block, return Status::OK(); } -} // namespace - Status parse_and_materialize_variant_columns(Block& block, const std::vector& variant_pos, const std::vector& configs) { RETURN_IF_CATCH_EXCEPTION( diff --git a/be/src/vec/common/variant_util.h b/be/src/vec/common/variant_util.h index 37dc452a3a2f62..a36179ac0fbf50 100644 --- a/be/src/vec/common/variant_util.h +++ b/be/src/vec/common/variant_util.h @@ -64,6 +64,13 @@ using JsonParser = JSONDataParser; const std::string SPARSE_COLUMN_PATH = "__DORIS_VARIANT_SPARSE__"; const std::string DOC_VALUE_COLUMN_PATH = "__DORIS_VARIANT_DOC_VALUE__"; namespace doris::vectorized::variant_util { + +// Convert a restricted glob pattern into a regex (for tests/internal use). +Status glob_to_regex(const std::string& glob_pattern, std::string* regex_pattern); + +// Match a glob pattern against a path using RE2. 
+bool glob_match_re2(const std::string& glob_pattern, const std::string& candidate_path); + using PathToNoneNullValues = std::unordered_map; using PathToDataTypes = std::unordered_map, PathInData::Hash>; diff --git a/be/test/olap/rowset/segment_v2/variant_util_test.cpp b/be/test/olap/rowset/segment_v2/variant_util_test.cpp index 78eacd6b3ac91c..24e7c3667028d3 100644 --- a/be/test/olap/rowset/segment_v2/variant_util_test.cpp +++ b/be/test/olap/rowset/segment_v2/variant_util_test.cpp @@ -209,4 +209,84 @@ TEST(VariantUtilTest, ParseVariantColumns_DocModeRejectOnlySubcolumnsConfig) { EXPECT_TRUE(st.ok()) << st.to_string(); } +TEST(VariantUtilTest, GlobToRegex) { + struct Case { + std::string glob; + std::string expected_regex; + }; + const std::vector cases = { + {"*", "^.*$"}, + {"?", "^.$"}, + {"a?b", "^a.b$"}, + {"a*b", "^a.*b$"}, + {"a.b", "^a\\.b$"}, + {"a+b", "^a\\+b$"}, + {"a{b}", "^a\\{b\\}$"}, + {"a\\*b", "^a\\*b$"}, + {"a\\?b", "^a\\?b$"}, + {"a\\[b", "^a\\[b$"}, + {"abc\\", "^abc\\\\$"}, + {"int_[0-9]", "^int_[0-9]$"}, + {"int_[!0-9]", "^int_[^0-9]$"}, + {"int_[^0-9]", "^int_[^0-9]$"}, + {"a[\\-]b", "^a[-]b$"}, + {"", "^$"}, + }; + + for (const auto& test_case : cases) { + std::string regex; + Status st = glob_to_regex(test_case.glob, &regex); + EXPECT_TRUE(st.ok()) << st.to_string() << " pattern=" << test_case.glob; + EXPECT_EQ(regex, test_case.expected_regex) << "pattern=" << test_case.glob; + } + + std::string regex; + Status st = glob_to_regex("int_[0-9", &regex); + EXPECT_FALSE(st.ok()); +} + +TEST(VariantUtilTest, GlobMatchRe2) { + struct Case { + std::string glob; + std::string candidate; + bool expected; + }; + const std::vector cases = { + {"*", "", true}, + {"*", "abc", true}, + {"?", "a", true}, + {"?", "", false}, + {"a?b", "acb", true}, + {"a?b", "ab", false}, + {"a*b", "ab", true}, + {"a*b", "axxxb", true}, + {"a*b", "a/b", true}, + {"a.b", "a.b", true}, + {"a.b", "acb", false}, + {"a+b", "a+b", true}, + {"a{b}", "a{b}", true}, + {R"(a\*b)", "a*b",
true}, + {R"(a\?b)", "a?b", true}, + {R"(a\[b)", "a[b", true}, + {R"(abc\)", R"(abc\)", true}, + {"int_[0-9]", "int_1", true}, + {"int_[0-9]", "int_a", false}, + {"int_[!0-9]", "int_a", true}, + {"int_[!0-9]", "int_1", false}, + {"int_[^0-9]", "int_b", true}, + {"int_[^0-9]", "int_2", false}, + {R"(a[\-]b)", "a-b", true}, + {"", "", true}, + {"", "a", false}, + }; + + for (const auto& test_case : cases) { + bool matched = glob_match_re2(test_case.glob, test_case.candidate); + EXPECT_EQ(matched, test_case.expected) + << "pattern=" << test_case.glob << " candidate=" << test_case.candidate; + } + + EXPECT_FALSE(glob_match_re2("int_[0-9", "int_1")); +} + } // namespace doris::vectorized::variant_util \ No newline at end of file diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java index d30471379a50ba..f259dc9e8e320e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/OlapTable.java @@ -40,6 +40,7 @@ import org.apache.doris.common.ErrorCode; import org.apache.doris.common.ErrorReport; import org.apache.doris.common.FeConstants; +import org.apache.doris.common.GlobRegexUtil; import org.apache.doris.common.Pair; import org.apache.doris.common.UserException; import org.apache.doris.common.io.DeepCopy; @@ -3733,12 +3734,11 @@ public Index getInvertedIndex(Column column, List subPath) { String childName = child.getName(); if (child.getFieldPatternType() == TPatternType.MATCH_NAME_GLOB) { try { - java.nio.file.PathMatcher matcher = java.nio.file.FileSystems.getDefault() - .getPathMatcher("glob:" + childName); - if (matcher.matches(java.nio.file.Paths.get(subPathString))) { + com.google.re2j.Pattern compiled = GlobRegexUtil.getOrCompilePattern(childName); + if (compiled.matcher(subPathString).matches()) { fieldPattern = childName; } - } catch (Exception e) { + } catch 
(com.google.re2j.PatternSyntaxException | IllegalArgumentException e) { continue; } } else if (child.getFieldPatternType() == TPatternType.MATCH_NAME) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/GlobRegexUtil.java b/fe/fe-core/src/main/java/org/apache/doris/common/GlobRegexUtil.java new file mode 100644 index 00000000000000..aa9bd5ca6b6df6 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/common/GlobRegexUtil.java @@ -0,0 +1,150 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.common; + +import com.google.re2j.Pattern; + +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Utility to convert a restricted glob pattern into a regex. + * + * Supported glob syntax: + * - '*' matches any sequence of characters + * - '?' 
matches any single character + * - '[...]' matches any character in the brackets + * - '[!...]' matches any character not in the brackets + * - '\\' escapes the next character + */ +public final class GlobRegexUtil { + // Small LRU to cap compiled pattern memory + private static final int REGEX_CACHE_CAPACITY = 256; + private static final Map REGEX_CACHE = new LinkedHashMap( + REGEX_CACHE_CAPACITY, 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > REGEX_CACHE_CAPACITY; + } + }; + + private GlobRegexUtil() { + } + + public static Pattern getOrCompilePattern(String globPattern) { + synchronized (REGEX_CACHE) { + Pattern cached = REGEX_CACHE.get(globPattern); + if (cached != null) { + return cached; + } + String regex = globToRegex(globPattern); + Pattern compiled = Pattern.compile(regex); + REGEX_CACHE.put(globPattern, compiled); + return compiled; + } + } + + public static String globToRegex(String pattern) { + StringBuilder regexBuilder = new StringBuilder(); + regexBuilder.append("^"); + boolean isEscaped = false; + int patternLength = pattern.length(); + for (int index = 0; index < patternLength; index++) { + char currentChar = pattern.charAt(index); + if (isEscaped) { + appendEscapedRegexChar(regexBuilder, currentChar); + isEscaped = false; + continue; + } + if (currentChar == '\\') { + isEscaped = true; + continue; + } + if (currentChar == '*') { + regexBuilder.append(".*"); + continue; + } + if (currentChar == '?') { + regexBuilder.append('.'); + continue; + } + if (currentChar == '[') { + int classIndex = index + 1; + boolean classClosed = false; + boolean isClassEscaped = false; + StringBuilder classBuffer = new StringBuilder(); + if (classIndex < patternLength + && (pattern.charAt(classIndex) == '!' 
|| pattern.charAt(classIndex) == '^')) { + classBuffer.append('^'); + classIndex++; + } + for (; classIndex < patternLength; classIndex++) { + char classChar = pattern.charAt(classIndex); + if (isClassEscaped) { + classBuffer.append(classChar); + isClassEscaped = false; + continue; + } + if (classChar == '\\') { + isClassEscaped = true; + continue; + } + if (classChar == ']') { + classClosed = true; + break; + } + classBuffer.append(classChar); + } + if (!classClosed) { + throw new IllegalArgumentException("Unclosed character class in glob pattern: " + pattern); + } + regexBuilder.append('[').append(classBuffer).append(']'); + index = classIndex; + continue; + } + appendEscapedRegexChar(regexBuilder, currentChar); + } + if (isEscaped) { + appendEscapedRegexChar(regexBuilder, '\\'); + } + regexBuilder.append("$"); + return regexBuilder.toString(); + } + + private static void appendEscapedRegexChar(StringBuilder regexBuilder, char ch) { + switch (ch) { + case '.': + case '^': + case '$': + case '+': + case '(': + case ')': + case '|': + case '{': + case '}': + case '[': + case ']': + case '\\': + regexBuilder.append('\\').append(ch); + break; + default: + regexBuilder.append(ch); + break; + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java index ccffa2f6ba5e40..a8e3bd9ded136b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantField.java @@ -17,15 +17,14 @@ package org.apache.doris.nereids.types; +import org.apache.doris.common.GlobRegexUtil; import org.apache.doris.nereids.util.Utils; import org.apache.doris.thrift.TPatternType; -import java.nio.file.FileSystems; -import java.nio.file.InvalidPathException; -import java.nio.file.PathMatcher; -import java.nio.file.Paths; +import com.google.re2j.Pattern; +import 
com.google.re2j.PatternSyntaxException; + import java.util.Objects; -import java.util.regex.PatternSyntaxException; /** * A field inside a VariantType. @@ -74,13 +73,14 @@ public String getComment() { /** * Check if the given field name matches this field's pattern. - * This method aligns with BE's fnmatch(pattern, path, FNM_PATHNAME) behavior. + * This method uses a restricted glob syntax converted to regex. * * Supported glob syntax: - * - '*' matches any sequence of characters except '/' - * - '?' matches any single character except '/' + * - '*' matches any sequence of characters + * - '?' matches any single character * - '[...]' matches any character in the brackets * - '[!...]' matches any character not in the brackets + * - '\\' escapes the next character * * @param fieldName the field name to check * @return true if the field name matches the pattern @@ -93,9 +93,9 @@ public boolean matches(String fieldName) { return false; } try { - PathMatcher matcher = FileSystems.getDefault().getPathMatcher("glob:" + pattern); - return matcher.matches(Paths.get(fieldName)); - } catch (PatternSyntaxException | InvalidPathException e) { + Pattern compiled = GlobRegexUtil.getOrCompilePattern(pattern); + return compiled.matcher(fieldName).matches(); + } catch (PatternSyntaxException | IllegalArgumentException e) { return false; } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/common/GlobRegexUtilTest.java b/fe/fe-core/src/test/java/org/apache/doris/common/GlobRegexUtilTest.java new file mode 100644 index 00000000000000..ff0823c8e0165a --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/common/GlobRegexUtilTest.java @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.common; + +import com.google.re2j.Pattern; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + + +public class GlobRegexUtilTest { + + private void assertGlobToRegex(String globPattern, String expectedRegex) { + String regex = GlobRegexUtil.globToRegex(globPattern); + Assertions.assertEquals(expectedRegex, regex, "pattern: " + globPattern); + } + + @Test + public void testGlobToRegexBasicTokens() { + assertGlobToRegex("*", "^.*$"); + assertGlobToRegex("?", "^.$"); + assertGlobToRegex("a?b", "^a.b$"); + assertGlobToRegex("a*b", "^a.*b$"); + } + + @Test + public void testGlobToRegexEscaping() { + assertGlobToRegex("a.b", "^a\\.b$"); + assertGlobToRegex("a+b", "^a\\+b$"); + assertGlobToRegex("a{b}", "^a\\{b\\}$"); + assertGlobToRegex("a\\*b", "^a\\*b$"); + assertGlobToRegex("a\\?b", "^a\\?b$"); + assertGlobToRegex("a\\[b", "^a\\[b$"); + assertGlobToRegex("abc\\", "^abc\\\\$"); + } + + @Test + public void testGlobToRegexCharacterClasses() { + assertGlobToRegex("int_[0-9]", "^int_[0-9]$"); + assertGlobToRegex("int_[!0-9]", "^int_[^0-9]$"); + assertGlobToRegex("int_[^0-9]", "^int_[^0-9]$"); + assertGlobToRegex("a[\\-]b", "^a[-]b$"); + } + + @Test + public void testGlobToRegexEmptyPattern() { + assertGlobToRegex("", "^$"); + } + + @Test + public void testGlobToRegexUnclosedClass() { + 
Assertions.assertThrows(IllegalArgumentException.class, + () -> GlobRegexUtil.globToRegex("int_[0-9")); + } + + @Test + public void testGetOrCompilePatternCache() { + Pattern first = GlobRegexUtil.getOrCompilePattern("num_*"); + Pattern second = GlobRegexUtil.getOrCompilePattern("num_*"); + Assertions.assertSame(first, second); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java index 6551f14932e320..bfbeef49d58def 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java @@ -245,13 +245,13 @@ public void testGlobUnclosedBracket() { @Test public void testGlobWithSlashSeparator() { - // With FNM_PATHNAME, '*' should not match '/' + // With glob->regex, '*' should match '/' VariantField field = new VariantField("int_*", BigIntType.INSTANCE, "", TPatternType.MATCH_NAME_GLOB.name()); Assertions.assertTrue(field.matches("int_nested")); Assertions.assertTrue(field.matches("int_nested.level1")); // '.' 
is matched by '*' - Assertions.assertFalse(field.matches("int_nested/level1")); // '/' is NOT matched by '*' + Assertions.assertTrue(field.matches("int_nested/level1")); // '/' is matched by '*' } @Test From 7675aff05c69d56a507237e0244cdea695468ee9 Mon Sep 17 00:00:00 2001 From: Gary Date: Wed, 4 Feb 2026 23:15:34 +0800 Subject: [PATCH 25/27] fix regex to pass ut --- be/src/vec/common/variant_util.cpp | 2 ++ .../src/main/java/org/apache/doris/common/GlobRegexUtil.java | 2 ++ 2 files changed, 4 insertions(+) diff --git a/be/src/vec/common/variant_util.cpp b/be/src/vec/common/variant_util.cpp index d186b334efb7c0..c5de50c4424eeb 100644 --- a/be/src/vec/common/variant_util.cpp +++ b/be/src/vec/common/variant_util.cpp @@ -109,6 +109,8 @@ inline void append_escaped_regex_char(std::string* regex_output, char ch) { case '^': case '$': case '+': + case '*': + case '?': case '(': case ')': case '|': diff --git a/fe/fe-core/src/main/java/org/apache/doris/common/GlobRegexUtil.java b/fe/fe-core/src/main/java/org/apache/doris/common/GlobRegexUtil.java index aa9bd5ca6b6df6..ff0687d5cc10f1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/common/GlobRegexUtil.java +++ b/fe/fe-core/src/main/java/org/apache/doris/common/GlobRegexUtil.java @@ -132,6 +132,8 @@ private static void appendEscapedRegexChar(StringBuilder regexBuilder, char ch) case '^': case '$': case '+': + case '*': + case '?': case '(': case ')': case '|': From 3f2fa07f20b1354c784094cd6deeb3fcacbf62a5 Mon Sep 17 00:00:00 2001 From: Gary Date: Thu, 5 Feb 2026 11:50:36 +0800 Subject: [PATCH 26/27] fix format and enhance test --- be/src/vec/common/variant_util.cpp | 20 ++-- .../rowset/segment_v2/variant_util_test.cpp | 56 ++++++++++- .../doris/common/GlobRegexUtilTest.java | 37 ++++++++ .../nereids/types/VariantFieldMatchTest.java | 94 +++++++++++++++++++ 4 files changed, 195 insertions(+), 12 deletions(-) diff --git a/be/src/vec/common/variant_util.cpp b/be/src/vec/common/variant_util.cpp index 
c5de50c4424eeb..069a64798d062a 100644 --- a/be/src/vec/common/variant_util.cpp +++ b/be/src/vec/common/variant_util.cpp @@ -37,8 +37,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -53,7 +53,6 @@ #include "common/config.h" #include "common/status.h" #include "exprs/json_functions.h" -#include "re2/re2.h" #include "olap/olap_common.h" #include "olap/rowset/beta_rowset.h" #include "olap/rowset/rowset.h" @@ -64,6 +63,7 @@ #include "olap/tablet.h" #include "olap/tablet_fwd.h" #include "olap/tablet_schema.h" +#include "re2/re2.h" #include "runtime/client_cache.h" #include "runtime/define_primitive_type.h" #include "runtime/exec_env.h" @@ -111,7 +111,7 @@ inline void append_escaped_regex_char(std::string* regex_output, char ch) { case '+': case '*': case '?': - case '(': + case '(': case ')': case '|': case '{': @@ -145,8 +145,8 @@ std::shared_ptr get_or_build_re2(const std::string& glob_pattern) { std::lock_guard lock(g_glob_regex_cache_mutex); auto it = g_glob_regex_cache.find(glob_pattern); if (it != g_glob_regex_cache.end()) { - g_glob_regex_cache_lru.splice(g_glob_regex_cache_lru.begin(), - g_glob_regex_cache_lru, it->second.lru_it); + g_glob_regex_cache_lru.splice(g_glob_regex_cache_lru.begin(), g_glob_regex_cache_lru, + it->second.lru_it); return it->second.re2; } } @@ -163,13 +163,13 @@ std::shared_ptr get_or_build_re2(const std::string& glob_pattern) { std::lock_guard lock(g_glob_regex_cache_mutex); auto it = g_glob_regex_cache.find(glob_pattern); if (it != g_glob_regex_cache.end()) { - g_glob_regex_cache_lru.splice(g_glob_regex_cache_lru.begin(), - g_glob_regex_cache_lru, it->second.lru_it); + g_glob_regex_cache_lru.splice(g_glob_regex_cache_lru.begin(), g_glob_regex_cache_lru, + it->second.lru_it); return it->second.re2; } g_glob_regex_cache_lru.push_front(glob_pattern); g_glob_regex_cache.emplace(glob_pattern, - GlobRegexCacheEntry{compiled, g_glob_regex_cache_lru.begin()}); + GlobRegexCacheEntry {compiled, 
g_glob_regex_cache_lru.begin()}); if (g_glob_regex_cache.size() > kGlobRegexCacheCapacity) { const std::string& evict_key = g_glob_regex_cache_lru.back(); g_glob_regex_cache.erase(evict_key); @@ -179,7 +179,6 @@ std::shared_ptr get_or_build_re2(const std::string& glob_pattern) { return compiled; } - // Convert a restricted glob pattern into a regex. // Supported: '*', '?', '[...]', '\\' escape. Others are treated as literals. Status glob_to_regex(const std::string& glob_pattern, std::string* regex_pattern) { @@ -234,7 +233,8 @@ Status glob_to_regex(const std::string& glob_pattern, std::string* regex_pattern class_buffer.push_back(class_char); } if (!class_closed) { - return Status::InvalidArgument("Unclosed character class in glob pattern: {}", glob_pattern); + return Status::InvalidArgument("Unclosed character class in glob pattern: {}", + glob_pattern); } regex_pattern->append("["); regex_pattern->append(class_buffer); diff --git a/be/test/olap/rowset/segment_v2/variant_util_test.cpp b/be/test/olap/rowset/segment_v2/variant_util_test.cpp index 24e7c3667028d3..981983fc0f9d52 100644 --- a/be/test/olap/rowset/segment_v2/variant_util_test.cpp +++ b/be/test/olap/rowset/segment_v2/variant_util_test.cpp @@ -219,18 +219,38 @@ TEST(VariantUtilTest, GlobToRegex) { {"?", "^.$"}, {"a?b", "^a.b$"}, {"a*b", "^a.*b$"}, + {"a**b", "^a.*.*b$"}, + {"a??b", "^a..b$"}, + {"?*", "^..*$"}, + {"*?", "^.*.$"}, {"a.b", "^a\\.b$"}, {"a+b", "^a\\+b$"}, {"a{b}", "^a\\{b\\}$"}, - {"a\\*b", "^a\\*b$"}, + {R"(a\*b)", R"(^a\*b$)"}, {"a\\?b", "^a\\?b$"}, {"a\\[b", "^a\\[b$"}, {"abc\\", "^abc\\\\$"}, + {"a|b", "^a\\|b$"}, + {"a(b)c", "^a\\(b\\)c$"}, + {"a^b", "^a\\^b$"}, + {"a$b", "^a\\$b$"}, {"int_[0-9]", "^int_[0-9]$"}, {"int_[!0-9]", "^int_[^0-9]$"}, {"int_[^0-9]", "^int_[^0-9]$"}, {"a[\\-]b", "^a[-]b$"}, + {"a[b-d]e", "^a[b-d]e$"}, + {"a[\\]]b", "^a[]]b$"}, + {"a[\\!]b", "^a[!]b$"}, {"", "^$"}, + {"a[[]b", "^a[[]b$"}, + {"a[]b", "^a[]b$"}, + {"[]", "^[]$"}, + {"[!]", "^[^]$"}, + {"[^]",
"^[^]$"}, + {"\\", "^\\\\$"}, + {"\\*", "^\\*$"}, + {"a\\*b", "^a\\*b$"}, + {"a[!\\]]b", "^a[^]]b$"}, }; for (const auto& test_case : cases) { @@ -243,6 +263,9 @@ TEST(VariantUtilTest, GlobToRegex) { std::string regex; Status st = glob_to_regex("int_[0-9", &regex); EXPECT_FALSE(st.ok()); + + st = glob_to_regex("a[\\]b", &regex); + EXPECT_FALSE(st.ok()); } TEST(VariantUtilTest, GlobMatchRe2) { @@ -260,11 +283,39 @@ TEST(VariantUtilTest, GlobMatchRe2) { {"a?b", "ab", false}, {"a*b", "ab", true}, {"a*b", "axxxb", true}, + {"a**b", "ab", true}, + {"a**b", "axxxb", true}, + {"?*", "", false}, + {"?*", "a", true}, + {"*?", "", false}, + {"*?", "a", true}, {"a*b", "a/b", true}, {"a.b", "a.b", true}, {"a.b", "acb", false}, {"a+b", "a+b", true}, {"a{b}", "a{b}", true}, + {"a|b", "a|b", true}, + {"a|b", "ab", false}, + {"a(b)c", "a(b)c", true}, + {"a(b)c", "abc", false}, + {"a^b", "a^b", true}, + {"a^b", "ab", false}, + {"a$b", "a$b", true}, + {"a$b", "ab", false}, + {"a[b-d]e", "ace", true}, + {"a[b-d]e", "aee", false}, + {"a[\\]]b", "a]b", true}, + {"a[\\]]b", "a[b", false}, + {"a[\\!]b", "a!b", true}, + {"a[\\!]b", "a]b", false}, + {"[]", "a", false}, + {"[!]", "]", false}, + {"\\", "\\", true}, + {"\\*", "\\abc", true}, + {"a[!\\]]b", "aXb", true}, + {"a[!\\]]b", "a]b", false}, + {"a[]b", "aXb", false}, + {"a[[]b", "a[b", true}, {R"(a\*b)", "a*b", true}, {R"(a\?b)", "a?b", true}, {R"(a\[b)", "a[b", true}, @@ -287,6 +338,7 @@ TEST(VariantUtilTest, GlobMatchRe2) { } EXPECT_FALSE(glob_match_re2("int_[0-9", "int_1")); + EXPECT_FALSE(glob_match_re2("a[\\]b", "a]b")); } -} // namespace doris::vectorized::variant_util \ No newline at end of file +} // namespace doris::vectorized::variant_util diff --git a/fe/fe-core/src/test/java/org/apache/doris/common/GlobRegexUtilTest.java b/fe/fe-core/src/test/java/org/apache/doris/common/GlobRegexUtilTest.java index ff0823c8e0165a..48f86c4e70c21a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/common/GlobRegexUtilTest.java +++
b/fe/fe-core/src/test/java/org/apache/doris/common/GlobRegexUtilTest.java @@ -37,6 +37,15 @@ public void testGlobToRegexBasicTokens() { assertGlobToRegex("a*b", "^a.*b$"); } + @Test + public void testGlobToRegexRepeatedWildcards() { + assertGlobToRegex("a**b", "^a.*.*b$"); + assertGlobToRegex("a??b", "^a..b$"); + assertGlobToRegex("?*", "^..*$"); + assertGlobToRegex("*?", "^.*.$"); + } + + @Test public void testGlobToRegexEscaping() { assertGlobToRegex("a.b", "^a\\.b$"); @@ -46,6 +55,10 @@ public void testGlobToRegexEscaping() { assertGlobToRegex("a\\?b", "^a\\?b$"); assertGlobToRegex("a\\[b", "^a\\[b$"); assertGlobToRegex("abc\\", "^abc\\\\$"); + assertGlobToRegex("a|b", "^a\\|b$"); + assertGlobToRegex("a(b)c", "^a\\(b\\)c$"); + assertGlobToRegex("a^b", "^a\\^b$"); + assertGlobToRegex("a$b", "^a\\$b$"); } @Test @@ -54,6 +67,9 @@ public void testGlobToRegexCharacterClasses() { assertGlobToRegex("int_[!0-9]", "^int_[^0-9]$"); assertGlobToRegex("int_[^0-9]", "^int_[^0-9]$"); assertGlobToRegex("a[\\-]b", "^a[-]b$"); + assertGlobToRegex("a[b-d]e", "^a[b-d]e$"); + assertGlobToRegex("a[\\]]b", "^a[]]b$"); + assertGlobToRegex("a[\\!]b", "^a[!]b$"); } @Test @@ -61,12 +77,33 @@ public void testGlobToRegexEmptyPattern() { assertGlobToRegex("", "^$"); } + + @Test + public void testGlobToRegexWeirdClasses() { + assertGlobToRegex("a[[]b", "^a[[]b$"); + assertGlobToRegex("a[]b", "^a[]b$"); + Assertions.assertThrows(IllegalArgumentException.class, + () -> GlobRegexUtil.globToRegex("a[\\]b")); + } + @Test public void testGlobToRegexUnclosedClass() { Assertions.assertThrows(IllegalArgumentException.class, () -> GlobRegexUtil.globToRegex("int_[0-9")); } + + @Test + public void testGlobToRegexMoreWeirdCases() { + assertGlobToRegex("[]", "^[]$"); + assertGlobToRegex("[!]", "^[^]$"); + assertGlobToRegex("[^]", "^[^]$"); + assertGlobToRegex("\\", "^\\\\$"); + assertGlobToRegex("\\*", "^\\*$"); + assertGlobToRegex("a\\*b", "^a\\*b$"); + assertGlobToRegex("a[!\\]]b", "^a[^]]b$"); + } + 
@Test public void testGetOrCompilePatternCache() { Pattern first = GlobRegexUtil.getOrCompilePattern("num_*"); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java index bfbeef49d58def..ecfb12b00485a0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java @@ -42,6 +42,44 @@ public void testExactMatch() { Assertions.assertFalse(field.matches("other_field")); } + @Test + public void testRegexMetaLiteralPatterns() { + VariantField pipe = new VariantField("a|b", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(pipe.matches("a|b")); + Assertions.assertFalse(pipe.matches("ab")); + + VariantField paren = new VariantField("a(b)c", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(paren.matches("a(b)c")); + Assertions.assertFalse(paren.matches("abc")); + + VariantField caret = new VariantField("a^b", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(caret.matches("a^b")); + Assertions.assertFalse(caret.matches("ab")); + + VariantField dollar = new VariantField("a$b", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(dollar.matches("a$b")); + Assertions.assertFalse(dollar.matches("ab")); + + VariantField range = new VariantField("a[b-d]e", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(range.matches("ace")); + Assertions.assertFalse(range.matches("aee")); + + VariantField escapedRight = new VariantField("a[\\]]b", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(escapedRight.matches("a]b")); + Assertions.assertFalse(escapedRight.matches("a[b")); + + VariantField escapedBang = new 
VariantField("a[\\!]b", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(escapedBang.matches("a!b")); + Assertions.assertFalse(escapedBang.matches("a]b")); + } + @Test public void testExactMatchDoesNotTreatGlob() { VariantField field = new VariantField("num_*", BigIntType.INSTANCE, "", @@ -100,6 +138,25 @@ public void testGlobMatchAll() { Assertions.assertTrue(field.matches("a.b.c")); } + @Test + public void testRepeatedWildcardPatterns() { + VariantField doubleStar = new VariantField("a**b", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(doubleStar.matches("ab")); + Assertions.assertTrue(doubleStar.matches("axxxb")); + + VariantField questionStar = new VariantField("?*", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertFalse(questionStar.matches("")); + Assertions.assertTrue(questionStar.matches("a")); + + VariantField starQuestion = new VariantField("*?", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertFalse(starQuestion.matches("")); + Assertions.assertTrue(starQuestion.matches("a")); + } + + @Test public void testGlobMatchWithDot() { // Pattern: metrics.* should match metrics.score, metrics.count, etc. 
@@ -243,6 +300,43 @@ public void testGlobUnclosedBracket() { Assertions.assertFalse(field.matches("int_1")); } + @Test + public void testWeirdGlobPatterns() { + VariantField emptyClass = new VariantField("a[]b", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertFalse(emptyClass.matches("aXb")); + + VariantField escapedBracket = new VariantField("a[[]b", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(escapedBracket.matches("a[b")); + } + + @Test + public void testMoreWeirdGlobPatterns() { + VariantField emptyClass = new VariantField("[]", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertFalse(emptyClass.matches("a")); + Assertions.assertFalse(emptyClass.matches("")); + + VariantField negatedEmpty = new VariantField("[!]", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertFalse(negatedEmpty.matches("]")); + + VariantField escapedBackslash = new VariantField("\\", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(escapedBackslash.matches("\\")); + + VariantField escapedStar = new VariantField("\\*", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(escapedStar.matches("*")); + Assertions.assertFalse(escapedStar.matches("\\\\abc")); + + VariantField escapedCharInClass = new VariantField("a[!\\]]b", BigIntType.INSTANCE, "", + TPatternType.MATCH_NAME_GLOB.name()); + Assertions.assertTrue(escapedCharInClass.matches("aXb")); + Assertions.assertFalse(escapedCharInClass.matches("a]b")); + } + @Test public void testGlobWithSlashSeparator() { // With glob->regex, '*' should match '/' From 0d23a248defb37f0a4d01952a65d295c4ffbbfc0 Mon Sep 17 00:00:00 2001 From: Gary Date: Thu, 5 Feb 2026 15:37:22 +0800 Subject: [PATCH 27/27] fix format and BE UT --- be/test/olap/rowset/segment_v2/variant_util_test.cpp | 2 +- 
.../org/apache/doris/nereids/types/VariantFieldMatchTest.java | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/be/test/olap/rowset/segment_v2/variant_util_test.cpp b/be/test/olap/rowset/segment_v2/variant_util_test.cpp index 981983fc0f9d52..bb87ee0ebd7d78 100644 --- a/be/test/olap/rowset/segment_v2/variant_util_test.cpp +++ b/be/test/olap/rowset/segment_v2/variant_util_test.cpp @@ -311,7 +311,7 @@ TEST(VariantUtilTest, GlobMatchRe2) { {"[]", "a", false}, {"[!]", "]", false}, {"\\", "\\", true}, - {"\\*", "\\abc", true}, + {"\\*", "\\abc", false}, {"a[!\\]]b", "aXb", true}, {"a[!\\]]b", "a]b", false}, {"a[]b", "aXb", false}, diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java index ecfb12b00485a0..66289238e86414 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/VariantFieldMatchTest.java @@ -156,7 +156,6 @@ public void testRepeatedWildcardPatterns() { Assertions.assertTrue(starQuestion.matches("a")); } - @Test public void testGlobMatchWithDot() { // Pattern: metrics.* should match metrics.score, metrics.count, etc.