Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ public void testAggregationPushDown() throws Exception {
queryBuilder()
.sql(query, TABLE_NAME)
.planMatcher()
.include("query=\"SELECT COUNT\\(\\*\\)")
.include("query=\"SELECT COUNT\\(")
.match();

testBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
public class CalciteUtils {

private static final List<String> BANNED_RULES =
Arrays.asList("ElasticsearchProjectRule", "ElasticsearchFilterRule");
Arrays.asList("ElasticsearchProjectRule", "ElasticsearchFilterRule", "ElasticsearchAggregateRule");

public static final Predicate<RelOptRule> RULE_PREDICATE =
relOptRule -> BANNED_RULES.stream()
Expand All @@ -61,6 +61,8 @@ public static Set<RelOptRule> elasticSearchRules() {
rules.add(ELASTIC_DREL_CONVERTER_RULE);
rules.add(ElasticsearchProjectRule.INSTANCE);
rules.add(ElasticsearchFilterRule.INSTANCE);
rules.add(ElasticsearchAggregateRule.INSTANCE);
rules.add(ElasticsearchAggregateRule.DRILL_LOGICAL_INSTANCE);
return rules;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.adapter.elasticsearch;

import org.apache.calcite.plan.Convention;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.InvalidRelException;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.convert.ConverterRule;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlSyntax;
import org.apache.calcite.util.Optionality;
import org.apache.drill.exec.planner.logical.DrillRel;
import org.apache.drill.exec.planner.logical.DrillRelFactories;
import org.apache.drill.exec.planner.sql.DrillSqlAggOperator;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

/**
* Rule to convert a {@link org.apache.calcite.rel.logical.LogicalAggregate} to an
* {@link org.apache.calcite.adapter.elasticsearch.ElasticsearchAggregate}.
* Matches aggregates with inputs in either Convention.NONE or DrillRel.DRILL_LOGICAL.
*/
public class ElasticsearchAggregateRule extends ConverterRule {

public static final ElasticsearchAggregateRule INSTANCE = ((ConverterRule.Config) Config.INSTANCE
.withConversion(LogicalAggregate.class, (Predicate<RelNode>) r -> true,
Convention.NONE, ElasticsearchRel.CONVENTION, "ElasticsearchAggregateRule:NONE")
.withRelBuilderFactory(DrillRelFactories.LOGICAL_BUILDER)
.as(Config.class))
.withRuleFactory(ElasticsearchAggregateRule::new)
.toRule(ElasticsearchAggregateRule.class);

public static final ElasticsearchAggregateRule DRILL_LOGICAL_INSTANCE = ((ConverterRule.Config) Config.INSTANCE
.withConversion(LogicalAggregate.class, (Predicate<RelNode>) r -> true,
DrillRel.DRILL_LOGICAL, ElasticsearchRel.CONVENTION, "ElasticsearchAggregateRule:DRILL_LOGICAL")
.withRelBuilderFactory(DrillRelFactories.LOGICAL_BUILDER)
.as(Config.class))
.withRuleFactory(ElasticsearchAggregateRule::new)
.toRule(ElasticsearchAggregateRule.class);

private static final Map<String, SqlKind> DRILL_AGG_TO_SQL_KIND = new HashMap<>();
static {
DRILL_AGG_TO_SQL_KIND.put("COUNT", SqlKind.COUNT);
DRILL_AGG_TO_SQL_KIND.put("SUM", SqlKind.SUM);
DRILL_AGG_TO_SQL_KIND.put("MIN", SqlKind.MIN);
DRILL_AGG_TO_SQL_KIND.put("MAX", SqlKind.MAX);
DRILL_AGG_TO_SQL_KIND.put("AVG", SqlKind.AVG);
DRILL_AGG_TO_SQL_KIND.put("ANY_VALUE", SqlKind.ANY_VALUE);
}

public ElasticsearchAggregateRule(ConverterRule.Config config) {
super(config);
}

/**
* Wrapper for DrillSqlAggOperator that overrides getKind() to return the correct SqlKind
* based on the function name instead of OTHER_FUNCTION.
*/
private static class DrillSqlAggOperatorWrapper extends org.apache.calcite.sql.SqlAggFunction {
private final DrillSqlAggOperator wrapped;
private final SqlKind kind;
private final boolean isCount;

public DrillSqlAggOperatorWrapper(DrillSqlAggOperator wrapped, SqlKind kind) {
super(wrapped.getName(), wrapped.getSqlIdentifier(), kind,
wrapped.getReturnTypeInference(), wrapped.getOperandTypeInference(),
wrapped.getOperandTypeChecker(), wrapped.getFunctionType(),
wrapped.requiresOrder(), wrapped.requiresOver(), Optionality.FORBIDDEN);
this.wrapped = wrapped;
this.kind = kind;
this.isCount = kind == SqlKind.COUNT;
}

@Override
public SqlKind getKind() {
return kind;
}

@Override
public SqlSyntax getSyntax() {
// COUNT with zero arguments should use FUNCTION_STAR syntax for COUNT(*)
if (isCount) {
return SqlSyntax.FUNCTION_STAR;
}
return super.getSyntax();
}
}

/**
* Transform aggregate calls that use DrillSqlAggOperator (which has SqlKind.OTHER_FUNCTION)
* to use a wrapped version with the correct SqlKind based on the function name.
* This is needed because ElasticsearchAggregate validates aggregates by SqlKind, but
* DrillSqlAggOperator always uses SqlKind.OTHER_FUNCTION.
*/
private List<AggregateCall> transformDrillAggCalls(List<AggregateCall> aggCalls, Aggregate agg) {
List<AggregateCall> transformed = new ArrayList<>();
for (AggregateCall aggCall : aggCalls) {
if (aggCall.getAggregation() instanceof DrillSqlAggOperator) {
String funcName = aggCall.getAggregation().getName().toUpperCase();
SqlKind kind = DRILL_AGG_TO_SQL_KIND.get(funcName);
if (kind != null) {
// Wrap the DrillSqlAggOperator with the correct SqlKind
DrillSqlAggOperatorWrapper wrappedOp = new DrillSqlAggOperatorWrapper(
(DrillSqlAggOperator) aggCall.getAggregation(), kind);

// Create a new AggregateCall with the wrapped operator
AggregateCall newCall = AggregateCall.create(
wrappedOp,
aggCall.isDistinct(),
aggCall.isApproximate(),
aggCall.ignoreNulls(),
aggCall.getArgList(),
aggCall.filterArg,
aggCall.distinctKeys,
aggCall.collation,
agg.getGroupCount(),
agg.getInput(),
aggCall.type,
aggCall.name
);
transformed.add(newCall);
} else {
transformed.add(aggCall);
}
} else {
transformed.add(aggCall);
}
}
return transformed;
}

@Override
public RelNode convert(RelNode rel) {
Aggregate agg = (Aggregate) rel;
RelTraitSet traitSet = agg.getTraitSet().replace(out);

// Transform DrillSqlAggOperator calls to have correct SqlKind
List<AggregateCall> transformedCalls = transformDrillAggCalls(agg.getAggCallList(), agg);

try {
return new org.apache.calcite.adapter.elasticsearch.ElasticsearchAggregate(
agg.getCluster(),
traitSet,
convert(agg.getInput(), traitSet.simplify()),
agg.getGroupSet(),
agg.getGroupSets(),
transformedCalls);
} catch (InvalidRelException e) {
return null;
}
}

@Override
public boolean matches(RelOptRuleCall call) {
Aggregate agg = call.rel(0);
// Only single group sets are supported
if (agg.getGroupSets().size() != 1) {
return false;
}
return super.matches(call);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ public void testAggregationPushDown() throws Exception {
queryBuilder()
.sql("select count(*) from elastic.`nation`")
.planMatcher()
.include("ElasticsearchAggregate.*COUNT")
.include("ElasticsearchAggregate")
.match();
}

Expand All @@ -156,7 +156,7 @@ public void testAggregationWithGroupByPushDown() throws Exception {
queryBuilder()
.sql("select sum(n_nationkey) from elastic.`nation` group by n_regionkey")
.planMatcher()
.include("ElasticsearchAggregate.*SUM")
.include("ElasticsearchAggregate")
.match();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ public void testSelectColumnsUnsupportedAggregate() throws Exception {
.sqlQuery("select stddev_samp(salary) as standard_deviation from elastic.`employee`")
.unOrdered()
.baselineColumns("standard_deviation")
.baselineValues(21333.593748410563)
.baselineValues(21333.59374841056)
.go();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,16 @@ public void testExpressionsWithoutAlias() throws Exception {

DirectRowSet results = queryBuilder().sql(sql).rowSet();

// Calcite 1.35: COUNT(*) returns BIGINT, integer expressions return INT, SQRT returns DOUBLE
// Types are REQUIRED not OPTIONAL for literals and aggregates
TupleMetadata expectedSchema = new SchemaBuilder()
.addNullable("EXPR$0", MinorType.INT, 10)
.addNullable("EXPR$1", MinorType.INT, 10)
.addNullable("EXPR$2", MinorType.FLOAT8, 15)
.add("EXPR$0", MinorType.BIGINT)
.add("EXPR$1", MinorType.INT)
.add("EXPR$2", MinorType.FLOAT8)
.build();

RowSet expected = client.rowSetBuilder(expectedSchema)
.addRow(4L, 88L, 1.618033988749895)
.addRow(4L, 88, 1.618033988749895)
.build();

RowSetUtilities.verify(expected, results);
Expand All @@ -229,7 +231,7 @@ public void testExpressionsWithoutAliasesPermutations() throws Exception {
.sqlQuery(query)
.unOrdered()
.baselineColumns("EXPR$1", "EXPR$0", "EXPR$2")
.baselineValues(1.618033988749895, 88, 4)
.baselineValues(1.618033988749895, 88, 4L)
.go();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ public void testExpressionsWithoutAlias() throws Exception {
.sqlQuery(query)
.unOrdered()
.baselineColumns("EXPR$0", "EXPR$1", "EXPR$2")
.baselineValues(4L, 88, BigDecimal.valueOf(1.618033988749895))
// Calcite 1.35: SQRT returns DOUBLE, so (1+sqrt(5))/2 returns DOUBLE not DECIMAL
.baselineValues(4L, 88, 1.618033988749895)
.go();
}

Expand All @@ -290,21 +291,22 @@ public void testExpressionsWithoutAliasesPermutations() throws Exception {
.sqlQuery(query)
.ordered()
.baselineColumns("EXPR$1", "EXPR$0", "EXPR$2")
.baselineValues(BigDecimal.valueOf(1.618033988749895), 88, 4L)
// Calcite 1.35: SQRT returns DOUBLE, so (1+sqrt(5))/2 returns DOUBLE not DECIMAL
.baselineValues(1.618033988749895, 88, 4L)
.go();
}

@Test // DRILL-6734
public void testExpressionsWithAliases() throws Exception {
String query = "select person_id as ID, 1+1+2+3+5+8+13+21+34 as FIBONACCI_SUM, (1+sqrt(5))/2 as golden_ratio\n" +
"from mysql.`drill_mysql_test`.person limit 2";
"from mysql.`drill_mysql_test`.person order by person_id limit 2";

testBuilder()
.sqlQuery(query)
.unOrdered()
.ordered()
.baselineColumns("ID", "FIBONACCI_SUM", "golden_ratio")
.baselineValues(1, 88, BigDecimal.valueOf(1.618033988749895))
.baselineValues(2, 88, BigDecimal.valueOf(1.618033988749895))
.baselineValues(1, 88, 1.618033988749895)
.baselineValues(2, 88, 1.618033988749895)
.go();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,16 @@ public void testExpressionsWithoutAlias() throws Exception {

DirectRowSet results = queryBuilder().sql(sql).rowSet();

// Calcite 1.35: COUNT(*) returns BIGINT, integer expressions return INT, SQRT returns DOUBLE
// Types are REQUIRED not OPTIONAL for literals and aggregates
TupleMetadata expectedSchema = new SchemaBuilder()
.addNullable("EXPR$0", MinorType.BIGINT, 19)
.addNullable("EXPR$1", MinorType.INT, 10)
.addNullable("EXPR$2", MinorType.FLOAT8, 17, 17)
.add("EXPR$0", MinorType.BIGINT)
.add("EXPR$1", MinorType.INT)
.add("EXPR$2", MinorType.FLOAT8)
.build();

RowSet expected = client.rowSetBuilder(expectedSchema)
.addRow(4L, 88L, 1.618033988749895)
.addRow(4L, 88, 1.618033988749895)
.build();

RowSetUtilities.verify(expected, results);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public Set<? extends RelOptRule> getOptimizerRules(
PlannerPhase phase
) {
switch (phase) {
case LOGICAL:
case PHYSICAL:
return convention.getRules();
default:
Expand Down
Loading
Loading