diff --git a/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/Accumulators.java b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/Accumulators.java index 9c15e4a7a2375..5f08a1699fd02 100644 --- a/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/Accumulators.java +++ b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/exec/exp/agg/Accumulators.java @@ -78,6 +78,8 @@ private static Supplier> accumulatorFunctionFactory( return avgFactory(call, hnd); case "SUM": return sumFactory(call, hnd); + case "SUM_WITH_KEEP": + return sumWithKeepFactory(call, ctx); case "$SUM0": return sumEmptyIsZeroFactory(call, hnd); case "MIN": @@ -213,6 +215,34 @@ private static Supplier> sumEmptyIsZeroFactory(AggregateC } } + /** */ + private static Supplier> sumWithKeepFactory( + AggregateCall call, + ExecutionContext ctx + ) { + assert call.getCollation() != null && !call.getCollation().getFieldCollations().isEmpty(); + + RowHandler hnd = ctx.rowHandler(); + Comparator cmp = ctx.expressionFactory().comparator(call.getCollation()); + + switch (call.type.getSqlTypeName()) { + case BIGINT: + case DECIMAL: + return () -> new SumWithKeep<>(call, hnd, cmp, DECIMAL); + + case DOUBLE: + case REAL: + case FLOAT: + return () -> new SumWithKeep<>(call, hnd, cmp, DOUBLE); + + case TINYINT: + case SMALLINT: + case INTEGER: + default: + return () -> new SumWithKeep<>(call, hnd, cmp, BIGINT); + } + } + /** */ private static Supplier> minFactory(AggregateCall call, RowHandler hnd) { switch (call.type.getSqlTypeName()) { @@ -696,6 +726,131 @@ public Sum(AggregateCall aggCall, Accumulator acc, RowHandler hnd) { } } + /** SUM(value) over rows tied at the first or last ordering key. */ + private static class SumWithKeep extends AbstractAccumulator implements StoringAccumulator { + /** Type used for intermediate sum. */ + private final SqlTypeName sumType; + + /** Comparator for WITHIN GROUP collation. */ + private final transient Comparator cmp; + + /** Row having the selected ordering key. */ + private Row bestRow; + + /** Intermediate sum. */ + private Object sum; + + /** Whether at least one row was seen. */ + private boolean touched; + + /** Whether the first ordering key should be kept. */ + private boolean first; + + /** */ + SumWithKeep(AggregateCall aggCall, RowHandler hnd, Comparator cmp, SqlTypeName sumType) { + super(aggCall, hnd); + + this.cmp = cmp; + this.sumType = sumType; + } + + /** {@inheritDoc} */ + @Override public void add(Row row) { + boolean first0 = "FIRST".equalsIgnoreCase(get(1, row).toString()); + + if (!touched) { + touched = true; + first = first0; + bestRow = row; + sum = get(0, row); + + return; + } + + assert first == first0; + + int cmp0 = cmp.compare(row, bestRow); + + if ((first && cmp0 < 0) || (!first && cmp0 > 0)) { + bestRow = row; + sum = get(0, row); + } + else if (cmp0 == 0) + addToSum(get(0, row)); + } + + /** {@inheritDoc} */ + @Override public void apply(Accumulator other) { + SumWithKeep other0 = (SumWithKeep)other; + + if (!other0.touched) + return; + + if (!touched) { + touched = true; + first = other0.first; + bestRow = other0.bestRow; + sum = other0.sum; + + return; + } + + assert first == other0.first; + + int cmp0 = cmp.compare(other0.bestRow, bestRow); + + if ((first && cmp0 < 0) || (!first && cmp0 > 0)) { + bestRow = other0.bestRow; + sum = other0.sum; + } + else if (cmp0 == 0) + addToSum(other0.sum); + } + + /** */ + private void addToSum(Object val) { + if (val == null) + return; + + if (sum == null) { + sum = val; + + return; + } + + switch (sumType) { + case DECIMAL: + sum = ((BigDecimal)sum).add((BigDecimal)val); + break; + + case DOUBLE: + sum = (Double)sum + (Double)val; + break; + + default: + sum = (Long)sum + (Long)val; + } + } + + /** {@inheritDoc} */ + @Override public Object end() { + return sum; + } + + /** {@inheritDoc} */ + @Override public List argumentTypes(IgniteTypeFactory typeFactory) { + return F.asList( + typeFactory.createTypeWithNullability(typeFactory.createSqlType(sumType), true), + typeFactory.createTypeWithNullability(typeFactory.createSqlType(VARCHAR), false) + ); + } + + /** {@inheritDoc} */ + @Override public RelDataType returnType(IgniteTypeFactory typeFactory) { + return typeFactory.createTypeWithNullability(typeFactory.createSqlType(sumType), true); + } + } + /** */ private static class DoubleSumEmptyIsZero extends AbstractAccumulator { /** */ diff --git a/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/sql/fun/IgniteOwnSqlOperatorTable.java b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/sql/fun/IgniteOwnSqlOperatorTable.java index b0a43abd2be06..cbea591417064 100644 --- a/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/sql/fun/IgniteOwnSqlOperatorTable.java +++ b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/sql/fun/IgniteOwnSqlOperatorTable.java @@ -17,6 +17,7 @@ package org.apache.ignite.internal.processors.query.calcite.sql.fun; import org.apache.calcite.sql.SqlFunction; +import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunctionCategory; import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.type.InferTypes; @@ -39,6 +40,9 @@ public class IgniteOwnSqlOperatorTable extends ReflectiveSqlOperatorTable { */ private static IgniteOwnSqlOperatorTable instance; + /** Sum of values having the first or last ordering key. */ + public static final SqlAggFunction SUM_WITH_KEEP = new SqlSumWithKeepAggFunction(); + /** * */ diff --git a/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/sql/fun/SqlSumWithKeepAggFunction.java b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/sql/fun/SqlSumWithKeepAggFunction.java new file mode 100644 index 0000000000000..f96f34a8ebca0 --- /dev/null +++ b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/sql/fun/SqlSumWithKeepAggFunction.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.ignite.internal.processors.query.calcite.sql.fun; + +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlCallBinding; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.type.OperandTypes; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.util.Optionality; +import org.apache.ignite.internal.processors.query.calcite.util.IgniteResource; + +import static org.apache.calcite.util.Static.RESOURCE; + +/** + * Aggregate that sums values having the first or last ordering key. + * + *

The syntax is + * {@code SUM_WITH_KEEP(value, 'FIRST'|'LAST') WITHIN GROUP (ORDER BY orderKey [, ...])}. + */ +public class SqlSumWithKeepAggFunction extends SqlAggFunction { + /** */ + public SqlSumWithKeepAggFunction() { + super( + "SUM_WITH_KEEP", + null, + SqlKind.SUM, + ReturnTypes.AGG_SUM, + null, + OperandTypes.family(SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER), + SqlFunctionCategory.NUMERIC, + false, + false, + Optionality.MANDATORY + ); + } + + /** {@inheritDoc} */ + @Override public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) { + if (!super.checkOperandTypes(callBinding, throwOnFailure)) + return false; + + if (!callBinding.isOperandLiteral(1, false)) { + if (throwOnFailure) + throw callBinding.newError(RESOURCE.argumentMustBeLiteral(getName())); + + return false; + } + + String mode = callBinding.getStringLiteralOperand(1); + + if (!"FIRST".equalsIgnoreCase(mode) && !"LAST".equalsIgnoreCase(mode)) { + if (throwOnFailure) + throw callBinding.newError(IgniteResource.INSTANCE.illegalSumWithKeepMode(mode)); + + return false; + } + + return true; + } + + /** {@inheritDoc} */ + @Override public Optionality getDistinctOptionality() { + return Optionality.FORBIDDEN; + } +} diff --git a/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/util/IgniteResource.java b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/util/IgniteResource.java index df74d56a91638..817a2071dd483 100644 --- a/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/util/IgniteResource.java +++ b/modules/calcite/src/main/java/org/apache/ignite/internal/processors/query/calcite/util/IgniteResource.java @@ -39,6 +39,10 @@ public interface IgniteResource { @Resources.BaseMessage("Illegal aggregate function. {0} is unsupported at the moment.") Resources.ExInst unsupportedAggregationFunction(String a0); + /** */ + @Resources.BaseMessage("Illegal SUM_WITH_KEEP mode ''{0}''. Expected FIRST or LAST.") + Resources.ExInst illegalSumWithKeepMode(String mode); + /** */ @Resources.BaseMessage("Illegal value of {0}. The value must be positive and less than Integer.MAX_VALUE " + "(" + Integer.MAX_VALUE + ")." ) diff --git a/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/exec/rel/BaseAggregateTest.java b/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/exec/rel/BaseAggregateTest.java index b0bfe89e9769b..536ef026f5650 100644 --- a/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/exec/rel/BaseAggregateTest.java +++ b/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/exec/rel/BaseAggregateTest.java @@ -28,6 +28,7 @@ import java.util.stream.IntStream; import java.util.stream.Stream; import com.google.common.collect.ImmutableList; +import org.apache.calcite.rel.RelFieldCollation; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.type.RelDataType; @@ -40,6 +41,7 @@ import org.apache.ignite.internal.processors.query.calcite.exec.exp.agg.AccumulatorWrapper; import org.apache.ignite.internal.processors.query.calcite.exec.exp.agg.Accumulators; import org.apache.ignite.internal.processors.query.calcite.exec.exp.agg.AggregateType; +import org.apache.ignite.internal.processors.query.calcite.sql.fun.IgniteOwnSqlOperatorTable; import org.apache.ignite.internal.processors.query.calcite.type.IgniteTypeFactory; import org.apache.ignite.internal.processors.query.calcite.util.TypeUtils; import org.apache.ignite.internal.util.typedef.F; @@ -242,6 +244,52 @@ public void max() { assertFalse(root.hasNext()); } + /** */ + @Test + public void sumWithKeep() { + ExecutionContext ctx = executionContext(F.first(nodes()), UUID.randomUUID(), 0); + IgniteTypeFactory tf = ctx.getTypeFactory(); + RelDataType rowType = TypeUtils.createRowType(tf, int.class, int.class, String.class, int.class); + RelDataType aggRowType = TypeUtils.createRowType(tf, long.class); + + for (String mode : Arrays.asList("FIRST", "LAST")) { + ScanNode scan = new ScanNode<>(ctx, rowType, Arrays.asList( + row(0, 10, mode, 2), + row(0, 20, mode, 1), + row(0, 30, mode, 1), + row(0, 40, mode, 3) + )); + + AggregateCall call = AggregateCall.create( + IgniteOwnSqlOperatorTable.SUM_WITH_KEEP, + false, + false, + false, + ImmutableIntList.of(1, 2), + -1, + RelCollations.of(new RelFieldCollation(3)), + tf.createJavaType(long.class), + null); + + SingleNode aggChain = createAggregateNodesChain( + ctx, + ImmutableList.of(ImmutableBitSet.of(0)), + call, + rowType, + aggRowType, + rowFactory(), + scan + ); + + RootNode root = new RootNode<>(ctx, aggRowType); + root.register(aggChain); + + assertTrue(root.hasNext()); + Assert.assertArrayEquals(row(0, "FIRST".equals(mode) ? 50L : 40L), root.next()); + assertFalse(root.hasNext()); + } + } + /** */ @Test public void avg() { diff --git a/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/integration/StdSqlOperatorsTest.java b/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/integration/StdSqlOperatorsTest.java index 72d6feaa33544..9468854ad74c9 100644 --- a/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/integration/StdSqlOperatorsTest.java +++ b/modules/calcite/src/test/java/org/apache/ignite/internal/processors/query/calcite/integration/StdSqlOperatorsTest.java @@ -26,6 +26,7 @@ import java.util.Collections; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.calcite.sql.validate.SqlValidatorException; import org.apache.ignite.internal.processors.query.calcite.QueryChecker; import org.apache.ignite.internal.processors.query.calcite.sql.fun.IgniteStdSqlOperatorTable; import org.apache.ignite.internal.util.typedef.F; @@ -93,6 +94,10 @@ public void testArithmetic() { /** */ @Test public void testAggregates() { + sql("CREATE TABLE keep_t(id INT PRIMARY KEY, grp INT, v INT, ord INT, priority INT)"); + sql("INSERT INTO keep_t VALUES (1, 1, 10, 2, 2), (2, 1, 20, 1, 2), (3, 1, 30, 1, 1), " + + "(4, 2, 40, 3, NULL), (5, 2, 60, 4, 1), (6, 2, 70, 4, 2)"); + assertExpression("COUNT(*)").returns(1L).check(); assertExpression("SUM(val)").returns(1L).check(); assertExpression("AVG(val)").returns(1).check(); @@ -109,6 +114,52 @@ public void testAggregates() { .returns((IntStream.range(1, 7).boxed().collect(Collectors.toList()))).check(); assertExpression("EVERY(val = 1)").returns(true).check(); assertExpression("SOME(val = 1)").returns(true).check(); + assertQuery("SELECT " + + "SUM_WITH_KEEP(v, 'FIRST') WITHIN GROUP (ORDER BY ord ASC), " + + "SUM_WITH_KEEP(v, 'LAST') WITHIN GROUP (ORDER BY ord ASC) FROM keep_t") + .returns(50L, 130L) + .check(); + assertQuery("SELECT grp, " + + "SUM_WITH_KEEP(v, 'FIRST') WITHIN GROUP (ORDER BY ord ASC), " + + "SUM_WITH_KEEP(v, 'LAST') WITHIN GROUP (ORDER BY ord ASC) " + + "FROM keep_t GROUP BY grp ORDER BY grp") + .returns(1, 50L, 10L) + .returns(2, 40L, 130L) + .check(); + + assertQuery("SELECT " + + "SUM_WITH_KEEP(v, 'FIRST') WITHIN GROUP (ORDER BY ord DESC), " + + "SUM_WITH_KEEP(v, 'LAST') WITHIN GROUP (ORDER BY ord DESC) FROM keep_t") + .returns(130L, 50L) + .check(); + + assertQuery("SELECT " + + "SUM_WITH_KEEP(v, 'FIRST') WITHIN GROUP (ORDER BY ord DESC, priority ASC), " + + "SUM_WITH_KEEP(v, 'LAST') WITHIN GROUP (ORDER BY ord DESC, priority ASC) FROM keep_t") + .returns(60L, 20L) + .check(); + + assertQuery("SELECT " + + "SUM_WITH_KEEP(v, 'FIRST') WITHIN GROUP (ORDER BY priority ASC NULLS FIRST), " + + "SUM_WITH_KEEP(v, 'LAST') WITHIN GROUP (ORDER BY priority ASC NULLS FIRST) FROM keep_t") + .returns(40L, 100L) + .check(); + + assertQuery("SELECT " + + "SUM_WITH_KEEP(v, 'FIRST') WITHIN GROUP (ORDER BY COALESCE(priority, 99)), " + + "SUM_WITH_KEEP(v, 'LAST') WITHIN GROUP (ORDER BY COALESCE(priority, 99)) FROM keep_t") + .returns(90L, 40L) + .check(); + + assertQuery("SELECT SUM_WITH_KEEP(v, 'FIRST') WITHIN GROUP " + + "(ORDER BY COALESCE(priority, 99) ASC, ord DESC) FROM keep_t") + .returns(60L) + .check(); + + assertThrows("SELECT SUM_WITH_KEEP(v, 'MIDDLE') WITHIN GROUP (ORDER BY ord) FROM keep_t", + SqlValidatorException.class, "Expected FIRST or LAST"); + assertThrows("SELECT SUM_WITH_KEEP(v, 'FIRST') FROM keep_t", + SqlValidatorException.class, "must contain a WITHIN GROUP clause"); } /** */