diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateExtractLiteralAggRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExtractLiteralAggRule.java new file mode 100644 index 00000000000..1374f2cd9b8 --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateExtractLiteralAggRule.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.tools.RelBuilder; + +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Rule transforms an {@link org.apache.calcite.rel.core.Aggregate} containing + * {@code LITERAL_AGG} aggregate function into an {@code Aggregate} that still + * performs "group by" on the relevant groups, while placing a {@code Project} + * RelNode on top that returns the literal value. + */ +@Value.Enclosing +public class AggregateExtractLiteralAggRule + extends RelRule + implements TransformationRule { + + protected AggregateExtractLiteralAggRule(Config config) { + super(config); + } + + @Override public void onMatch(RelOptRuleCall call) { + final RelBuilder relBuilder = call.builder(); + final Aggregate aggregate = call.rel(0); + final List aggCalls = aggregate.getAggCallList(); + if (aggCalls.isEmpty()) { + return; + } + + // Collect indices of LITERAL_AGG calls. + final List literalAggIndices = new ArrayList<>(); + for (int i = 0; i < aggCalls.size(); i++) { + final AggregateCall ac = aggCalls.get(i); + if (ac.getAggregation().getKind() == SqlKind.LITERAL_AGG) { + literalAggIndices.add(i); + } + } + + if (literalAggIndices.isEmpty()) { + // nothing to do + return; + } + + // Build new AggregateCall list without LITERAL_AGG entries. + final List newAggCalls = new ArrayList<>(); + final Map oldAggIndexToNewAggIndex = new HashMap<>(); + int newAggPos = 0; + for (int i = 0; i < aggCalls.size(); i++) { + if (!literalAggIndices.contains(i)) { + newAggCalls.add(aggCalls.get(i)); + oldAggIndexToNewAggIndex.put(i, newAggPos++); + } + } + + relBuilder.push(aggregate.getInput()); + final RelBuilder.GroupKey groupKey = + relBuilder.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets()); + final RelNode newAggregate = relBuilder.aggregate(groupKey, newAggCalls).build(); + + final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder(); + + // Number of group columns in output (group keys appear first). + final int groupCount = aggregate.getGroupSet().cardinality(); + final int origAggCount = aggCalls.size(); + final int origOutputCount = groupCount + origAggCount; + + // Build projection expressions to restore original output layout. + final List projects = new ArrayList<>(origOutputCount); + for (int outPos = 0; outPos < origOutputCount; outPos++) { + if (outPos < groupCount) { + // Group key columns remain in the same positions. + projects.add(rexBuilder.makeInputRef(newAggregate, outPos)); + } else { + // Aggregate output: determine original aggregate index. + final int origAggIndex = outPos - groupCount; + if (literalAggIndices.contains(origAggIndex)) { + // Replacement for LITERAL_AGG: try to extract literal from the original AggregateCall. + projects.add(aggCalls.get(origAggIndex).rexList.get(0)); + } else { + // Non-literal aggregate: compute its new output index in newAggregate. + final Integer newAggIndex = oldAggIndexToNewAggIndex.get(origAggIndex); + if (newAggIndex != null) { + projects.add(rexBuilder.makeInputRef(newAggregate, groupCount + newAggIndex)); + } + } + } + } + + relBuilder.push(newAggregate); + relBuilder.project(projects); + call.transformTo(relBuilder.build()); + } + + /** Rule configuration. */ + @Value.Immutable + public interface Config extends RelRule.Config { + Config DEFAULT = ImmutableAggregateExtractLiteralAggRule.Config.of() + .withOperandSupplier(b0 -> + b0.operand(Aggregate.class).anyInputs()); + + @Override default AggregateExtractLiteralAggRule toRule() { + return new AggregateExtractLiteralAggRule(this); + } + } +} diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java index c604f713580..fadcd48a4ea 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java @@ -959,4 +959,8 @@ private CoreRules() {} * into equivalent {@link Union} ALL of GROUP BY operations. */ public static final AggregateGroupingSetsToUnionRule AGGREGATE_GROUPING_SETS_TO_UNION = AggregateGroupingSetsToUnionRule.Config.DEFAULT.toRule(); + + /** Rule that gets rid of the LITERAL_AGG into most databases can handle. */ + public static final AggregateExtractLiteralAggRule AGGREGATE_EXTRACT_LITERAL_AGG = + AggregateExtractLiteralAggRule.Config.DEFAULT.toRule(); } diff --git a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java index 959cdada0bc..1ca667ad838 100644 --- a/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelOptRulesTest.java @@ -11973,4 +11973,46 @@ private void checkLoptOptimizeJoinRule(LoptOptimizeJoinRule rule) { .check(); } + /** Test case of + * [CALCITE-7242] + * Implement a rule to eliminate LITERAL_AGG so that other databases can handle it. */ + @Test void testAggregateExtractLiteralAggRule1() { + final String sql = "select deptno, name = ANY (\n" + + " select mgr from emp)\n" + + "from dept"; + sql(sql) + .withSubQueryRules() + .withLateDecorrelate(true) + .withAfter((fixture, rel) -> { + final HepProgram program = HepProgram.builder() + .addRuleInstance(CoreRules.AGGREGATE_EXTRACT_LITERAL_AGG) + .build(); + final HepPlanner hep = new HepPlanner(program); + hep.setRoot(rel); + return hep.findBestExp(); + }) + .check(); + } + + /** Test case of + * [CALCITE-7242] + * Implement a rule to eliminate LITERAL_AGG so that other databases can handle it. */ + @Test void testAggregateExtractLiteralAggRule2() { + final String sql = "select empno\n" + + "from sales.emp\n" + + "where deptno in (select deptno from sales.emp where empno < 20)\n" + + "or emp.sal < 100"; + sql(sql) + .withSubQueryRules() + .withLateDecorrelate(true) + .withAfter((fixture, rel) -> { + final HepProgram program = HepProgram.builder() + .addRuleInstance(CoreRules.AGGREGATE_EXTRACT_LITERAL_AGG) + .build(); + final HepPlanner hep = new HepPlanner(program); + hep.setRoot(rel); + return hep.findBestExp(); + }) + .check(); + } } diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index 2b2deaa226b..7cc4ca85a97 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -229,6 +229,97 @@ LogicalProject(HIREDATE=[$1]) LogicalProject(MGR=[$3]) LogicalFilter(condition=[AND(IS NULL($3), =($4, CURRENT_TIMESTAMP))]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + + + + + + + ($2, 0)), AND(<($3, $2), null, <>($2, 0), IS NULL($5)))]) + LogicalJoin(condition=[=($1, $4)], joinType=[left]) + LogicalJoin(condition=[true], joinType=[inner]) + LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) + LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT($0)]) + LogicalProject(MGR=[$3]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalAggregate(group=[{0}], i=[LITERAL_AGG(true)]) + LogicalProject(MGR=[$3]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + ($2, 0)), AND(<($3, $2), null, <>($2, 0), IS NULL($5)))]) + LogicalJoin(condition=[=($1, $4)], joinType=[left]) + LogicalJoin(condition=[true], joinType=[inner]) + LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) + LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT($0)]) + LogicalProject(MGR=[$3]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) + LogicalProject(MGR=[$0], $f1=[true]) + LogicalAggregate(group=[{0}]) + LogicalProject(MGR=[$3]) + LogicalTableScan(table=[[CATALOG, SALES, EMP]]) +]]> + + + + + + + + + + + + + +