From 8d5ba352a9596f9ef5c88d999f94006830b894c3 Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Wed, 13 May 2026 20:49:41 +0200 Subject: [PATCH 1/5] [FLINK-36602][table] Upgrade Calcite version to 1.38.0 --- flink-python/pyflink/table/tests/test_udf.py | 2 +- flink-table/flink-sql-parser/pom.xml | 8 +- .../src/main/codegen/data/Parser.tdd | 1 - .../src/main/codegen/templates/Parser.jj | 117 +- .../java/org/apache/calcite/sql/SqlJoin.java | 293 --- .../parser/validate/FlinkSqlConformance.java | 5 - .../sql/parser/FlinkSqlParserImplTest.java | 3 + .../flink-table-calcite-bridge/pom.xml | 6 +- .../functions/BuiltInFunctionDefinitions.java | 2 +- .../org/apache/calcite/plan/RelOptUtil.java | 107 +- .../calcite/rel/metadata/RelMdPredicates.java | 1051 -------- .../calcite/rel/rules/SubQueryRemoveRule.java | 10 +- .../rel/type/RelDataTypeFactoryImpl.java | 19 +- .../org/apache/calcite/rex/RexBuilder.java | 437 +++- .../org/apache/calcite/rex/RexChecker.java | 34 +- .../org/apache/calcite/rex/RexShuttle.java | 35 +- .../java/org/apache/calcite/rex/RexUtil.java | 67 +- .../apache/calcite/runtime/SqlFunctions.java | 647 ++++- .../calcite/sql/SqlIntervalQualifier.java | 51 +- .../java/org/apache/calcite/sql/SqlUtil.java | 76 +- .../apache/calcite/sql/type/BasicSqlType.java | 322 +++ .../calcite/sql/type/SqlTypeFactoryImpl.java | 18 +- .../apache/calcite/sql/type/SqlTypeUtil.java | 96 +- .../sql/validate/SqlValidatorImpl.java | 502 +++- .../apache/calcite/sql2rel/AggConverter.java | 33 +- .../calcite/sql2rel/RelDecorrelator.java | 73 +- .../calcite/sql2rel/SqlToRelConverter.java | 615 +++-- .../sql2rel/StandardConvertletTable.java | 2183 ----------------- .../planner/calcite/FlinkConvertletTable.java | 22 + .../planner/calcite/FlinkTypeFactory.java | 7 +- .../planner/calcite/FlinkTypeSystem.java | 10 + .../calcite/RelTimeIndicatorConverter.java | 1 + .../exec/serde/RexNodeJsonDeserializer.java | 6 +- .../nodes/exec/utils/CommonPythonUtil.java | 4 +- ...markIntoTableSourceScanAcrossCalcRule.java | 2 +- .../plan/schema/TimeIndicatorRelDataType.java | 76 + .../src/main/resources/META-INF/NOTICE | 4 +- .../planner/codegen/ExpressionReducer.scala | 14 +- .../BatchPhysicalOverAggregateRule.scala | 1 + .../schema/TimeIndicatorRelDataType.scala | 55 - .../planner/plan/utils/FlinkRelOptUtil.scala | 4 +- .../planner/plan/utils/FlinkRexUtil.scala | 2 +- .../plan/utils/OverAggregateUtil.scala | 1 + .../functions/JsonFunctionsITCase.java | 30 +- .../SqlNodeToCallOperationTest.java | 2 +- .../plan/batch/sql/RowLevelUpdateTest.java | 7 +- .../exec/serde/RexNodeJsonSerdeTest.java | 2 +- .../table/planner/plan/batch/sql/CalcTest.xml | 10 +- .../plan/batch/sql/DeadlockBreakupTest.xml | 6 +- .../batch/sql/ForwardHashExchangeTest.xml | 8 +- .../plan/batch/sql/RemoveCollationTest.xml | 4 +- .../plan/batch/sql/RemoveShuffleTest.xml | 16 +- .../plan/batch/sql/SubplanReuseTest.xml | 8 +- .../planner/plan/batch/sql/TableSinkTest.xml | 10 +- .../planner/plan/batch/sql/UnionTest.xml | 2 +- .../planner/plan/batch/sql/ValuesTest.xml | 3 +- .../plan/batch/sql/agg/OverAggregateTest.xml | 54 +- .../AggregateReduceFunctionsRuleTest.xml | 4 +- .../logical/ConvertToNotInOrInRuleTest.xml | 22 +- .../logical/FlinkPruneEmptyRulesTest.xml | 27 +- .../PushFilterIntoTableSourceScanRuleTest.xml | 2 +- ...veUnreachableCoalesceArgumentsRuleTest.xml | 4 +- .../logical/WindowGroupReorderRuleTest.xml | 32 +- .../logical/subquery/SubQuerySemiJoinTest.xml | 61 + ...ushLocalAggIntoTableSourceScanRuleTest.xml | 4 +- .../planner/plan/stream/sql/CalcTest.xml | 8 +- .../planner/plan/stream/sql/DeltaJoinTest.xml | 4 +- .../plan/stream/sql/MatchRecognizeTest.xml | 4 +- .../planner/plan/stream/sql/UnionTest.xml | 2 +- .../planner/plan/stream/sql/ValuesTest.xml | 3 +- .../plan/stream/sql/agg/AggregateTest.xml | 4 +- .../plan/stream/sql/agg/OverAggregateTest.xml | 46 +- .../planner/plan/stream/table/ValuesTest.xml | 82 +- .../FlinkRelMdColumnIntervalTest.scala | 6 +- .../metadata/FlinkRelMdHandlerTestBase.scala | 4 + .../metadata/FlinkRelMdSelectivityTest.scala | 1 + .../metadata/FlinkRelMdUpsertKeysTest.scala | 1 + .../logical/FlinkPruneEmptyRulesTest.scala | 2 +- .../subquery/SubQueryAntiJoinTest.scala | 2 +- .../subquery/SubQuerySemiJoinTest.scala | 40 +- flink-table/pom.xml | 2 +- 81 files changed, 3004 insertions(+), 4547 deletions(-) delete mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlJoin.java delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/metadata/RelMdPredicates.java create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/BasicSqlType.java delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java create mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/schema/TimeIndicatorRelDataType.java delete mode 100644 flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/schema/TimeIndicatorRelDataType.scala diff --git a/flink-python/pyflink/table/tests/test_udf.py b/flink-python/pyflink/table/tests/test_udf.py index 182c46bbd9cf8..588add0ffba99 100644 --- a/flink-python/pyflink/table/tests/test_udf.py +++ b/flink-python/pyflink/table/tests/test_udf.py @@ -218,7 +218,7 @@ def udf_with_constant_params(p, null_param, tinyint_param, smallint_param, int_p "cast (1 as SMALLINT)," "cast (1 as INT)," "cast (1 as BIGINT)," - "cast (1.05 as DECIMAL)," + "cast (1.05 as DECIMAL(3,2))," "cast (1.23 as FLOAT)," "cast (1.98932 as DOUBLE)," "true," diff --git a/flink-table/flink-sql-parser/pom.xml b/flink-table/flink-sql-parser/pom.xml index 366d6d8ea1f2f..0264583bfecc7 100644 --- a/flink-table/flink-sql-parser/pom.xml +++ b/flink-table/flink-sql-parser/pom.xml @@ -68,13 +68,13 @@ under the License. ${calcite.version} + + + + + + + + + + + @@ -1491,6 +1520,38 @@ LogicalProject(a=[$0], b=[$1], c=[$2]) +- LogicalProject(d=[$0]) +- LogicalFilter(condition=[true]) +- LogicalTableScan(table=[[default_catalog, default_database, r]]) +]]> + + + + + + + + + + + diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml index 7852555acf870..229bac52ea0e8 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/physical/batch/PushLocalAggIntoTableSourceScanRuleTest.xml @@ -529,14 +529,14 @@ FROM inventory]]> (COUNT($3) OVER (PARTITION BY $1), 0), $SUM0($3) OVER (PARTITION BY $1), null:BIGINT)], name=[$1]) +LogicalProject(id=[$0], amount=[$2], EXPR$2=[CASE(>(COUNT($3) OVER (PARTITION BY $1), 0), SUM($3) OVER (PARTITION BY $1), null:BIGINT)], name=[$1]) +- LogicalTableScan(table=[[default_catalog, default_database, inventory]]) ]]> (w0$o0, 0), w0$o1, null:BIGINT) AS EXPR$2, name]) -+- OverAggregate(partitionBy=[name], window#0=[COUNT(price) AS w0$o0, $SUM0(price) AS w0$o1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[id, name, amount, price, w0$o0, w0$o1]) ++- OverAggregate(partitionBy=[name], window#0=[COUNT(price) AS w0$o0, SUM(price) AS w0$o1 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING], select=[id, name, amount, price, w0$o0, w0$o1]) +- Sort(orderBy=[name ASC]) +- Exchange(distribution=[hash[name]]) +- TableSourceScan(table=[[default_catalog, default_database, inventory, project=[id, name, amount, price], metadata=[]]], fields=[id, name, amount, price]) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml index 11b409c98fe18..b7b97659a739c 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/CalcTest.xml @@ -181,13 +181,13 @@ Calc(select=[a, b, c], where=[((a < 10) AND (b > 20))]) @@ -198,13 +198,13 @@ Calc(select=[ARRAY(0.12, 0.5, 0.99) AS EXPR$0]) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.xml index 7c04bf6c2aaa6..0995440d9eeb7 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/DeltaJoinTest.xml @@ -377,7 +377,7 @@ Sink(table=[default_catalog.default_database.snk], fields=[a0, a1, a2, a3, b0, b @@ -385,7 +385,7 @@ LogicalSink(table=[default_catalog.default_database.snk], fields=[a0, a1, a2, a3 (COUNT($1) OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST), 0), $SUM0($1) OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST), null:INTEGER)]), rowType=[RecordType(VARCHAR(2147483647) symbol, INTEGER price, INTEGER tax, TIMESTAMP_LTZ(3) *ROWTIME* matchRowtime, INTEGER price_sum)] +LogicalProject(symbol=[$0], price=[$1], tax=[$2], matchRowtime=[$3], price_sum=[CASE(>(COUNT($1) OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST), 0), SUM($1) OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST), null:INTEGER)]), rowType=[RecordType(VARCHAR(2147483647) symbol, INTEGER price, INTEGER tax, TIMESTAMP_LTZ(3) *ROWTIME* matchRowtime, INTEGER price_sum)] +- LogicalProject(symbol=[$0], price=[$1], tax=[$2], matchRowtime=[$3]), rowType=[RecordType(VARCHAR(2147483647) symbol, INTEGER price, INTEGER tax, TIMESTAMP_LTZ(3) *ROWTIME* matchRowtime)] +- LogicalMatch(partition=[[0]], order=[[1 ASC-nulls-first]], outputFields=[[symbol, price, tax, matchRowtime]], allRows=[false], after=[FLAG(SKIP TO NEXT ROW)], pattern=[_UTF-16LE'A'], isStrictStarts=[false], isStrictEnds=[false], subsets=[[]], patternDefinitions=[[>(PREV(A.$2, 0), 0)]], inputFields=[[symbol, ts_ltz, price, tax]]), rowType=[RecordType(VARCHAR(2147483647) symbol, INTEGER price, INTEGER tax, TIMESTAMP_LTZ(3) *ROWTIME* matchRowtime)] +- LogicalWatermarkAssigner(rowtime=[ts_ltz], watermark=[-($1, 1000:INTERVAL SECOND)]), rowType=[RecordType(VARCHAR(2147483647) symbol, TIMESTAMP_LTZ(3) *ROWTIME* ts_ltz, INTEGER price, INTEGER tax)] @@ -199,7 +199,7 @@ LogicalProject(symbol=[$0], price=[$1], tax=[$2], matchRowtime=[$3], price_sum=[ (w0$o0, 0), w0$o1, null:INTEGER) AS price_sum]), rowType=[RecordType(VARCHAR(2147483647) symbol, INTEGER price, INTEGER tax, TIMESTAMP_LTZ(3) *ROWTIME* matchRowtime, INTEGER price_sum)] -+- OverAggregate(partitionBy=[symbol], orderBy=[matchRowtime ASC], window=[ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[symbol, price, tax, matchRowtime, COUNT(price) AS w0$o0, $SUM0(price) AS w0$o1]), rowType=[RecordType(VARCHAR(2147483647) symbol, INTEGER price, INTEGER tax, TIMESTAMP_LTZ(3) *ROWTIME* matchRowtime, BIGINT w0$o0, INTEGER w0$o1)] ++- OverAggregate(partitionBy=[symbol], orderBy=[matchRowtime ASC], window=[ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[symbol, price, tax, matchRowtime, COUNT(price) AS w0$o0, SUM(price) AS w0$o1]), rowType=[RecordType(VARCHAR(2147483647) symbol, INTEGER price, INTEGER tax, TIMESTAMP_LTZ(3) *ROWTIME* matchRowtime, BIGINT w0$o0, INTEGER w0$o1)] +- Exchange(distribution=[hash[symbol]]), rowType=[RecordType(VARCHAR(2147483647) symbol, INTEGER price, INTEGER tax, TIMESTAMP_LTZ(3) *ROWTIME* matchRowtime)] +- Match(partitionBy=[symbol], orderBy=[ts_ltz ASC], measures=[FINAL(A.price) AS price, FINAL(A.tax) AS tax, FINAL(MATCH_ROWTIME(*.ts_ltz)) AS matchRowtime], rowsPerMatch=[ONE ROW PER MATCH], after=[SKIP TO NEXT ROW], pattern=[_UTF-16LE'A'], define=[{A=>(PREV(A.$2, 0), 0)}]), rowType=[RecordType(VARCHAR(2147483647) symbol, INTEGER price, INTEGER tax, TIMESTAMP_LTZ(3) *ROWTIME* matchRowtime)] +- Exchange(distribution=[hash[symbol]]), rowType=[RecordType(VARCHAR(2147483647) symbol, TIMESTAMP_LTZ(3) *ROWTIME* ts_ltz, INTEGER price, INTEGER tax)] diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/UnionTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/UnionTest.xml index 2969cdc7486ee..0dbdafc413b55 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/UnionTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/UnionTest.xml @@ -70,7 +70,7 @@ LogicalProject(a=[$0], b=[$1]), rowType=[RecordType(INTEGER a, DECIMAL(20, 1) b) +- LogicalUnion(all=[true]), rowType=[RecordType(INTEGER a, DECIMAL(20, 1) b)] :- LogicalProject(a=[$0], b=[$1]), rowType=[RecordType(INTEGER a, BIGINT b)] : +- LogicalTableScan(table=[[default_catalog, default_database, MyTable1]]), rowType=[RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c)] - +- LogicalProject(a=[$0], EXPR$1=[0:DECIMAL(2, 1)]), rowType=[RecordType(INTEGER a, DECIMAL(2, 1) EXPR$1)] + +- LogicalProject(a=[$0], EXPR$1=[0.0:DECIMAL(2, 1)]), rowType=[RecordType(INTEGER a, DECIMAL(2, 1) EXPR$1)] +- LogicalTableScan(table=[[default_catalog, default_database, MyTable2]]), rowType=[RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c)] ]]> diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ValuesTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ValuesTest.xml index 098ecd36f3471..42260c1b33fc6 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ValuesTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ValuesTest.xml @@ -100,7 +100,8 @@ LogicalProject(a=[$0], b=[$1], c=[$2]) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/AggregateTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/AggregateTest.xml index 92fccedf388e4..c2a34a6c49bac 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/AggregateTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/AggregateTest.xml @@ -68,7 +68,7 @@ FROM T GROUP BY a @@ -76,7 +76,7 @@ LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)], EXPR$2=[SUM($2)], EXPR$3=[SUM($3 diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml index dc0185da0ec8d..fdd3264866764 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/OverAggregateTest.xml @@ -111,14 +111,14 @@ FROM MyTable (COUNT($2) OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST RANGE 7200000:INTERVAL HOUR PRECEDING), 0), $SUM0($2) OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST RANGE 7200000:INTERVAL HOUR PRECEDING), null:BIGINT), COUNT($2) OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST RANGE 7200000:INTERVAL HOUR PRECEDING))]) +LogicalProject(a=[$0], avgA=[/(CASE(>(COUNT($2) OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST RANGE 7200000:INTERVAL HOUR PRECEDING), 0), SUM($2) OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST RANGE 7200000:INTERVAL HOUR PRECEDING), null:BIGINT), COUNT($2) OVER (PARTITION BY $0 ORDER BY $3 NULLS FIRST RANGE 7200000:INTERVAL HOUR PRECEDING))]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:BIGINT) / w0$o0) AS avgA]) -+- OverAggregate(partitionBy=[a], orderBy=[proctime ASC], window=[ RANGE BETWEEN 7200000 PRECEDING AND CURRENT ROW], select=[a, c, proctime, COUNT(c) AS w0$o0, $SUM0(c) AS w0$o1]) ++- OverAggregate(partitionBy=[a], orderBy=[proctime ASC], window=[ RANGE BETWEEN 7200000 PRECEDING AND CURRENT ROW], select=[a, c, proctime, COUNT(c) AS w0$o0, SUM(c) AS w0$o1]) +- Exchange(distribution=[hash[a]]) +- Calc(select=[a, c, proctime]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -138,14 +138,14 @@ FROM MyTable (COUNT($0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), 0), $SUM0($0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), null:INTEGER)]) +LogicalProject(c=[$2], cnt1=[COUNT($0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING)], sum1=[CASE(>(COUNT($0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), 0), SUM($0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), null:INTEGER)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:INTEGER) AS sum1]) -+- OverAggregate(partitionBy=[c], orderBy=[proctime ASC], window=[ ROWS BETWEEN 2 PRECEDING AND CURRENT ROW], select=[a, c, proctime, COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1]) ++- OverAggregate(partitionBy=[c], orderBy=[proctime ASC], window=[ ROWS BETWEEN 2 PRECEDING AND CURRENT ROW], select=[a, c, proctime, COUNT(a) AS w0$o0, SUM(a) AS w0$o1]) +- Exchange(distribution=[hash[c]]) +- Calc(select=[a, c, proctime]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -158,14 +158,14 @@ Calc(select=[c, w0$o0 AS cnt1, CASE((w0$o0 > 0), w0$o1, null:INTEGER) AS sum1]) (COUNT($2) OVER (PARTITION BY $0 ORDER BY PROCTIME() NULLS FIRST ROWS 4 PRECEDING), 0), $SUM0($2) OVER (PARTITION BY $0 ORDER BY PROCTIME() NULLS FIRST ROWS 4 PRECEDING), null:BIGINT)], EXPR$2=[MIN($2) OVER (PARTITION BY $0 ORDER BY PROCTIME() NULLS FIRST ROWS 4 PRECEDING)]) +LogicalProject(a=[$0], EXPR$1=[CASE(>(COUNT($2) OVER (PARTITION BY $0 ORDER BY PROCTIME() NULLS FIRST ROWS 4 PRECEDING), 0), SUM($2) OVER (PARTITION BY $0 ORDER BY PROCTIME() NULLS FIRST ROWS 4 PRECEDING), null:BIGINT)], EXPR$2=[MIN($2) OVER (PARTITION BY $0 ORDER BY PROCTIME() NULLS FIRST ROWS 4 PRECEDING)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:BIGINT) AS EXPR$1, w0$o2 AS EXPR$2]) -+- OverAggregate(partitionBy=[a], orderBy=[$2 ASC], window=[ ROWS BETWEEN 4 PRECEDING AND CURRENT ROW], select=[a, c, $2, COUNT(c) AS w0$o0, $SUM0(c) AS w0$o1, MIN(c) AS w0$o2]) ++- OverAggregate(partitionBy=[a], orderBy=[$2 ASC], window=[ ROWS BETWEEN 4 PRECEDING AND CURRENT ROW], select=[a, c, $2, COUNT(c) AS w0$o0, SUM(c) AS w0$o1, MIN(c) AS w0$o2]) +- Exchange(distribution=[hash[a]]) +- Calc(select=[a, c, PROCTIME() AS $2]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -183,14 +183,14 @@ FROM MyTable (COUNT($0) OVER (ORDER BY $3 NULLS FIRST), 0), $SUM0($0) OVER (ORDER BY $3 NULLS FIRST), null:INTEGER)]) +LogicalProject(c=[$2], cnt1=[COUNT($0) OVER (ORDER BY $3 NULLS FIRST)], cnt2=[CASE(>(COUNT($0) OVER (ORDER BY $3 NULLS FIRST), 0), SUM($0) OVER (ORDER BY $3 NULLS FIRST), null:INTEGER)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:INTEGER) AS cnt2]) -+- OverAggregate(orderBy=[proctime ASC], window=[ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, proctime, COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1]) ++- OverAggregate(orderBy=[proctime ASC], window=[ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, proctime, COUNT(a) AS w0$o0, SUM(a) AS w0$o1]) +- Exchange(distribution=[single]) +- Calc(select=[a, c, proctime]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -231,14 +231,14 @@ FROM MyTable (COUNT($0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST), 0), $SUM0($0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST), null:INTEGER)]) +LogicalProject(c=[$2], cnt1=[COUNT($0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST)], cnt2=[CASE(>(COUNT($0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST), 0), SUM($0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST), null:INTEGER)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:INTEGER) AS cnt2]) -+- OverAggregate(partitionBy=[c], orderBy=[proctime ASC], window=[ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, proctime, COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1]) ++- OverAggregate(partitionBy=[c], orderBy=[proctime ASC], window=[ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, proctime, COUNT(a) AS w0$o0, SUM(a) AS w0$o1]) +- Exchange(distribution=[hash[c]]) +- Calc(select=[a, c, proctime]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -282,14 +282,14 @@ FROM MyTable (COUNT(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), 0), $SUM0(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), null:INTEGER)]) +LogicalProject(c=[$2], cnt1=[COUNT(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING)], sum1=[CASE(>(COUNT(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), 0), SUM(DISTINCT $0) OVER (PARTITION BY $2 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), null:INTEGER)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:INTEGER) AS sum1]) -+- OverAggregate(partitionBy=[c], orderBy=[proctime ASC], window=[ ROWS BETWEEN 2 PRECEDING AND CURRENT ROW], select=[a, c, proctime, COUNT(DISTINCT a) AS w0$o0, $SUM0(DISTINCT a) AS w0$o1]) ++- OverAggregate(partitionBy=[c], orderBy=[proctime ASC], window=[ ROWS BETWEEN 2 PRECEDING AND CURRENT ROW], select=[a, c, proctime, COUNT(DISTINCT a) AS w0$o0, SUM(DISTINCT a) AS w0$o1]) +- Exchange(distribution=[hash[c]]) +- Calc(select=[a, c, proctime]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -313,14 +313,14 @@ FROM MyTable (COUNT($0) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), 0), $SUM0($0) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), null:INTEGER)], cnt2=[COUNT(DISTINCT $0) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING)], sum2=[CASE(>(COUNT(DISTINCT $2) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), 0), $SUM0(DISTINCT $2) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), null:BIGINT)]) +LogicalProject(b=[$1], cnt1=[COUNT($0) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING)], sum1=[CASE(>(COUNT($0) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), 0), SUM($0) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), null:INTEGER)], cnt2=[COUNT(DISTINCT $0) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING)], sum2=[CASE(>(COUNT(DISTINCT $2) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), 0), SUM(DISTINCT $2) OVER (PARTITION BY $1 ORDER BY $3 NULLS FIRST ROWS 2 PRECEDING), null:BIGINT)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:INTEGER) AS sum1, w0$o2 AS cnt2, CASE((w0$o3 > 0), w0$o4, null:BIGINT) AS sum2]) -+- OverAggregate(partitionBy=[b], orderBy=[proctime ASC], window=[ ROWS BETWEEN 2 PRECEDING AND CURRENT ROW], select=[a, b, c, proctime, COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1, COUNT(DISTINCT a) AS w0$o2, COUNT(DISTINCT c) AS w0$o3, $SUM0(DISTINCT c) AS w0$o4]) ++- OverAggregate(partitionBy=[b], orderBy=[proctime ASC], window=[ ROWS BETWEEN 2 PRECEDING AND CURRENT ROW], select=[a, b, c, proctime, COUNT(a) AS w0$o0, SUM(a) AS w0$o1, COUNT(DISTINCT a) AS w0$o2, COUNT(DISTINCT c) AS w0$o3, SUM(DISTINCT c) AS w0$o4]) +- Exchange(distribution=[hash[b]]) +- Calc(select=[a, b, c, proctime]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -437,14 +437,14 @@ FROM MyTable (COUNT($0) OVER (ORDER BY $4 NULLS FIRST), 0), $SUM0($0) OVER (ORDER BY $4 NULLS FIRST), null:INTEGER)]) +LogicalProject(c=[$2], cnt1=[COUNT($0) OVER (ORDER BY $4 NULLS FIRST)], cnt2=[CASE(>(COUNT($0) OVER (ORDER BY $4 NULLS FIRST), 0), SUM($0) OVER (ORDER BY $4 NULLS FIRST), null:INTEGER)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:INTEGER) AS cnt2]) -+- OverAggregate(orderBy=[rowtime ASC], window=[ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, rowtime, COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1]) ++- OverAggregate(orderBy=[rowtime ASC], window=[ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, rowtime, COUNT(a) AS w0$o0, SUM(a) AS w0$o1]) +- Exchange(distribution=[single]) +- Calc(select=[a, c, rowtime]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -462,14 +462,14 @@ FROM MyTable (COUNT($0) OVER (ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), 0), $SUM0($0) OVER (ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), null:INTEGER)]) +LogicalProject(c=[$2], cnt1=[COUNT($0) OVER (ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING)], cnt2=[CASE(>(COUNT($0) OVER (ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), 0), SUM($0) OVER (ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), null:INTEGER)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:INTEGER) AS cnt2]) -+- OverAggregate(orderBy=[rowtime ASC], window=[ ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, rowtime, COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1]) ++- OverAggregate(orderBy=[rowtime ASC], window=[ ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, rowtime, COUNT(a) AS w0$o0, SUM(a) AS w0$o1]) +- Exchange(distribution=[single]) +- Calc(select=[a, c, rowtime]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -487,14 +487,14 @@ FROM MyTable (COUNT($0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST), 0), $SUM0($0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST), null:INTEGER)]) +LogicalProject(c=[$2], cnt1=[COUNT($0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST)], cnt2=[CASE(>(COUNT($0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST), 0), SUM($0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST), null:INTEGER)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:INTEGER) AS cnt2]) -+- OverAggregate(partitionBy=[c], orderBy=[rowtime ASC], window=[ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, rowtime, COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1]) ++- OverAggregate(partitionBy=[c], orderBy=[rowtime ASC], window=[ RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, rowtime, COUNT(a) AS w0$o0, SUM(a) AS w0$o1]) +- Exchange(distribution=[hash[c]]) +- Calc(select=[a, c, rowtime]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -512,14 +512,14 @@ FROM MyTable (COUNT($0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), 0), $SUM0($0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), null:INTEGER)]) +LogicalProject(c=[$2], cnt1=[COUNT($0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING)], cnt2=[CASE(>(COUNT($0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), 0), SUM($0) OVER (PARTITION BY $2 ORDER BY $4 NULLS FIRST ROWS UNBOUNDED PRECEDING), null:INTEGER)]) +- LogicalTableScan(table=[[default_catalog, default_database, MyTable]]) ]]> 0), w0$o1, null:INTEGER) AS cnt2]) -+- OverAggregate(partitionBy=[c], orderBy=[rowtime ASC], window=[ ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, rowtime, COUNT(a) AS w0$o0, $SUM0(a) AS w0$o1]) ++- OverAggregate(partitionBy=[c], orderBy=[rowtime ASC], window=[ ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW], select=[a, c, rowtime, COUNT(a) AS w0$o0, SUM(a) AS w0$o1]) +- Exchange(distribution=[hash[c]]) +- Calc(select=[a, c, rowtime]) +- DataStreamScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c, proctime, rowtime]) @@ -728,7 +728,7 @@ LogicalProject(c=[$2], cnt1=[COUNT($0) OVER (PARTITION BY $1 ORDER BY $3 NULLS F @@ -45,11 +45,11 @@ LogicalUnion(all=[true]) @@ -58,30 +58,30 @@ Union(all=[true], union=[f0]) @@ -90,7 +90,7 @@ Union(all=[true], union=[f0, f1, f2]) @@ -146,30 +146,30 @@ Union(all=[true], union=[f0, f1, f2]) @@ -178,30 +178,30 @@ Union(all=[true], union=[f0, f1, f2]) @@ -244,18 +244,18 @@ Union(all=[true], union=[a, b]) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala index 232213f48046e..aab16f95d4c9f 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala @@ -139,7 +139,7 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { // id <= 20 val expr1 = relBuilder.call(LESS_THAN_OR_EQUAL, relBuilder.field(0), relBuilder.literal(20)) // id > 10.0 (note: the types of id and literal are different) - val expr2 = relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(10.0)) + val expr2 = relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(bd(10.0))) // DIV(id, 2) > 3 val expr3 = relBuilder.call( GREATER_THAN, @@ -165,13 +165,13 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { // id <= 20 AND id > 10 AND DIV(id, 2) > 3 val filter2 = relBuilder.push(ts).filter(expr1, expr2, expr3).build assertEquals( - ValueInterval(bd(10.0), bd(20), includeLower = false), + ValueInterval(bd(10.0), bd(20.0), includeLower = false), mq.getColumnInterval(filter2, 0)) // id <= 20 AND id > 10 AND score < 4.1 val filter3 = relBuilder.push(ts).filter(expr1, expr2, expr4).build assertEquals( - ValueInterval(bd(10.0), bd(20), includeLower = false), + ValueInterval(bd(10.0), bd(20.0), includeLower = false), mq.getColumnInterval(filter3, 0)) // score > 6.0 OR score <= 4.0 diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala index eaca9fa23c901..d8815d3aa6bf7 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala @@ -2435,6 +2435,7 @@ class FlinkRelMdHandlerTestBase { true, RexWindowBounds.create(SqlWindow.createUnboundedPreceding(new SqlParserPos(0, 0)), null), RexWindowBounds.create(SqlWindow.createCurrentRow(new SqlParserPos(0, 0)), null), + null, RelCollations.of( new RelFieldCollation( 1, @@ -2568,6 +2569,7 @@ class FlinkRelMdHandlerTestBase { true, RexWindowBounds.create(SqlWindow.createUnboundedPreceding(new SqlParserPos(0, 0)), null), RexWindowBounds.create(SqlWindow.createCurrentRow(new SqlParserPos(0, 0)), null), + null, RelCollations.of( new RelFieldCollation( 1, @@ -2589,6 +2591,7 @@ class FlinkRelMdHandlerTestBase { false, RexWindowBounds.create(SqlWindow.createUnboundedPreceding(new SqlParserPos(4, 15)), null), RexWindowBounds.create(SqlWindow.createCurrentRow(new SqlParserPos(0, 0)), null), + null, RelCollations.of( new RelFieldCollation( 2, @@ -2634,6 +2637,7 @@ class FlinkRelMdHandlerTestBase { false, RexWindowBounds.create(SqlWindow.createUnboundedPreceding(new SqlParserPos(7, 19)), null), RexWindowBounds.create(SqlWindow.createUnboundedFollowing(new SqlParserPos(0, 0)), null), + null, RelCollations.EMPTY, ImmutableList.of( new Window.RexWinAggCall( diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivityTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivityTest.scala index eea6b0b433605..1fc2f71faeb13 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivityTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivityTest.scala @@ -520,6 +520,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { true, RexWindowBounds.create(SqlWindow.createUnboundedPreceding(new SqlParserPos(0, 0)), null), RexWindowBounds.create(SqlWindow.createCurrentRow(new SqlParserPos(0, 0)), null), + null, RelCollations.of( new RelFieldCollation( 1, diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala index 996ef4f2de8bf..4a77662bd2b70 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala @@ -571,6 +571,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { true, RexWindowBounds.create(SqlWindow.createUnboundedPreceding(new SqlParserPos(0, 0)), null), RexWindowBounds.create(SqlWindow.createCurrentRow(new SqlParserPos(0, 0)), null), + null, RelCollations.of( new RelFieldCollation( 2, diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkPruneEmptyRulesTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkPruneEmptyRulesTest.scala index 74e9ad050fea3..11258852f29a2 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkPruneEmptyRulesTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/FlinkPruneEmptyRulesTest.scala @@ -47,7 +47,7 @@ class FlinkPruneEmptyRulesTest extends TableTestBase { CoreRules.FILTER_REDUCE_EXPRESSIONS, CoreRules.PROJECT_REDUCE_EXPRESSIONS, CoreRules.FILTER_SET_OP_TRANSPOSE, - CoreRules.FILTER_PROJECT_TRANSPOSE, + FlinkFilterProjectTransposeRule.INSTANCE, CoreRules.PROJECT_MERGE, CoreRules.PROJECT_FILTER_VALUES_MERGE, FlinkPruneEmptyRules.UNION_INSTANCE, diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQueryAntiJoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQueryAntiJoinTest.scala index 7ecef1016c0d1..065f5b97c7750 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQueryAntiJoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQueryAntiJoinTest.scala @@ -737,7 +737,7 @@ class SubQueryAntiJoinTest extends SubQueryTestBase { // TODO some bugs in SubQueryRemoveRule // the result RelNode (LogicalJoin(condition=[=($1, $11)], joinType=[left])) // after SubQueryRemoveRule is unexpected - assertThatExceptionOfType(classOf[AssertionError]) + assertThatExceptionOfType(classOf[NullPointerException]) .isThrownBy(() => util.verifyRelPlanNotExpected(sqlQuery, "joinType=[anti]")) } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala index 5b0802ccf59f8..809ac66c577c2 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.scala @@ -964,25 +964,9 @@ class SubQuerySemiJoinTest extends SubQueryTestBase { @Test def testInWithCorrelatedOnHaving(): Unit = { - // TODO There are some bugs when converting SqlNode to RelNode: val sqlQuery = "SELECT SUM(a) AS s FROM x GROUP BY b " + "HAVING MAX(a) IN (SELECT d FROM y WHERE y.d = x.b)" - - // the logical plan is: - // - // LogicalProject(s=[$1]) - // LogicalFilter(condition=[IN($2, { - // LogicalProject(d=[$1]) - // LogicalFilter(condition=[=($1, $cor0.b)]) - // LogicalTableScan(table=[[builtin, default, r]]) - // })]) - // LogicalAggregate(group=[{0}], s=[SUM($1)], agg#1=[MAX($1)]) - // LogicalProject(b=[$1], a=[$0]) - // LogicalTableScan(table=[[builtin, default, l]]) - // - // LogicalFilter lost variablesSet information. - - util.verifyRelPlanNotExpected(sqlQuery, "joinType=[semi]") + util.verifyRelPlan(sqlQuery) } @Test @@ -1622,25 +1606,9 @@ class SubQuerySemiJoinTest extends SubQueryTestBase { @Test def testExistsWithCorrelatedOnHaving(): Unit = { - // TODO There are some bugs when converting SqlNode to RelNode: - val sqlQuery1 = + val sqlQuery = "SELECT SUM(a) AS s FROM x GROUP BY b HAVING EXISTS (SELECT * FROM y WHERE y.d = x.b)" - - // the logical plan is: - // - // LogicalProject(s=[$1]) - // LogicalFilter(condition=[IN($2, { - // LogicalProject(d=[$1]) - // LogicalFilter(condition=[=($1, $cor0.b)]) - // LogicalTableScan(table=[[builtin, default, r]]) - // })]) - // LogicalAggregate(group=[{0}], s=[SUM($1)], agg#1=[MAX($1)]) - // LogicalProject(b=[$1], a=[$0]) - // LogicalTableScan(table=[[builtin, default, l]]) - // - // LogicalFilter lost variablesSet information. - - util.verifyRelPlanNotExpected(sqlQuery1, "joinType=[semi]") + util.verifyRelPlan(sqlQuery) } @Test @@ -1711,7 +1679,7 @@ class SubQuerySemiJoinTest extends SubQueryTestBase { // TODO some bugs in SubQueryRemoveRule // the result RelNode (LogicalJoin(condition=[=($1, $8)], joinType=[left])) // after SubQueryRemoveRule is unexpected - assertThatExceptionOfType(classOf[AssertionError]) + assertThatExceptionOfType(classOf[NullPointerException]) .isThrownBy(() => util.verifyRelPlanNotExpected(sqlQuery, "joinType=[semi]")) } diff --git a/flink-table/pom.xml b/flink-table/pom.xml index 476e911568487..eacd06e2cc1ce 100644 --- a/flink-table/pom.xml +++ b/flink-table/pom.xml @@ -79,7 +79,7 @@ under the License. - 1.37.0 + 1.38.0 3.1.11 33.4.0-jre 2.5.2 From e7ad03281274c4c95cedfe8e220b887b54a37c2c Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Thu, 28 May 2026 13:50:12 +0200 Subject: [PATCH 2/5] [FLINK-36602][tests] Replace deprecated `JavaConversions` in `FlinkRelMdHandlerTest` and its children --- .../FlinkRelMdColumnIntervalTest.scala | 36 +- .../FlinkRelMdColumnNullCountTest.scala | 11 +- .../FlinkRelMdColumnOriginNullCountTest.scala | 11 +- .../FlinkRelMdColumnUniquenessTest.scala | 4 +- .../FlinkRelMdDistinctRowCountTest.scala | 5 +- .../metadata/FlinkRelMdDistributionTest.scala | 18 +- ...FlinkRelMdFilteredColumnIntervalTest.scala | 5 +- .../metadata/FlinkRelMdHandlerTestBase.scala | 345 ++++++++++-------- .../FlinkRelMdModifiedMonotonicityTest.scala | 11 +- .../metadata/FlinkRelMdRowCollationTest.scala | 24 +- .../metadata/FlinkRelMdRowCountTest.scala | 4 +- .../metadata/FlinkRelMdSelectivityTest.scala | 29 +- .../plan/metadata/FlinkRelMdSizeTest.scala | 106 +++--- .../metadata/FlinkRelMdUniqueGroupsTest.scala | 24 +- .../metadata/FlinkRelMdUniqueKeysTest.scala | 173 +++++---- .../metadata/FlinkRelMdUpsertKeysTest.scala | 183 +++++----- 16 files changed, 533 insertions(+), 456 deletions(-) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala index aab16f95d4c9f..1902cf6e4be06 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnIntervalTest.scala @@ -27,7 +27,7 @@ import org.apache.flink.table.types.logical.IntType import org.apache.calcite.rel.RelDistributions import org.apache.calcite.rel.core.JoinRelType import org.apache.calcite.rel.logical.LogicalExchange -import org.apache.calcite.rex.{RexCall, RexUtil} +import org.apache.calcite.rex.{RexCall, RexNode, RexUtil} import org.apache.calcite.sql.fun.SqlStdOperatorTable._ import org.apache.calcite.util.{DateString, TimestampString, TimeString} import org.junit.jupiter.api.Assertions._ @@ -35,8 +35,6 @@ import org.junit.jupiter.api.Test import java.sql.{Date, Time, Timestamp} -import scala.collection.JavaConversions._ - class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { @Test @@ -238,7 +236,11 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { val expr7 = relBuilder.call(GREATER_THAN, relBuilder.field(2), relBuilder.literal(1.9d)) // calc => projects + filter(id <= 20) - val calc1 = createLogicalCalc(studentLogicalScan, outputRowType, projects, List(expr1)) + val calc1 = createLogicalCalc( + studentLogicalScan, + outputRowType, + projects, + java.util.List.of[RexNode](expr1)) assertEquals(ValueInterval(bd(0), bd(20)), mq.getColumnInterval(calc1, 0)) assertNull(mq.getColumnInterval(calc1, 1)) assertEqualsAsDouble(ValueInterval(bd(2.9), bd(5.0)), mq.getColumnInterval(calc1, 2)) @@ -254,7 +256,11 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { // calc => project + filter(id <= 20 AND id > 10 AND DIV(id, 2) > 3) val calc2 = - createLogicalCalc(studentLogicalScan, outputRowType, projects, List(expr1, expr2, expr3)) + createLogicalCalc( + studentLogicalScan, + outputRowType, + projects, + java.util.List.of(expr1, expr2, expr3)) assertEquals( ValueInterval(bd(10), bd(20), includeLower = false), mq.getColumnInterval(calc2, 0)) @@ -262,7 +268,11 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { // calc => project + filter(id <= 20 AND id > 10 AND score < 4.1) val calc3 = - createLogicalCalc(studentLogicalScan, outputRowType, projects, List(expr1, expr2, expr4)) + createLogicalCalc( + studentLogicalScan, + outputRowType, + projects, + java.util.List.of(expr1, expr2, expr4)) assertEquals( ValueInterval(bd(10), bd(20), includeLower = false), mq.getColumnInterval(calc3, 0)) @@ -272,7 +282,7 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, outputRowType, projects, - List(relBuilder.call(OR, expr5, expr6))) + java.util.List.of[RexNode](relBuilder.call(OR, expr5, expr6))) assertEqualsAsDouble(ValueInterval(bd(2.9), bd(5.0)), mq.getColumnInterval(calc4, 2)) // calc => project + filter(score > 6.0 OR score <= 4.0 OR id < 20) @@ -280,7 +290,7 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, outputRowType, projects, - List(relBuilder.call(OR, expr5, expr6, expr1))) + java.util.List.of[RexNode](relBuilder.call(OR, expr5, expr6, expr1))) assertEqualsAsDouble(ValueInterval(bd(2.9), bd(5.0)), mq.getColumnInterval(calc5, 2)) // calc => project + filter((id <= 20 AND score < 4.1) OR NOT(DIV(id, 2) > 3 OR score > 1.9)) @@ -288,7 +298,7 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, outputRowType, projects, - List( + java.util.List.of[RexNode]( relBuilder.call( OR, relBuilder.call(AND, expr1, expr4), @@ -301,7 +311,7 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, outputRowType, projects, - List( + java.util.List.of[RexNode]( relBuilder.call( OR, relBuilder.call(AND, expr1, expr4), @@ -331,7 +341,11 @@ class FlinkRelMdColumnIntervalTest extends FlinkRelMdHandlerTestBase { Array("f0", "f1", "f2", "f3"), Array(new IntType(), new IntType(), new IntType(), new IntType())) val calc8 = - createLogicalCalc(studentLogicalScan, rowType, List(expr8, expr9, expr10, expr11), List()) + createLogicalCalc( + studentLogicalScan, + rowType, + java.util.List.of(expr8, expr9, expr10, expr11), + java.util.List.of()) assertEquals(ValueInterval(bd(0), bd(1)), mq.getColumnInterval(calc8, 0)) assertEquals(ValueInterval(bd(10), bd(12)), mq.getColumnInterval(calc8, 1)) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnNullCountTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnNullCountTest.scala index 9d7a4073c8f08..4d4af154745d1 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnNullCountTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnNullCountTest.scala @@ -18,12 +18,11 @@ package org.apache.flink.table.planner.plan.metadata import org.apache.calcite.rel.core.JoinRelType +import org.apache.calcite.rex.RexNode import org.apache.calcite.sql.fun.SqlStdOperatorTable._ import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test -import scala.collection.JavaConversions._ - class FlinkRelMdColumnNullCountTest extends FlinkRelMdHandlerTestBase { @Test @@ -126,7 +125,7 @@ class FlinkRelMdColumnNullCountTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, studentLogicalScan.getRowType, relBuilder.fields(), - List(expr1)) + java.util.List.of[RexNode](expr1)) assertEquals(0.0, mq.getColumnNullCount(calc2, 0)) assertEquals(0.0, mq.getColumnNullCount(calc2, 1)) assertNull(mq.getColumnNullCount(calc2, 2)) @@ -141,7 +140,7 @@ class FlinkRelMdColumnNullCountTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, studentLogicalScan.getRowType, relBuilder.fields(), - List(expr2)) + java.util.List.of[RexNode](expr2)) assertEquals(0.0, mq.getColumnNullCount(calc3, 0)) assertEquals(0.0, mq.getColumnNullCount(calc3, 1)) assertNull(mq.getColumnNullCount(calc3, 2)) @@ -156,7 +155,7 @@ class FlinkRelMdColumnNullCountTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, logicalProject.getRowType, logicalProject.getProjects, - List(expr1)) + java.util.List.of[RexNode](expr1)) assertEquals(0.0, mq.getColumnNullCount(calc4, 0)) assertEquals(0.0, mq.getColumnNullCount(calc4, 1)) assertNull(mq.getColumnNullCount(calc4, 2)) @@ -175,7 +174,7 @@ class FlinkRelMdColumnNullCountTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, logicalProject.getRowType, logicalProject.getProjects, - List(expr2)) + java.util.List.of[RexNode](expr2)) assertEquals(0.0, mq.getColumnNullCount(calc5, 0)) assertEquals(0.0, mq.getColumnNullCount(calc5, 1)) assertNull(mq.getColumnNullCount(calc5, 2)) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnOriginNullCountTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnOriginNullCountTest.scala index ef23c4858dda1..4a137ce040fc7 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnOriginNullCountTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnOriginNullCountTest.scala @@ -18,12 +18,11 @@ package org.apache.flink.table.planner.plan.metadata import org.apache.calcite.rel.core.JoinRelType +import org.apache.calcite.rex.RexNode import org.apache.calcite.sql.fun.SqlStdOperatorTable.{EQUALS, LESS_THAN_OR_EQUAL} import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test -import scala.collection.JavaConversions._ - class FlinkRelMdColumnOriginNullCountTest extends FlinkRelMdHandlerTestBase { @Test @@ -69,7 +68,7 @@ class FlinkRelMdColumnOriginNullCountTest extends FlinkRelMdHandlerTestBase { val ts = relBuilder.scan("MyTable3").build() relBuilder.push(ts) - val projects = List( + val projects = java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)), relBuilder.field(0), relBuilder.field(1), @@ -95,14 +94,14 @@ class FlinkRelMdColumnOriginNullCountTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, studentLogicalScan.getRowType, relBuilder.fields(), - List(expr)) + java.util.List.of[RexNode](expr)) (0 until calc1.getRowType.getFieldCount).foreach { idx => assertNull(mq.getColumnOriginNullCount(calc1, idx)) } val ts = relBuilder.scan("MyTable3").build() relBuilder.push(ts) - val projects = List( + val projects = java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)), relBuilder.field(0), relBuilder.field(1), @@ -110,7 +109,7 @@ class FlinkRelMdColumnOriginNullCountTest extends FlinkRelMdHandlerTestBase { relBuilder.literal(null) ) val outputRowType = relBuilder.project(projects).build().getRowType - val calc2 = createLogicalCalc(ts, outputRowType, projects, List()) + val calc2 = createLogicalCalc(ts, outputRowType, projects, java.util.List.of()) assertEquals(null, mq.getColumnOriginNullCount(calc2, 0)) assertEquals(1.0, mq.getColumnOriginNullCount(calc2, 1)) assertEquals(0.0, mq.getColumnOriginNullCount(calc2, 2)) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniquenessTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniquenessTest.scala index 57920d5d6ccc9..49d07882c63f4 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniquenessTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdColumnUniquenessTest.scala @@ -29,8 +29,6 @@ import org.apache.calcite.util.ImmutableBitSet import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test -import scala.collection.JavaConversions._ - class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase { @Test @@ -93,7 +91,7 @@ class FlinkRelMdColumnUniquenessTest extends FlinkRelMdHandlerTestBase { // project: id, cast(id as long not null), name, cast(name as varchar not null) relBuilder.push(studentLogicalScan) - val exprs = List( + val exprs = java.util.List.of( relBuilder.field(0), relBuilder.cast(relBuilder.field(0), BIGINT), relBuilder.field(1), diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala index 568307603f898..4beb3b1802ad6 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala @@ -20,13 +20,12 @@ package org.apache.flink.table.planner.plan.metadata import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalRank import org.apache.flink.table.planner.plan.utils.FlinkRelMdUtil +import org.apache.calcite.rex.RexNode import org.apache.calcite.sql.fun.SqlStdOperatorTable._ import org.apache.calcite.util.ImmutableBitSet import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test -import scala.collection.JavaConversions._ - class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase { @Test @@ -187,7 +186,7 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, logicalProject.getRowType, logicalProject.getProjects, - List(expr1)) + java.util.List.of[RexNode](expr1)) assertEquals(1.0, mq.getDistinctRowCount(calc, ImmutableBitSet.of(), null)) assertEquals(25.0, mq.getDistinctRowCount(calc, ImmutableBitSet.of(0), null)) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistributionTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistributionTest.scala index 0b741f2fb22e7..b8325f725f625 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistributionTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistributionTest.scala @@ -24,12 +24,11 @@ import org.apache.flink.table.planner.plan.nodes.physical.stream.StreamPhysicalD import org.apache.flink.table.types.logical.{BigIntType, DoubleType} import com.google.common.collect.ImmutableList +import org.apache.calcite.rex.RexNode import org.apache.calcite.sql.fun.SqlStdOperatorTable._ import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test -import scala.collection.JavaConversions._ - class FlinkRelMdDistributionTest extends FlinkRelMdHandlerTestBase { @Test @@ -70,7 +69,11 @@ class FlinkRelMdDistributionTest extends FlinkRelMdHandlerTestBase { relBuilder.push(scan1) val expr4 = relBuilder.call(LESS_THAN, relBuilder.field(4), relBuilder.literal(170.0)) val calc = - createLogicalCalc(scan1, logicalProject.getRowType, logicalProject.getProjects, List(expr4)) + createLogicalCalc( + scan1, + logicalProject.getRowType, + logicalProject.getProjects, + java.util.List.of[RexNode](expr4)) assertEquals( FlinkRelDistribution.hash(Array(6), requireStrict = false), mq.flinkDistribution(calc)) @@ -80,7 +83,7 @@ class FlinkRelMdDistributionTest extends FlinkRelMdHandlerTestBase { createDataStreamScan(ImmutableList.of("student"), flinkLogicalTraits.replace(distribution01)) relBuilder.push(scan2) // projects: $0==1, $0, $1, true, 2.1, 2 - val projects1 = List( + val projects1 = java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)), relBuilder.field(0), relBuilder.field(1), @@ -92,13 +95,14 @@ class FlinkRelMdDistributionTest extends FlinkRelMdHandlerTestBase { relBuilder.push(scan2) val expr1 = relBuilder.call(LESS_THAN_OR_EQUAL, relBuilder.field(0), relBuilder.literal(2)) // calc => projects + filter: $0 <= 2 - val calc1 = createLogicalCalc(scan2, outputRowType, projects1, List(expr1)) + val calc1 = + createLogicalCalc(scan2, outputRowType, projects1, java.util.List.of[RexNode](expr1)) assertEquals( FlinkRelDistribution.hash(Array(1, 2), requireStrict = false), mq.flinkDistribution(calc1)) // projects: $0==1, $0, 2.1, true, 2.1, 2 - val projects2 = List( + val projects2 = java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)), relBuilder.field(0), makeLiteral(2.1, new DoubleType(), isNullable = false, allowCast = true), @@ -106,7 +110,7 @@ class FlinkRelMdDistributionTest extends FlinkRelMdHandlerTestBase { makeLiteral(2.1, new DoubleType(), isNullable = false, allowCast = true), makeLiteral(2L, new BigIntType(), isNullable = false, allowCast = true) ) - val calc2 = createLogicalCalc(scan2, outputRowType, projects2, List()) + val calc2 = createLogicalCalc(scan2, outputRowType, projects2, java.util.List.of()) assertEquals(FlinkRelDistribution.ANY, mq.flinkDistribution(calc2)) } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnIntervalTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnIntervalTest.scala index 6ea6a129fa1ec..96a24e332fa39 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnIntervalTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdFilteredColumnIntervalTest.scala @@ -26,7 +26,7 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable.{DIVIDE, EQUALS, GREATER_T import org.junit.jupiter.api.{BeforeEach, Test} import org.junit.jupiter.api.Assertions.{assertEquals, assertNull} -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters.seqAsJavaListConverter class FlinkRelMdFilteredColumnIntervalTest extends FlinkRelMdHandlerTestBase { private var ts: RelNode = _ @@ -202,7 +202,8 @@ class FlinkRelMdFilteredColumnIntervalTest extends FlinkRelMdHandlerTestBase { new BooleanType() ) ) - val calc = createLogicalCalc(ts, outputRowType, projects, List(expr1)) + val calc = + createLogicalCalc(ts, outputRowType, projects.asJava, java.util.List.of[RexNode](expr1)) assertEquals(ValueInterval(bd(-5), bd(2)), mq.getFilteredColumnInterval(calc, 0, -1)) assertEquals(ValueInterval(bd(0d), bd(6.1d)), mq.getFilteredColumnInterval(calc, 1, -1)) assertEquals(ValueInterval(bd(-5), bd(2)), mq.getFilteredColumnInterval(calc, 0, 2)) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala index d8815d3aa6bf7..2662d3c3139ad 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdHandlerTestBase.scala @@ -77,7 +77,7 @@ import java.time.Duration import java.util import java.util.Collections -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, seqAsJavaListConverter} class FlinkRelMdHandlerTestBase { @@ -177,11 +177,20 @@ class FlinkRelMdHandlerTestBase { createTableSourceTable(ImmutableList.of("TableSourceTable1"), streamPhysicalTraits) protected lazy val flinkLogicalIntermediateTableScan: FlinkLogicalIntermediateTableScan = - createIntermediateScan(streamExchangeById, flinkLogicalTraits, Set(ImmutableBitSet.of(0))) + createIntermediateScan( + streamExchangeById, + flinkLogicalTraits, + java.util.Set.of[ImmutableBitSet](ImmutableBitSet.of(0))) protected lazy val batchPhysicalIntermediateTableScan: BatchPhysicalIntermediateTableScan = - createIntermediateScan(batchExchangeById, batchPhysicalTraits, Set(ImmutableBitSet.of(0))) + createIntermediateScan( + batchExchangeById, + batchPhysicalTraits, + java.util.Set.of[ImmutableBitSet](ImmutableBitSet.of(0))) protected lazy val streamPhysicalIntermediateTableScan: StreamPhysicalIntermediateTableScan = - createIntermediateScan(streamExchangeById, streamPhysicalTraits, Set(ImmutableBitSet.of(0))) + createIntermediateScan( + streamExchangeById, + streamPhysicalTraits, + java.util.Set.of[ImmutableBitSet](ImmutableBitSet.of(0))) protected lazy val tablePartiallyProjectedKeyLogicalScan: LogicalTableScan = createTableSourceTable( @@ -248,7 +257,7 @@ class FlinkRelMdHandlerTestBase { List(null, "false", "2017-09-01", "10:00:01", null, "3.12", null, null), List("3", "true", null, "10:00:02", "2017-10-01 01:00:00", "3.0", null, "xyz"), List("2", "true", "2017-10-02", "09:59:59", "2017-07-01 01:00:00", "-1", null, "F") - ).map(createLiteralList(valuesType, _)) + ).map(createLiteralList(valuesType, _)).asJava relBuilder.values(tupleList, valuesType) relBuilder.build().asInstanceOf[LogicalValues] } @@ -257,7 +266,7 @@ class FlinkRelMdHandlerTestBase { // case sex = 'M' then 1 else 2, true, 2.1, 2, cast(score as double not null) as s from student protected lazy val logicalProject: LogicalProject = { relBuilder.push(studentLogicalScan) - val projects = List( + val projects = java.util.List.of( // id relBuilder.field(0), // name @@ -304,7 +313,7 @@ class FlinkRelMdHandlerTestBase { studentLogicalScan, logicalProject.getRowType, logicalProject.getProjects, - List(expr)) + java.util.List.of[RexNode](expr)) (filter, calc) } @@ -383,7 +392,12 @@ class FlinkRelMdHandlerTestBase { } protected lazy val intermediateTable = - new IntermediateRelTable(Seq(""), streamExchangeById, null, false, Set(ImmutableBitSet.of(0))) + new IntermediateRelTable( + java.util.List.of[String](""), + streamExchangeById, + null, + false, + java.util.Set.of[ImmutableBitSet](ImmutableBitSet.of(0))) protected lazy val intermediateScan = new FlinkLogicalIntermediateTableScan( cluster, @@ -403,7 +417,7 @@ class FlinkRelMdHandlerTestBase { protected def createSorts(sortKeys: () => Seq[RexNode]): (RelNode, RelNode, RelNode, RelNode) = { val logicalSort = relBuilder .scan("student") - .sort(sortKeys()) + .sort(sortKeys().asJava) .build .asInstanceOf[LogicalSort] val collation = logicalSort.getCollation @@ -513,7 +527,7 @@ class FlinkRelMdHandlerTestBase { sortKeys: () => Seq[RexNode]): (RelNode, RelNode, RelNode, RelNode, RelNode, RelNode) = { val logicalSortLimit = relBuilder .scan("student") - .sort(sortKeys()) + .sort(sortKeys().asJava) .limit(10, 20) .build .asInstanceOf[LogicalSort] @@ -873,10 +887,10 @@ class FlinkRelMdHandlerTestBase { ) val builder = typeFactory.builder() - firstRow.getRowType.getFieldList.dropRight(2).foreach(builder.add) + firstRow.getRowType.getFieldList.asScala.dropRight(2).foreach(builder.add) val projectProgram = RexProgram.create( firstRow.getRowType, - Array(0, 1, 2).map(i => RexInputRef.of(i, firstRow.getRowType)).toList, + Array(0, 1, 2).map(i => RexInputRef.of(i, firstRow.getRowType)).toList.asJava, null, builder.build(), rexBuilder @@ -1016,8 +1030,8 @@ class FlinkRelMdHandlerTestBase { false, false, false, - Seq().toList, - Seq(Integer.valueOf(3)).toList, + java.util.List.of[RexNode](), + java.util.List.of[Integer](Integer.valueOf(3)), -1, null, RelCollations.of(), @@ -1034,7 +1048,7 @@ class FlinkRelMdHandlerTestBase { studentLogicalScan, ImmutableBitSet.of(0), null, - Seq(tableAggCall)) + Seq(tableAggCall).asJava) val flinkLogicalTableAgg = new FlinkLogicalTableAggregate( cluster, @@ -1042,7 +1056,7 @@ class FlinkRelMdHandlerTestBase { studentLogicalScan, ImmutableBitSet.of(0), null, - Seq(tableAggCall) + Seq(tableAggCall).asJava ) val builder = typeFactory.builder() @@ -1075,7 +1089,7 @@ class FlinkRelMdHandlerTestBase { relBuilder.scan("TemporalTable1") val ts = relBuilder.peek() val project = relBuilder - .project(relBuilder.fields(Seq[Integer](2, 0, 1, 4).toList)) + .project(relBuilder.fields(Seq[Integer](2, 0, 1, 4).toList.asJava)) .build() .asInstanceOf[Project] val program = @@ -1122,9 +1136,9 @@ class FlinkRelMdHandlerTestBase { streamExchange, flinkLogicalWindowAgg.getRowType, Array(1), - flinkLogicalWindowAgg.getAggCallList, + flinkLogicalWindowAgg.getAggCallList.asScala.toSeq, tumblingGroupWindow, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, emitStrategy ) @@ -1170,11 +1184,11 @@ class FlinkRelMdHandlerTestBase { logicalAgg.getAggCallList ) - val aggCalls = logicalAgg.getAggCallList + val aggCalls = logicalAgg.getAggCallList.asScala val aggFunctionFactory = new AggFunctionFactory( FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType), Array.empty[Int], - Array.fill(aggCalls.size())(false), + Array.fill(aggCalls.size)(false), false) val aggCallToAggFunction = aggCalls.zipWithIndex.map { case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) @@ -1210,7 +1224,7 @@ class FlinkRelMdHandlerTestBase { Array(3), auxGrouping = Array(), true, - aggCallToAggFunction) + aggCallToAggFunction.toSeq) val batchExchange1 = new BatchPhysicalExchange( cluster, @@ -1226,7 +1240,7 @@ class FlinkRelMdHandlerTestBase { batchLocalAgg.getInput.getRowType, Array(0), auxGrouping = Array(), - aggCallToAggFunction, + aggCallToAggFunction.toSeq, isMerge = true) val batchExchange2 = new BatchPhysicalExchange( @@ -1243,17 +1257,17 @@ class FlinkRelMdHandlerTestBase { batchExchange2.getRowType, Array(3), auxGrouping = Array(), - aggCallToAggFunction, + aggCallToAggFunction.toSeq, isMerge = false) val aggCallNeedRetractions = - AggregateUtil.deriveAggCallNeedRetractions(1, aggCalls, needRetraction = false, null) + AggregateUtil.deriveAggCallNeedRetractions(1, aggCalls.toSeq, needRetraction = false, null) val streamLocalAgg = new StreamPhysicalLocalGroupAggregate( cluster, streamPhysicalTraits, studentStreamScan, Array(3), - aggCalls, + aggCalls.toSeq, aggCallNeedRetractions, false, PartialFinalType.NONE) @@ -1269,7 +1283,7 @@ class FlinkRelMdHandlerTestBase { streamExchange1, rowTypeOfGlobalAgg, Array(0), - aggCalls, + aggCalls.toSeq, aggCallNeedRetractions, streamLocalAgg.getInput.getRowType, AggregateUtil.needRetraction(streamLocalAgg), @@ -1286,7 +1300,7 @@ class FlinkRelMdHandlerTestBase { streamExchange2, rowTypeOfGlobalAgg, Array(3), - aggCalls) + aggCalls.toSeq) ( logicalAgg, @@ -1328,7 +1342,7 @@ class FlinkRelMdHandlerTestBase { streamGlobalAggWithoutLocalWithFilter) = { relBuilder.push(studentLogicalScan) - val projects = List( + val projects = java.util.List.of( relBuilder.field(0), relBuilder.field(1), relBuilder.field(2), @@ -1371,7 +1385,7 @@ class FlinkRelMdHandlerTestBase { false, false, false, - List(Integer.valueOf(argIndex)), + java.util.List.of[Integer](Integer.valueOf(argIndex)), filterArg, null, RelCollations.EMPTY, @@ -1401,10 +1415,10 @@ class FlinkRelMdHandlerTestBase { val logicalAggWithFilter = LogicalAggregate.create( calcOnStudentScan, - List(), + java.util.List.of[RelHint](), ImmutableBitSet.of(3), - List(ImmutableBitSet.of(3)), - aggCallList) + java.util.List.of[ImmutableBitSet](ImmutableBitSet.of(3)), + aggCallList.asJava) val flinkLogicalAggWithFilter = new FlinkLogicalAggregate( cluster, @@ -1414,11 +1428,11 @@ class FlinkRelMdHandlerTestBase { logicalAggWithFilter.getGroupSets, logicalAggWithFilter.getAggCallList) - val aggCalls = logicalAggWithFilter.getAggCallList + val aggCalls = logicalAggWithFilter.getAggCallList.asScala val aggFunctionFactory = new AggFunctionFactory( FlinkTypeFactory.toLogicalRowType(calcOnStudentScan.getRowType), Array.empty[Int], - Array.fill(aggCalls.size())(false), + Array.fill(aggCalls.size)(false), false) val aggCallToAggFunction = aggCalls.zipWithIndex.map { case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) @@ -1476,7 +1490,7 @@ class FlinkRelMdHandlerTestBase { Array(3), auxGrouping = Array(), true, - aggCallToAggFunction) + aggCallToAggFunction.toSeq) val batchExchange1 = new BatchPhysicalExchange( cluster, @@ -1492,7 +1506,7 @@ class FlinkRelMdHandlerTestBase { batchLocalAggWithFilter.getInput.getRowType, Array(0), auxGrouping = Array(), - aggCallToAggFunction, + aggCallToAggFunction.toSeq, isMerge = true) val batchExchange2 = new BatchPhysicalExchange( @@ -1509,17 +1523,17 @@ class FlinkRelMdHandlerTestBase { batchExchange2.getRowType, Array(3), auxGrouping = Array(), - aggCallToAggFunction, + aggCallToAggFunction.toSeq, isMerge = false) val aggCallNeedRetractions = - AggregateUtil.deriveAggCallNeedRetractions(1, aggCalls, needRetraction = false, null) + AggregateUtil.deriveAggCallNeedRetractions(1, aggCalls.toSeq, needRetraction = false, null) val streamLocalAggWithFilter = new StreamPhysicalLocalGroupAggregate( cluster, streamPhysicalTraits, calcOnStudentScan, Array(3), - aggCalls, + aggCalls.toSeq, aggCallNeedRetractions, false, PartialFinalType.NONE) @@ -1535,7 +1549,7 @@ class FlinkRelMdHandlerTestBase { streamExchange1, rowTypeOfGlobalAgg, Array(0), - aggCalls, + aggCalls.toSeq, aggCallNeedRetractions, streamLocalAggWithFilter.getInput.getRowType, AggregateUtil.needRetraction(streamLocalAggWithFilter), @@ -1552,7 +1566,7 @@ class FlinkRelMdHandlerTestBase { streamExchange2, rowTypeOfGlobalAgg, Array(3), - aggCalls) + aggCalls.toSeq) ( logicalAggWithFilter, @@ -1599,13 +1613,13 @@ class FlinkRelMdHandlerTestBase { logicalAggWithAuxGroup.getAggCallList ) - val aggCalls = logicalAggWithAuxGroup.getAggCallList.filter { + val aggCalls = logicalAggWithAuxGroup.getAggCallList.asScala.filter { call => call.getAggregation != FlinkSqlOperatorTable.AUXILIARY_GROUP } val aggFunctionFactory = new AggFunctionFactory( FlinkTypeFactory.toLogicalRowType(studentBatchScan.getRowType), Array.empty[Int], - Array.fill(aggCalls.size())(false), + Array.fill(aggCalls.size)(false), false) val aggCallToAggFunction = aggCalls.zipWithIndex.map { case (call, index) => (call, aggFunctionFactory.createAggFunction(call, index)) @@ -1629,7 +1643,7 @@ class FlinkRelMdHandlerTestBase { Array(0), auxGrouping = Array(1, 4), true, - aggCallToAggFunction) + aggCallToAggFunction.toSeq) val hash0 = FlinkRelDistribution.hash(Array(0), requireStrict = true) val batchExchange = new BatchPhysicalExchange( @@ -1655,7 +1669,7 @@ class FlinkRelMdHandlerTestBase { batchLocalAggWithAuxGroup.getInput.getRowType, Array(0), auxGrouping = Array(1, 2), - aggCallToAggFunction, + aggCallToAggFunction.toSeq, isMerge = true) val batchExchange2 = new BatchPhysicalExchange( @@ -1672,7 +1686,7 @@ class FlinkRelMdHandlerTestBase { batchExchange2.getRowType, Array(0), auxGrouping = Array(1, 4), - aggCallToAggFunction, + aggCallToAggFunction.toSeq, isMerge = false) ( @@ -1700,8 +1714,8 @@ class FlinkRelMdHandlerTestBase { intervalOfMillis(900000) ) - protected lazy val namedPropertiesOfWindowAgg: Seq[NamedWindowProperty] = - Seq( + protected lazy val namedPropertiesOfWindowAgg: java.util.List[NamedWindowProperty] = + java.util.List.of( new NamedWindowProperty("w$start", new WindowStart(windowRef)), new NamedWindowProperty("w$end", new WindowStart(windowRef)), new NamedWindowProperty("w$rowtime", new RowtimeAttribute(windowRef)), @@ -1725,7 +1739,7 @@ class FlinkRelMdHandlerTestBase { relBuilder.scan("TemporalTable1") val ts = relBuilder.peek() val project = relBuilder - .project(relBuilder.fields(Seq[Integer](0, 1, 4, 2).toList)) + .project(relBuilder.fields(Seq[Integer](0, 1, 4, 2).toList.asJava)) .build() .asInstanceOf[Project] val program = @@ -1736,7 +1750,7 @@ class FlinkRelMdHandlerTestBase { false, false, false, - List[Integer](3), + java.util.List.of[Integer](3), -1, null, RelCollations.EMPTY, @@ -1780,17 +1794,17 @@ class FlinkRelMdHandlerTestBase { AggregateUtil.transformToBatchAggregateFunctions( typeFactory, FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType), - flinkLogicalWindowAgg.getAggCallList) - val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates) + flinkLogicalWindowAgg.getAggCallList.asScala.toSeq) + val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.asScala.zip(aggregates) val localWindowAggTypes = (Array(0, 1).map(batchCalc.getRowType.getFieldList.get(_).getType) ++ // grouping Array(longType) ++ // assignTs - aggCallOfWindowAgg.map(_.getType)).toList // agg calls + aggCallOfWindowAgg.asScala.map(_.getType)).toList.asJava // agg calls val localWindowAggNames = (Array(0, 1).map(batchCalc.getRowType.getFieldNames.get(_)) ++ // grouping Array("assignedWindow$") ++ // assignTs - Array("count$0")).toList // agg calls + Array("count$0")).toList.asJava // agg calls val localWindowAggRowType = typeFactory.createStructType(localWindowAggTypes, localWindowAggNames) val batchLocalWindowAgg = new BatchPhysicalLocalHashWindowAggregate( @@ -1801,11 +1815,11 @@ class FlinkRelMdHandlerTestBase { batchCalc.getRowType, Array(0, 1), Array.empty, - aggCallToAggFunction, + aggCallToAggFunction.toSeq, tumblingGroupWindow, inputTimeFieldIndex = 2, inputTimeIsDate = false, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, enableAssignPane = false) val batchExchange2 = new BatchPhysicalExchange( cluster, @@ -1820,11 +1834,11 @@ class FlinkRelMdHandlerTestBase { batchCalc.getRowType, Array(0, 1), Array.empty, - aggCallToAggFunction, + aggCallToAggFunction.toSeq, tumblingGroupWindow, inputTimeFieldIndex = 2, inputTimeIsDate = false, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, enableAssignPane = false, isMerge = true ) @@ -1837,11 +1851,11 @@ class FlinkRelMdHandlerTestBase { batchExchange1.getRowType, Array(0, 1), Array.empty, - aggCallToAggFunction, + aggCallToAggFunction.toSeq, tumblingGroupWindow, inputTimeFieldIndex = 2, inputTimeIsDate = false, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, enableAssignPane = false, isMerge = false ) @@ -1863,9 +1877,9 @@ class FlinkRelMdHandlerTestBase { streamExchange, flinkLogicalWindowAgg.getRowType, Array(0, 1), - flinkLogicalWindowAgg.getAggCallList, + flinkLogicalWindowAgg.getAggCallList.asScala.toSeq, tumblingGroupWindow, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, emitStrategy ) @@ -1895,7 +1909,7 @@ class FlinkRelMdHandlerTestBase { relBuilder.scan("TemporalTable1") val ts = relBuilder.peek() val project = relBuilder - .project(relBuilder.fields(Seq[Integer](0, 1, 4).toList)) + .project(relBuilder.fields(Seq[Integer](0, 1, 4).toList.asJava)) .build() .asInstanceOf[Project] val program = @@ -1906,7 +1920,7 @@ class FlinkRelMdHandlerTestBase { false, false, false, - List[Integer](0), + java.util.List.of[Integer](0), -1, null, RelCollations.EMPTY, @@ -1950,17 +1964,17 @@ class FlinkRelMdHandlerTestBase { AggregateUtil.transformToBatchAggregateFunctions( typeFactory, FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType), - flinkLogicalWindowAgg.getAggCallList) - val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.zip(aggregates) + flinkLogicalWindowAgg.getAggCallList.asScala.toSeq) + val aggCallToAggFunction = flinkLogicalWindowAgg.getAggCallList.asScala.zip(aggregates) val localWindowAggTypes = (Array(batchCalc.getRowType.getFieldList.get(1).getType) ++ // grouping Array(longType) ++ // assignTs - aggCallOfWindowAgg.map(_.getType)).toList // agg calls + aggCallOfWindowAgg.asScala.map(_.getType)).toList.asJava // agg calls val localWindowAggNames = (Array(batchCalc.getRowType.getFieldNames.get(1)) ++ // grouping Array("assignedWindow$") ++ // assignTs - Array("count$0")).toList // agg calls + Array("count$0")).toList.asJava // agg calls val localWindowAggRowType = typeFactory.createStructType(localWindowAggTypes, localWindowAggNames) val batchLocalWindowAgg = new BatchPhysicalLocalHashWindowAggregate( @@ -1971,11 +1985,11 @@ class FlinkRelMdHandlerTestBase { batchCalc.getRowType, Array(1), Array.empty, - aggCallToAggFunction, + aggCallToAggFunction.toSeq, tumblingGroupWindow, inputTimeFieldIndex = 2, inputTimeIsDate = false, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, enableAssignPane = false) val batchExchange2 = new BatchPhysicalExchange( cluster, @@ -1990,11 +2004,11 @@ class FlinkRelMdHandlerTestBase { batchCalc.getRowType, Array(0), Array.empty, - aggCallToAggFunction, + aggCallToAggFunction.toSeq, tumblingGroupWindow, inputTimeFieldIndex = 2, inputTimeIsDate = false, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, enableAssignPane = false, isMerge = true ) @@ -2007,11 +2021,11 @@ class FlinkRelMdHandlerTestBase { batchExchange1.getRowType, Array(1), Array.empty, - aggCallToAggFunction, + aggCallToAggFunction.toSeq, tumblingGroupWindow, inputTimeFieldIndex = 2, inputTimeIsDate = false, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, enableAssignPane = false, isMerge = false ) @@ -2033,9 +2047,9 @@ class FlinkRelMdHandlerTestBase { streamExchange, flinkLogicalWindowAgg.getRowType, Array(1), - flinkLogicalWindowAgg.getAggCallList, + flinkLogicalWindowAgg.getAggCallList.asScala.toSeq, tumblingGroupWindow, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, emitStrategy ) @@ -2064,7 +2078,7 @@ class FlinkRelMdHandlerTestBase { relBuilder.scan("TemporalTable2") val ts = relBuilder.peek() val project = relBuilder - .project(relBuilder.fields(Seq[Integer](0, 2, 4, 1).toList)) + .project(relBuilder.fields(Seq[Integer](0, 2, 4, 1).toList.asJava)) .build() .asInstanceOf[Project] val program = @@ -2075,7 +2089,7 @@ class FlinkRelMdHandlerTestBase { false, false, false, - List[Integer](1), + java.util.List.of[Integer](1), -1, null, RelCollations.EMPTY, @@ -2088,7 +2102,7 @@ class FlinkRelMdHandlerTestBase { false, false, false, - List[Integer](3), + java.util.List.of[Integer](3), -1, null, RelCollations.EMPTY, @@ -2129,24 +2143,24 @@ class FlinkRelMdHandlerTestBase { val hash0 = FlinkRelDistribution.hash(Array(0), requireStrict = true) val batchExchange1 = new BatchPhysicalExchange(cluster, batchPhysicalTraits.replace(hash0), batchCalc, hash0) - val aggCallsWithoutAuxGroup = flinkLogicalWindowAggWithAuxGroup.getAggCallList.drop(1) + val aggCallsWithoutAuxGroup = flinkLogicalWindowAggWithAuxGroup.getAggCallList.asScala.drop(1) val (_, _, aggregates) = AggregateUtil.transformToBatchAggregateFunctions( typeFactory, FlinkTypeFactory.toLogicalRowType(batchExchange1.getRowType), - aggCallsWithoutAuxGroup) + aggCallsWithoutAuxGroup.toSeq) val aggCallToAggFunction = aggCallsWithoutAuxGroup.zip(aggregates) val localWindowAggTypes = (Array(batchCalc.getRowType.getFieldList.get(0).getType) ++ // grouping Array(longType) ++ // assignTs Array(batchCalc.getRowType.getFieldList.get(1).getType) ++ // auxGrouping - aggCallsWithoutAuxGroup.map(_.getType)).toList // agg calls + aggCallsWithoutAuxGroup.map(_.getType)).toList.asJava // agg calls val localWindowAggNames = (Array(batchCalc.getRowType.getFieldNames.get(0)) ++ // grouping Array("assignedWindow$") ++ // assignTs Array(batchCalc.getRowType.getFieldNames.get(1)) ++ // auxGrouping - Array("count$0")).toList // agg calls + Array("count$0")).toList.asJava // agg calls val localWindowAggRowType = typeFactory.createStructType(localWindowAggTypes, localWindowAggNames) val batchLocalWindowAggWithAuxGroup = new BatchPhysicalLocalHashWindowAggregate( @@ -2157,11 +2171,11 @@ class FlinkRelMdHandlerTestBase { batchCalc.getRowType, Array(0), Array(1), - aggCallToAggFunction, + aggCallToAggFunction.toSeq, tumblingGroupWindow, inputTimeFieldIndex = 2, inputTimeIsDate = false, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, enableAssignPane = false) val batchExchange2 = new BatchPhysicalExchange( cluster, @@ -2176,11 +2190,11 @@ class FlinkRelMdHandlerTestBase { batchCalc.getRowType, Array(0), Array(2), // local output grouping keys: grouping + assignTs + auxGrouping - aggCallToAggFunction, + aggCallToAggFunction.toSeq, tumblingGroupWindow, inputTimeFieldIndex = 2, inputTimeIsDate = false, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, enableAssignPane = false, isMerge = true ) @@ -2193,11 +2207,11 @@ class FlinkRelMdHandlerTestBase { batchExchange1.getRowType, Array(0), Array(1), - aggCallToAggFunction, + aggCallToAggFunction.toSeq, tumblingGroupWindow, inputTimeFieldIndex = 2, inputTimeIsDate = false, - namedPropertiesOfWindowAgg, + namedPropertiesOfWindowAgg.asScala.toSeq, enableAssignPane = false, isMerge = false ) @@ -2250,7 +2264,10 @@ class FlinkRelMdHandlerTestBase { val rowTypeOfCalc = createRowType("id", "name", "score", "age", "class") val rexProgram = RexProgram.create( studentFlinkLogicalScan.getRowType, - Array(0, 1, 2, 3, 6).map(i => RexInputRef.of(i, studentFlinkLogicalScan.getRowType)).toList, + Array(0, 1, 2, 3, 6) + .map(i => RexInputRef.of(i, studentFlinkLogicalScan.getRowType)) + .toList + .asJava, null, rowTypeOfCalc, rexBuilder @@ -2292,20 +2309,23 @@ class FlinkRelMdHandlerTestBase { "cnt") val projectProgram = RexProgram.create( flinkLogicalOverAgg.getRowType, - (0 until flinkLogicalOverAgg.getRowType.getFieldCount).flatMap { - i => - if (i < 8 || i >= 10) { - Array[RexNode](RexInputRef.of(i, flinkLogicalOverAgg.getRowType)) - } else if (i == 8) { - Array[RexNode]( - rexBuilder.makeCall( - SqlStdOperatorTable.DIVIDE, - RexInputRef.of(8, flinkLogicalOverAgg.getRowType), - RexInputRef.of(9, flinkLogicalOverAgg.getRowType))) - } else { - Array.empty[RexNode] - } - }.toList, + (0 until flinkLogicalOverAgg.getRowType.getFieldCount) + .flatMap { + i => + if (i < 8 || i >= 10) { + Array[RexNode](RexInputRef.of(i, flinkLogicalOverAgg.getRowType)) + } else if (i == 8) { + Array[RexNode]( + rexBuilder.makeCall( + SqlStdOperatorTable.DIVIDE, + RexInputRef.of(8, flinkLogicalOverAgg.getRowType), + RexInputRef.of(9, flinkLogicalOverAgg.getRowType))) + } else { + Array.empty[RexNode] + } + } + .toList + .asJava, null, rowTypeOfWindowAggOutput, rexBuilder @@ -2344,7 +2364,7 @@ class FlinkRelMdHandlerTestBase { sort1, outputRowType1, sort1.getRowType, - Seq(overAggGroups(0)), + Seq(overAggGroups.get(0)), flinkLogicalOverAgg ) @@ -2376,7 +2396,7 @@ class FlinkRelMdHandlerTestBase { sort2, outputRowType2, sort2.getRowType, - Seq(overAggGroups(1)), + Seq(overAggGroups.get(1)), flinkLogicalOverAgg ) @@ -2406,7 +2426,7 @@ class FlinkRelMdHandlerTestBase { exchange2, outputRowType3, exchange2.getRowType, - Seq(overAggGroups(2)), + Seq(overAggGroups.get(2)), flinkLogicalOverAgg ) @@ -2483,7 +2503,10 @@ class FlinkRelMdHandlerTestBase { val rowTypeOfCalc = createRowType("id", "name", "score", "age", "class") val rexProgram = RexProgram.create( studentFlinkLogicalScan.getRowType, - Array(0, 1, 2, 3, 6).map(i => RexInputRef.of(i, studentFlinkLogicalScan.getRowType)).toList, + Array(0, 1, 2, 3, 6) + .map(i => RexInputRef.of(i, studentFlinkLogicalScan.getRowType)) + .toList + .asJava, null, rowTypeOfCalc, rexBuilder @@ -2527,20 +2550,23 @@ class FlinkRelMdHandlerTestBase { createRowType("id", "name", "score", "age", "class", "rk", "drk", "avg_score") val projectProgram = RexProgram.create( flinkLogicalOverAgg.getRowType, - (0 until flinkLogicalOverAgg.getRowType.getFieldCount).flatMap { - i => - if (i < 7) { - Array[RexNode](RexInputRef.of(i, flinkLogicalOverAgg.getRowType)) - } else if (i == 7) { - Array[RexNode]( - rexBuilder.makeCall( - SqlStdOperatorTable.DIVIDE, - RexInputRef.of(7, flinkLogicalOverAgg.getRowType), - RexInputRef.of(8, flinkLogicalOverAgg.getRowType))) - } else { - Array.empty[RexNode] - } - }.toList, + (0 until flinkLogicalOverAgg.getRowType.getFieldCount) + .flatMap { + i => + if (i < 7) { + Array[RexNode](RexInputRef.of(i, flinkLogicalOverAgg.getRowType)) + } else if (i == 7) { + Array[RexNode]( + rexBuilder.makeCall( + SqlStdOperatorTable.DIVIDE, + RexInputRef.of(7, flinkLogicalOverAgg.getRowType), + RexInputRef.of(8, flinkLogicalOverAgg.getRowType))) + } else { + Array.empty[RexNode] + } + } + .toList + .asJava, null, rowTypeOfWindowAggOutput, rexBuilder @@ -3284,7 +3310,7 @@ class FlinkRelMdHandlerTestBase { relBuilder.push(studentLogicalScan) if (windowFunctionCall) { - val projects = List( + val projects = java.util.List.of( relBuilder.field(0), relBuilder.field(1), relBuilder.call(FlinkSqlOperatorTable.PROCTIME)) @@ -3342,7 +3368,10 @@ class FlinkRelMdHandlerTestBase { FunctionIdentifier.of("STRING_SPLIT"), new JavaUserDefinedTableFunctions.StringSplit())), rexBuilder.makeFieldAccess(correlVar, 1), - rexBuilder.makeCall(stringType, CAST, List(rexBuilder.makeFieldAccess(correlVar, 0))) + rexBuilder.makeCall( + stringType, + CAST, + java.util.List.of[RexNode](rexBuilder.makeFieldAccess(correlVar, 0))) ) new FlinkLogicalTableFunctionScan( cluster, @@ -3421,7 +3450,7 @@ class FlinkRelMdHandlerTestBase { )) } } - val program = RexProgram.create(tvf.getRowType, projects, null, outputType, rexBuilder) + val program = RexProgram.create(tvf.getRowType, projects.asJava, null, outputType, rexBuilder) new StreamPhysicalCalc( cluster, streamPhysicalTraits, @@ -3451,7 +3480,12 @@ class FlinkRelMdHandlerTestBase { protected def createUnionOnWindowTVF( tvf1: CommonPhysicalWindowTableFunction, tvf2: CommonPhysicalWindowTableFunction): Union = { - new StreamPhysicalUnion(cluster, streamPhysicalTraits, List(tvf1, tvf2), true, tvf1.getRowType) + new StreamPhysicalUnion( + cluster, + streamPhysicalTraits, + java.util.List.of(tvf1, tvf2), + true, + tvf1.getRowType) } // hash by field a @@ -3487,8 +3521,8 @@ class FlinkRelMdHandlerTestBase { relBuilder.push(windowTableFunctionScan) val groupKey = if (groupByWindow) - List(relBuilder.field(0), relBuilder.field(3), relBuilder.field(4)) - else List(relBuilder.field(0)) + java.util.List.of(relBuilder.field(0), relBuilder.field(3), relBuilder.field(4)) + else java.util.List.of[RexInputRef](relBuilder.field(0)) val logicalAgg = relBuilder .aggregate( @@ -3535,7 +3569,7 @@ class FlinkRelMdHandlerTestBase { traitSet, streamTumbleWindowTVFRel, Array(), - logicalAgg.getAggCallList, + logicalAgg.getAggCallList.asScala.toSeq, new WindowAttachedWindowingStrategy(tumbleWindowSpec, timeAttributeType, 5, 6), namedWindowProperties) @@ -3656,7 +3690,13 @@ class FlinkRelMdHandlerTestBase { upsertKeys: util.Set[ImmutableBitSet], statistic: FlinkStatistic = FlinkStatistic.UNKNOWN): T = { val intermediateTable = - new IntermediateRelTable(Seq(""), relNode, null, false, upsertKeys, statistic) + new IntermediateRelTable( + java.util.List.of[String](""), + relNode, + null, + false, + upsertKeys, + statistic) val conventionTrait = traitSet.getTrait(ConventionTraitDef.INSTANCE) val scan = conventionTrait match { @@ -3722,26 +3762,29 @@ class FlinkRelMdHandlerTestBase { literalValues: Seq[String]): util.List[RexLiteral] = { require(literalValues.length == rowType.getFieldCount) val rexBuilder = relBuilder.getRexBuilder - literalValues.zipWithIndex.map { - case (v, index) => - val fieldType = rowType.getFieldList.get(index).getType - if (v == null) { - rexBuilder.makeNullLiteral(fieldType) - } else { - fieldType.getSqlTypeName match { - case BIGINT => rexBuilder.makeLiteral(v.toLong, fieldType, true) - case INTEGER => rexBuilder.makeLiteral(v.toInt, fieldType, true) - case BOOLEAN => rexBuilder.makeLiteral(v.toBoolean) - case DATE => rexBuilder.makeDateLiteral(new DateString(v)) - case TIME => rexBuilder.makeTimeLiteral(new TimeString(v), 0) - case TIMESTAMP => rexBuilder.makeTimestampLiteral(new TimestampString(v), 0) - case DOUBLE => rexBuilder.makeApproxLiteral(BigDecimal.valueOf(v.toDouble)) - case FLOAT => rexBuilder.makeApproxLiteral(BigDecimal.valueOf(v.toFloat)) - case VARCHAR => rexBuilder.makeLiteral(v) - case _ => throw new TableException(s"${fieldType.getSqlTypeName} is not supported!") - } - }.asInstanceOf[RexLiteral] - }.toList + literalValues.zipWithIndex + .map { + case (v, index) => + val fieldType = rowType.getFieldList.get(index).getType + if (v == null) { + rexBuilder.makeNullLiteral(fieldType) + } else { + fieldType.getSqlTypeName match { + case BIGINT => rexBuilder.makeLiteral(v.toLong, fieldType, true) + case INTEGER => rexBuilder.makeLiteral(v.toInt, fieldType, true) + case BOOLEAN => rexBuilder.makeLiteral(v.toBoolean) + case DATE => rexBuilder.makeDateLiteral(new DateString(v)) + case TIME => rexBuilder.makeTimeLiteral(new TimeString(v), 0) + case TIMESTAMP => rexBuilder.makeTimestampLiteral(new TimestampString(v), 0) + case DOUBLE => rexBuilder.makeApproxLiteral(BigDecimal.valueOf(v.toDouble)) + case FLOAT => rexBuilder.makeApproxLiteral(BigDecimal.valueOf(v.toFloat)) + case VARCHAR => rexBuilder.makeLiteral(v) + case _ => throw new TableException(s"${fieldType.getSqlTypeName} is not supported!") + } + }.asInstanceOf[RexLiteral] + } + .toList + .asJava } protected def createLogicalCalc( diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala index a086c081426ef..586c0daa92c78 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdModifiedMonotonicityTest.scala @@ -26,7 +26,7 @@ import org.apache.calcite.rel.RelCollations import org.apache.calcite.rel.core.JoinRelType import org.apache.calcite.rel.hint.RelHint import org.apache.calcite.rel.logical.LogicalCalc -import org.apache.calcite.rex.{RexNode, RexProgram} +import org.apache.calcite.rex.{RexInputRef, RexNode, RexProgram} import org.apache.calcite.sql.fun.SqlStdOperatorTable._ import org.apache.calcite.sql.validate.SqlMonotonicity import org.apache.calcite.sql.validate.SqlMonotonicity._ @@ -36,7 +36,7 @@ import org.junit.jupiter.api.Test import java.util -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters.seqAsJavaListConverter import scala.language.postfixOps class FlinkRelMdModifiedMonotonicityTest extends FlinkRelMdHandlerTestBase { @@ -60,8 +60,9 @@ class FlinkRelMdModifiedMonotonicityTest extends FlinkRelMdHandlerTestBase { relBuilder.push(inputAgg) // project `age` field and corresponding output type - val projection = List(relBuilder.field("age")) - val ageFieldType = inputAgg.getRowType.getFieldList.filter(x => x.getName.equals("age")) + val projection = java.util.List.of[RexInputRef](relBuilder.field("age")) + val ageFieldType = + inputAgg.getRowType.getFieldList.stream().filter(x => x.getName.equals("age")).toList val outputType = new RelRecordType(ageFieldType) // select age from (select id, age, count() from student by id, age) where ... @@ -250,7 +251,7 @@ class FlinkRelMdModifiedMonotonicityTest extends FlinkRelMdHandlerTestBase { projectWithMaxAgg, ImmutableBitSet.of(0), null, - Seq(tableAggCall) + Seq(tableAggCall).toList.asJava ) assertEquals(null, mq.getRelModifiedMonotonicity(tableAggregate)) } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCollationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCollationTest.scala index 5ee117dc1a42b..c92937cca6f1b 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCollationTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCollationTest.scala @@ -31,7 +31,7 @@ import org.apache.calcite.util.ImmutableBitSet import org.junit.jupiter.api.Assertions.{assertEquals, assertTrue} import org.junit.jupiter.api.Test -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters.seqAsJavaListConverter class FlinkRelMdRowCollationTest extends FlinkRelMdHandlerTestBase { @@ -47,7 +47,7 @@ class FlinkRelMdRowCollationTest extends FlinkRelMdHandlerTestBase { List("1", "9.0", "true", "2"), List("2", "6.0", "false", "3"), List("3", "3.0", "true", "4") - ).map(createLiteralList(valuesType, _)) + ).map(createLiteralList(valuesType, _)).asJava relBuilder.clear() relBuilder.values(tupleList, valuesType) relBuilder.build().asInstanceOf[LogicalValues] @@ -61,11 +61,20 @@ class FlinkRelMdRowCollationTest extends FlinkRelMdHandlerTestBase { // Test intermediate table scan. val flinkLogicalIntermediateTableScan: FlinkLogicalIntermediateTableScan = - createIntermediateScan(flinkLogicalSort, flinkLogicalTraits, Set(ImmutableBitSet.of(0))) + createIntermediateScan( + flinkLogicalSort, + flinkLogicalTraits, + java.util.Set.of[ImmutableBitSet](ImmutableBitSet.of(0))) val batchPhysicalIntermediateTableScan: BatchPhysicalIntermediateTableScan = - createIntermediateScan(batchSort, batchPhysicalTraits, Set(ImmutableBitSet.of(0))) + createIntermediateScan( + batchSort, + batchPhysicalTraits, + java.util.Set.of[ImmutableBitSet](ImmutableBitSet.of(0))) val streamPhysicalIntermediateTableScan: StreamPhysicalIntermediateTableScan = - createIntermediateScan(streamSort, streamPhysicalTraits, Set(ImmutableBitSet.of(0))) + createIntermediateScan( + streamSort, + streamPhysicalTraits, + java.util.Set.of[ImmutableBitSet](ImmutableBitSet.of(0))) Array( flinkLogicalIntermediateTableScan, batchPhysicalIntermediateTableScan, @@ -104,7 +113,7 @@ class FlinkRelMdRowCollationTest extends FlinkRelMdHandlerTestBase { val project: LogicalProject = { relBuilder.push(collationValues) - val projects = List( + val projects = java.util.List.of( // a + b relBuilder.call(PLUS, relBuilder.field(0), relBuilder.literal(1)), // c @@ -125,7 +134,8 @@ class FlinkRelMdRowCollationTest extends FlinkRelMdHandlerTestBase { .add("a", SqlTypeName.BIGINT) .add("ts", SqlTypeName.VARCHAR) .build() - val tupleList = List(List("3", "2015-07-24 10:00:00")).map(createLiteralList(valuesType, _)) + val tupleList = + List(List("3", "2015-07-24 10:00:00")).map(createLiteralList(valuesType, _)).asJava relBuilder.values(tupleList, valuesType) val project2 = relBuilder .project( diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCountTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCountTest.scala index 3792805bfff6a..ad44ac902db6f 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCountTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdRowCountTest.scala @@ -28,8 +28,6 @@ import org.apache.calcite.util.ImmutableBitSet import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test -import scala.collection.JavaConversions._ - class FlinkRelMdRowCountTest extends FlinkRelMdHandlerTestBase { @Test @@ -174,7 +172,7 @@ class FlinkRelMdRowCountTest extends FlinkRelMdHandlerTestBase { false, false, false, - List[Integer](3), + java.util.List.of[Integer](3), -1, null, RelCollations.EMPTY, diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivityTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivityTest.scala index 1fc2f71faeb13..e8894c4f05f23 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivityTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSelectivityTest.scala @@ -40,7 +40,7 @@ import org.junit.jupiter.api.Test import java.util import java.util.Collections -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters.seqAsJavaListConverter class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { @@ -72,9 +72,9 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { val expr1 = relBuilder.call(LESS_THAN_OR_EQUAL, relBuilder.field(0), relBuilder.literal(2)) val expr2 = relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(-1)) val expr3 = relBuilder.call(LESS_THAN, relBuilder.field(1), relBuilder.literal(1.1d)) - relBuilder.filter(List(expr1, expr2, expr3)) + relBuilder.filter(java.util.List.of(expr1, expr2, expr3)) // top projects: $0==1, $0, $1, true, 2.1, 2 - val projects = List( + val projects = java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)), relBuilder.field(0), relBuilder.field(1), @@ -101,7 +101,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { val expr1 = relBuilder.call(LESS_THAN_OR_EQUAL, relBuilder.field(0), relBuilder.literal(2)) val expr2 = relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(-1)) val expr3 = relBuilder.call(LESS_THAN, relBuilder.field(1), relBuilder.literal(1.1d)) - val filter = relBuilder.filter(List(expr1, expr2, expr3)).build() + val filter = relBuilder.filter(java.util.List.of(expr1, expr2, expr3)).build() relBuilder.push(filter) val pred1 = relBuilder.call(LESS_THAN_OR_EQUAL, relBuilder.field(0), relBuilder.literal(1)) assertEquals((1.0 + 1.0) / (2.0 + 1.0), mq.getSelectivity(filter, pred1)) @@ -115,7 +115,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { relBuilder.push(ts) // projects: $0==1, $0, $1, true, 2.1, 2 - val projects = List( + val projects = java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)), relBuilder.field(0), relBuilder.field(1), @@ -132,7 +132,8 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { val expr2 = relBuilder.call(GREATER_THAN, relBuilder.field(0), relBuilder.literal(-1)) val expr3 = relBuilder.call(LESS_THAN, relBuilder.field(1), relBuilder.literal(1.1d)) val rexBuilder = relBuilder.getRexBuilder - val predicate = RexUtil.composeConjunction(rexBuilder, List(expr1, expr2, expr3), true) + val predicate = + RexUtil.composeConjunction(rexBuilder, java.util.List.of(expr1, expr2, expr3), true) val program = RexProgram.create(ts.getRowType, projects, predicate, outputRowType, rexBuilder) val calc = new BatchPhysicalCalc(cluster, batchPhysicalTraits, ts, program, outputRowType) @@ -331,14 +332,14 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { val aggWithAuxGroupAndExpand = relBuilder .push(expand) .aggregate( - relBuilder.groupKey(relBuilder.fields(Seq[Integer](0, 4).toList)), + relBuilder.groupKey(relBuilder.fields(Seq[Integer](0, 4).toList.asJava)), Lists.newArrayList( AggregateCall.create( FlinkSqlOperatorTable.AUXILIARY_GROUP, false, false, false, - List[Integer](1), + java.util.List.of[Integer](1), -1, null, RelCollations.EMPTY, @@ -351,7 +352,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { false, false, false, - List[Integer](2), + java.util.List.of[Integer](2), -1, null, RelCollations.EMPTY, @@ -364,7 +365,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { false, false, false, - List[Integer](3), + java.util.List.of[Integer](3), -1, null, RelCollations.EMPTY, @@ -544,9 +545,9 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { ) )) val scan: FlinkLogicalDataStreamTableScan = - createDataStreamScan(List("MyTable4"), flinkLogicalTraits) + createDataStreamScan(java.util.List.of[String]("MyTable4"), flinkLogicalTraits) val builder = typeFactory.builder - scan.getRowType.getFieldList.foreach(f => builder.add(f.getName, f.getType)) + scan.getRowType.getFieldList.stream().forEach(f => builder.add(f.getName, f.getType)) builder.add(rankAggCall.getName, rankAggCall.getType) builder.add(maxAggCall.getName, maxAggCall.getType) val overWindow = new FlinkLogicalOverAggregate( @@ -607,7 +608,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { relBuilder.push(ts).push(right) val joinCondition = RexUtil.composeConjunction( rexBuilder, - List( + java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(2, 0, 0), relBuilder.field(2, 1, 0)), relBuilder.call(GREATER_THAN, relBuilder.field(2, 0, 0), relBuilder.literal(-1)), relBuilder.call(GREATER_THAN, relBuilder.field(2, 1, 1), relBuilder.literal(0.1d)) @@ -619,7 +620,7 @@ class FlinkRelMdSelectivityTest extends FlinkRelMdHandlerTestBase { right, Collections.emptyList(), joinCondition, - Set.empty[CorrelationId], + java.util.Set.of[CorrelationId], JoinRelType.INNER) relBuilder.push(join) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSizeTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSizeTest.scala index 706ca89348948..738d5c22463cf 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSizeTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdSizeTest.scala @@ -20,8 +20,6 @@ package org.apache.flink.table.planner.plan.metadata import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test -import scala.collection.JavaConversions._ - class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { @Test @@ -45,47 +43,47 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { .foreach { scan => assertEquals( - Seq(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0), - mq.getAverageColumnSizes(scan).toList) + java.util.List.of(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0), + mq.getAverageColumnSizes(scan)) } Array(empLogicalScan, empBatchScan, empStreamScan).foreach { scan => assertEquals( - Seq(4.0, 12.0, 12.0, 4.0, 12.0, 8.0, 8.0, 4.0), - mq.getAverageColumnSizes(scan).toList) + java.util.List.of(4.0, 12.0, 12.0, 4.0, 12.0, 8.0, 8.0, 4.0), + mq.getAverageColumnSizes(scan)) } } @Test def testAverageColumnSizeOnValues(): Unit = { assertEquals( - Seq(6.25, 1.0, 9.25, 12.0, 9.25, 8.0, 1.0, 3.75), - mq.getAverageColumnSizes(logicalValues).toList) + java.util.List.of(6.25, 1.0, 9.25, 12.0, 9.25, 8.0, 1.0, 3.75), + mq.getAverageColumnSizes(logicalValues)) assertEquals( - Seq(8.0, 1.0, 12.0, 12.0, 12.0, 8.0, 4.0, 12.0), - mq.getAverageColumnSizes(emptyValues).toList) + java.util.List.of(8.0, 1.0, 12.0, 12.0, 12.0, 8.0, 4.0, 12.0), + mq.getAverageColumnSizes(emptyValues)) } @Test def testAverageColumnSizeOnProject(): Unit = { assertEquals( - Seq(8.0, 7.2, 8.0, 4.0, 8.0, 8.0, 8.0, 4.0, 1.0, 8.0, 8.0, 8.0), - mq.getAverageColumnSizes(logicalProject).toList) + java.util.List.of(8.0, 7.2, 8.0, 4.0, 8.0, 8.0, 8.0, 4.0, 1.0, 8.0, 8.0, 8.0), + mq.getAverageColumnSizes(logicalProject)) } @Test def testAverageColumnSizeOnFilter(): Unit = { assertEquals( - Seq(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0), - mq.getAverageColumnSizes(logicalFilter).toList) + java.util.List.of(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0), + mq.getAverageColumnSizes(logicalFilter)) } @Test def testAverageColumnSizeOnCalc(): Unit = { assertEquals( - Seq(8.0, 7.2, 8.0, 4.0, 8.0, 8.0, 8.0, 4.0, 1.0, 8.0, 8.0, 8.0), - mq.getAverageColumnSizes(logicalCalc).toList) + java.util.List.of(8.0, 7.2, 8.0, 4.0, 8.0, 8.0, 8.0, 4.0, 1.0, 8.0, 8.0, 8.0), + mq.getAverageColumnSizes(logicalCalc)) } @Test @@ -93,8 +91,8 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { Array(logicalExpand, flinkLogicalExpand, batchExpand, streamExpand).foreach { expand => assertEquals( - Seq(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0, 8.0), - mq.getAverageColumnSizes(expand).toList) + java.util.List.of(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0, 8.0), + mq.getAverageColumnSizes(expand)) } } @@ -103,8 +101,8 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { Array(batchExchange, streamExchange).foreach { exchange => assertEquals( - Seq(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0), - mq.getAverageColumnSizes(exchange).toList) + java.util.List.of(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0), + mq.getAverageColumnSizes(exchange)) } } @@ -113,12 +111,12 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { Array(logicalRank, flinkLogicalRank, batchGlobalRank, streamRank).foreach { rank => assertEquals( - Seq(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0, 8.0), - mq.getAverageColumnSizes(rank).toList) + java.util.List.of(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0, 8.0), + mq.getAverageColumnSizes(rank)) } assertEquals( - Seq(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0), - mq.getAverageColumnSizes(batchLocalRank).toList) + java.util.List.of(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0), + mq.getAverageColumnSizes(batchLocalRank)) } @Test @@ -140,7 +138,9 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { streamLimit ).foreach { sort => - assertEquals(Seq(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0), mq.getAverageColumnSizes(sort).toList) + assertEquals( + java.util.List.of(8.0, 7.2, 8.0, 4.0, 8.0, 1.0, 4.0), + mq.getAverageColumnSizes(sort)) } } @@ -153,7 +153,8 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { batchGlobalAggWithoutLocal, streamGlobalAggWithLocal, streamGlobalAggWithoutLocal).foreach { - agg => assertEquals(Seq(4.0, 8.0, 8.0, 8.0, 8.0, 8.0), mq.getAverageColumnSizes(agg).toList) + agg => + assertEquals(java.util.List.of(4.0, 8.0, 8.0, 8.0, 8.0, 8.0), mq.getAverageColumnSizes(agg)) } Array( @@ -161,7 +162,8 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { flinkLogicalAggWithAuxGroup, batchGlobalAggWithLocalWithAuxGroup, batchGlobalAggWithoutLocalWithAuxGroup).foreach { - agg => assertEquals(Seq(8.0, 7.2, 8.0, 8.0, 8.0, 8.0), mq.getAverageColumnSizes(agg).toList) + agg => + assertEquals(java.util.List.of(8.0, 7.2, 8.0, 8.0, 8.0, 8.0), mq.getAverageColumnSizes(agg)) } } @@ -172,9 +174,14 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { flinkLogicalWindowAgg, batchGlobalWindowAggWithoutLocalAgg, batchGlobalWindowAggWithLocalAgg).foreach { - agg => assertEquals(Seq(4d, 32d, 8d, 12d, 12d, 12d, 12d), mq.getAverageColumnSizes(agg).toSeq) + agg => + assertEquals( + java.util.List.of(4d, 32d, 8d, 12d, 12d, 12d, 12d), + mq.getAverageColumnSizes(agg)) } - assertEquals(Seq(4.0, 32.0, 8.0, 8.0), mq.getAverageColumnSizes(batchLocalWindowAgg).toSeq) + assertEquals( + java.util.List.of(4.0, 32.0, 8.0, 8.0), + mq.getAverageColumnSizes(batchLocalWindowAgg)) Array( logicalWindowAggWithAuxGroup, @@ -182,11 +189,14 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { batchGlobalWindowAggWithoutLocalAggWithAuxGroup, batchGlobalWindowAggWithLocalAggWithAuxGroup ).foreach { - agg => assertEquals(Seq(8d, 4d, 8d, 12d, 12d, 12d, 12d), mq.getAverageColumnSizes(agg).toSeq) + agg => + assertEquals( + java.util.List.of(8d, 4d, 8d, 12d, 12d, 12d, 12d), + mq.getAverageColumnSizes(agg)) } assertEquals( - Seq(8d, 8d, 4d, 8d), - mq.getAverageColumnSizes(batchLocalWindowAggWithAuxGroup).toSeq) + java.util.List.of(8d, 8d, 4d, 8d), + mq.getAverageColumnSizes(batchLocalWindowAggWithAuxGroup)) } @Test @@ -194,19 +204,19 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { Array(flinkLogicalOverAgg, batchOverAgg).foreach { agg => assertEquals( - Seq(8.0, 7.2, 8.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0), - mq.getAverageColumnSizes(agg).toList) + java.util.List.of(8.0, 7.2, 8.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0), + mq.getAverageColumnSizes(agg)) } assertEquals( - Seq(8.0, 12.0, 8.0, 4.0, 4.0, 8.0, 8.0, 8.0), - mq.getAverageColumnSizes(streamOverAgg).toList) + java.util.List.of(8.0, 12.0, 8.0, 4.0, 4.0, 8.0, 8.0, 8.0), + mq.getAverageColumnSizes(streamOverAgg)) } @Test def testAverageColumnSizeOnJoin(): Unit = { assertEquals( - Seq(4.0, 8.0, 12.0, 88.8, 4.0, 8.0, 8.0, 4.0, 8.0), - mq.getAverageColumnSizes(logicalInnerJoinOnUniqueKeys).toList) + java.util.List.of(4.0, 8.0, 12.0, 88.8, 4.0, 8.0, 8.0, 4.0, 8.0), + mq.getAverageColumnSizes(logicalInnerJoinOnUniqueKeys)) Array( logicalInnerJoinOnDisjointKeys, logicalLeftJoinNotOnUniqueKeys, @@ -214,39 +224,45 @@ class FlinkRelMdSizeTest extends FlinkRelMdHandlerTestBase { logicalFullJoinWithoutEquiCond).foreach { join => assertEquals( - Seq(4.0, 8.0, 12.0, 88.8, 4.0, 4.0, 8.0, 12.0, 10.52, 4.0), - mq.getAverageColumnSizes(join).toList) + java.util.List.of(4.0, 8.0, 12.0, 88.8, 4.0, 4.0, 8.0, 12.0, 10.52, 4.0), + mq.getAverageColumnSizes(join)) } Array(logicalSemiJoinOnUniqueKeys, logicalAntiJoinNotOnUniqueKeys).foreach { - join => assertEquals(Seq(4.0, 8.0, 12.0, 88.8, 4.0), mq.getAverageColumnSizes(join).toList) + join => + assertEquals(java.util.List.of(4.0, 8.0, 12.0, 88.8, 4.0), mq.getAverageColumnSizes(join)) } } @Test def testAverageColumnSizeOnUnion(): Unit = { Array(logicalUnion, logicalUnionAll).foreach { - union => assertEquals(Seq(4.0, 8.0, 12.0, 49.66, 4.0), mq.getAverageColumnSizes(union).toList) + union => + assertEquals(java.util.List.of(4.0, 8.0, 12.0, 49.66, 4.0), mq.getAverageColumnSizes(union)) } } @Test def testAverageColumnSizeOnIntersect(): Unit = { Array(logicalIntersect, logicalIntersectAll).foreach { - union => assertEquals(Seq(4.0, 8.0, 12.0, 88.8, 4.0), mq.getAverageColumnSizes(union).toList) + union => + assertEquals(java.util.List.of(4.0, 8.0, 12.0, 88.8, 4.0), mq.getAverageColumnSizes(union)) } } @Test def testAverageColumnSizeOnMinus(): Unit = { Array(logicalMinus, logicalMinusAll).foreach { - union => assertEquals(Seq(4.0, 8.0, 12.0, 88.8, 4.0), mq.getAverageColumnSizes(union).toList) + union => + assertEquals(java.util.List.of(4.0, 8.0, 12.0, 88.8, 4.0), mq.getAverageColumnSizes(union)) } } @Test def testAverageColumnSizeOnDefault(): Unit = { - assertEquals(Seq(8.0, 12.0, 8.0, 4.0, 8.0, 12.0, 4.0), mq.getAverageColumnSizes(testRel).toList) + assertEquals( + java.util.List.of(8.0, 12.0, 8.0, 4.0, 8.0, 12.0, 4.0), + mq.getAverageColumnSizes(testRel)) } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueGroupsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueGroupsTest.scala index 8ae002b351b6a..7f63798a2b8b1 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueGroupsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueGroupsTest.scala @@ -21,14 +21,13 @@ import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalExpand import org.apache.flink.table.planner.plan.utils.ExpandUtil import com.google.common.collect.{ImmutableList, ImmutableSet} +import org.apache.calcite.rex.RexNode import org.apache.calcite.sql.fun.SqlStdOperatorTable.{GREATER_THAN, LESS_THAN_OR_EQUAL, MULTIPLY, PLUS} import org.apache.calcite.util.ImmutableBitSet import org.assertj.core.api.Assertions.assertThatThrownBy import org.junit.jupiter.api.Assertions.assertEquals import org.junit.jupiter.api.Test -import scala.collection.JavaConversions._ - class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase { @Test @@ -86,7 +85,7 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase { // a <= 2 and b > 10 val expr1 = relBuilder.call(LESS_THAN_OR_EQUAL, relBuilder.field(0), relBuilder.literal(2)) val expr2 = relBuilder.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(10d)) - val filter1 = relBuilder.filter(List(expr1, expr2)).build() + val filter1 = relBuilder.filter(java.util.List.of(expr1, expr2)).build() assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(filter1, ImmutableBitSet.of(0))) assertEquals(ImmutableBitSet.of(0, 1), mq.getUniqueGroups(filter1, ImmutableBitSet.of(0, 1))) @@ -95,7 +94,7 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase { // a <= 2 and b > 10 val expr3 = relBuilder.call(LESS_THAN_OR_EQUAL, relBuilder.field(0), relBuilder.literal(2)) val expr4 = relBuilder.call(GREATER_THAN, relBuilder.field(1), relBuilder.literal(10d)) - val filter2 = relBuilder.filter(List(expr3, expr4)).build() + val filter2 = relBuilder.filter(java.util.List.of(expr3, expr4)).build() assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(filter2, ImmutableBitSet.of(0))) assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(filter2, ImmutableBitSet.of(0, 1))) assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(filter2, ImmutableBitSet.of(0, 1, 2))) @@ -107,7 +106,7 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase { relBuilder.scan("MyTable4") // a, b, c val proj1 = relBuilder - .project(List(relBuilder.field(0), relBuilder.field(1), relBuilder.field(2))) + .project(java.util.List.of(relBuilder.field(0), relBuilder.field(1), relBuilder.field(2))) .build() assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(proj1, ImmutableBitSet.of(0))) assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(proj1, ImmutableBitSet.of(0, 1))) @@ -119,7 +118,7 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase { // a + b, b * 2, 1 val proj2 = relBuilder .project( - List( + java.util.List.of( relBuilder.call(PLUS, relBuilder.field(0), relBuilder.field(1)), relBuilder.call(MULTIPLY, relBuilder.field(1), relBuilder.literal(2)), relBuilder.literal(1) @@ -135,7 +134,7 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase { // a, b * 2, c, 1, 2 val proj3 = relBuilder .project( - List( + java.util.List.of( relBuilder.field(0), relBuilder.call(MULTIPLY, relBuilder.field(1), relBuilder.literal(2)), relBuilder.field(2), @@ -156,8 +155,8 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase { relBuilder.scan("MyTable4") // a, a as a1, $2 val proj4 = relBuilder - .project( - List(relBuilder.field(0), relBuilder.alias(relBuilder.field(0), "a1"), relBuilder.field(2))) + .project(java.util.List + .of(relBuilder.field(0), relBuilder.alias(relBuilder.field(0), "a1"), relBuilder.field(2))) .build() assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(proj4, ImmutableBitSet.of(0))) assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(proj4, ImmutableBitSet.of(0, 1))) @@ -167,7 +166,8 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase { relBuilder.clear() relBuilder.scan("MyTable4") // true, 1 - val proj5 = relBuilder.project(List(relBuilder.literal(true), relBuilder.literal(1))).build() + val proj5 = + relBuilder.project(java.util.List.of(relBuilder.literal(true), relBuilder.literal(1))).build() assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(proj5, ImmutableBitSet.of(0))) assertEquals(ImmutableBitSet.of(1), mq.getUniqueGroups(proj5, ImmutableBitSet.of(1))) assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(proj5, ImmutableBitSet.of(0, 1))) @@ -179,7 +179,7 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase { // project: a, b * 2, c, 1, 2 // filter: a > 1 relBuilder.push(ts) - val projects = List( + val projects = java.util.List.of( relBuilder.field(0), relBuilder.call(MULTIPLY, relBuilder.field(1), relBuilder.literal(2)), relBuilder.field(2), @@ -188,7 +188,7 @@ class FlinkRelMdUniqueGroupsTest extends FlinkRelMdHandlerTestBase { ) val condition = relBuilder.call(LESS_THAN_OR_EQUAL, relBuilder.field(0), relBuilder.literal(2)) val outputRowType = relBuilder.push(ts).project(projects).build().getRowType - val calc = createLogicalCalc(ts, outputRowType, projects, List(condition)) + val calc = createLogicalCalc(ts, outputRowType, projects, java.util.List.of[RexNode](condition)) assertEquals(ImmutableBitSet.of(0), mq.getUniqueGroups(calc, ImmutableBitSet.of(0))) assertEquals(ImmutableBitSet.of(1), mq.getUniqueGroups(calc, ImmutableBitSet.of(1))) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala index 41fd47cc62a61..38f0c1fe0d45c 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUniqueKeysTest.scala @@ -25,6 +25,7 @@ import org.apache.flink.table.planner.plan.utils.ExpandUtil import com.google.common.collect.{ImmutableList, ImmutableSet} import org.apache.calcite.prepare.CalciteCatalogReader import org.apache.calcite.rel.hint.RelHint +import org.apache.calcite.rex.RexNode import org.apache.calcite.sql.`type`.SqlTypeName.VARCHAR import org.apache.calcite.sql.fun.SqlStdOperatorTable.{EQUALS, LESS_THAN} import org.apache.calcite.util.ImmutableBitSet @@ -33,14 +34,14 @@ import org.junit.jupiter.api.Test import java.util.Collections -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters.setAsJavaSetConverter class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { @Test def testGetUniqueKeysOnTableScan(): Unit = { Array(studentLogicalScan, studentBatchScan, studentStreamScan).foreach { - scan => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(scan).toSet) + scan => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(scan)) } Array(empLogicalScan, empBatchScan, empStreamScan).foreach { @@ -49,21 +50,21 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { val table = relBuilder.getRelOptSchema .asInstanceOf[CalciteCatalogReader] - .getTable(Seq("projected_table_source_table")) + .getTable(java.util.List.of[String]("projected_table_source_table")) .asInstanceOf[TableSourceTable] val tableSourceScan = new StreamPhysicalTableSourceScan( cluster, streamPhysicalTraits, Collections.emptyList[RelHint](), table) - assertEquals(uniqueKeys(Array(0, 2)), mq.getUniqueKeys(tableSourceScan).toSet) + assertEquals(uniqueKeys(Array(0, 2)), mq.getUniqueKeys(tableSourceScan)) } @Test def testGetUniqueKeysOnProjectedTableScanWithPartialCompositePrimaryKey(): Unit = { val table = relBuilder.getRelOptSchema .asInstanceOf[CalciteCatalogReader] - .getTable(Seq("projected_table_source_table_with_partial_pk")) + .getTable(java.util.List.of[String]("projected_table_source_table_with_partial_pk")) .asInstanceOf[TableSourceTable] val tableSourceScan = new StreamPhysicalTableSourceScan( cluster, @@ -81,11 +82,11 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { @Test def testGetUniqueKeysOnProject(): Unit = { - assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(logicalProject).toSet) + assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(logicalProject)) relBuilder.push(studentLogicalScan) // id=1, id, cast(id AS bigint not null), cast(id AS int), $1 - val exprs = List( + val exprs = java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)), relBuilder.field(0), rexBuilder.makeCast(longType, relBuilder.field(0)), @@ -94,8 +95,8 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { ) val project1 = relBuilder.project(exprs).build() // INT -> BIGINT is an injective cast, so position 2 is also a unique key - assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(project1).toSet) - assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(project1, true).toSet) + assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(project1)) + assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(project1, true)) } @Test @@ -108,14 +109,14 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { // Project: CAST(id AS STRING), name // id (position 0 in source) is the unique key - val exprs = List( + val exprs = java.util.List.of( rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS STRING) relBuilder.field(1) // name ) val project = relBuilder.project(exprs).build() // The casted id at position 0 should still be recognized as unique - assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(project).toSet) + assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(project)) } @Test @@ -127,7 +128,7 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { val stringType = typeFactory.createSqlType(VARCHAR, 100) // Project: CAST(id AS STRING), id, name - val exprs = List( + val exprs = java.util.List.of( rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS STRING) - injective relBuilder.field(0), // id (raw reference) relBuilder.field(1) // name @@ -135,7 +136,7 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { val project = relBuilder.project(exprs).build() // Both position 0 (STRING cast of id) and position 1 (raw id) are unique keys - assertEquals(uniqueKeys(Array(0), Array(1)), mq.getUniqueKeys(project).toSet) + assertEquals(uniqueKeys(Array(0), Array(1)), mq.getUniqueKeys(project)) } @Test @@ -148,7 +149,7 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { // Project: id, CAST(name AS STRING) // id is the unique key; name is NOT a key (even after casting) - val exprs = List( + val exprs = java.util.List.of( relBuilder.field(0), // id - the unique key rexBuilder.makeCast(stringType, relBuilder.field(1)) // CAST(name AS STRING) - not a key ) @@ -156,7 +157,7 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { // Only position 0 (id) is a unique key // Position 1 (cast of name) is NOT a key because name wasn't a key to begin with - assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(project).toSet) + assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(project)) } @Test @@ -168,39 +169,39 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { val stringType = typeFactory.createSqlType(VARCHAR, 100) // First, project id as STRING to simulate a STRING key column - val stringKeyExprs = List( + val stringKeyExprs = java.util.List.of( rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS STRING) relBuilder.field(1) // name ) val stringKeyProject = relBuilder.project(stringKeyExprs).build() // At this point, position 0 is a STRING that's still a unique key - assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(stringKeyProject).toSet) + assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(stringKeyProject)) // Now cast the STRING back to INT - this is a non-injective (narrowing) cast relBuilder.push(stringKeyProject) - val narrowedExprs = List( + val narrowedExprs = java.util.List.of( rexBuilder.makeCast(intType, relBuilder.field(0)), // CAST(string_id AS INT) - NOT injective relBuilder.field(1) // name ) val narrowedProject = relBuilder.project(narrowedExprs).build() // The key is LOST because STRING->INT is not injective - assertEquals(uniqueKeys(), mq.getUniqueKeys(narrowedProject).toSet) + assertEquals(uniqueKeys(), mq.getUniqueKeys(narrowedProject)) } @Test def testGetUniqueKeysOnFilter(): Unit = { - assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(logicalFilter).toSet) + assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(logicalFilter)) } @Test def testGetUniqueKeysOnWatermark(): Unit = { - assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(logicalWatermarkAssigner).toSet) + assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(logicalWatermarkAssigner)) } @Test def testGetUniqueKeysOnMiniBatchAssigner(): Unit = { - assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(streamMiniBatchAssigner).toSet) + assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(streamMiniBatchAssigner)) } @Test @@ -212,11 +213,11 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, logicalProject.getRowType, logicalProject.getProjects, - List(expr)) - assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(logicalCalc).toSet) + java.util.List.of[RexNode](expr)) + assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(logicalCalc)) // id=1, id, cast(id AS bigint not null), cast(id AS int), $1 - val exprs = List( + val exprs = java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)), relBuilder.field(0), rexBuilder.makeCast(longType, relBuilder.field(0)), @@ -224,16 +225,17 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { relBuilder.field(1) ) val rowType = relBuilder.project(exprs).build().getRowType - val calc2 = createLogicalCalc(studentLogicalScan, rowType, exprs, List(expr)) + val calc2 = + createLogicalCalc(studentLogicalScan, rowType, exprs, java.util.List.of[RexNode](expr)) // INT -> BIGINT is an injective cast, so position 2 is also a unique key - assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(calc2).toSet) - assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(calc2, true).toSet) + assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(calc2)) + assertEquals(uniqueKeys(Array(1), Array(2)), mq.getUniqueKeys(calc2, true)) } @Test def testGetUniqueKeysOnExpand(): Unit = { Array(logicalExpand, flinkLogicalExpand, batchExpand, streamExpand).foreach { - expand => assertEquals(uniqueKeys(Array(0, 7)), mq.getUniqueKeys(expand).toSet) + expand => assertEquals(uniqueKeys(Array(0, 7)), mq.getUniqueKeys(expand)) } val expandProjects = ExpandUtil.createExpandProjects( @@ -259,20 +261,18 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { @Test def testGetUniqueKeysOnExchange(): Unit = { Array(batchExchange, streamExchange).foreach { - exchange => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(exchange).toSet) + exchange => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(exchange)) } } @Test def testGetUniqueKeysOnRank(): Unit = { Array(logicalRank, flinkLogicalRank, batchLocalRank, batchGlobalRank, streamRank).foreach { - rank => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(rank).toSet) + rank => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(rank)) } Array(logicalRowNumber, flinkLogicalRowNumber, streamRowNumber) - .foreach { - rank => assertEquals(uniqueKeys(Array(0), Array(7)), mq.getUniqueKeys(rank).toSet) - } + .foreach(rank => assertEquals(uniqueKeys(Array(0), Array(7)), mq.getUniqueKeys(rank))) } @Test @@ -292,25 +292,25 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { flinkLogicalLimit, batchLimit, streamLimit - ).foreach(sort => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(sort).toSet)) + ).foreach(sort => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(sort))) } @Test def testGetUniqueKeysOnStreamExecDeduplicate(): Unit = { - assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(streamProcTimeDeduplicateFirstRow).toSet) - assertEquals(uniqueKeys(Array(1, 2)), mq.getUniqueKeys(streamProcTimeDeduplicateLastRow).toSet) - assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(streamRowTimeDeduplicateFirstRow).toSet) - assertEquals(uniqueKeys(Array(1, 2)), mq.getUniqueKeys(streamRowTimeDeduplicateLastRow).toSet) + assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(streamProcTimeDeduplicateFirstRow)) + assertEquals(uniqueKeys(Array(1, 2)), mq.getUniqueKeys(streamProcTimeDeduplicateLastRow)) + assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(streamRowTimeDeduplicateFirstRow)) + assertEquals(uniqueKeys(Array(1, 2)), mq.getUniqueKeys(streamRowTimeDeduplicateLastRow)) } @Test def testGetUniqueKeysOnStreamExecChangelogNormalize(): Unit = { - assertEquals(uniqueKeys(Array(1, 0)), mq.getUniqueKeys(streamChangelogNormalize).toSet) + assertEquals(uniqueKeys(Array(1, 0)), mq.getUniqueKeys(streamChangelogNormalize)) } @Test def testGetUniqueKeysOnStreamExecDropUpdateBefore(): Unit = { - assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(streamDropUpdateBefore).toSet) + assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(streamDropUpdateBefore)) } @Test @@ -322,7 +322,7 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { batchGlobalAggWithoutLocal, streamGlobalAggWithLocal, streamGlobalAggWithoutLocal).foreach { - agg => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(agg).toSet) + agg => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(agg)) } assertNull(mq.getUniqueKeys(batchLocalAgg)) assertNull(mq.getUniqueKeys(streamLocalAgg)) @@ -332,7 +332,7 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { flinkLogicalAggWithAuxGroup, batchGlobalAggWithLocalWithAuxGroup, batchGlobalAggWithoutLocalWithAuxGroup).foreach { - agg => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(agg).toSet) + agg => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(agg)) } assertNull(mq.getUniqueKeys(batchLocalAggWithAuxGroup)) } @@ -376,93 +376,88 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { @Test def testGetUniqueKeysOnOverAgg(): Unit = { Array(flinkLogicalOverAgg, batchOverAgg).foreach { - agg => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(agg).toSet) + agg => assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(agg)) } - assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(streamOverAgg).toSet) + assertEquals(uniqueKeys(Array(0)), mq.getUniqueKeys(streamOverAgg)) } @Test def testGetUniqueKeysOnJoin(): Unit = { assertEquals( uniqueKeys(Array(1), Array(5), Array(1, 5), Array(5, 6), Array(1, 5, 6)), - mq.getUniqueKeys(logicalInnerJoinOnUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalInnerJoinNotOnUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalInnerJoinOnRHSUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalInnerJoinWithoutEquiCond).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalInnerJoinWithEquiAndNonEquiCond).toSet) + mq.getUniqueKeys(logicalInnerJoinOnUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalInnerJoinNotOnUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalInnerJoinOnRHSUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalInnerJoinWithoutEquiCond)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalInnerJoinWithEquiAndNonEquiCond)) assertEquals( uniqueKeys(Array(1), Array(1, 5), Array(1, 5, 6)), - mq.getUniqueKeys(logicalLeftJoinOnUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalLeftJoinNotOnUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalLeftJoinOnRHSUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalLeftJoinWithoutEquiCond).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalLeftJoinWithEquiAndNonEquiCond).toSet) + mq.getUniqueKeys(logicalLeftJoinOnUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalLeftJoinNotOnUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalLeftJoinOnRHSUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalLeftJoinWithoutEquiCond)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalLeftJoinWithEquiAndNonEquiCond)) assertEquals( uniqueKeys(Array(5), Array(1, 5), Array(5, 6), Array(1, 5, 6)), - mq.getUniqueKeys(logicalRightJoinOnUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalRightJoinNotOnUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalRightJoinOnLHSUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalRightJoinWithoutEquiCond).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalRightJoinWithEquiAndNonEquiCond).toSet) + mq.getUniqueKeys(logicalRightJoinOnUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalRightJoinNotOnUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalRightJoinOnLHSUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalRightJoinWithoutEquiCond)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalRightJoinWithEquiAndNonEquiCond)) assertEquals( uniqueKeys(Array(1, 5), Array(1, 5, 6)), - mq.getUniqueKeys(logicalFullJoinOnUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalFullJoinNotOnUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalFullJoinOnRHSUniqueKeys).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalFullJoinWithoutEquiCond).toSet) - assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalFullJoinWithEquiAndNonEquiCond).toSet) - - assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalSemiJoinOnUniqueKeys).toSet) - assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalSemiJoinNotOnUniqueKeys).toSet) + mq.getUniqueKeys(logicalFullJoinOnUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalFullJoinNotOnUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalFullJoinOnRHSUniqueKeys)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalFullJoinWithoutEquiCond)) + assertEquals(uniqueKeys(), mq.getUniqueKeys(logicalFullJoinWithEquiAndNonEquiCond)) + + assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalSemiJoinOnUniqueKeys)) + assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalSemiJoinNotOnUniqueKeys)) assertNull(mq.getUniqueKeys(logicalSemiJoinOnRHSUniqueKeys)) - assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalSemiJoinWithoutEquiCond).toSet) - assertEquals( - uniqueKeys(Array(1)), - mq.getUniqueKeys(logicalSemiJoinWithEquiAndNonEquiCond).toSet) + assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalSemiJoinWithoutEquiCond)) + assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalSemiJoinWithEquiAndNonEquiCond)) - assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalAntiJoinOnUniqueKeys).toSet) - assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalAntiJoinNotOnUniqueKeys).toSet) + assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalAntiJoinOnUniqueKeys)) + assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalAntiJoinNotOnUniqueKeys)) assertNull(mq.getUniqueKeys(logicalAntiJoinOnRHSUniqueKeys)) - assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalAntiJoinWithoutEquiCond).toSet) - assertEquals( - uniqueKeys(Array(1)), - mq.getUniqueKeys(logicalAntiJoinWithEquiAndNonEquiCond).toSet) + assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalAntiJoinWithoutEquiCond)) + assertEquals(uniqueKeys(Array(1)), mq.getUniqueKeys(logicalAntiJoinWithEquiAndNonEquiCond)) } @Test def testGetUniqueKeysOnLookupJoin(): Unit = { Array(batchLookupJoin, streamLookupJoin).foreach { - join => assertEquals(uniqueKeys(), mq.getUniqueKeys(join).toSet) + join => assertEquals(uniqueKeys(), mq.getUniqueKeys(join)) } } @Test def testGetUniqueKeysOnLookupJoinWithPk(): Unit = { Array(batchLookupJoinWithPk, streamLookupJoinWithPk).foreach { - join => - assertEquals(uniqueKeys(Array(7), Array(0, 7), Array(0)), mq.getUniqueKeys(join).toSet) + join => assertEquals(uniqueKeys(Array(7), Array(0, 7), Array(0)), mq.getUniqueKeys(join)) } } @Test def testGetUniqueKeysOnLookupJoinNotContainsPk(): Unit = { Array(batchLookupJoinNotContainsPk, streamLookupJoinNotContainsPk).foreach { - join => assertEquals(uniqueKeys(), mq.getUniqueKeys(join).toSet) + join => assertEquals(uniqueKeys(), mq.getUniqueKeys(join)) } } @Test def testGetUniqueKeysOnSetOp(): Unit = { Array(logicalUnionAll, logicalIntersectAll, logicalMinusAll).foreach { - setOp => assertEquals(uniqueKeys(), mq.getUniqueKeys(setOp).toSet) + setOp => assertEquals(uniqueKeys(), mq.getUniqueKeys(setOp)) } Array(logicalUnion, logicalIntersect, logicalMinus).foreach { - setOp => assertEquals(uniqueKeys(Array(0, 1, 2, 3, 4)), mq.getUniqueKeys(setOp).toSet) + setOp => assertEquals(uniqueKeys(Array(0, 1, 2, 3, 4)), mq.getUniqueKeys(setOp)) } } @@ -475,19 +470,19 @@ class FlinkRelMdUniqueKeysTest extends FlinkRelMdHandlerTestBase { def testGetUniqueKeysOnTableScanTable(): Unit = { assertEquals( uniqueKeys(Array(0, 1), Array(0, 1, 5)), - mq.getUniqueKeys(logicalLeftJoinOnContainedUniqueKeys).toSet + mq.getUniqueKeys(logicalLeftJoinOnContainedUniqueKeys) ) assertEquals( uniqueKeys(Array(0, 1, 5)), - mq.getUniqueKeys(logicalLeftJoinOnDisjointUniqueKeys).toSet + mq.getUniqueKeys(logicalLeftJoinOnDisjointUniqueKeys) ) assertEquals( uniqueKeys(), - mq.getUniqueKeys(logicalLeftJoinWithNoneKeyTableUniqueKeys).toSet + mq.getUniqueKeys(logicalLeftJoinWithNoneKeyTableUniqueKeys) ) } - private def uniqueKeys(keys: Array[Int]*): Set[ImmutableBitSet] = { - keys.map(k => ImmutableBitSet.of(k: _*)).toSet + private def uniqueKeys(keys: Array[Int]*): java.util.Set[ImmutableBitSet] = { + keys.map(k => ImmutableBitSet.of(k: _*)).toSet.asJava } } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala index 4a77662bd2b70..4ae6f389e5182 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdUpsertKeysTest.scala @@ -42,14 +42,14 @@ import org.junit.jupiter.api.Test import java.util.Collections -import scala.collection.JavaConversions._ +import scala.collection.JavaConverters.setAsJavaSetConverter class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { @Test def testGetUpsertKeysOnTableScan(): Unit = { Array(studentLogicalScan, studentBatchScan, studentStreamScan).foreach { - scan => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(scan).toSet) + scan => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(scan)) } Array(empLogicalScan, empBatchScan, empStreamScan).foreach { @@ -58,21 +58,21 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { val table = relBuilder.getRelOptSchema .asInstanceOf[CalciteCatalogReader] - .getTable(Seq("projected_table_source_table")) + .getTable(java.util.List.of[String]("projected_table_source_table")) .asInstanceOf[TableSourceTable] val tableSourceScan = new StreamPhysicalTableSourceScan( cluster, streamPhysicalTraits, Collections.emptyList[RelHint](), table) - assertEquals(toBitSet(Array(0, 2)), mq.getUpsertKeys(tableSourceScan).toSet) + assertEquals(toBitSet(Array(0, 2)), mq.getUpsertKeys(tableSourceScan)) } @Test def testGetUpsertKeysOnProjectedTableScanWithPartialCompositePrimaryKey(): Unit = { val table = relBuilder.getRelOptSchema .asInstanceOf[CalciteCatalogReader] - .getTable(Seq("projected_table_source_table_with_partial_pk")) + .getTable(java.util.List.of[String]("projected_table_source_table_with_partial_pk")) .asInstanceOf[TableSourceTable] val tableSourceScan = new StreamPhysicalTableSourceScan( cluster, @@ -90,11 +90,11 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { @Test def testGetUpsertKeysOnProject(): Unit = { - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalProject).toSet) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalProject)) relBuilder.push(studentLogicalScan) // id=1, id, cast(id AS bigint not null), cast(id AS int), $1 - val exprs = List( + val exprs = java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)), relBuilder.field(0), // INT -> BIGINT is an injective cast, so position 2 is now also an upsert key @@ -104,7 +104,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { ) val project1 = relBuilder.project(exprs).build() - assertEquals(toBitSet(Array(1), Array(2)), mq.getUpsertKeys(project1).toSet) + assertEquals(toBitSet(Array(1), Array(2)), mq.getUpsertKeys(project1)) } @Test @@ -116,14 +116,14 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { val stringType = typeFactory.createSqlType(VARCHAR, 100) // Project: CAST(id AS STRING), name - val exprs = List( + val exprs = java.util.List.of( rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS STRING) relBuilder.field(1) // name ) val project = relBuilder.project(exprs).build() // The casted id at position 0 should still be recognized as upsert key - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(project).toSet) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(project)) } @Test @@ -135,7 +135,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { val stringType = typeFactory.createSqlType(VARCHAR, 100) // Project: CAST(id AS STRING), id, name - val exprs = List( + val exprs = java.util.List.of( rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS STRING) - injective relBuilder.field(0), // id (raw reference) relBuilder.field(1) // name @@ -143,7 +143,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { val project = relBuilder.project(exprs).build() // Both position 0 (STRING cast of id) and position 1 (raw id) are upsert keys - assertEquals(toBitSet(Array(0), Array(1)), mq.getUpsertKeys(project).toSet) + assertEquals(toBitSet(Array(0), Array(1)), mq.getUpsertKeys(project)) } @Test @@ -155,38 +155,38 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { val stringType = typeFactory.createSqlType(VARCHAR, 100) // First, project id as STRING to simulate a STRING key column - val stringKeyExprs = List( + val stringKeyExprs = java.util.List.of( rexBuilder.makeCast(stringType, relBuilder.field(0)), // CAST(id AS STRING) relBuilder.field(1) // name ) val stringKeyProject = relBuilder.project(stringKeyExprs).build() - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(stringKeyProject).toSet) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(stringKeyProject)) // Now cast the STRING back to INT - this is a non-injective cast relBuilder.push(stringKeyProject) - val narrowedExprs = List( + val narrowedExprs = java.util.List.of( rexBuilder.makeCast(intType, relBuilder.field(0)), // CAST(string_id AS INT) - NOT injective relBuilder.field(1) // name ) val narrowedProject = relBuilder.project(narrowedExprs).build() // The key is LOST because STRING->INT is not injective - assertEquals(toBitSet(), mq.getUpsertKeys(narrowedProject).toSet) + assertEquals(toBitSet(), mq.getUpsertKeys(narrowedProject)) } @Test def testGetUpsertKeysOnFilter(): Unit = { - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalFilter).toSet) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalFilter)) } @Test def testGetUpsertKeysOnWatermark(): Unit = { - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalWatermarkAssigner).toSet) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalWatermarkAssigner)) } @Test def testGetUpsertKeysOnMiniBatchAssigner(): Unit = { - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(streamMiniBatchAssigner).toSet) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(streamMiniBatchAssigner)) } @Test @@ -198,11 +198,11 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { studentLogicalScan, logicalProject.getRowType, logicalProject.getProjects, - List(expr)) - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalCalc).toSet) + java.util.List.of[RexNode](expr)) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(logicalCalc)) // id=1, id, cast(id AS bigint not null), cast(id AS int), $1 - val exprs = List( + val exprs = java.util.List.of( relBuilder.call(EQUALS, relBuilder.field(0), relBuilder.literal(1)), relBuilder.field(0), rexBuilder.makeCast(longType, relBuilder.field(0)), @@ -210,15 +210,16 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { relBuilder.field(1) ) val rowType = relBuilder.project(exprs).build().getRowType - val calc2 = createLogicalCalc(studentLogicalScan, rowType, exprs, List(expr)) + val calc2 = + createLogicalCalc(studentLogicalScan, rowType, exprs, java.util.List.of[RexNode](expr)) // INT -> BIGINT is an injective cast, so position 2 is now also an upsert key - assertEquals(toBitSet(Array(1), Array(2)), mq.getUpsertKeys(calc2).toSet) + assertEquals(toBitSet(Array(1), Array(2)), mq.getUpsertKeys(calc2)) } @Test def testGetUpsertKeysOnExpand(): Unit = { Array(logicalExpand, flinkLogicalExpand, batchExpand, streamExpand).foreach { - expand => assertEquals(toBitSet(Array(0, 7)), mq.getUpsertKeys(expand).toSet) + expand => assertEquals(toBitSet(Array(0, 7)), mq.getUpsertKeys(expand)) } val expandProjects = ExpandUtil.createExpandProjects( @@ -244,18 +245,18 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { @Test def testGetUpsertKeysOnExchange(): Unit = { Array(batchExchange, streamExchange).foreach { - exchange => assertEquals(toBitSet(), mq.getUpsertKeys(exchange).toSet) + exchange => assertEquals(toBitSet(), mq.getUpsertKeys(exchange)) } Array(batchExchangeById, streamExchangeById).foreach { - exchange => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(exchange).toSet) + exchange => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(exchange)) } } @Test def testGetUpsertKeysOnRank(): Unit = { Array(logicalRank, flinkLogicalRank, batchLocalRank, batchGlobalRank, streamRank).foreach { - rank => assertEquals(toBitSet(), mq.getUpsertKeys(rank).toSet) + rank => assertEquals(toBitSet(), mq.getUpsertKeys(rank)) } Array( @@ -263,22 +264,20 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { flinkLogicalRankById, batchLocalRankById, batchGlobalRankById, - streamRankById).foreach { - rank => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(rank).toSet) - } + streamRankById).foreach(rank => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(rank))) Array(logicalRowNumber, flinkLogicalRowNumber, streamRowNumber) - .foreach(rank => assertEquals(toBitSet(Array(0), Array(7)), mq.getUpsertKeys(rank).toSet)) + .foreach(rank => assertEquals(toBitSet(Array(0), Array(7)), mq.getUpsertKeys(rank))) } @Test def testGetUpsertKeysOnSort(): Unit = { def testWithoutKey(rel: RelNode): Unit = { - assertEquals(toBitSet(), mq.getUpsertKeys(rel).toSet) + assertEquals(toBitSet(), mq.getUpsertKeys(rel)) } def testWithKey(rel: RelNode): Unit = { - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(rel).toSet) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(rel)) } testWithoutKey(logicalSort) @@ -311,20 +310,20 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { @Test def testGetUpsertKeysOnStreamExecDeduplicate(): Unit = { - assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(streamProcTimeDeduplicateFirstRow).toSet) - assertEquals(toBitSet(Array(1, 2)), mq.getUpsertKeys(streamProcTimeDeduplicateLastRow).toSet) - assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(streamRowTimeDeduplicateFirstRow).toSet) - assertEquals(toBitSet(Array(1, 2)), mq.getUpsertKeys(streamRowTimeDeduplicateLastRow).toSet) + assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(streamProcTimeDeduplicateFirstRow)) + assertEquals(toBitSet(Array(1, 2)), mq.getUpsertKeys(streamProcTimeDeduplicateLastRow)) + assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(streamRowTimeDeduplicateFirstRow)) + assertEquals(toBitSet(Array(1, 2)), mq.getUpsertKeys(streamRowTimeDeduplicateLastRow)) } @Test def testGetUpsertKeysOnStreamExecChangelogNormalize(): Unit = { - assertEquals(toBitSet(Array(1, 0)), mq.getUpsertKeys(streamChangelogNormalize).toSet) + assertEquals(toBitSet(Array(1, 0)), mq.getUpsertKeys(streamChangelogNormalize)) } @Test def testGetUpsertKeysOnStreamExecDropUpdateBefore(): Unit = { - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(streamDropUpdateBefore).toSet) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(streamDropUpdateBefore)) } @Test @@ -336,7 +335,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { batchGlobalAggWithoutLocal, streamGlobalAggWithLocal, streamGlobalAggWithoutLocal).foreach { - agg => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(agg).toSet) + agg => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(agg)) } assertNull(mq.getUpsertKeys(batchLocalAgg)) assertNull(mq.getUpsertKeys(streamLocalAgg)) @@ -346,7 +345,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { flinkLogicalAggWithAuxGroup, batchGlobalAggWithLocalWithAuxGroup, batchGlobalAggWithoutLocalWithAuxGroup).foreach { - agg => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(agg).toSet) + agg => assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(agg)) } assertNull(mq.getUpsertKeys(batchLocalAggWithAuxGroup)) } @@ -390,88 +389,88 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { @Test def testGetUpsertKeysOnOverAgg(): Unit = { Array(flinkLogicalOverAgg, batchOverAgg, streamOverAgg).foreach { - agg => assertEquals(toBitSet(), mq.getUpsertKeys(agg).toSet) + agg => assertEquals(toBitSet(), mq.getUpsertKeys(agg)) } - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(streamOverAggById).toSet) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(streamOverAggById)) } @Test def testGetUpsertKeysOnJoin(): Unit = { assertEquals( toBitSet(Array(1), Array(5), Array(1, 5), Array(5, 6), Array(1, 5, 6)), - mq.getUpsertKeys(logicalInnerJoinOnUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinNotOnUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinOnRHSUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinWithoutEquiCond).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinWithEquiAndNonEquiCond).toSet) + mq.getUpsertKeys(logicalInnerJoinOnUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinNotOnUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinOnRHSUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinWithoutEquiCond)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalInnerJoinWithEquiAndNonEquiCond)) assertEquals( toBitSet(Array(1), Array(1, 5), Array(1, 5, 6)), - mq.getUpsertKeys(logicalLeftJoinOnUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinNotOnUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinOnRHSUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinWithoutEquiCond).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinWithEquiAndNonEquiCond).toSet) + mq.getUpsertKeys(logicalLeftJoinOnUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinNotOnUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinOnRHSUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinWithoutEquiCond)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalLeftJoinWithEquiAndNonEquiCond)) assertEquals( toBitSet(Array(5), Array(1, 5), Array(5, 6), Array(1, 5, 6)), - mq.getUpsertKeys(logicalRightJoinOnUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalRightJoinNotOnUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalRightJoinOnLHSUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalRightJoinWithoutEquiCond).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalRightJoinWithEquiAndNonEquiCond).toSet) + mq.getUpsertKeys(logicalRightJoinOnUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalRightJoinNotOnUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalRightJoinOnLHSUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalRightJoinWithoutEquiCond)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalRightJoinWithEquiAndNonEquiCond)) assertEquals( toBitSet(Array(1, 5), Array(1, 5, 6)), - mq.getUpsertKeys(logicalFullJoinOnUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalFullJoinNotOnUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalFullJoinOnRHSUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalFullJoinWithoutEquiCond).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalFullJoinWithEquiAndNonEquiCond).toSet) - - assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(logicalSemiJoinOnUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalSemiJoinNotOnUniqueKeys).toSet) + mq.getUpsertKeys(logicalFullJoinOnUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalFullJoinNotOnUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalFullJoinOnRHSUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalFullJoinWithoutEquiCond)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalFullJoinWithEquiAndNonEquiCond)) + + assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(logicalSemiJoinOnUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalSemiJoinNotOnUniqueKeys)) assertNull(mq.getUpsertKeys(logicalSemiJoinOnRHSUniqueKeys)) - assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(logicalSemiJoinWithoutEquiCond).toSet) - assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(logicalSemiJoinWithEquiAndNonEquiCond).toSet) + assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(logicalSemiJoinWithoutEquiCond)) + assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(logicalSemiJoinWithEquiAndNonEquiCond)) - assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(logicalAntiJoinOnUniqueKeys).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalAntiJoinNotOnUniqueKeys).toSet) + assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(logicalAntiJoinOnUniqueKeys)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalAntiJoinNotOnUniqueKeys)) assertNull(mq.getUpsertKeys(logicalAntiJoinOnRHSUniqueKeys)) - assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(logicalAntiJoinWithoutEquiCond).toSet) - assertEquals(toBitSet(), mq.getUpsertKeys(logicalAntiJoinWithEquiAndNonEquiCond).toSet) + assertEquals(toBitSet(Array(1)), mq.getUpsertKeys(logicalAntiJoinWithoutEquiCond)) + assertEquals(toBitSet(), mq.getUpsertKeys(logicalAntiJoinWithEquiAndNonEquiCond)) } @Test def testGetUpsertKeysOnLookupJoin(): Unit = { Array(batchLookupJoin, streamLookupJoin).foreach { - join => assertEquals(toBitSet(), mq.getUpsertKeys(join).toSet) + join => assertEquals(toBitSet(), mq.getUpsertKeys(join)) } } @Test def testGetUpsertKeysOnLookupJoinWithPk(): Unit = { Array(batchLookupJoinWithPk, streamLookupJoinWithPk).foreach { - join => assertEquals(toBitSet(Array(7), Array(0, 7), Array(0)), mq.getUpsertKeys(join).toSet) + join => assertEquals(toBitSet(Array(7), Array(0, 7), Array(0)), mq.getUpsertKeys(join)) } } @Test def testGetUpsertKeysOnLookupJoinNotContainsPk(): Unit = { Array(batchLookupJoinNotContainsPk, streamLookupJoinNotContainsPk).foreach { - join => assertEquals(toBitSet(), mq.getUpsertKeys(join).toSet) + join => assertEquals(toBitSet(), mq.getUpsertKeys(join)) } } @Test def testGetUpsertKeysOnSetOp(): Unit = { Array(logicalUnionAll, logicalIntersectAll, logicalMinusAll).foreach { - setOp => assertEquals(toBitSet(), mq.getUpsertKeys(setOp).toSet) + setOp => assertEquals(toBitSet(), mq.getUpsertKeys(setOp)) } Array(logicalUnion, logicalIntersect, logicalMinus).foreach { - setOp => assertEquals(toBitSet(Array(0, 1, 2, 3, 4)), mq.getUpsertKeys(setOp).toSet) + setOp => assertEquals(toBitSet(Array(0, 1, 2, 3, 4)), mq.getUpsertKeys(setOp)) } } @@ -482,7 +481,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { @Test def testGetUpsertKeysOnIntermediateScan(): Unit = { - assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(intermediateScan).toSet) + assertEquals(toBitSet(Array(0)), mq.getUpsertKeys(intermediateScan)) } @Test @@ -490,7 +489,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { // Immutable columns: {0, 1, 2} (PK 'a' + immutable 'c', 'd') assertEquals( toBitSet(Array(0), Array(0, 1, 2)), - mq.getUpsertKeys(tableWithImmutableColsLogicalScan).toSet) + mq.getUpsertKeys(tableWithImmutableColsLogicalScan)) } @Test @@ -502,7 +501,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { streamPhysicalTraits.replace(hash1), tableWithImmutableColsStreamScan, hash1) - assertEquals(toBitSet(Array(0), Array(0, 1, 2)), mq.getUpsertKeys(exchange1).toSet) + assertEquals(toBitSet(Array(0), Array(0, 1, 2)), mq.getUpsertKeys(exchange1)) // Hash exchange on column 3 (rowtime, NOT immutable) val hash3 = FlinkRelDistribution.hash(Array(3), requireStrict = true) @@ -511,7 +510,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { streamPhysicalTraits.replace(hash3), tableWithImmutableColsStreamScan, hash3) - assertEquals(toBitSet(), mq.getUpsertKeys(exchange3).toSet) + assertEquals(toBitSet(), mq.getUpsertKeys(exchange3)) } @Test @@ -519,12 +518,12 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { // Sort on column 1 (c, immutable) relBuilder.push(tableWithImmutableColsLogicalScan) val sort1 = relBuilder.sort(relBuilder.field(1)).build() - assertEquals(toBitSet(Array(0), Array(0, 1, 2)), mq.getUpsertKeys(sort1).toSet) + assertEquals(toBitSet(Array(0), Array(0, 1, 2)), mq.getUpsertKeys(sort1)) // Sort on column 3 (rowtime, NOT immutable) relBuilder.push(tableWithImmutableColsLogicalScan) val sort3 = relBuilder.sort(relBuilder.field(3)).build() - assertEquals(toBitSet(), mq.getUpsertKeys(sort3).toSet) + assertEquals(toBitSet(), mq.getUpsertKeys(sort3)) } @Test @@ -553,11 +552,11 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { // Rank partitioned by column 1 (c, immutable) val rank1 = buildRank(1) - assertEquals(toBitSet(Array(0), Array(0, 1, 2)), mq.getUpsertKeys(rank1).toSet) + assertEquals(toBitSet(Array(0), Array(0, 1, 2)), mq.getUpsertKeys(rank1)) // Rank partitioned by column 3 (rowtime, NOT immutable) val rank3 = buildRank(3) - assertEquals(toBitSet(), mq.getUpsertKeys(rank3).toSet) + assertEquals(toBitSet(), mq.getUpsertKeys(rank3)) } @Test @@ -621,11 +620,11 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { // Over agg partitioned by column 1 (c, immutable) val over1 = buildOverAgg(1) - assertEquals(toBitSet(Array(0), Array(0, 1, 2)), mq.getUpsertKeys(over1).toSet) + assertEquals(toBitSet(Array(0), Array(0, 1, 2)), mq.getUpsertKeys(over1)) // Over agg partitioned by column 3 (rowtime, NOT immutable) val over3 = buildOverAgg(3) - assertEquals(toBitSet(), mq.getUpsertKeys(over3).toSet) + assertEquals(toBitSet(), mq.getUpsertKeys(over3)) } @Test @@ -639,7 +638,7 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { JoinRelType.SEMI, relBuilder.call(EQUALS, relBuilder.field(2, 0, 1), relBuilder.field(2, 1, 1))) .build() - assertEquals(toBitSet(Array(0), Array(0, 1, 2)), mq.getUpsertKeys(join1).toSet) + assertEquals(toBitSet(Array(0), Array(0, 1, 2)), mq.getUpsertKeys(join1)) } @Test @@ -659,10 +658,10 @@ class FlinkRelMdUpsertKeysTest extends FlinkRelMdHandlerTestBase { .build() assertEquals( toBitSet(Array(0, 4), Array(0, 4, 5, 6), Array(0, 1, 2, 4), Array(0, 1, 2, 4, 5, 6)), - mq.getUpsertKeys(join).toSet) + mq.getUpsertKeys(join)) } - private def toBitSet(keys: Array[Int]*): Set[ImmutableBitSet] = { - keys.map(k => ImmutableBitSet.of(k: _*)).toSet + private def toBitSet(keys: Array[Int]*): java.util.Set[ImmutableBitSet] = { + keys.map(k => ImmutableBitSet.of(k: _*)).toSet.asJava } } From 9e4cd0c593404e5ce0e08e84e5b25231306f7975 Mon Sep 17 00:00:00 2001 From: Sergey Nuyanzin Date: Thu, 28 May 2026 15:54:22 +0200 Subject: [PATCH 3/5] [FLINK-39817][table] Upgrade Calcite to 1.39.0 --- flink-table/flink-sql-parser/pom.xml | 8 +- .../src/main/codegen/data/Parser.tdd | 4 + .../src/main/codegen/templates/Parser.jj | 102 +- .../calcite/sql/type/SqlTypeFamily.java | 292 -- .../apache/calcite/sql/type/SqlTypeName.java | 1078 ---- .../parser/validate/FlinkSqlConformance.java | 15 + .../flink-table-calcite-bridge/pom.xml | 8 +- .../calcite/jdbc/SimpleCalciteSchema.java | 56 +- .../org/apache/calcite/plan/RelOptUtil.java | 4596 ----------------- .../calcite/rel/rules/SubQueryRemoveRule.java | 115 +- .../rel/type/RelDataTypeFactoryImpl.java | 14 +- .../org/apache/calcite/rex/RexBuilder.java | 2206 -------- .../org/apache/calcite/rex/RexChecker.java | 203 - .../apache/calcite/rex/RexFieldAccess.java | 141 - .../org/apache/calcite/rex/RexProgram.java | 984 ---- .../org/apache/calcite/rex/RexShuttle.java | 284 - .../java/org/apache/calcite/rex/RexUtil.java | 18 +- .../apache/calcite/runtime/SqlFunctions.java | 384 +- .../java/org/apache/calcite/sql/SqlUtil.java | 6 +- .../apache/calcite/sql/type/BasicSqlType.java | 2 +- .../sql/type/SqlTypeAssignmentRule.java | 4 +- .../calcite/sql/type/SqlTypeFactoryImpl.java | 64 +- .../apache/calcite/sql/type/SqlTypeUtil.java | 86 +- .../sql/validate/SqlValidatorImpl.java | 408 +- .../apache/calcite/sql2rel/AggConverter.java | 2 +- .../calcite/sql2rel/RelDecorrelator.java | 486 +- .../calcite/sql2rel/SqlToRelConverter.java | 492 +- .../calcite/FlinkCalciteSqlValidator.java | 11 + .../src/main/resources/META-INF/NOTICE | 6 +- .../planner/plan/utils/TemporalJoinUtil.scala | 3 - .../planner/functions/CastFunctionITCase.java | 3 +- .../table/planner/plan/batch/sql/CalcTest.xml | 2 +- .../plan/batch/sql/DeadlockBreakupTest.xml | 2 +- .../batch/sql/DynamicFunctionPlanTest.xml | 9 +- .../plan/batch/sql/SetOperatorsTest.xml | 4 +- .../ReplaceIntersectWithSemiJoinRuleTest.xml | 2 +- .../ReplaceMinusWithAntiJoinRuleTest.xml | 2 +- .../logical/WindowGroupReorderRuleTest.xml | 2 +- .../logical/subquery/SubQueryAntiJoinTest.xml | 180 + .../logical/subquery/SubQuerySemiJoinTest.xml | 109 + ...PushCalcPastChangelogNormalizeRuleTest.xml | 18 +- .../planner/plan/stream/sql/CalcTest.xml | 2 +- .../plan/stream/sql/SetOperatorsTest.xml | 4 +- .../planner/plan/stream/table/ValuesTest.xml | 2 +- .../expressions/SqlExpressionTest.scala | 3 +- .../subquery/SubQueryAntiJoinTest.scala | 16 +- .../subquery/SubQuerySemiJoinTest.scala | 12 +- flink-table/pom.xml | 2 +- 48 files changed, 1879 insertions(+), 10573 deletions(-) delete mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/type/SqlTypeFamily.java delete mode 100644 flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/type/SqlTypeName.java delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/calcite/plan/RelOptUtil.java delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexBuilder.java delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexChecker.java delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexFieldAccess.java delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexProgram.java delete mode 100644 flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexShuttle.java diff --git a/flink-table/flink-sql-parser/pom.xml b/flink-table/flink-sql-parser/pom.xml index 0264583bfecc7..8be0c191ffb04 100644 --- a/flink-table/flink-sql-parser/pom.xml +++ b/flink-table/flink-sql-parser/pom.xml @@ -68,11 +68,11 @@ under the License. ${calcite.version} - 3.1.11 + 3.1.12 diff --git a/flink-table/flink-sql-parser/pom.xml b/flink-table/flink-sql-parser/pom.xml index 8be0c191ffb04..2c4c37d47bb67 100644 --- a/flink-table/flink-sql-parser/pom.xml +++ b/flink-table/flink-sql-parser/pom.xml @@ -68,10 +68,10 @@ under the License. ${calcite.version} diff --git a/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd b/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd index 84b69813a5247..b7e6d7ab9ea3f 100644 --- a/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd +++ b/flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd @@ -781,4 +781,5 @@ includeBraces: true includeAdditionalDeclarations: false includeParsingStringLiteralAsArrayLiteral: false + includeIntervalWithoutQualifier: false } diff --git a/flink-table/flink-sql-parser/src/main/codegen/templates/Parser.jj b/flink-table/flink-sql-parser/src/main/codegen/templates/Parser.jj index 3b03088e25ec1..38b6748352831 100644 --- a/flink-table/flink-sql-parser/src/main/codegen/templates/Parser.jj +++ b/flink-table/flink-sql-parser/src/main/codegen/templates/Parser.jj @@ -5232,6 +5232,19 @@ TimeUnit Second() : { return warn(TimeUnit.SECOND); } } +SqlIntervalQualifier IntervalWithoutQualifier() : +{ + final Span s; + int secondFracPrec = RelDataType.PRECISION_NOT_SPECIFIED; +} +{ + { + s = span(); + return new SqlIntervalQualifier(TimeUnit.SECOND, -1, null, secondFracPrec, + s.end(this)); + } +} + SqlIntervalQualifier IntervalQualifier() : { final Span s; @@ -5891,6 +5904,7 @@ SqlTypeNameSpec SqlTypeName(Span s) : SqlTypeNameSpec SqlTypeName1(Span s) : { final SqlTypeName sqlTypeName; + boolean unsigned = false; } { ( @@ -5904,13 +5918,40 @@ SqlTypeNameSpec SqlTypeName1(Span s) : | { s.add(this); sqlTypeName = SqlTypeName.BOOLEAN; } | - ( | ) { s.add(this); sqlTypeName = SqlTypeName.INTEGER; } + ( | ) ( { unsigned = true; })? { + if (unsigned && !this.conformance.supportsUnsignedTypes()) { + throw SqlUtil.newContextException(getPos(), RESOURCE.unsignedDisabled()); + } + s.add(this); sqlTypeName = unsigned ? SqlTypeName.UINTEGER : SqlTypeName.INTEGER; + } | - { s.add(this); sqlTypeName = SqlTypeName.TINYINT; } + { + if (!this.conformance.supportsUnsignedTypes()) { + throw SqlUtil.newContextException(getPos(), RESOURCE.unsignedDisabled()); + } + s.add(this); sqlTypeName = SqlTypeName.UINTEGER; + } + | + ( { unsigned = true; })? { + if (unsigned && !this.conformance.supportsUnsignedTypes()) { + throw SqlUtil.newContextException(getPos(), RESOURCE.unsignedDisabled()); + } + s.add(this); sqlTypeName = unsigned ? SqlTypeName.UTINYINT : SqlTypeName.TINYINT; + } | - { s.add(this); sqlTypeName = SqlTypeName.SMALLINT; } + ( { unsigned = true; })? { + if (unsigned && !this.conformance.supportsUnsignedTypes()) { + throw SqlUtil.newContextException(getPos(), RESOURCE.unsignedDisabled()); + } + s.add(this); sqlTypeName = unsigned ? SqlTypeName.USMALLINT : SqlTypeName.SMALLINT; + } | - { s.add(this); sqlTypeName = SqlTypeName.BIGINT; } + ( { unsigned = true; })? { + if (unsigned && !this.conformance.supportsUnsignedTypes()) { + throw SqlUtil.newContextException(getPos(), RESOURCE.unsignedDisabled()); + } + s.add(this); sqlTypeName = unsigned ? SqlTypeName.UBIGINT : SqlTypeName.BIGINT; + } | { s.add(this); sqlTypeName = SqlTypeName.REAL; } | @@ -6296,7 +6337,12 @@ SqlNode BuiltinFunctionCall() : ( dt = DataType() { args.add(dt); } | + LOOKAHEAD(2) e = IntervalQualifier() { args.add(e); } +<#if (parser.includeIntervalWithoutQualifier!default.parser.includeIntervalWithoutQualifier) > + | + e = IntervalWithoutQualifier() { args.add(e); } + ) [ format = StringLiteral() { args.add(format); } ] { @@ -7936,6 +7982,7 @@ SqlBinaryOperator BinaryRowOperator() : { // is handled as a special case { return SqlStdOperatorTable.EQUALS; } +| { return SqlStdOperatorTable.BIT_LEFT_SHIFT; } | { return SqlStdOperatorTable.GREATER_THAN; } | { return SqlStdOperatorTable.LESS_THAN; } | { return SqlStdOperatorTable.LESS_THAN_OR_EQUAL; } @@ -7967,7 +8014,9 @@ SqlBinaryOperator BinaryRowOperator() : | { return SqlStdOperatorTable.NOT_SUBMULTISET_OF; } | { return SqlStdOperatorTable.CONTAINS; } | { return SqlStdOperatorTable.OVERLAPS; } +| { return SqlStdOperatorTable.BITXOR_OPERATOR; } | { return SqlStdOperatorTable.PERIOD_EQUALS; } +| { return SqlStdOperatorTable.BITAND_OPERATOR; } | { return SqlStdOperatorTable.PRECEDES; } | { return SqlStdOperatorTable.SUCCEEDS; } | LOOKAHEAD(2) { return SqlStdOperatorTable.IMMEDIATELY_PRECEDES; } @@ -8714,6 +8763,7 @@ SqlPostfixOperator PostfixRowOperator() : | < UNPIVOT: "UNPIVOT" > | < UNNAMED: "UNNAMED" > | < UNNEST: "UNNEST" > +| < UNSIGNED: "UNSIGNED" > | < UPDATE: "UPDATE" > { beforeTableName(); } | < UPPER: "UPPER" > | < UPSERT: "UPSERT" > @@ -8957,6 +9007,8 @@ void NonReservedKeyWord2of3() : | < DOUBLE_QUOTE: "\"" > | < VERTICAL_BAR: "|" > | < CARET: "^" > +| < AMPERSAND: "&" > +| < LEFTSHIFT: "<<" > | < DOLLAR: "$" > <#list (parser.binaryOperatorsTokens!default.parser.binaryOperatorsTokens) as operator> | ${operator} diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlCollectionTypeNameSpec.java b/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlCollectionTypeNameSpec.java index a260a6a730d81..3ca5081b410b1 100644 --- a/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlCollectionTypeNameSpec.java +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlCollectionTypeNameSpec.java @@ -63,22 +63,40 @@ */ public class SqlCollectionTypeNameSpec extends SqlTypeNameSpec { private final SqlTypeNameSpec elementTypeName; + private final boolean elementTypeNullable; private final SqlTypeName collectionTypeName; /** * Creates a {@code SqlCollectionTypeNameSpec}. * * @param elementTypeName Type of the collection element + * @param elementTypeNullable Type of the collection element is nullable * @param collectionTypeName Collection type name * @param pos Parser position, must not be null */ public SqlCollectionTypeNameSpec( - SqlTypeNameSpec elementTypeName, SqlTypeName collectionTypeName, SqlParserPos pos) { + SqlTypeNameSpec elementTypeName, + boolean elementTypeNullable, + SqlTypeName collectionTypeName, + SqlParserPos pos) { super(new SqlIdentifier(collectionTypeName.name(), pos), pos); this.elementTypeName = requireNonNull(elementTypeName, "elementTypeName"); + this.elementTypeNullable = elementTypeNullable; this.collectionTypeName = requireNonNull(collectionTypeName, "collectionTypeName"); } + /** + * Creates a {@code SqlCollectionTypeNameSpec}. + * + * @param elementTypeName Type of the collection element + * @param collectionTypeName Collection type name + * @param pos Parser position, must not be null + */ + public SqlCollectionTypeNameSpec( + SqlTypeNameSpec elementTypeName, SqlTypeName collectionTypeName, SqlParserPos pos) { + this(elementTypeName, true, collectionTypeName, pos); + } + public SqlTypeNameSpec getElementTypeName() { return elementTypeName; } diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlMapTypeNameSpec.java b/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlMapTypeNameSpec.java index 827302de6a2a2..616dff23490f7 100644 --- a/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlMapTypeNameSpec.java +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlMapTypeNameSpec.java @@ -57,9 +57,12 @@ public SqlDataTypeSpec getValType() { @Override public RelDataType deriveType(SqlValidator validator) { - return validator - .getTypeFactory() - .createMapType(keyType.deriveType(validator), valType.deriveType(validator)); + boolean keyCanBeNullable = + validator.getTypeFactory().getTypeSystem().mapKeysCanBeNullable(); + RelDataType kType = keyType.deriveType(validator, keyCanBeNullable); + + RelDataType valueType = valType.deriveType(validator, true); + return validator.getTypeFactory().createMapType(kType, valueType); } @Override diff --git a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/validate/FlinkSqlConformance.java b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/validate/FlinkSqlConformance.java index 0524581f2a39e..ccc1ef14f605a 100644 --- a/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/validate/FlinkSqlConformance.java +++ b/flink-table/flink-sql-parser/src/main/java/org/apache/flink/sql/parser/validate/FlinkSqlConformance.java @@ -47,6 +47,11 @@ public SelectAliasLookup isSelectAlias() { return SqlConformanceEnum.DEFAULT.isSelectAlias(); } + @Override + public boolean isNonStrictGroupBy() { + return false; + } + @Override public boolean isGroupByOrdinal() { return false; @@ -192,6 +197,11 @@ public boolean checkedArithmetic() { return SqlConformanceEnum.DEFAULT.checkedArithmetic(); } + @Override + public boolean supportsUnsignedTypes() { + return false; + } + @Override public boolean isValueAllowed() { return true; diff --git a/flink-table/flink-table-calcite-bridge/pom.xml b/flink-table/flink-table-calcite-bridge/pom.xml index 90ade1deb93c9..77ee0a2fd4950 100644 --- a/flink-table/flink-table-calcite-bridge/pom.xml +++ b/flink-table/flink-table-calcite-bridge/pom.xml @@ -45,18 +45,18 @@ under the License. ${calcite.version} @@ -371,6 +377,7 @@ under the License. commons-io:commons-io org.apache.commons:commons-math3 org.checkerframework:checker-qual + org.jooq:joou-java-6 org.apache.flink:flink-sql-parser @@ -425,6 +432,12 @@ under the License. com.ibm.icu org.apache.flink.table.shaded.com.ibm.icu + + + + org.jooq + org.apache.flink.table.shaded.org.jooq + diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java index 0114ad194fd27..3a78b222602f9 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java @@ -17,9 +17,11 @@ package org.apache.calcite.rel.rules; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; import org.apache.calcite.plan.RelOptRuleCall; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Collect; import org.apache.calcite.rel.core.CorrelationId; @@ -30,6 +32,7 @@ import org.apache.calcite.rel.metadata.RelMdUtil; import org.apache.calcite.rel.metadata.RelMetadataQuery; import org.apache.calcite.rex.LogicVisitor; +import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; @@ -742,10 +745,29 @@ private static RexNode rewriteIn( case TRUE_FALSE_UNKNOWN: case UNKNOWN_AS_TRUE: // Builds the cross join - builder.aggregate( - builder.groupKey(), - builder.count(false, "c"), - builder.count(builder.fields()).as("ck")); + // Some databases don't support use FILTER clauses for aggregate functions + // like {@code COUNT(*) FILTER (WHERE not(a is null))} + // So use count(*) when only one column + if (builder.fields().size() <= 1) { + builder.aggregate( + builder.groupKey(), + builder.count(false, "c"), + builder.count(builder.fields()).as("ck")); + } else { + builder.aggregate( + builder.groupKey(), + builder.count(false, "c"), + builder.count() + .filter( + builder.not( + builder.and( + builder.fields().stream() + .map(builder::isNull) + .collect( + Collectors + .toList())))) + .as("ck")); + } builder.as(ctAlias); if (!variablesSet.isEmpty()) { builder.join(JoinRelType.LEFT, trueLiteral, variablesSet); @@ -925,13 +947,30 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) { final Join join = call.rel(0); final RelBuilder builder = call.builder(); final RexSubQuery e = requireNonNull(RexUtil.SubQueryFinder.find(join.getCondition())); - final RelOptUtil.Logic logic = - LogicVisitor.find(RelOptUtil.Logic.TRUE, ImmutableList.of(join.getCondition()), e); ImmutableBitSet inputSet = RelOptUtil.InputFinder.bits(e.getOperands(), null); int nFieldsLeft = join.getLeft().getRowType().getFieldCount(); int nFieldsRight = join.getRight().getRowType().getFieldCount(); + // Correlation columns should also be considered. + // For example: + // LogicalJoin + // left right + // | | + // LogicalProject.NONE.[0, 1] LogicalValues.NONE.[0] + // RecordType(INTEGER DEPTNO, CHAR(11) DNAME) RecordType(INTEGER DEPTNO) + // + // and subquery: $SCALAR_QUERY with correlate + // LogicalProject(DEPTNO=[$1]) + // LogicalFilter(condition=[=(CAST($0):CHAR(11) NOT NULL, $cor0.DNAME)]) + // + // In such a case $cor0.DNAME need to be accounted as input form left side. + final Set variablesSet = RelOptUtil.getVariablesUsed(e.rel); + for (CorrelationId id : variablesSet) { + ImmutableBitSet requiredColumns = RelOptUtil.correlationColumns(id, e.rel); + inputSet = ImmutableBitSet.union(ImmutableList.of(requiredColumns, inputSet)); + } + boolean inputIntersectsLeftSide = inputSet.intersects(ImmutableBitSet.range(0, nFieldsLeft)); boolean inputIntersectsRightSide = @@ -945,10 +984,17 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) { return; } - final Set variablesSet = RelOptUtil.getVariablesUsed(e.rel); if (inputIntersectsLeftSide) { builder.push(join.getLeft()); + final RelOptUtil.Logic logic = + LogicVisitor.find( + join.getJoinType().generatesNullsOnRight() + ? RelOptUtil.Logic.TRUE_FALSE_UNKNOWN + : RelOptUtil.Logic.TRUE, + ImmutableList.of(join.getCondition()), + e); + final RexNode target = rule.apply(e, variablesSet, logic, builder, 1, nFieldsLeft, 0); final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target); @@ -968,12 +1014,94 @@ private static void matchJoin(SubQueryRemoveRule rule, RelOptRuleCall call) { .union(ImmutableBitSet.range(nFields - nFieldsRight, nFields))); builder.project(fields); } else { - builder.push(join.getLeft()); builder.push(join.getRight()); + final RelOptUtil.Logic logic = + LogicVisitor.find( + join.getJoinType().generatesNullsOnLeft() + ? RelOptUtil.Logic.TRUE_FALSE_UNKNOWN + : RelOptUtil.Logic.TRUE, + ImmutableList.of(join.getCondition()), + e); + + RexSubQuery subQuery = e; + + if (!variablesSet.isEmpty()) { + // Original correlates reference joint row type, but we are about to create + // new join of original right side and correlated sub-query. Therefore we have + // to adjust correlated variables in following way: + // 1) new correlation variable must reference row type of right side only + // 2) field index must be shifted on the size of the left side + // Example: + // SELECT e1.* + // FROM emp e1 + // JOIN dept d + // ON e1.deptno = d.deptno + // AND d.deptno IN ( + // SELECT e3.empno + // FROM emp e3 + // WHERE d.deptno > e3.comm + // ) + // ORDER BY e1.empno, e1.deptno; + // + // LogicalJoin(condition=[AND(=($7, $8), IN(CAST($8):SMALLINT NOT NULL, { + // LogicalProject(EMPNO=[$0]) + // LogicalFilter(condition=[>(CAST($cor0.DEPTNO0):DECIMAL(7, 2) NOT NULL, $6)]) + // LogicalTableScan(table=[[scott, EMP]]) + // }))], joinType=[inner]) + // LogicalTableScan(table=[[scott, EMP]]) + // LogicalProject(DEPTNO=[$0]) + // LogicalTableScan(table=[[scott, DEPT]]) + // + // Rewrite to: + // + // LogicalProject(EMPNO=[$0], ENAME=[$1], ..., COMM=[$6], DEPTNO=[$7], DEPTNO0=[$8]) + // LogicalJoin(condition=[=($7, $8)], joinType=[inner]) + // LogicalTableScan(table=[[scott, EMP]]) + // LogicalFilter(condition=[=(CAST($0):SMALLINT NOT NULL, $1)]) + // LogicalCorrelate(correlation=[$cor0], joinType=[inner], + // requiredColumns=[{0}]) + // LogicalProject(DEPTNO=[$0]) + // LogicalTableScan(table=[[scott, DEPT]]) + // LogicalProject(EMPNO=[$0]) + // LogicalFilter(condition=[>(CAST($cor0.DEPTNO):DECIMAL(7, 2) NOT NULL, + // $6)]) + // LogicalTableScan(table=[[scott, EMP]]) + CorrelationId id = Iterables.getOnlyElement(variablesSet); + RexBuilder rexBuilder = builder.getRexBuilder(); + + RelNode newSubQueryRel = + e.rel.accept( + new RelHomogeneousShuttle() { + @Override + public RelNode visit(RelNode other) { + RelNode node = + RexUtil.shiftFieldAccess( + rexBuilder, + other, + id, + join.getRight(), + -nFieldsLeft); + return super.visit(node); + } + }); + subQuery = e.clone(newSubQueryRel); + } + + subQuery = + subQuery.clone( + subQuery.getType(), + RexUtil.shift(subQuery.getOperands(), -nFieldsLeft)); + final int nFields = join.getRowType().getFieldCount(); - final RexNode target = rule.apply(e, variablesSet, logic, builder, 2, nFields, 0); - final RexShuttle shuttle = new ReplaceSubQueryShuttle(e, target); + final RexNode target = + rule.apply(subQuery, variablesSet, logic, builder, 1, nFieldsRight, 0); + final RexShuttle shuttle = + new ReplaceSubQueryShuttle(e, RexUtil.shift(target, nFieldsLeft)); + + RelNode newRight = builder.build(); + builder.push(join.getLeft()); + builder.push(newRight); builder.join(join.getJoinType(), shuttle.apply(join.getCondition())); builder.project(fields(builder, nFields)); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java index 9752813928f78..3ba7ea582554a 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/type/RelDataTypeFactoryImpl.java @@ -58,7 +58,6 @@ *

FLINK modifications are at lines * *

    - *
  1. Should be removed after fixing CALCITE-5199: Lines 242-244 *
  2. Added in FLINK-39695 (backport of CALCITE-6764): Lines 407 ~ 438 *
*/ @@ -238,9 +237,7 @@ private RelDataType createStructType( // recursively compute column-wise least restrictive // preserve the struct kind from type0 - // FLINK MODIFICATION BEGIN final Builder builder = builder().kind(type0.getStructKind()); - // FLINK MODIFICATION END for (int j = 0; j < fieldCount; ++j) { // REVIEW jvs 22-Jan-2004: Always use the field name from the // first type? @@ -277,9 +274,11 @@ private RelDataType createStructType( if (type == null) { return null; } - return sqlTypeName == SqlTypeName.ARRAY - ? new ArraySqlType(type, isNullable) - : new MultisetSqlType(type, isNullable); + RelDataType collection = + sqlTypeName == SqlTypeName.ARRAY + ? createArrayType(type, -1) + : createMultisetType(type, -1); + return createTypeWithNullability(collection, isNullable); } protected @Nullable RelDataType leastRestrictiveMapType( @@ -302,7 +301,7 @@ private RelDataType createStructType( if (valueType == null) { return null; } - return new MapSqlType(keyType, valueType, isNullable); + return createTypeWithNullability(createMapType(keyType, valueType), isNullable); } protected RelDataType leastRestrictiveIntervalDatetimeType( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexSimplify.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexSimplify.java new file mode 100644 index 0000000000000..3f0f5539f8cc6 --- /dev/null +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexSimplify.java @@ -0,0 +1,3707 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.calcite.rex; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.BoundType; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableRangeSet; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; +import com.google.common.collect.Range; +import com.google.common.collect.RangeSet; +import com.google.common.collect.Sets; +import com.google.common.collect.TreeRangeSet; +import org.apache.calcite.avatica.util.TimeUnit; +import org.apache.calcite.avatica.util.TimeUnitRange; +import org.apache.calcite.plan.RelOptPredicateList; +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.plan.Strong; +import org.apache.calcite.rel.core.Project; +import org.apache.calcite.rel.metadata.NullSentinel; +import org.apache.calcite.rel.type.RelDataType; +import org.apache.calcite.sql.SqlAggFunction; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.type.SqlTypeCoercionRule; +import org.apache.calcite.sql.type.SqlTypeFamily; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.calcite.util.Bug; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.Pair; +import org.apache.calcite.util.RangeSets; +import org.apache.calcite.util.Sarg; +import org.apache.calcite.util.Util; +import org.checkerframework.checker.nullness.qual.Nullable; + +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.Collection; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.util.Objects.requireNonNull; +import static org.apache.calcite.linq4j.Nullness.castNonNull; +import static org.apache.calcite.rex.RexUnknownAs.FALSE; +import static org.apache.calcite.rex.RexUnknownAs.TRUE; +import static org.apache.calcite.rex.RexUnknownAs.UNKNOWN; + +/** Context required to simplify a row-expression. */ +public class RexSimplify { + private final boolean paranoid; + public final RexBuilder rexBuilder; + private final RelOptPredicateList predicates; + + /** + * How to treat UNKNOWN values, if one of the deprecated {@code simplify} methods without an + * {@code unknownAs} argument is called. + */ + final RexUnknownAs defaultUnknownAs; + + final boolean predicateElimination; + private final RexExecutor executor; + + private static final Strong STRONG = new Strong(); + + /** + * Creates a RexSimplify. + * + * @param rexBuilder Rex builder + * @param predicates Predicates known to hold on input fields + * @param executor Executor for constant reduction, not null + */ + public RexSimplify( + RexBuilder rexBuilder, RelOptPredicateList predicates, RexExecutor executor) { + this(rexBuilder, predicates, UNKNOWN, true, false, executor); + } + + /** Internal constructor. */ + private RexSimplify( + RexBuilder rexBuilder, + RelOptPredicateList predicates, + RexUnknownAs defaultUnknownAs, + boolean predicateElimination, + boolean paranoid, + RexExecutor executor) { + this.rexBuilder = requireNonNull(rexBuilder, "rexBuilder"); + this.predicates = requireNonNull(predicates, "predicates"); + this.defaultUnknownAs = requireNonNull(defaultUnknownAs, "defaultUnknownAs"); + this.predicateElimination = predicateElimination; + this.paranoid = paranoid; + this.executor = requireNonNull(executor, "executor"); + } + + @Deprecated // to be removed before 2.0 + public RexSimplify(RexBuilder rexBuilder, boolean unknownAsFalse, RexExecutor executor) { + this( + rexBuilder, + RelOptPredicateList.EMPTY, + RexUnknownAs.falseIf(unknownAsFalse), + true, + false, + executor); + } + + @Deprecated // to be removed before 2.0 + public RexSimplify( + RexBuilder rexBuilder, + RelOptPredicateList predicates, + boolean unknownAsFalse, + RexExecutor executor) { + this(rexBuilder, predicates, RexUnknownAs.falseIf(unknownAsFalse), true, false, executor); + } + + // ~ Methods ---------------------------------------------------------------- + + /** + * Returns a RexSimplify the same as this but with a specified {@link #defaultUnknownAs} value. + * + * @deprecated Use methods with a {@link RexUnknownAs} argument, such as {@link + * #simplify(RexNode, RexUnknownAs)}. + */ + @Deprecated // to be removed before 2.0 + public RexSimplify withUnknownAsFalse(boolean unknownAsFalse) { + final RexUnknownAs defaultUnknownAs = RexUnknownAs.falseIf(unknownAsFalse); + return defaultUnknownAs == this.defaultUnknownAs + ? this + : new RexSimplify( + rexBuilder, + predicates, + defaultUnknownAs, + predicateElimination, + paranoid, + executor); + } + + /** Returns a RexSimplify the same as this but with a specified {@link #predicates} value. */ + public RexSimplify withPredicates(RelOptPredicateList predicates) { + return predicates == this.predicates + ? this + : new RexSimplify( + rexBuilder, + predicates, + defaultUnknownAs, + predicateElimination, + paranoid, + executor); + } + + /** + * Returns a RexSimplify the same as this but which verifies that the expression before and + * after simplification are equivalent. + * + * @see #verify + */ + public RexSimplify withParanoid(boolean paranoid) { + return paranoid == this.paranoid + ? this + : new RexSimplify( + rexBuilder, + predicates, + defaultUnknownAs, + predicateElimination, + paranoid, + executor); + } + + /** + * Returns a RexSimplify the same as this but with a specified {@link #predicateElimination} + * value. + * + *

This is introduced temporarily, until {@link Bug#CALCITE_2401_FIXED [CALCITE-2401] is + * fixed}. + */ + private RexSimplify withPredicateElimination(boolean predicateElimination) { + return predicateElimination == this.predicateElimination + ? this + : new RexSimplify( + rexBuilder, + predicates, + defaultUnknownAs, + predicateElimination, + paranoid, + executor); + } + + /** + * Simplifies a boolean expression, always preserving its type and its nullability. + * + *

This is useful if you are simplifying expressions in a {@link Project}. + */ + public RexNode simplifyPreservingType(RexNode e) { + return simplifyPreservingType(e, defaultUnknownAs, true); + } + + public RexNode simplifyPreservingType( + RexNode e, RexUnknownAs unknownAs, boolean matchNullability) { + final RexNode e2 = simplifyUnknownAs(e, unknownAs); + if (e2.getType() == e.getType()) { + return e2; + } + if (!matchNullability + && SqlTypeUtil.equalSansNullability( + rexBuilder.typeFactory, e2.getType(), e.getType())) { + return e2; + } + final RexNode e3 = rexBuilder.makeCast(e.getType(), e2, matchNullability, false); + if (e3.equals(e)) { + return e; + } + return e3; + } + + /** + * Simplifies a boolean expression. + * + *

In particular: + * + *

    + *
  • {@code simplify(x = 1 OR NOT x = 1 OR x IS NULL)} returns {@code TRUE} + *
  • {@code simplify(x = 1 AND FALSE)} returns {@code FALSE} + *
+ * + *

Handles UNKNOWN values using the policy specified when you created this {@code + * RexSimplify}. Unless you used a deprecated constructor, that policy is {@link + * RexUnknownAs#UNKNOWN}. + * + *

If the expression is a predicate in a WHERE clause, consider instead using {@link + * #simplifyUnknownAsFalse(RexNode)}. + * + * @param e Expression to simplify + */ + public RexNode simplify(RexNode e) { + return simplifyUnknownAs(e, defaultUnknownAs); + } + + /** + * As {@link #simplify(RexNode)}, but for a boolean expression for which a result of UNKNOWN + * will be treated as FALSE. + * + *

Use this form for expressions on a WHERE, ON, HAVING or FILTER(WHERE) clause. + * + *

This may allow certain additional simplifications. A result of UNKNOWN may yield FALSE, + * however it may still yield UNKNOWN. (If the simplified expression has type BOOLEAN NOT NULL, + * then of course it can only return FALSE.) + */ + public final RexNode simplifyUnknownAsFalse(RexNode e) { + return simplifyUnknownAs(e, FALSE); + } + + /** + * As {@link #simplify(RexNode)}, but specifying how UNKNOWN values are to be treated. + * + *

If UNKNOWN is treated as FALSE, this may allow certain additional simplifications. A + * result of UNKNOWN may yield FALSE, however it may still yield UNKNOWN. (If the simplified + * expression has type BOOLEAN NOT NULL, then of course it can only return FALSE.) + */ + public RexNode simplifyUnknownAs(RexNode e, RexUnknownAs unknownAs) { + final RexNode simplified = withParanoid(false).simplify(e, unknownAs); + if (paranoid) { + verify(e, simplified, unknownAs); + } + return simplified; + } + + /** + * Internal method to simplify an expression. + * + *

Unlike the public {@link #simplify(RexNode)} and {@link #simplifyUnknownAsFalse(RexNode)} + * methods, never calls {@link #verify(RexNode, RexNode, RexUnknownAs)}. Verify adds an overhead + * that is only acceptable for a top-level call. + */ + RexNode simplify(RexNode e, RexUnknownAs unknownAs) { + if (isSafeExpression(e) && STRONG.isNull(e)) { + // Only boolean NULL (aka UNKNOWN) can be converted to FALSE. Even in + // unknownAs=FALSE mode, we must not convert a NULL integer (say) to FALSE + if (e.getType().getSqlTypeName() == SqlTypeName.BOOLEAN) { + switch (unknownAs) { + case FALSE: + case TRUE: + return rexBuilder.makeLiteral(unknownAs.toBoolean()); + default: + break; + } + } + return rexBuilder.makeNullLiteral(e.getType()); + } + switch (e.getKind()) { + case AND: + return simplifyAnd((RexCall) e, unknownAs); + case OR: + return simplifyOr((RexCall) e, unknownAs); + case NOT: + return simplifyNot((RexCall) e, unknownAs); + case CASE: + return simplifyCase((RexCall) e, unknownAs); + case COALESCE: + return simplifyCoalesce((RexCall) e); + case CAST: + case SAFE_CAST: + return simplifyCast((RexCall) e); + case CEIL: + case FLOOR: + return simplifyCeilFloor((RexCall) e); + case TRIM: + return simplifyTrim((RexCall) e); + case IS_NULL: + case IS_NOT_NULL: + case IS_TRUE: + case IS_NOT_TRUE: + case IS_FALSE: + case IS_NOT_FALSE: + assert e instanceof RexCall; + return simplifyIs((RexCall) e, unknownAs); + case EQUALS: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + case NOT_EQUALS: + return simplifyComparison((RexCall) e, unknownAs); + case SEARCH: + return simplifySearch((RexCall) e, unknownAs); + case LIKE: + return simplifyLike((RexCall) e, unknownAs); + case MINUS_PREFIX: + case CHECKED_MINUS_PREFIX: + return simplifyUnaryMinus((RexCall) e, unknownAs); + case PLUS_PREFIX: + return simplifyUnaryPlus((RexCall) e, unknownAs); + case PLUS: + case MINUS: + case TIMES: + case DIVIDE: + case CHECKED_PLUS: + case CHECKED_MINUS: + case CHECKED_TIMES: + case CHECKED_DIVIDE: + return simplifyArithmetic((RexCall) e); + case M2V: + return simplifyM2v((RexCall) e); + default: + if (e.getClass() == RexCall.class) { + return simplifyGenericNode((RexCall) e); + } else { + return e; + } + } + } + + /** Applies NOT to an expression. */ + RexNode not(RexNode e) { + return RexUtil.not(rexBuilder, e); + } + + /** Applies IS NOT FALSE to an expression. */ + RexNode isNotFalse(RexNode e) { + return e.isAlwaysTrue() + ? rexBuilder.makeLiteral(true) + : e.isAlwaysFalse() + ? rexBuilder.makeLiteral(false) + : e.getKind() == SqlKind.NOT + ? isNotTrue(((RexCall) e).operands.get(0)) + : predicates.isEffectivelyNotNull(e) + ? e // would "CAST(e AS BOOLEAN NOT NULL)" better? + : rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_FALSE, e); + } + + /** Applies IS NOT TRUE to an expression. */ + RexNode isNotTrue(RexNode e) { + return e.isAlwaysTrue() + ? rexBuilder.makeLiteral(false) + : e.isAlwaysFalse() + ? rexBuilder.makeLiteral(true) + : e.getKind() == SqlKind.NOT + ? isNotFalse(((RexCall) e).operands.get(0)) + : predicates.isEffectivelyNotNull(e) + ? not(e) + : rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_TRUE, e); + } + + /** Applies IS TRUE to an expression. */ + RexNode isTrue(RexNode e) { + return e.isAlwaysTrue() + ? rexBuilder.makeLiteral(true) + : e.isAlwaysFalse() + ? rexBuilder.makeLiteral(false) + : e.getKind() == SqlKind.NOT + ? isFalse(((RexCall) e).operands.get(0)) + : predicates.isEffectivelyNotNull(e) + ? e // would "CAST(e AS BOOLEAN NOT NULL)" better? + : rexBuilder.makeCall(SqlStdOperatorTable.IS_TRUE, e); + } + + /** Applies IS FALSE to an expression. */ + RexNode isFalse(RexNode e) { + return e.isAlwaysTrue() + ? rexBuilder.makeLiteral(false) + : e.isAlwaysFalse() + ? rexBuilder.makeLiteral(true) + : e.getKind() == SqlKind.NOT + ? isTrue(((RexCall) e).operands.get(0)) + : predicates.isEffectivelyNotNull(e) + ? not(e) + : rexBuilder.makeCall(SqlStdOperatorTable.IS_FALSE, e); + } + + /** Runs simplification inside a non-specialized node. */ + private RexNode simplifyGenericNode(RexCall e) { + final List operands = new ArrayList<>(e.operands); + simplifyList(operands, UNKNOWN); + if (e.operands.equals(operands)) { + return e; + } + return rexBuilder.makeCall(e.getParserPosition(), e.getType(), e.getOperator(), operands); + } + + /** + * Try to find a literal with the given value in the input list. The type of the literal must be + * one of the numeric types. + */ + private static int findLiteralIndex(List operands, BigDecimal value) { + for (int i = 0; i < operands.size(); i++) { + if (operands.get(i).isA(SqlKind.LITERAL)) { + Comparable comparable = ((RexLiteral) operands.get(i)).getValue(); + if (comparable instanceof BigDecimal + && value.compareTo((BigDecimal) comparable) == 0) { + return i; + } + } + } + return -1; + } + + private RexNode simplifyArithmetic(RexCall e) { + if (e.getType().getSqlTypeName().getFamily() != SqlTypeFamily.NUMERIC + || e.getOperands().stream() + .anyMatch( + o -> + e.getType().getSqlTypeName().getFamily() + != SqlTypeFamily.NUMERIC)) { + // we only support simplifying numeric types. + return simplifyGenericNode(e); + } + + assert e.getOperands().size() == 2; + + switch (e.getKind()) { + // These simplifications are safe for both checked and unchecked arithemtic. + case PLUS: + case CHECKED_PLUS: + return simplifyPlus(e); + case MINUS: + case CHECKED_MINUS: + return simplifyMinus(e); + case TIMES: + case CHECKED_TIMES: + return simplifyMultiply(e); + case DIVIDE: + case CHECKED_DIVIDE: + return simplifyDivide(e); + default: + throw new IllegalArgumentException( + "Unsupported arithmeitc operation " + e.getKind()); + } + } + + private RexNode simplifyPlus(RexCall e) { + final int zeroIndex = findLiteralIndex(e.operands, BigDecimal.ZERO); + if (zeroIndex >= 0) { + // return the other operand. + RexNode other = e.getOperands().get((zeroIndex + 1) % 2); + return other.getType().equals(e.getType()) + ? other + : rexBuilder.makeCast(e.getParserPosition(), e.getType(), other); + } + return simplifyGenericNode(e); + } + + private RexNode simplifyMinus(RexCall e) { + final int zeroIndex = findLiteralIndex(e.operands, BigDecimal.ZERO); + if (zeroIndex == 1) { + RexNode leftOperand = e.getOperands().get(0); + return leftOperand.getType().equals(e.getType()) + ? leftOperand + : rexBuilder.makeCast(e.getParserPosition(), e.getType(), leftOperand); + } + return simplifyGenericNode(e); + } + + private RexNode simplifyMultiply(RexCall e) { + final int oneIndex = findLiteralIndex(e.operands, BigDecimal.ONE); + if (oneIndex >= 0) { + // return the other operand. + RexNode other = e.getOperands().get((oneIndex + 1) % 2); + return other.getType().equals(e.getType()) + ? other + : rexBuilder.makeCast(e.getParserPosition(), e.getType(), other); + } + return simplifyGenericNode(e); + } + + private RexNode simplifyDivide(RexCall e) { + final int oneIndex = findLiteralIndex(e.operands, BigDecimal.ONE); + if (oneIndex == 1) { + RexNode leftOperand = e.getOperands().get(0); + return leftOperand.getType().equals(e.getType()) + ? leftOperand + : rexBuilder.makeCast(e.getParserPosition(), e.getType(), leftOperand); + } + return simplifyGenericNode(e); + } + + private RexNode simplifyLike(RexCall e, RexUnknownAs unknownAs) { + if (e.operands.get(1) instanceof RexLiteral) { + final RexLiteral literal = (RexLiteral) e.operands.get(1); + String likeStr = requireNonNull(literal.getValueAs(String.class)); + Pattern pattern = Pattern.compile("%+"); + String value = pattern.matcher(likeStr).replaceAll("%"); + if ("%".equals(value)) { + // "x LIKE '%'" or "x LIKE '%...'" simplifies to "x = x" + final RexNode x = e.operands.get(0); + return simplify( + rexBuilder.makeCall( + e.getParserPosition(), SqlStdOperatorTable.EQUALS, x, x), + unknownAs); + } + // simplify "x LIKE '%%\%%a%%%'" to "x LIKE '%\%%a%'", default escape is '\' + if (e.operands.size() == 2) { + e = + (RexCall) + rexBuilder.makeCall( + e.getParserPosition(), + e.getOperator(), + e.operands.get(0), + rexBuilder.makeLiteral( + simplifyLikeString(likeStr, '\\', '%'))); + } + if (e.operands.size() == 3 && e.operands.get(2) instanceof RexLiteral) { + final RexLiteral escapeLiteral = (RexLiteral) e.operands.get(2); + // FLINK MODIFICATION BEGIN, CALCITE-7578 + final String escapeStr = requireNonNull(escapeLiteral.getValueAs(String.class)); + if (escapeStr.length() == 1) { + char escape = escapeStr.charAt(0); + e = + (RexCall) + rexBuilder.makeCall( + e.getParserPosition(), + e.getOperator(), + e.operands.get(0), + rexBuilder.makeLiteral( + simplifyLikeString(likeStr, escape, '%')), + escapeLiteral); + } + // FLINK MODIFICATION END + } + } + return simplifyGenericNode(e); + } + + // string 'AA%%__%%AA' simplify to 'AA__%AA' + // string with even escapes 'AA\\\\%%__%%AA' simplify to 'AA\\__%AA' + // string with odd escapes 'AA\\\\\\%%__%%AA' simplify to 'AA\\\\\\%__%AA' + private String simplifyMixedWildcards(String str, char escape) { + // FLINK MODIFICATION BEGIN + Pattern pattern = getMixedWildCardPattern(escape); + // FLINK MODIFICATION END + Matcher matcher = pattern.matcher(str); + StringBuilder builder = new StringBuilder(); + int from = 0; + while (matcher.find()) { + int start = matcher.start(); + String group = requireNonNull(matcher.group(0)); + if (start > 0 + && str.charAt(start - 1) == escape + && consecutiveSameCharCountBefore(str, start - 1, escape) % 2 == 1) { + builder.append(str, from, start + 1); + builder.append(simplifyPercentAndUnderline(group.substring(1))); + } else { + builder.append(str, from, start); + builder.append(simplifyPercentAndUnderline(group)); + } + from = matcher.end(); + } + if (from < str.length()) { + builder.append(str.substring(from)); + } + return builder.toString(); + } + + // FLINK MODIFICATION BEGIN + private static Pattern getMixedWildCardPattern(char escape) { + switch (escape) { + case '%': + return Pattern.compile("_+"); + case '_': + return Pattern.compile("%+"); + default: + return Pattern.compile("[_%]+"); + } + } + + // FLINK MODIFICATION END + + // Tool method: count the number of consecutive identical characters before index + private int consecutiveSameCharCountBefore(String str, int index, char escape) { + int count = 0; + while (index >= 0) { + if (str.charAt(index) != escape) { + break; + } + count++; + index--; + } + return count; + } + + // Tool method: simplified string mixed with '%' and '_' + private String simplifyPercentAndUnderline(String str) { + StringBuilder builder = new StringBuilder(); + boolean containsPercent = false; + for (int index = 0; index < str.length(); index++) { + if (str.charAt(index) == '%') { + containsPercent = true; + continue; + } + if (str.charAt(index) == '_') { + builder.append('_'); + } + } + if (containsPercent) { + builder.append('%'); + } + return builder.toString(); + } + + /** + * Simplifies like string with escape. A like '%%#%%A%%' escape '#' should simplify to A like + * '%#%%A%' escape '#'. + */ + private String simplifyLikeString(String content, char escape, char wildcard) { + int escapeCount = 0; + int wildcardCount = 0; + StringBuilder builder = new StringBuilder(); + for (int index = 0; index < content.length(); index++) { + char c = content.charAt(index); + if (c == escape) { + builder.append(c); + escapeCount++; + wildcardCount = 0; + continue; + } + if (c == wildcard) { + if (escapeCount % 2 == 1) { + builder.append(wildcard); + } else if (wildcardCount == 0) { + builder.append(wildcard); + wildcardCount++; + } + escapeCount = 0; + continue; + } + builder.append(c); + escapeCount = 0; + wildcardCount = 0; + } + return simplifyMixedWildcards(builder.toString(), escape); + } + + // e must be a comparison (=, >, >=, <, <=, !=) + private RexNode simplifyComparison(RexCall e, RexUnknownAs unknownAs) { + //noinspection unchecked + return simplifyComparison(e, unknownAs, Comparable.class); + } + + // e must be a comparison (=, >, >=, <, <=, !=) + private > RexNode simplifyComparison( + RexCall e, RexUnknownAs unknownAs, Class clazz) { + final List operands = new ArrayList<>(e.operands); + // UNKNOWN mode is warranted: false = null + simplifyList(operands, UNKNOWN); + + // Simplify "x x" + final RexNode o0 = operands.get(0); + final RexNode o1 = operands.get(1); + if (o0.equals(o1) && RexUtil.isDeterministic(o0)) { + RexNode newExpr; + switch (e.getKind()) { + case EQUALS: + case GREATER_THAN_OR_EQUAL: + case LESS_THAN_OR_EQUAL: + // "x = x" simplifies to "null or x is not null" (similarly <= and >=) + newExpr = + rexBuilder.makeCall( + e.getParserPosition(), + SqlStdOperatorTable.OR, + rexBuilder.makeNullLiteral(e.getType()), + rexBuilder.makeCall( + e.getParserPosition(), + SqlStdOperatorTable.IS_NOT_NULL, + o0)); + return simplify(newExpr, unknownAs); + case NOT_EQUALS: + case LESS_THAN: + case GREATER_THAN: + // "x != x" simplifies to "null and x is null" (similarly < and >) + newExpr = + rexBuilder.makeCall( + e.getParserPosition(), + SqlStdOperatorTable.AND, + rexBuilder.makeNullLiteral(e.getType()), + rexBuilder.makeCall( + e.getParserPosition(), + SqlStdOperatorTable.IS_NULL, + o0)); + return simplify(newExpr, unknownAs); + default: + // unknown kind + } + } + + if (o0.getType().getSqlTypeName() == SqlTypeName.BOOLEAN) { + Comparison cmp = + Comparison.of( + rexBuilder.makeCall(e.getParserPosition(), e.getOperator(), o0, o1), + node -> true); + if (cmp != null) { + if (cmp.literal.isAlwaysTrue()) { + switch (cmp.kind) { + case GREATER_THAN_OR_EQUAL: + case EQUALS: // x=true + return cmp.ref; + case LESS_THAN: + case NOT_EQUALS: // x!=true + return simplify(not(cmp.ref), unknownAs); + case GREATER_THAN: + /* this is false, but could be null if x is null */ + if (!cmp.ref.getType().isNullable()) { + return rexBuilder.makeLiteral(false); + } + break; + case LESS_THAN_OR_EQUAL: + /* this is true, but could be null if x is null */ + if (!cmp.ref.getType().isNullable()) { + return rexBuilder.makeLiteral(true); + } + break; + default: + break; + } + } + if (cmp.literal.isAlwaysFalse()) { + switch (cmp.kind) { + case EQUALS: + case LESS_THAN_OR_EQUAL: + return simplify(not(cmp.ref), unknownAs); + case NOT_EQUALS: + case GREATER_THAN: + return cmp.ref; + case GREATER_THAN_OR_EQUAL: + /* this is true, but could be null if x is null */ + if (!cmp.ref.getType().isNullable()) { + return rexBuilder.makeLiteral(true); + } + break; + case LESS_THAN: + /* this is false, but could be null if x is null */ + if (!cmp.ref.getType().isNullable()) { + return rexBuilder.makeLiteral(false); + } + break; + default: + break; + } + } + } + } + + // Simplify " " + // For example, "1 = 2" becomes FALSE; + // "1 != 1" becomes FALSE; + // "1 != NULL" becomes UNKNOWN (or FALSE if unknownAsFalse); + // "1 != '1'" is unchanged because the types are not the same. + if (o0.isA(SqlKind.LITERAL) + && o1.isA(SqlKind.LITERAL) + && SqlTypeUtil.equalSansNullability( + rexBuilder.getTypeFactory(), o0.getType(), o1.getType())) { + final C v0 = ((RexLiteral) o0).getValueAs(clazz); + final C v1 = ((RexLiteral) o1).getValueAs(clazz); + if (v0 == null || v1 == null) { + return unknownAs == FALSE + ? rexBuilder.makeLiteral(false) + : rexBuilder.makeNullLiteral(e.getType()); + } + final int comparisonResult = v0.compareTo(v1); + switch (e.getKind()) { + case EQUALS: + return rexBuilder.makeLiteral(comparisonResult == 0); + case GREATER_THAN: + return rexBuilder.makeLiteral(comparisonResult > 0); + case GREATER_THAN_OR_EQUAL: + return rexBuilder.makeLiteral(comparisonResult >= 0); + case LESS_THAN: + return rexBuilder.makeLiteral(comparisonResult < 0); + case LESS_THAN_OR_EQUAL: + return rexBuilder.makeLiteral(comparisonResult <= 0); + case NOT_EQUALS: + return rexBuilder.makeLiteral(comparisonResult != 0); + default: + throw new AssertionError(); + } + } + + RexNode node = simplifyComparisonWithNull(e, unknownAs); + if (node instanceof RexLiteral) { + return node; + } + + // If none of the arguments were simplified, return the call unchanged. + final RexNode e2; + if (operands.equals(e.operands)) { + e2 = e; + } else { + e2 = rexBuilder.makeCall(e.getParserPosition(), e.op, operands); + } + return simplifyUsingPredicates(e2, clazz); + } + + /** + * If this RexNode is a comparison against NULL, return FALSE, otherwise return it unchanged. + */ + static RexNode simplifyComparisonWithNull( + RexNode e, RexBuilder rexBuilder, RexUnknownAs unknownAs) { + final RexSimplify.Comparison comparison = RexSimplify.Comparison.of(e); + if (comparison != null) { + boolean againstNull = comparison.literal.isNull(); + // There is another possibility to check: in a comparison like 1 = null, + // the "non-literal" side of the Comparison can be null + if (comparison.ref instanceof RexLiteral) { + againstNull = againstNull || ((RexLiteral) comparison.ref).isNull(); + } + if (againstNull) { + return unknownAs == FALSE + ? rexBuilder.makeLiteral(false) + : rexBuilder.makeNullLiteral(e.getType()); + } + } + return e; + } + + public static RexNode simplifyComparisonWithNull(RexNode e, RexBuilder rexBuilder) { + return RexSimplify.simplifyComparisonWithNull(e, rexBuilder, FALSE); + } + + /** + * If this RexNode is a comparison against NULL, return a simplified form, otherwise return it + * unchanged. + */ + public RexNode simplifyComparisonWithNull(RexNode e, RexUnknownAs unknownAs) { + return simplifyComparisonWithNull(e, this.rexBuilder, unknownAs); + } + + /** Simplifies a conjunction of boolean expressions. */ + @Deprecated // to be removed before 2.0 + public RexNode simplifyAnds(Iterable nodes) { + ensureParanoidOff(); + return simplifyAnds(nodes, defaultUnknownAs); + } + + // package-protected only for a deprecated method; treat as private + RexNode simplifyAnds(Iterable nodes, RexUnknownAs unknownAs) { + final List terms = new ArrayList<>(); + final List notTerms = new ArrayList<>(); + for (RexNode e : nodes) { + RelOptUtil.decomposeConjunction(e, terms, notTerms); + } + simplifyList(terms, UNKNOWN); + simplifyList(notTerms, UNKNOWN); + if (unknownAs == FALSE) { + return simplifyAnd2ForUnknownAsFalse(terms, notTerms); + } + return simplifyAnd2(terms, notTerms); + } + + private void simplifyList(List terms, RexUnknownAs unknownAs) { + terms.replaceAll(e -> simplify(e, unknownAs)); + } + + private void simplifyAndTerms(List terms, RexUnknownAs unknownAs) { + RexSimplify simplify = this; + for (int i = 0; i < terms.size(); i++) { + RexNode t = terms.get(i); + if (Predicate.of(t) == null) { + continue; + } + terms.set(i, simplify.simplify(t, unknownAs)); + RelOptPredicateList newPredicates = + simplify.predicates.union( + rexBuilder, + RelOptPredicateList.of(rexBuilder, terms.subList(i, i + 1))); + simplify = simplify.withPredicates(newPredicates); + } + for (int i = 0; i < terms.size(); i++) { + RexNode t = terms.get(i); + if (Predicate.of(t) != null) { + continue; + } + terms.set(i, simplify.simplify(t, unknownAs)); + } + } + + private void simplifyOrTerms(List terms, RexUnknownAs unknownAs) { + // Suppose we are processing "e1(x) OR e2(x) OR e3(x)". When we are + // visiting "e3(x)" we know both "e1(x)" and "e2(x)" are not true (they + // may be unknown), because if either of them were true we would have + // stopped. + RexSimplify simplify = this; + + // 'doneTerms' prevents us from visiting a term in both first and second + // loops. If we did this, the second visit would have a predicate saying + // that 'term' is false. Effectively, we sort terms: visiting + // 'allowedAsPredicate' terms in the first loop, and + // non-'allowedAsPredicate' in the second. Each term is visited once. + final BitSet doneTerms = new BitSet(); + for (int i = 0; i < terms.size(); i++) { + final RexNode t = terms.get(i); + if (!simplify.allowedAsPredicateDuringOrSimplification(t)) { + continue; + } + doneTerms.set(i); + final RexNode t2 = simplify.simplify(t, unknownAs); + terms.set(i, t2); + final RexNode inverse = simplify.simplify(isNotTrue(t2), RexUnknownAs.UNKNOWN); + final RelOptPredicateList newPredicates = + simplify.predicates.union( + rexBuilder, + RelOptPredicateList.of(rexBuilder, ImmutableList.of(inverse))); + simplify = simplify.withPredicates(newPredicates); + } + for (int i = 0; i < terms.size(); i++) { + final RexNode t = terms.get(i); + if (doneTerms.get(i)) { + continue; // we visited this term in the first loop + } + terms.set(i, simplify.simplify(t, unknownAs)); + } + } + + /** + * Decides whether the given node could be used as a predicate during the simplification of + * other OR operands. + */ + private boolean allowedAsPredicateDuringOrSimplification(final RexNode t) { + Predicate predicate = Predicate.of(t); + return predicate != null && predicate.allowedInOr(predicates); + } + + private RexNode simplifyNot(RexCall call, RexUnknownAs unknownAs) { + final RexNode a = call.getOperands().get(0); + final List newOperands; + switch (a.getKind()) { + case NOT: + // NOT NOT x ==> x + return simplify(((RexCall) a).getOperands().get(0), unknownAs); + + case SEARCH: + // NOT SEARCH(x, Sarg[(-inf, 10) OR NULL) ==> SEARCH(x, Sarg[[10, +inf)]) + final RexCall call2 = (RexCall) a; + final RexNode ref = call2.operands.get(0); + final RexLiteral literal = (RexLiteral) call2.operands.get(1); + final Sarg sarg = literal.getValueAs(Sarg.class); + return simplifySearch( + call2.clone( + call2.type, + ImmutableList.of( + ref, + rexBuilder.makeLiteral( + requireNonNull(sarg, "sarg").negate(), + literal.getType(), + literal.getTypeName()))), + unknownAs.negate()); + + case LITERAL: + if (a.getType().getSqlTypeName() == SqlTypeName.BOOLEAN + && !RexLiteral.isNullLiteral(a)) { + return rexBuilder.makeLiteral(!RexLiteral.booleanValue(a)); + } + break; + + case AND: + // NOT distributivity for AND + newOperands = new ArrayList<>(); + for (RexNode operand : ((RexCall) a).getOperands()) { + newOperands.add(simplify(not(operand), unknownAs)); + } + return simplify( + rexBuilder.makeCall( + call.getParserPosition(), SqlStdOperatorTable.OR, newOperands), + unknownAs); + + case OR: + // NOT distributivity for OR + newOperands = new ArrayList<>(); + for (RexNode operand : ((RexCall) a).getOperands()) { + newOperands.add(simplify(not(operand), unknownAs)); + } + return simplify( + rexBuilder.makeCall( + call.getParserPosition(), SqlStdOperatorTable.AND, newOperands), + unknownAs); + + case CASE: + newOperands = new ArrayList<>(); + List operands = ((RexCall) a).getOperands(); + for (int i = 0; i < operands.size(); i += 2) { + if (i + 1 == operands.size()) { + newOperands.add(not(operands.get(i))); + } else { + newOperands.add(operands.get(i)); + newOperands.add(not(operands.get(i + 1))); + } + } + return simplify( + rexBuilder.makeCall( + call.getParserPosition(), SqlStdOperatorTable.CASE, newOperands), + unknownAs); + + case IN: + case NOT_IN: + // do not try to negate + break; + + default: + final SqlKind negateKind = a.getKind().negate(); + if (a.getKind() != negateKind) { + return simplify( + rexBuilder.makeCall( + call.getParserPosition(), + RexUtil.op(negateKind), + ((RexCall) a).getOperands()), + unknownAs); + } + final SqlKind negateKind2 = a.getKind().negateNullSafe(); + if (a.getKind() != negateKind2) { + return simplify( + rexBuilder.makeCall( + call.getParserPosition(), + RexUtil.op(negateKind2), + ((RexCall) a).getOperands()), + unknownAs); + } + } + + RexNode a2 = simplify(a, unknownAs.negate()); + if (a == a2) { + return call; + } + return not(a2); + } + + private RexNode simplifyUnaryMinus(RexCall call, RexUnknownAs unknownAs) { + final RexNode a = call.getOperands().get(0); + if (a.getKind() == SqlKind.MINUS_PREFIX) { + // -(-(x)) ==> x + return simplify(((RexCall) a).getOperands().get(0), unknownAs); + } + return simplifyGenericNode(call); + } + + private RexNode simplifyUnaryPlus(RexCall call, RexUnknownAs unknownAs) { + return simplify(call.getOperands().get(0), unknownAs); + } + + private RexNode simplifyIs(RexCall call, RexUnknownAs unknownAs) { + final SqlKind kind = call.getKind(); + final RexNode a = call.getOperands().get(0); + final RexNode simplified = simplifyIs1(kind, a, unknownAs); + return simplified == null ? call : simplified; + } + + private @Nullable RexNode simplifyIs1(SqlKind kind, RexNode a, RexUnknownAs unknownAs) { + // UnknownAs.FALSE corresponds to x IS TRUE evaluation + // UnknownAs.TRUE to x IS NOT FALSE + // Note that both UnknownAs.TRUE and UnknownAs.FALSE only changes the meaning of Unknown + // (1) if we are already in UnknownAs.FALSE mode; x IS TRUE can be simplified to x + // (2) similarly in UnknownAs.TRUE mode; x IS NOT FALSE can be simplified to x + // (3) x IS FALSE could be rewritten to (NOT x) IS TRUE and from there the 1. rule applies + // (4) x IS NOT TRUE can be rewritten to (NOT x) IS NOT FALSE and from there the 2. rule + // applies + if (kind == SqlKind.IS_TRUE && unknownAs == RexUnknownAs.FALSE) { + return simplify(a, unknownAs); + } + if (kind == SqlKind.IS_FALSE && unknownAs == RexUnknownAs.FALSE) { + return simplify(not(a), unknownAs); + } + if (kind == SqlKind.IS_NOT_FALSE && unknownAs == RexUnknownAs.TRUE) { + return simplify(a, unknownAs); + } + if (kind == SqlKind.IS_NOT_TRUE && unknownAs == RexUnknownAs.TRUE) { + return simplify(not(a), unknownAs); + } + final RexNode pred = simplifyIsPredicate(kind, a); + if (pred != null) { + return pred; + } + + return simplifyIs2(kind, a, unknownAs); + } + + private @Nullable RexNode simplifyIsPredicate(SqlKind kind, RexNode a) { + if (!(RexUtil.isReferenceOrAccess(a, true) || RexUtil.isDeterministic(a))) { + return null; + } + + for (RexNode p : predicates.pulledUpPredicates) { + IsPredicate pred = IsPredicate.of(p); + if (pred == null || !a.equals(pred.ref)) { + continue; + } + if (kind == pred.kind) { + return rexBuilder.makeLiteral(true); + } + } + return null; + } + + private @Nullable RexNode simplifyIs2(SqlKind kind, RexNode a, RexUnknownAs unknownAs) { + final RexNode simplified; + switch (kind) { + case IS_NULL: + // x IS NULL ==> FALSE (if x is not nullable) + validateStrongPolicy(a); + simplified = simplifyIsNull(a); + if (simplified != null) { + return simplified; + } + break; + case IS_NOT_NULL: + // x IS NOT NULL ==> TRUE (if x is not nullable) + validateStrongPolicy(a); + simplified = simplifyIsNotNull(a); + if (simplified != null) { + return simplified; + } + break; + + case IS_TRUE: + // x IS TRUE ==> x (if x is not nullable) + if (predicates.isEffectivelyNotNull(a)) { + return simplify(a, unknownAs); + } + simplified = simplify(a, RexUnknownAs.FALSE); + if (simplified == a) { + return null; + } + return isTrue(simplified); + + case IS_NOT_FALSE: + // x IS NOT FALSE ==> x (if x is not nullable) + if (predicates.isEffectivelyNotNull(a)) { + return simplify(a, unknownAs); + } + simplified = simplify(a, RexUnknownAs.TRUE); + if (simplified == a) { + return null; + } + return isNotFalse(simplified); + + case IS_FALSE: + case IS_NOT_TRUE: + // x IS NOT TRUE ==> NOT x (if x is not nullable) + // x IS FALSE ==> NOT x (if x is not nullable) + if (predicates.isEffectivelyNotNull(a)) { + return simplify(not(a), unknownAs); + } + break; + + default: + break; + } + switch (a.getKind()) { + case NOT: + // (NOT x) IS TRUE ==> x IS FALSE + // Similarly for IS NOT TRUE, IS FALSE, etc. + // + // Note that + // (NOT x) IS TRUE !=> x IS FALSE + // because of null values. + final SqlOperator notKind = RexUtil.op(kind.negateNullSafe()); + final RexNode arg = ((RexCall) a).operands.get(0); + return simplify(rexBuilder.makeCall(notKind, arg), UNKNOWN); + default: + break; + } + final RexNode a2 = simplify(a, UNKNOWN); + if (a != a2) { + return rexBuilder.makeCall(RexUtil.op(kind), ImmutableList.of(a2)); + } + return null; // cannot be simplified + } + + private @Nullable RexNode simplifyIsNotNull(RexNode a) { + // Simplify the argument first, + // call ourselves recursively to see whether we can make more progress. + // For example, given + // "(CASE WHEN FALSE THEN 1 ELSE 2) IS NOT NULL" we first simplify the + // argument to "2", and only then we can simplify "2 IS NOT NULL" to "TRUE". + a = simplify(a, UNKNOWN); + if (!a.getType().isNullable() && isSafeExpression(a)) { + return rexBuilder.makeLiteral(true); + } + if (RexUtil.isLosslessCast(a)) { + if (!a.getType().isNullable()) { + return rexBuilder.makeLiteral(true); + } + return rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, RexUtil.removeCast(a)); + } + if (predicates.pulledUpPredicates.contains(a)) { + return rexBuilder.makeLiteral(true); + } + if (hasCustomNullabilityRules(a.getKind())) { + return null; + } + switch (Strong.policy(a)) { + case NOT_NULL: + return rexBuilder.makeLiteral(true); + case ANY: + // "f" is a strong operator, so "f(operand0, operand1) IS NOT NULL" + // simplifies to "operand0 IS NOT NULL AND operand1 IS NOT NULL" + final List operands = new ArrayList<>(); + for (RexNode operand : ((RexCall) a).getOperands()) { + final RexNode simplified = simplifyIsNotNull(operand); + if (simplified == null) { + operands.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, operand)); + } else if (simplified.isAlwaysFalse()) { + return rexBuilder.makeLiteral(false); + } else { + operands.add(simplified); + } + } + return RexUtil.composeConjunction(rexBuilder, operands); + case CUSTOM: + switch (a.getKind()) { + case LITERAL: + return rexBuilder.makeLiteral(!((RexLiteral) a).isNull()); + default: + throw new AssertionError( + "every CUSTOM policy needs a handler, " + a.getKind()); + } + case AS_IS: + default: + return null; + } + } + + private @Nullable RexNode simplifyIsNull(RexNode a) { + // Simplify the argument first, + // call ourselves recursively to see whether we can make more progress. + // For example, given + // "(CASE WHEN FALSE THEN 1 ELSE 2) IS NULL" we first simplify the + // argument to "2", and only then we can simplify "2 IS NULL" to "FALSE". + a = simplify(a, UNKNOWN); + if (!a.getType().isNullable() && isSafeExpression(a)) { + return rexBuilder.makeLiteral(false); + } + if (RexUtil.isLosslessCast(a)) { + if (!a.getType().isNullable()) { + return rexBuilder.makeLiteral(false); + } + return rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, RexUtil.removeCast(a)); + } + if (RexUtil.isNull(a)) { + return rexBuilder.makeLiteral(true); + } + if (hasCustomNullabilityRules(a.getKind())) { + return null; + } + switch (Strong.policy(a)) { + case NOT_NULL: + return rexBuilder.makeLiteral(false); + case ANY: + // "f" is a strong operator, so "f(operand0, operand1) IS NULL" simplifies + // to "operand0 IS NULL OR operand1 IS NULL" + final List operands = new ArrayList<>(); + for (RexNode operand : ((RexCall) a).getOperands()) { + final RexNode simplified = simplifyIsNull(operand); + if (simplified == null) { + operands.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, operand)); + } else { + operands.add(simplified); + } + } + return RexUtil.composeDisjunction(rexBuilder, operands, false); + case AS_IS: + default: + return null; + } + } + + /** + * Validates strong policy for specified {@link RexNode}. + * + * @param rexNode Rex node to validate the strong policy + * @throws AssertionError If the validation fails + */ + private static void validateStrongPolicy(RexNode rexNode) { + if (hasCustomNullabilityRules(rexNode.getKind())) { + return; + } + switch (Strong.policy(rexNode)) { + case NOT_NULL: + assert !rexNode.getType().isNullable(); + break; + case ANY: + List operands = ((RexCall) rexNode).getOperands(); + if (rexNode.getType().isNullable()) { + assert operands.stream() + .map(RexNode::getType) + .anyMatch(RelDataType::isNullable); + } else { + assert operands.stream() + .map(RexNode::getType) + .noneMatch(RelDataType::isNullable); + } + break; + default: + break; + } + } + + /** + * Returns {@code true} if specified {@link SqlKind} has custom nullability rules which depend + * not only on the nullability of input operands. + * + *

For example, CAST may be used to change the nullability of its operand type, so it may be + * nullable, though the argument type was non-nullable. + * + * @param sqlKind Sql kind to check + * @return {@code true} if specified {@link SqlKind} has custom nullability rules + */ + private static boolean hasCustomNullabilityRules(SqlKind sqlKind) { + switch (sqlKind) { + case CAST: + case ITEM: + return true; + default: + return false; + } + } + + private RexNode simplifyCoalesce(RexCall call) { + final Set operandSet = new HashSet<>(); + final List operands = new ArrayList<>(); + for (RexNode operand : call.getOperands()) { + operand = simplify(operand, UNKNOWN); + if (!RexUtil.isNull(operand) && operandSet.add(operand)) { + operands.add(operand); + } + if (!operand.getType().isNullable()) { + break; + } + } + switch (operands.size()) { + case 0: + return rexBuilder.makeNullLiteral(call.type); + case 1: + return operands.get(0); + default: + if (operands.equals(call.operands)) { + return call; + } + return call.clone(call.type, operands); + } + } + + private RexNode simplifyCase(RexCall call, RexUnknownAs unknownAs) { + List inputBranches = + CaseBranch.fromCaseOperands(rexBuilder, new ArrayList<>(call.getOperands())); + + // run simplification on all operands + RexSimplify condSimplifier = this.withPredicates(RelOptPredicateList.EMPTY); + RexSimplify valueSimplifier = this; + RelDataType caseType = call.getType(); + + boolean conditionNeedsSimplify = false; + CaseBranch lastBranch = null; + List branches = new ArrayList<>(); + for (CaseBranch inputBranch : inputBranches) { + // simplify the condition + RexNode newCond = condSimplifier.simplify(inputBranch.cond, RexUnknownAs.FALSE); + if (newCond.isAlwaysFalse()) { + // If the condition is false, we do not need to add it + continue; + } + + // simplify the value + RexNode newValue = valueSimplifier.simplify(inputBranch.value, unknownAs); + + // create new branch + if (lastBranch != null) { + if (lastBranch.value.equals(newValue) && isSafeExpression(newCond)) { + // in this case, last branch and new branch have the same conclusion, + // hence we create a new composite condition and we do not add it to + // the final branches for the time being + newCond = rexBuilder.makeCall(SqlStdOperatorTable.OR, lastBranch.cond, newCond); + conditionNeedsSimplify = true; + } else { + // if we reach here, the new branch is not mergeable with the last one, + // hence we are going to add the last branch to the final branches. + // if the last branch was merged, then we will simplify it first. + // otherwise, we just add it + CaseBranch branch = + generateBranch(conditionNeedsSimplify, condSimplifier, lastBranch); + if (!branch.cond.isAlwaysFalse()) { + // If the condition is not false, we add it to the final result + branches.add(branch); + if (branch.cond.isAlwaysTrue()) { + // If the condition is always true, we are done + lastBranch = null; + break; + } + } + conditionNeedsSimplify = false; + } + } + lastBranch = new CaseBranch(newCond, newValue); + + if (newCond.isAlwaysTrue()) { + // If the condition is always true, we are done (useful in first loop iteration) + break; + } + } + if (lastBranch != null) { + // we need to add the last pending branch once we have finished + // with the for loop + CaseBranch branch = generateBranch(conditionNeedsSimplify, condSimplifier, lastBranch); + if (!branch.cond.isAlwaysFalse()) { + branches.add(branch); + } + } + + if (branches.size() == 1) { + // we can just return the value in this case (matching the case type) + final RexNode value = branches.get(0).value; + if (sameTypeOrNarrowsNullability(caseType, value.getType())) { + return value; + } else { + return rexBuilder.makeAbstractCast( + call.getParserPosition(), caseType, value, false); + } + } + + if (call.getType().getSqlTypeName() == SqlTypeName.BOOLEAN) { + final RexNode result = simplifyBooleanCase(rexBuilder, branches, unknownAs, caseType); + if (result != null) { + if (sameTypeOrNarrowsNullability(caseType, result.getType())) { + return simplify(result, unknownAs); + } else { + // If the simplification would widen the nullability + RexNode simplified = simplify(result, UNKNOWN); + if (!simplified.getType().isNullable()) { + return simplified; + } else { + return rexBuilder.makeCast( + call.getParserPosition(), call.getType(), simplified); + } + } + } + } + List newOperands = CaseBranch.toCaseOperands(branches); + if (newOperands.equals(call.getOperands())) { + return call; + } + return rexBuilder.makeCall(call.getParserPosition(), SqlStdOperatorTable.CASE, newOperands); + } + + /** + * If boolean is true, simplify cond in input branch and return new branch. Otherwise, simply + * return input branch. + */ + private static CaseBranch generateBranch( + boolean simplifyCond, RexSimplify simplifier, CaseBranch branch) { + if (simplifyCond) { + // the previous branch was merged, time to simplify it and + // add it to the final result + return new CaseBranch( + simplifier.simplify(branch.cond, RexUnknownAs.FALSE), branch.value); + } + return branch; + } + + /** Return if the new type is the same and at most narrows the nullability. */ + private boolean sameTypeOrNarrowsNullability(RelDataType oldType, RelDataType newType) { + return oldType.equals(newType) + || (SqlTypeUtil.equalSansNullability(rexBuilder.typeFactory, oldType, newType) + && oldType.isNullable()); + } + + /** Object to describe a CASE branch. */ + static final class CaseBranch { + + private final RexNode cond; + private final RexNode value; + + CaseBranch(RexNode cond, RexNode value) { + this.cond = cond; + this.value = value; + } + + @Override + public String toString() { + return cond + " => " + value; + } + + /** Given "CASE WHEN p1 THEN v1 ... ELSE e END" returns [(p1, v1), ..., (true, e)]. */ + private static List fromCaseOperands( + RexBuilder rexBuilder, List operands) { + List ret = new ArrayList<>(); + for (int i = 0; i < operands.size() - 1; i += 2) { + ret.add(new CaseBranch(operands.get(i), operands.get(i + 1))); + } + ret.add(new CaseBranch(rexBuilder.makeLiteral(true), Util.last(operands))); + return ret; + } + + private static List toCaseOperands(List branches) { + List ret = new ArrayList<>(); + for (int i = 0; i < branches.size() - 1; i++) { + CaseBranch branch = branches.get(i); + ret.add(branch.cond); + ret.add(branch.value); + } + CaseBranch lastBranch = Util.last(branches); + assert lastBranch.cond.isAlwaysTrue(); + ret.add(lastBranch.value); + return ret; + } + } + + /** Decides whether it is safe to flatten the given CASE part into ANDs/ORs. */ + enum SafeRexVisitor implements RexVisitor { + INSTANCE; + + @SuppressWarnings("ImmutableEnumChecker") + private final Set safeOps; + + @SuppressWarnings("ImmutableEnumChecker") + private final ImmutableSet safeOperators; + + SafeRexVisitor() { + ImmutableSet.Builder builder = ImmutableSet.builder(); + builder.addAll(SqlStdOperatorTable.QUANTIFY_OPERATORS); + safeOperators = builder.build(); + + Set safeOps = EnumSet.noneOf(SqlKind.class); + safeOps.addAll(SqlKind.COMPARISON); + safeOps.add(SqlKind.ARRAY_VALUE_CONSTRUCTOR); + safeOps.add(SqlKind.PLUS_PREFIX); + safeOps.add(SqlKind.MINUS_PREFIX); + safeOps.add(SqlKind.CHECKED_MINUS_PREFIX); + safeOps.add(SqlKind.PLUS); + safeOps.add(SqlKind.MINUS); + safeOps.add(SqlKind.TIMES); + safeOps.add(SqlKind.CHECKED_PLUS); + safeOps.add(SqlKind.CHECKED_MINUS); + safeOps.add(SqlKind.CHECKED_TIMES); + safeOps.add(SqlKind.IS_FALSE); + safeOps.add(SqlKind.IS_NOT_FALSE); + safeOps.add(SqlKind.IS_TRUE); + safeOps.add(SqlKind.IS_NOT_TRUE); + safeOps.add(SqlKind.IS_NULL); + safeOps.add(SqlKind.IS_NOT_NULL); + safeOps.add(SqlKind.IS_DISTINCT_FROM); + safeOps.add(SqlKind.IS_NOT_DISTINCT_FROM); + safeOps.add(SqlKind.IN); + safeOps.add(SqlKind.SEARCH); + safeOps.add(SqlKind.OR); + safeOps.add(SqlKind.AND); + safeOps.add(SqlKind.NOT); + safeOps.add(SqlKind.CASE); + safeOps.add(SqlKind.LIKE); + safeOps.add(SqlKind.COALESCE); + safeOps.add(SqlKind.TRIM); + safeOps.add(SqlKind.LTRIM); + safeOps.add(SqlKind.RTRIM); + safeOps.add(SqlKind.BETWEEN); + safeOps.add(SqlKind.CEIL); + safeOps.add(SqlKind.FLOOR); + safeOps.add(SqlKind.REVERSE); + safeOps.add(SqlKind.TIMESTAMP_ADD); + safeOps.add(SqlKind.TIMESTAMP_DIFF); + this.safeOps = Sets.immutableEnumSet(safeOps); + } + + @Override + public Boolean visitInputRef(RexInputRef inputRef) { + return true; + } + + @Override + public Boolean visitLocalRef(RexLocalRef localRef) { + return false; + } + + @Override + public Boolean visitLiteral(RexLiteral literal) { + return true; + } + + @Override + public Boolean visitCall(RexCall call) { + SqlKind sqlKind = call.getKind(); + SqlOperator sqlOperator = call.getOperator(); + + switch (sqlKind) { + case DIVIDE: + case MOD: + List operands = call.getOperands(); + boolean isSafe = + RexVisitorImpl.visitArrayAnd(this, ImmutableList.of(operands.get(0))); + if (!isSafe) { + return false; + } + if (operands.get(1) instanceof RexLiteral) { + RexLiteral literal = (RexLiteral) operands.get(1); + return RexUtil.isNullLiteral(literal, true); + } + return false; + default: + break; + } + + if (sqlOperator.isSafeOperator() + || RexUtil.isLosslessCast(call) + || safeOps.contains(sqlKind) + || safeOperators.contains(sqlOperator)) { + return RexVisitorImpl.visitArrayAnd(this, call.operands); + } + + return false; + } + + @Override + public Boolean visitOver(RexOver over) { + return false; + } + + @Override + public Boolean visitCorrelVariable(RexCorrelVariable correlVariable) { + return false; + } + + @Override + public Boolean visitDynamicParam(RexDynamicParam dynamicParam) { + return false; + } + + @Override + public Boolean visitRangeRef(RexRangeRef rangeRef) { + return false; + } + + @Override + public Boolean visitFieldAccess(RexFieldAccess fieldAccess) { + return true; + } + + @Override + public Boolean visitSubQuery(RexSubQuery subQuery) { + return false; + } + + @Override + public Boolean visitTableInputRef(RexTableInputRef fieldRef) { + return false; + } + + @Override + public Boolean visitPatternFieldRef(RexPatternFieldRef fieldRef) { + return false; + } + + @Override + public Boolean visitLambda(RexLambda lambda) { + return lambda.getExpression().accept(this); + } + + @Override + public Boolean visitLambdaRef(RexLambdaRef lambdaRef) { + return true; + } + + @Override + public Boolean visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + return true; + } + } + + /** + * Analyzes a given {@link RexNode} and decides whenever it is safe to unwind. + * + *

"Safe" means that it only contains a combination of known good operators. + * + *

Division is an unsafe operator; consider the following: + * + *

case when a > 0 then 1 / a else null end
+ */ + static boolean isSafeExpression(RexNode r) { + return r.accept(SafeRexVisitor.INSTANCE); + } + + private @Nullable RexNode simplifyBooleanCase( + RexBuilder rexBuilder, + List inputBranches, + @SuppressWarnings("unused") RexUnknownAs unknownAs, + RelDataType branchType) { + RexNode result; + + // prepare all condition/branches for boolean interpretation + // It's done here make these interpretation changes available to case2or simplifications + // but not interfere with the normal simplification recursion + List branches = new ArrayList<>(); + for (CaseBranch branch : inputBranches) { + if ((!branches.isEmpty() && !isSafeExpression(branch.cond)) + || !isSafeExpression(branch.value)) { + return null; + } + final RexNode cond = isTrue(branch.cond); + final RexNode value; + if (!branchType.equals(branch.value.getType())) { + value = rexBuilder.makeAbstractCast(branchType, branch.value, false); + } else { + value = branch.value; + } + branches.add(new CaseBranch(cond, value)); + } + + result = simplifyBooleanCaseGeneric(rexBuilder, branches); + return result; + } + + /** + * Generic boolean case simplification. + * + *

Rewrites: + * + *

+     * CASE
+     *   WHEN p1 THEN x
+     *   WHEN p2 THEN y
+     *   ELSE z
+     * END
+     * 
+ * + * to + * + *
(p1 and x) or (p2 and y and not(p1)) or (true and z and not(p1) and not(p2))
+ */ + private static RexNode simplifyBooleanCaseGeneric( + RexBuilder rexBuilder, List branches) { + + boolean booleanBranches = + branches.stream() + .allMatch( + branch -> + branch.value.isAlwaysTrue() + || branch.value.isAlwaysFalse()); + final List terms = new ArrayList<>(); + final List notTerms = new ArrayList<>(); + for (CaseBranch branch : branches) { + boolean useBranch = !branch.value.isAlwaysFalse(); + if (useBranch) { + final RexNode branchTerm; + if (branch.value.isAlwaysTrue()) { + branchTerm = branch.cond; + } else { + branchTerm = + rexBuilder.makeCall(SqlStdOperatorTable.AND, branch.cond, branch.value); + } + terms.add(RexUtil.andNot(rexBuilder, branchTerm, notTerms)); + } + if (booleanBranches && useBranch) { + // we are safe to ignore this branch because for boolean true branches: + // a || (b && !a) === a || b + } else { + notTerms.add(branch.cond); + } + } + return RexUtil.composeDisjunction(rexBuilder, terms); + } + + @Deprecated // to be removed before 2.0 + public RexNode simplifyAnd(RexCall e) { + ensureParanoidOff(); + return simplifyAnd(e, defaultUnknownAs); + } + + RexNode simplifyAnd(RexCall e, RexUnknownAs unknownAs) { + List operands = RelOptUtil.conjunctions(e); + + if (unknownAs == FALSE && predicateElimination) { + simplifyAndTerms(operands, FALSE); + } else { + simplifyList(operands, unknownAs); + } + + final List terms = new ArrayList<>(); + final List notTerms = new ArrayList<>(); + + final SargCollector sargCollector = new SargCollector(rexBuilder, true); + operands.forEach(t -> sargCollector.accept(t, terms)); + if (sargCollector.needToFix()) { + operands.clear(); + terms.forEach(t -> operands.add(SargCollector.fix(rexBuilder, t, unknownAs))); + } + terms.clear(); + + for (RexNode o : operands) { + RelOptUtil.decomposeConjunction(o, terms, notTerms); + } + + switch (unknownAs) { + case FALSE: + return simplifyAnd2ForUnknownAsFalse(terms, notTerms, Comparable.class); + default: + break; + } + return simplifyAnd2(terms, notTerms); + } + + // package-protected only to support a deprecated method; treat as private + RexNode simplifyAnd2(List terms, List notTerms) { + for (RexNode term : terms) { + if (term.isAlwaysFalse()) { + return rexBuilder.makeLiteral(false); + } + } + if (terms.isEmpty() && notTerms.isEmpty()) { + return rexBuilder.makeLiteral(true); + } + // If one of the not-disjunctions is a disjunction that is wholly + // contained in the disjunctions list, the expression is not + // satisfiable. + // + // Example #1. x AND y AND z AND NOT (x AND y) - not satisfiable + // Example #2. x AND y AND NOT (x AND y) - not satisfiable + // Example #3. x AND y AND NOT (x AND y AND z) - may be satisfiable + List notSatisfiableNullables = null; + for (RexNode notDisjunction : notTerms) { + final List terms2 = RelOptUtil.conjunctions(notDisjunction); + if (!terms.containsAll(terms2)) { + // may be satisfiable ==> check other terms + continue; + } + if (!notDisjunction.getType().isNullable()) { + // x is NOT nullable, then x AND NOT(x) ==> FALSE + return rexBuilder.makeLiteral(false); + } + // x AND NOT(x) is UNKNOWN for NULL input + // So we search for the shortest notDisjunction then convert + // original expression to NULL and x IS NULL + if (notSatisfiableNullables == null) { + notSatisfiableNullables = new ArrayList<>(); + } + notSatisfiableNullables.add(notDisjunction); + } + + if (notSatisfiableNullables != null) { + // Remove the intersection of "terms" and "notTerms" + terms.removeAll(notSatisfiableNullables); + notTerms.removeAll(notSatisfiableNullables); + + // The intersection simplify to "null and x1 is null and x2 is null..." + terms.add(rexBuilder.makeNullLiteral(notSatisfiableNullables.get(0).getType())); + for (RexNode notSatisfiableNullable : notSatisfiableNullables) { + terms.add( + simplifyIs( + (RexCall) + rexBuilder.makeCall( + SqlStdOperatorTable.IS_NULL, + notSatisfiableNullable), + UNKNOWN)); + } + } + // Add the NOT disjunctions back in. + for (RexNode notDisjunction : notTerms) { + terms.add(simplify(not(notDisjunction), UNKNOWN)); + } + return RexUtil.composeConjunction(rexBuilder, terms); + } + + /** + * As {@link #simplifyAnd2(List, List)} but we assume that if the expression returns UNKNOWN it + * will be interpreted as FALSE. + */ + RexNode simplifyAnd2ForUnknownAsFalse(List terms, List notTerms) { + //noinspection unchecked + return simplifyAnd2ForUnknownAsFalse(terms, notTerms, Comparable.class); + } + + private > RexNode simplifyAnd2ForUnknownAsFalse( + List terms, List notTerms, Class clazz) { + for (RexNode term : terms) { + if (term.isAlwaysFalse() || RexLiteral.isNullLiteral(term)) { + return rexBuilder.makeLiteral(false); + } + } + if (terms.isEmpty() && notTerms.isEmpty()) { + return rexBuilder.makeLiteral(true); + } + if (terms.size() == 1 && notTerms.isEmpty()) { + // Make sure "x OR y OR x" (a single-term conjunction) gets simplified. + return simplify(terms.get(0), FALSE); + } + // Try to simplify the expression + final Multimap> equalityTerms = ArrayListMultimap.create(); + final Map, List>> rangeTerms = new HashMap<>(); + final Map equalityConstantTerms = new HashMap<>(); + final Set negatedTerms = new HashSet<>(); + final Set nullOperands = new HashSet<>(); + final Set notNullOperands = new LinkedHashSet<>(); + + // Add the predicates from the source to the range terms. + for (RexNode predicate : predicates.pulledUpPredicates) { + final Comparison comparison = Comparison.of(predicate); + if (comparison != null && comparison.kind != SqlKind.NOT_EQUALS) { // not supported yet + final C v0 = comparison.literal.getValueAs(clazz); + if (v0 != null) { + final RexNode result = + processRange( + rexBuilder, + terms, + rangeTerms, + predicate, + comparison.ref, + v0, + comparison.kind); + if (result != null) { + // Not satisfiable + return result; + } + } + } + } + + for (int i = 0; i < terms.size(); i++) { + RexNode term = terms.get(i); + if (!RexUtil.isDeterministic(term)) { + continue; + } + // Simplify BOOLEAN expressions if possible + while (term.getKind() == SqlKind.EQUALS) { + RexCall call = (RexCall) term; + if (call.getOperands().get(0).isAlwaysTrue()) { + term = call.getOperands().get(1); + terms.set(i, term); + continue; + } else if (call.getOperands().get(1).isAlwaysTrue()) { + term = call.getOperands().get(0); + terms.set(i, term); + continue; + } + break; + } + switch (term.getKind()) { + case EQUALS: + case NOT_EQUALS: + case LESS_THAN: + case GREATER_THAN: + case LESS_THAN_OR_EQUAL: + case GREATER_THAN_OR_EQUAL: + RexCall call = (RexCall) term; + final RexNode left = call.getOperands().get(0); + final RexNode right = call.getOperands().get(1); + final Comparison comparison = Comparison.of(term); + // Check for comparison with null values + if (comparison != null && comparison.literal.getValue() == null) { + return rexBuilder.makeLiteral(false); + } + // Check for equality on different constants. If the same ref or CAST(ref) + // is equal to different constants, this condition cannot be satisfied, + // and hence it can be evaluated to FALSE + if (term.getKind() == SqlKind.EQUALS) { + if (comparison != null) { + final RexLiteral literal = comparison.literal; + final RexLiteral prevLiteral = + equalityConstantTerms.put(comparison.ref, literal); + + if (prevLiteral != null + && literal.getType().equals(prevLiteral.getType()) + && !literal.equals(prevLiteral)) { + return rexBuilder.makeLiteral(false); + } + } else if (RexUtil.isReferenceOrAccess(left, true) + && RexUtil.isReferenceOrAccess(right, true)) { + equalityTerms.put(left, Pair.of(right, term)); + } + } + // Assume the expression a > 5 is part of a Filter condition. + // Then we can derive the negated term: a <= 5. + // But as the comparison is string based and thus operands order dependent, + // we should also add the inverted negated term: 5 >= a. + // Observe that for creating the inverted term we invert the list of operands. + RexNode negatedTerm = RexUtil.negate(rexBuilder, call); + if (negatedTerm != null) { + negatedTerms.add(negatedTerm); + RexNode invertNegatedTerm = + RexUtil.invert(rexBuilder, (RexCall) negatedTerm); + if (invertNegatedTerm != null) { + negatedTerms.add(invertNegatedTerm); + } + } + // Remove terms that are implied by predicates on the input, + // or weaken terms that are partially implied. + // E.g. given predicate "x >= 5" and term "x between 3 and 10" + // we weaken to term to "x between 5 and 10". + final RexNode term2 = simplifyUsingPredicates(term, clazz); + if (term2 != term) { + terms.set(i, term = term2); + } + // Range + if (comparison != null + && comparison.kind != SqlKind.NOT_EQUALS) { // not supported yet + final C constant = comparison.literal.getValueAs(clazz); + if (constant == null) { + break; + } + final RexNode result = + processRange( + rexBuilder, + terms, + rangeTerms, + term, + comparison.ref, + constant, + comparison.kind); + if (result != null) { + // Not satisfiable + return result; + } + } + break; + case IS_NOT_NULL: + notNullOperands.add(((RexCall) term).getOperands().get(0)); + terms.remove(i); + --i; + break; + case IS_NULL: + nullOperands.add(((RexCall) term).getOperands().get(0)); + break; + default: + break; + } + } + // Check for equality of two refs wrt equality with constants + // Example #1. x=5 AND y=5 AND x=y : x=5 AND y=5 + // Example #2. x=5 AND y=6 AND x=y - not satisfiable + for (RexNode ref1 : equalityTerms.keySet()) { + final RexLiteral literal1 = equalityConstantTerms.get(ref1); + if (literal1 == null) { + continue; + } + Collection> references = equalityTerms.get(ref1); + for (Pair ref2 : references) { + final RexLiteral literal2 = equalityConstantTerms.get(ref2.left); + if (literal2 == null) { + continue; + } + if (literal1.getType().equals(literal2.getType()) && !literal1.equals(literal2)) { + // If an expression is equal to two different constants, + // it is not satisfiable + return rexBuilder.makeLiteral(false); + } + // Otherwise we can remove the term, as we already know that + // the expression is equal to two constants + terms.remove(ref2.right); + } + } + // If one of the not-disjunctions is a disjunction that is wholly + // contained in the disjunctions list, the expression is not + // satisfiable. + // + // Example #1. x AND y AND z AND NOT (x AND y) - not satisfiable + // Example #2. x AND y AND NOT (x AND y) - not satisfiable + // Example #3. x AND y AND NOT (x AND y AND z) - may be satisfiable + final Set termsSet = new HashSet<>(terms); + for (RexNode notDisjunction : notTerms) { + if (!RexUtil.isDeterministic(notDisjunction)) { + continue; + } + final List terms2Set = RelOptUtil.conjunctions(notDisjunction); + if (termsSet.containsAll(terms2Set)) { + return rexBuilder.makeLiteral(false); + } + } + // The negated terms: only deterministic expressions + for (RexNode negatedTerm : negatedTerms) { + if (termsSet.contains(negatedTerm)) { + return rexBuilder.makeLiteral(false); + } + } + // Add the NOT disjunctions back in. + for (RexNode notDisjunction : notTerms) { + terms.add(not(notDisjunction)); + } + // Find operands that make will let whole expression evaluate to FALSE if set to NULL + final Set strongOperands = new HashSet<>(); + for (RexNode term : terms) { + if (!RexUtil.isDeterministic(term)) { + continue; + } + final VariableCollector collector = new VariableCollector(); + term.accept(collector); + for (RexInputRef ref : collector.refs) { + final boolean strong = Strong.isNotTrue(term, ImmutableBitSet.of(ref.index)); + if (strong) { + strongOperands.add(ref); + } + } + final RexUtil.FieldAccessFinder fieldAccessFinder = new RexUtil.FieldAccessFinder(); + term.accept(fieldAccessFinder); + for (RexFieldAccess rexFieldAccess : fieldAccessFinder.getFieldAccessList()) { + final boolean strong = Strong.of(ImmutableSet.of(rexFieldAccess)).isNotTrue(term); + if (strong) { + strongOperands.add(rexFieldAccess); + } + } + } + // If one column should be null and is in a comparison predicate, + // it is not satisfiable. + // Example. IS NULL(x) AND x < 5 - not satisfiable + if (!Collections.disjoint(nullOperands, strongOperands)) { + return rexBuilder.makeLiteral(false); + } + // Remove not necessary IS NOT NULL expressions. + // Example. IS NOT NULL(x) AND x < 5 : x < 5 + for (RexNode operand : notNullOperands) { + if (!strongOperands.contains(operand)) { + terms.add(rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, operand)); + } + } + return RexUtil.composeConjunction(rexBuilder, terms); + } + + private > RexNode simplifyUsingPredicates(RexNode e, Class clazz) { + if (predicates.pulledUpPredicates.isEmpty()) { + return e; + } + + final Comparison comparison = Comparison.of(e); + // Check for comparison with null values + if (comparison == null || comparison.literal.getValue() == null) { + return e; + } + + final C v0 = comparison.literal.getValueAs(clazz); + if (v0 == null) { + return e; + } + final RangeSet rangeSet = rangeSet(comparison.kind, v0); + final RangeSet rangeSet2 = + residue(comparison.ref, rangeSet, predicates.pulledUpPredicates, clazz); + if (rangeSet2.isEmpty()) { + // Term is impossible to satisfy given these predicates + return rexBuilder.makeLiteral(false); + } else if (rangeSet2.equals(rangeSet)) { + // no change + return e; + } else if (rangeSet2.equals(RangeSets.rangeSetAll())) { + // Range is always satisfied given these predicates; but nullability might + // be problematic + return simplify( + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, comparison.ref), + RexUnknownAs.UNKNOWN); + } else if (rangeSet2.asRanges().size() == 1 + && Iterables.getOnlyElement(rangeSet2.asRanges()).hasLowerBound() + && Iterables.getOnlyElement(rangeSet2.asRanges()).hasUpperBound() + && Iterables.getOnlyElement(rangeSet2.asRanges()) + .lowerEndpoint() + .equals(Iterables.getOnlyElement(rangeSet2.asRanges()).upperEndpoint())) { + final Range r = Iterables.getOnlyElement(rangeSet2.asRanges()); + // range is now a point; it's worth simplifying + return rexBuilder.makeCall( + SqlStdOperatorTable.EQUALS, + comparison.ref, + rexBuilder.makeLiteral( + r.lowerEndpoint(), + comparison.literal.getType(), + comparison.literal.getTypeName())); + } else { + // range has been reduced but it's not worth simplifying + return e; + } + } + + /** + * Weakens a term so that it checks only what is not implied by predicates. + * + *

The term is broken into "ref comparison constant", for example "$0 < 5". + * + *

Examples: + * + *

    + *
  • {@code residue($0 < 10, [$0 < 5])} returns {@code true} + *
  • {@code residue($0 < 10, [$0 < 20, $0 > 0])} returns {@code $0 < 10} + *
+ */ + private static > RangeSet residue( + RexNode ref, RangeSet r0, List predicates, Class clazz) { + RangeSet result = r0; + for (RexNode predicate : predicates) { + switch (predicate.getKind()) { + case EQUALS: + case NOT_EQUALS: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + final RexCall call = (RexCall) predicate; + final Comparison comparison = Comparison.of(call); + if (comparison != null && comparison.ref.equals(ref)) { + final C c1 = comparison.literal.getValueAs(clazz); + if (c1 == null) { + throw new AssertionError( + "value must not be null in " + comparison.literal); + } + switch (predicate.getKind()) { + case NOT_EQUALS: + // We want to intersect result with the range set of everything but + // c1. We subtract the point c1 from result, which is equivalent. + final Range pointRange = range(SqlKind.EQUALS, c1); + final RangeSet notEqualsRangeSet = + ImmutableRangeSet.of(pointRange).complement(); + if (result.enclosesAll(notEqualsRangeSet)) { + result = RangeSets.rangeSetAll(); + continue; + } + result = RangeSets.minus(result, pointRange); + break; + default: + final Range r1 = range(comparison.kind, c1); + if (result.encloses(r1)) { + // Given these predicates, term is always satisfied. + // e.g. r0 is "$0 < 10", r1 is "$0 < 5" + result = RangeSets.rangeSetAll(); + continue; + } + result = result.subRangeSet(r1); + } + if (result.isEmpty()) { + break; // short-cut + } + } + break; + default: + break; + } + } + return result; + } + + /** + * Simplifies OR(x, x) into x, and similar. The simplified expression returns UNKNOWN values as + * is (not as FALSE). + */ + @Deprecated // to be removed before 2.0 + public RexNode simplifyOr(RexCall call) { + ensureParanoidOff(); + return simplifyOr(call, UNKNOWN); + } + + private RexNode simplifyOr(RexCall call, RexUnknownAs unknownAs) { + assert call.getKind() == SqlKind.OR; + final List terms0 = RelOptUtil.disjunctions(call); + final List terms; + if (predicateElimination) { + terms = Util.moveToHead(terms0, e -> e.getKind() == SqlKind.IS_NULL); + simplifyOrTerms(terms, unknownAs); + } else { + terms = terms0; + simplifyList(terms, unknownAs); + } + return simplifyOrs(terms, unknownAs); + } + + /** + * Simplifies a list of terms and combines them into an OR. Modifies the list in place. The + * simplified expression returns UNKNOWN values as is (not as FALSE). + */ + @Deprecated // to be removed before 2.0 + public RexNode simplifyOrs(List terms) { + ensureParanoidOff(); + return simplifyOrs(terms, UNKNOWN); + } + + private void ensureParanoidOff() { + if (paranoid) { + throw new UnsupportedOperationException("Paranoid is not supported for this method"); + } + } + + /** Simplifies a list of terms and combines them into an OR. Modifies the list in place. */ + private RexNode simplifyOrs(List terms, RexUnknownAs unknownAs) { + final SargCollector sargCollector = new SargCollector(rexBuilder, false); + final List newTerms = new ArrayList<>(); + terms.forEach(t -> sargCollector.accept(t, newTerms)); + if (sargCollector.needToFix()) { + terms.clear(); + newTerms.forEach(t -> terms.add(SargCollector.fix(rexBuilder, t, unknownAs))); + } + + // CALCITE-3198 Auxiliary map to simplify cases like: + // X <> A OR X <> B => X IS NOT NULL or NULL + // The map key will be the 'X'; and the value the first call 'X<>A' that is found, + // or 'X IS NOT NULL' if a simplification takes place (because another 'X<>B' is found) + final Map notEqualsComparisonMap = new HashMap<>(); + final RexLiteral trueLiteral = rexBuilder.makeLiteral(true); + for (int i = 0; i < terms.size(); i++) { + final RexNode term = terms.get(i); + switch (term.getKind()) { + case LITERAL: + if (RexLiteral.isNullLiteral(term)) { + if (unknownAs == FALSE) { + terms.remove(i); + --i; + continue; + } else if (unknownAs == TRUE) { + return trueLiteral; + } + } else { + if (RexLiteral.booleanValue(term)) { + return term; // true + } else { + terms.remove(i); + --i; + continue; + } + } + break; + case NOT_EQUALS: + final Comparison notEqualsComparison = Comparison.of(term); + if (notEqualsComparison != null) { + // We are dealing with a X<>A term, check if we saw before another + // NOT_EQUALS involving X + final RexNode prevNotEquals = + notEqualsComparisonMap.get(notEqualsComparison.ref); + if (prevNotEquals == null) { + // This is the first NOT_EQUALS involving X, put it in the map + notEqualsComparisonMap.put(notEqualsComparison.ref, term); + } else { + // There is already in the map another NOT_EQUALS involving X: + // - if it is already an IS_NOT_NULL: it was already simplified, + // ignore this term + // - if it is not an IS_NOT_NULL (i.e. it is a NOT_EQUALS): check + // comparison values + if (prevNotEquals.getKind() != SqlKind.IS_NOT_NULL) { + final Comparable comparable1 = + notEqualsComparison.literal.getValue(); + final Comparable comparable2 = + castNonNull(Comparison.of(prevNotEquals)) + .literal + .getValue(); + //noinspection unchecked + if (comparable1 != null + && comparable2 != null + && comparable1.compareTo(comparable2) != 0) { + // X <> A OR X <> B => X IS NOT NULL OR NULL + final RexNode isNotNull = + rexBuilder.makeCall( + SqlStdOperatorTable.IS_NOT_NULL, + notEqualsComparison.ref); + final RexNode constantNull = + rexBuilder.makeNullLiteral(trueLiteral.getType()); + final RexNode newCondition = + simplify( + rexBuilder.makeCall( + SqlStdOperatorTable.OR, + isNotNull, + constantNull), + unknownAs); + if (newCondition.isAlwaysTrue()) { + return trueLiteral; + } + notEqualsComparisonMap.put(notEqualsComparison.ref, isNotNull); + final int pos = terms.indexOf(prevNotEquals); + terms.set(pos, newCondition); + } + } + terms.remove(i); + --i; + continue; + } + } + break; + case IS_NOT_TRUE: + RexNode arg = ((RexCall) term).getOperands().get(0); + if (isSafeExpression(arg) && terms.contains(arg)) { + return trueLiteral; + } + break; + case NOT: + RexNode x = ((RexCall) term).getOperands().get(0); + if (isSafeExpression(x) && terms.contains(x)) { + if (!x.getType().isNullable()) { + return trueLiteral; + } + + final RexNode isNotNull = + rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, x); + terms.set(terms.indexOf(x), simplifyIs((RexCall) isNotNull, unknownAs)); + terms.set(i, rexBuilder.makeNullLiteral(x.getType())); + i--; + } + break; + default: + break; + } + } + return RexUtil.composeDisjunction(rexBuilder, terms); + } + + private void verify(RexNode before, RexNode simplified, RexUnknownAs unknownAs) { + if (simplified.isAlwaysFalse() && before.isAlwaysTrue()) { + throw new AssertionError( + "always true [" + before + "] simplified to always false [" + simplified + "]"); + } + if (simplified.isAlwaysTrue() && before.isAlwaysFalse()) { + throw new AssertionError( + "always false [" + before + "] simplified to always true [" + simplified + "]"); + } + final RexAnalyzer foo0 = new RexAnalyzer(before, predicates); + final RexAnalyzer foo1 = new RexAnalyzer(simplified, predicates); + if (foo0.unsupportedCount > 0 || foo1.unsupportedCount > 0) { + // Analyzer cannot handle this expression currently + return; + } + if (!foo0.variables.containsAll(foo1.variables)) { + throw new AssertionError( + "variable mismatch: " + + before + + " has " + + foo0.variables + + ", " + + simplified + + " has " + + foo1.variables); + } + assignment_loop: + for (Map map : foo0.assignments()) { + for (RexNode predicate : predicates.pulledUpPredicates) { + final Comparable v = RexInterpreter.evaluate(predicate, map); + if (!Boolean.TRUE.equals(v)) { + continue assignment_loop; + } + } + Comparable v0 = RexInterpreter.evaluate(foo0.e, map); + if (v0 == null) { + throw new AssertionError("interpreter returned null for " + foo0.e); + } + Comparable v1 = RexInterpreter.evaluate(foo1.e, map); + if (v1 == null) { + throw new AssertionError("interpreter returned null for " + foo1.e); + } + if (before.getType().getSqlTypeName() == SqlTypeName.BOOLEAN) { + switch (unknownAs) { + case FALSE: + case TRUE: + if (v0 == NullSentinel.INSTANCE) { + v0 = unknownAs.toBoolean(); + } + if (v1 == NullSentinel.INSTANCE) { + v1 = unknownAs.toBoolean(); + } + break; + default: + break; + } + } + if (!v0.equals(v1)) { + throw new AssertionError( + "result mismatch (unknown as " + + unknownAs + + "): when applied to " + + map + + ",\n" + + before + + " yielded " + + v0 + + ";\n" + + simplified + + " yielded " + + v1); + } + } + } + + private RexNode simplifySearch(RexCall call, RexUnknownAs unknownAs) { + assert call.getKind() == SqlKind.SEARCH; + final RexNode operand = call.getOperands().get(0); + final RexNode simplifiedOperand = simplify(operand, unknownAs); + final boolean operandUnchanged = operand.equals(simplifiedOperand); + final RexNode searchOperand = operandUnchanged ? operand : simplifiedOperand; + if (call.getOperands().get(1) instanceof RexLiteral) { + RexLiteral literal = (RexLiteral) call.getOperands().get(1); + final Sarg sarg = castNonNull(literal.getValueAs(Sarg.class)); + if (sarg.isAll() || sarg.isNone()) { + RexNode rexNode = RexUtil.simpleSarg(rexBuilder, searchOperand, sarg, unknownAs); + return simplify(rexNode, unknownAs); + } + // Remove null from sarg if the left-hand side is never null + if (sarg.nullAs != UNKNOWN) { + final RexNode simplified = simplifyIs1(SqlKind.IS_NULL, searchOperand, unknownAs); + if (simplified != null && simplified.isAlwaysFalse()) { + final Sarg sarg2 = Sarg.of(UNKNOWN, sarg.rangeSet); + final RexLiteral literal2 = + rexBuilder.makeLiteral(sarg2, literal.getType(), literal.getTypeName()); + // Now we've strengthened the Sarg, try to simplify again + return simplifySearch( + call.clone(call.type, ImmutableList.of(searchOperand, literal2)), + unknownAs); + } + } else if (sarg.isPoints() && sarg.pointCount <= 1) { + // Expand "SEARCH(x, Sarg([point])" to "x = point" + // and "SEARCH(x, Sarg([])" to "false" + return RexUtil.expandSearch(rexBuilder, null, call); + } + } + return operandUnchanged + ? call + : call.clone( + call.type, ImmutableList.of(simplifiedOperand, call.getOperands().get(1))); + } + + private RexNode simplifyCast(RexCall e) { + RexNode operand = e.getOperands().get(0); + operand = simplify(operand, UNKNOWN); + // The type of DYNAMIC_PARAM is indeterminate, so the cast cannot be eliminated + if (operand.getKind() != SqlKind.DYNAMIC_PARAM + && sameTypeOrNarrowsNullability(e.getType(), operand.getType()) + // DECIMAL casts are never no-ops: they perform bounds checking + && e.getType().getSqlTypeName() != SqlTypeName.DECIMAL) { + return operand; + } + if (RexUtil.isLosslessCast(operand)) { + // x :: y below means cast(x as y) (which is PostgreSQL-specific cast by the way) + // A) Remove lossless casts: + // A.1) intExpr :: bigint :: int => intExpr + // A.2) char2Expr :: char(5) :: char(2) => char2Expr + // B) There are cases when we can't remove two casts, but we could probably remove inner + // one + // B.1) char2expression :: char(4) :: char(5) -> char2expression :: char(5) + // B.2) char2expression :: char(10) :: char(5) -> char2expression :: char(5) + // B.3) char2expression :: varchar(10) :: char(5) -> char2expression :: char(5) + // B.4) char6expression :: varchar(10) :: char(5) -> char6expression :: char(5) + // C) Simplification is not possible: + // C.1) char6expression :: char(3) :: char(5) -> must not be changed + // the input is truncated to 3 chars, so we can't use char6expression :: char(5) + // C.2) varchar2Expr :: char(5) :: varchar(2) -> must not be changed + // the input have to be padded with spaces (up to 2 chars) + // C.3) char2expression :: char(4) :: varchar(5) -> must not be changed + // would not have the padding + + // The approach seems to be: + // 1) Ensure inner cast is lossless (see if above) + // 2) If operand of the inner cast has the same type as the outer cast, + // remove two casts except C.2 or C.3-like pattern (== inner cast is CHAR) + // 3) If outer cast is lossless, remove inner cast (B-like cases) + + // Here we try to remove two casts in one go (A-like cases) + RexNode intExpr = ((RexCall) operand).operands.get(0); + // intExpr == CHAR detects A.1 + // operand != CHAR detects C.2 + if ((intExpr.getType().getSqlTypeName() == SqlTypeName.CHAR + || operand.getType().getSqlTypeName() != SqlTypeName.CHAR) + && sameTypeOrNarrowsNullability(e.getType(), intExpr.getType())) { + return intExpr; + } + // Here we try to remove inner cast (B-like cases) + if (RexUtil.isLosslessCast(intExpr.getType(), operand.getType()) + && (e.getType().getSqlTypeName() == operand.getType().getSqlTypeName() + || e.getType().getSqlTypeName() == SqlTypeName.CHAR + || operand.getType().getSqlTypeName() != SqlTypeName.CHAR) + && SqlTypeCoercionRule.instance() + .canApplyFrom( + intExpr.getType().getSqlTypeName(), + e.getType().getSqlTypeName())) { + return rexBuilder.makeCast(e.getParserPosition(), e.getType(), intExpr); + } + } + final boolean safe = e.getKind() == SqlKind.SAFE_CAST; + switch (operand.getKind()) { + case LITERAL: + final RexLiteral literal = (RexLiteral) operand; + final Comparable value = literal.getValueAs(Comparable.class); + final SqlTypeName typeName = literal.getTypeName(); + + // First, try to remove the cast without changing the value. + // makeCast and canRemoveCastFromLiteral have the same logic, so we are + // sure to be able to remove the cast. + if (rexBuilder.canRemoveCastFromLiteral(e.getType(), value, typeName)) { + return rexBuilder.makeCast(e.getParserPosition(), e.getType(), operand); + } + + // Next, try to convert the value to a different type, + // e.g. CAST('123' as integer) + switch (literal.getTypeName()) { + case TIME: + switch (e.getType().getSqlTypeName()) { + case TIMESTAMP: + return e; + default: + break; + } + break; + default: + break; + } + final List reducedValues = new ArrayList<>(); + final RexNode simplifiedExpr = + e.operandCount() == 2 + ? rexBuilder.makeCast( + e.getParserPosition(), + e.getType(), + operand, + safe, + safe, + (RexLiteral) e.getOperands().get(1)) + : rexBuilder.makeCast(e.getType(), operand, safe, safe); + executor.reduce(rexBuilder, ImmutableList.of(simplifiedExpr), reducedValues); + RexNode reducedRexNode = requireNonNull(Iterables.getOnlyElement(reducedValues)); + if (reducedRexNode.isA(SqlKind.CAST)) { + RexNode reducedOperand = ((RexCall) reducedRexNode).getOperands().get(0); + if (sameTypeOrNarrowsNullability( + reducedRexNode.getType(), reducedOperand.getType())) { + return reducedOperand; + } + } + return reducedRexNode; + default: + if (operand == e.getOperands().get(0)) { + return e; + } else { + return e.operands.size() > 1 + ? rexBuilder.makeCast( + e.getParserPosition(), + e.getType(), + operand, + safe, + safe, + (RexLiteral) e.getOperands().get(1)) + : rexBuilder.makeCast(e.getType(), operand, safe, safe); + } + } + } + + /** + * Tries to simplify CEIL/FLOOR function on top of CEIL/FLOOR. + * + *

Examples: + * + *

    + *
  • {@code floor(floor($0, flag(hour)), flag(day))} returns {@code floor($0, flag(day))} + *
  • {@code ceil(ceil($0, flag(second)), flag(day))} returns {@code ceil($0, flag(day))} + *
  • {@code floor(floor($0, flag(day)), flag(second))} does not change + *
+ */ + private RexNode simplifyCeilFloor(RexCall e) { + if (e.getOperands().size() != 2) { + // Bail out since we only simplify floor + return e; + } + final RexNode operand = simplify(e.getOperands().get(0), UNKNOWN); + if (e.getKind() == operand.getKind()) { + assert e.getKind() == SqlKind.CEIL || e.getKind() == SqlKind.FLOOR; + // CEIL/FLOOR on top of CEIL/FLOOR + final RexCall child = (RexCall) operand; + if (child.getOperands().size() != 2) { + // Bail out since we only simplify ceil/floor + return e; + } + final RexLiteral parentFlag = (RexLiteral) e.operands.get(1); + final TimeUnitRange parentFlagValue = (TimeUnitRange) parentFlag.getValue(); + final RexLiteral childFlag = (RexLiteral) child.operands.get(1); + final TimeUnitRange childFlagValue = (TimeUnitRange) childFlag.getValue(); + if (parentFlagValue != null && childFlagValue != null) { + if (canRollUp(parentFlagValue.startUnit, childFlagValue.startUnit)) { + return e.clone( + e.getType(), ImmutableList.of(child.getOperands().get(0), parentFlag)); + } + } + } + return e.clone(e.getType(), ImmutableList.of(operand, e.getOperands().get(1))); + } + + /** + * Simplify TRIM function by eliminating nested duplication. + * + *

Examples: + * + *

    + *
  • {@code trim(trim(' aa '))} returns {@code trim(' aa ')} + *
  • {@code trim(BOTH ' ' from trim(BOTH ' ' from ' aa '))} returns {@code trim(BOTH ' ' + * from ' aa ')} + *
  • {@code trim(LEADING 'a' from trim(BOTH ' ' from ' aa '))} does not change + *
+ */ + private RexNode simplifyTrim(RexCall e) { + if (e.getOperands().size() != 3) { + return e; + } + + RexNode trimType = simplify(e.operands.get(0)); + RexNode trimed = simplify(e.operands.get(1)); + if (e.getOperands().get(2) instanceof RexCall) { + RexCall childNode = (RexCall) e.getOperands().get(2); + // only strings with the same trim method and deduplication will be eliminated. + if (childNode.getKind() == SqlKind.TRIM + && trimType.equals(simplify(childNode.operands.get(0))) + && trimed.equals(simplify(childNode.operands.get(1)))) { + return simplifyTrim(childNode); + } + } + + ArrayList rexNodes = new ArrayList<>(); + rexNodes.add(trimType); + rexNodes.add(trimed); + rexNodes.add(simplify(e.operands.get(2))); + RexNode rexNode = rexBuilder.makeCall(e.getType(), e.getOperator(), rexNodes); + return rexNode; + } + + /** Method that returns whether we can rollup from inner time unit to outer time unit. */ + private static boolean canRollUp(TimeUnit outer, TimeUnit inner) { + // Special handling for QUARTER as it is not in the expected + // order in TimeUnit + switch (outer) { + case YEAR: + case MONTH: + case DAY: + case HOUR: + case MINUTE: + case SECOND: + case MILLISECOND: + case MICROSECOND: + switch (inner) { + case YEAR: + case QUARTER: + case MONTH: + case DAY: + case HOUR: + case MINUTE: + case SECOND: + case MILLISECOND: + case MICROSECOND: + if (inner == TimeUnit.QUARTER) { + return outer == TimeUnit.YEAR; + } + return outer.ordinal() <= inner.ordinal(); + default: + break; + } + break; + case QUARTER: + switch (inner) { + case QUARTER: + case MONTH: + case DAY: + case HOUR: + case MINUTE: + case SECOND: + case MILLISECOND: + case MICROSECOND: + return true; + default: + break; + } + break; + default: + break; + } + return false; + } + + /** + * Simplifies a measure being converted immediately (in the same SELECT clause) back to a value. + * + *

For most expressions {@code e}, simplifies "{@code m2v(v2m(e))}" to "{@code e}". For + * example, "{@code SELECT deptno + 1 AS MEASURE m}" is equivalent to "{@code SELECT deptno + 1 + * AS m}". + * + *

The exception is aggregate functions. "{@code SELECT COUNT(*) + 1 AS MEASURE m}" + * simplifies to "{@code SELECT COUNT(*) OVER (ROWS CURRENT ROW) + 1 AS MEASURE m}". + * + * @param e Call to {@code M2V} to be simplified + * @return Simplified call + */ + private RexNode simplifyM2v(RexCall e) { + assert e.op.kind == SqlKind.M2V; + final RexNode operand = e.getOperands().get(0); + switch (operand.getKind()) { + case V2M: + // M2V(V2M(x)) --> x + return flattenAggregate(((RexCall) operand).operands.get(0)); + default: + return e; + } + } + + /** + * Traverses over an expression, converting aggregate functions into single-row aggregate + * functions. + */ + private RexNode flattenAggregate(RexNode e) { + return e.accept( + new RexShuttle() { + @Override + public RexNode visitCall(RexCall call) { + if (call.op.isAggregator()) { + final RexWindow w = + rexBuilder.makeWindow( + ImmutableList.of(), + ImmutableList.of(), + RexWindowBounds.CURRENT_ROW, + RexWindowBounds.CURRENT_ROW, + true); + return new RexOver( + call.type, + (SqlAggFunction) call.op, + call.operands, + w, + false, + false); + } + return super.visitCall(call); + } + }); + } + + /** + * Removes any casts that change nullability but not type. + * + *

For example, {@code CAST(1 = 0 AS BOOLEAN)} becomes {@code 1 = 0}. + */ + public RexNode removeNullabilityCast(RexNode e) { + return RexUtil.removeNullabilityCast(rexBuilder.getTypeFactory(), e); + } + + private static > @Nullable RexNode processRange( + RexBuilder rexBuilder, + List terms, + Map, List>> rangeTerms, + RexNode term, + RexNode ref, + C v0, + SqlKind comparison) { + Pair, List> p = rangeTerms.get(ref); + if (p == null) { + rangeTerms.put(ref, Pair.of(range(comparison, v0), ImmutableList.of(term))); + } else { + // Exists + boolean removeUpperBound = false; + boolean removeLowerBound = false; + Range r = p.left; + final RexLiteral trueLiteral = rexBuilder.makeLiteral(true); + switch (comparison) { + case EQUALS: + if (!r.contains(v0)) { + // Range is empty, not satisfiable + return rexBuilder.makeLiteral(false); + } + rangeTerms.put(ref, Pair.of(Range.singleton(v0), ImmutableList.of(term))); + // remove + for (RexNode e : p.right) { + replaceLast(terms, e, trueLiteral); + } + break; + case LESS_THAN: + { + int comparisonResult = 0; + if (r.hasUpperBound()) { + comparisonResult = v0.compareTo(r.upperEndpoint()); + } + if (comparisonResult <= 0) { + // 1) No upper bound, or + // 2) We need to open the upper bound, or + // 3) New upper bound is lower than old upper bound + if (r.hasLowerBound()) { + if (v0.compareTo(r.lowerEndpoint()) <= 0) { + // Range is empty, not satisfiable + return rexBuilder.makeLiteral(false); + } + // a <= x < b OR a < x < b + r = + Range.range( + r.lowerEndpoint(), + r.lowerBoundType(), + v0, + BoundType.OPEN); + } else { + // x < b + r = Range.lessThan(v0); + } + + if (r.isEmpty()) { + // Range is empty, not satisfiable + return rexBuilder.makeLiteral(false); + } + + // remove prev upper bound + removeUpperBound = true; + } else { + // Remove this term as it is contained in current upper bound + replaceLast(terms, term, trueLiteral); + } + break; + } + case LESS_THAN_OR_EQUAL: + { + int comparisonResult = -1; + if (r.hasUpperBound()) { + comparisonResult = v0.compareTo(r.upperEndpoint()); + } + if (comparisonResult < 0) { + // 1) No upper bound, or + // 2) New upper bound is lower than old upper bound + if (r.hasLowerBound()) { + if (v0.compareTo(r.lowerEndpoint()) < 0) { + // Range is empty, not satisfiable + return rexBuilder.makeLiteral(false); + } + // a <= x <= b OR a < x <= b + r = + Range.range( + r.lowerEndpoint(), + r.lowerBoundType(), + v0, + BoundType.CLOSED); + } else { + // x <= b + r = Range.atMost(v0); + } + + if (r.isEmpty()) { + // Range is empty, not satisfiable + return rexBuilder.makeLiteral(false); + } + + // remove prev upper bound + removeUpperBound = true; + } else { + // Remove this term as it is contained in current upper bound + replaceLast(terms, term, trueLiteral); + } + break; + } + case GREATER_THAN: + { + int comparisonResult = 0; + if (r.hasLowerBound()) { + comparisonResult = v0.compareTo(r.lowerEndpoint()); + } + if (comparisonResult >= 0) { + // 1) No lower bound, or + // 2) We need to open the lower bound, or + // 3) New lower bound is greater than old lower bound + if (r.hasUpperBound()) { + if (v0.compareTo(r.upperEndpoint()) >= 0) { + // Range is empty, not satisfiable + return rexBuilder.makeLiteral(false); + } + // a < x <= b OR a < x < b + r = + Range.range( + v0, + BoundType.OPEN, + r.upperEndpoint(), + r.upperBoundType()); + } else { + // x > a + r = Range.greaterThan(v0); + } + + if (r.isEmpty()) { + // Range is empty, not satisfiable + return rexBuilder.makeLiteral(false); + } + + // remove prev lower bound + removeLowerBound = true; + } else { + // Remove this term as it is contained in current lower bound + replaceLast(terms, term, trueLiteral); + } + break; + } + case GREATER_THAN_OR_EQUAL: + { + int comparisonResult = 1; + if (r.hasLowerBound()) { + comparisonResult = v0.compareTo(r.lowerEndpoint()); + } + if (comparisonResult > 0) { + // 1) No lower bound, or + // 2) New lower bound is greater than old lower bound + if (r.hasUpperBound()) { + if (v0.compareTo(r.upperEndpoint()) > 0) { + // Range is empty, not satisfiable + return rexBuilder.makeLiteral(false); + } + // a <= x <= b OR a <= x < b + r = + Range.range( + v0, + BoundType.CLOSED, + r.upperEndpoint(), + r.upperBoundType()); + } else { + // x >= a + r = Range.atLeast(v0); + } + + if (r.isEmpty()) { + // Range is empty, not satisfiable + return rexBuilder.makeLiteral(false); + } + + // remove prev lower bound + removeLowerBound = true; + } else { + // Remove this term as it is contained in current lower bound + replaceLast(terms, term, trueLiteral); + } + break; + } + default: + throw new AssertionError(); + } + if (removeUpperBound) { + ImmutableList.Builder newBounds = ImmutableList.builder(); + for (RexNode e : p.right) { + if (isUpperBound(e)) { + replaceLast(terms, e, trueLiteral); + } else { + newBounds.add(e); + } + } + newBounds.add(term); + rangeTerms.put(ref, Pair.of(r, newBounds.build())); + } else if (removeLowerBound) { + ImmutableList.Builder newBounds = ImmutableList.builder(); + for (RexNode e : p.right) { + if (isLowerBound(e)) { + replaceLast(terms, e, trueLiteral); + } else { + newBounds.add(e); + } + } + newBounds.add(term); + rangeTerms.put(ref, Pair.of(r, newBounds.build())); + } + } + // Default + return null; + } + + private static > Range range(SqlKind comparison, C c) { + switch (comparison) { + case EQUALS: + return Range.singleton(c); + case LESS_THAN: + return Range.lessThan(c); + case LESS_THAN_OR_EQUAL: + return Range.atMost(c); + case GREATER_THAN: + return Range.greaterThan(c); + case GREATER_THAN_OR_EQUAL: + return Range.atLeast(c); + default: + throw new AssertionError(); + } + } + + private static > RangeSet rangeSet(SqlKind comparison, C c) { + switch (comparison) { + case EQUALS: + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + return ImmutableRangeSet.of(range(comparison, c)); + case NOT_EQUALS: + return ImmutableRangeSet.builder() + .add(range(SqlKind.LESS_THAN, c)) + .add(range(SqlKind.GREATER_THAN, c)) + .build(); + default: + throw new AssertionError(); + } + } + + /** Marker interface for predicates (expressions that evaluate to BOOLEAN). */ + private interface Predicate { + /** Wraps an expression in a Predicate or returns null. */ + static @Nullable Predicate of(RexNode t) { + final Predicate p = Comparison.of(t); + if (p != null) { + return p; + } + return IsPredicate.of(t); + } + + /** Returns whether this predicate can be used while simplifying other OR operands. */ + default boolean allowedInOr(RelOptPredicateList predicates) { + return true; + } + } + + /** Visitor which finds all inputs used by an expressions. */ + private static class VariableCollector extends RexVisitorImpl { + private final Set refs = new HashSet<>(); + + VariableCollector() { + super(true); + } + + @Override + public Void visitInputRef(RexInputRef inputRef) { + refs.add(inputRef); + return super.visitInputRef(inputRef); + } + } + + /** + * Represents a simple Comparison. + * + *

Left hand side is a {@link RexNode}, right hand side is a literal. + */ + private static class Comparison implements Predicate { + final RexNode ref; + final SqlKind kind; + final RexLiteral literal; + + private Comparison(RexNode ref, SqlKind kind, RexLiteral literal) { + this.ref = requireNonNull(ref, "ref"); + this.kind = requireNonNull(kind, "kind"); + this.literal = requireNonNull(literal, "literal"); + } + + /** + * Creates a comparison, between a {@link RexInputRef} or {@link RexFieldAccess} or + * deterministic {@link RexCall} and a literal. + */ + static @Nullable Comparison of(RexNode e) { + return of( + e, + node -> + RexUtil.isReferenceOrAccess(node, true) + || RexUtil.isDeterministic(node)); + } + + /** Creates a comparison, or returns null. */ + static @Nullable Comparison of( + RexNode e, java.util.function.Predicate nodePredicate) { + switch (e.getKind()) { + case EQUALS: + case NOT_EQUALS: + case LESS_THAN: + case GREATER_THAN: + case LESS_THAN_OR_EQUAL: + case GREATER_THAN_OR_EQUAL: + final RexCall call = (RexCall) e; + final RexNode left = call.getOperands().get(0); + final RexNode right = call.getOperands().get(1); + switch (right.getKind()) { + case LITERAL: + if (nodePredicate.test(left)) { + return new Comparison(left, e.getKind(), (RexLiteral) right); + } + break; + default: + break; + } + switch (left.getKind()) { + case LITERAL: + if (nodePredicate.test(right)) { + return new Comparison( + right, e.getKind().reverse(), (RexLiteral) left); + } + break; + default: + break; + } + break; + default: + break; + } + return null; + } + + @Override + public boolean allowedInOr(RelOptPredicateList predicates) { + // if ref is not a 'loss-less' cast then can't be allowed to be used + // while simplifying other OR operands + if (ref.isA(SqlKind.CAST) && !RexUtil.isLosslessCast(ref)) { + return false; + } + return !ref.getType().isNullable() || predicates.isEffectivelyNotNull(ref); + } + } + + /** Represents an IS Predicate. */ + private static class IsPredicate implements Predicate { + final RexNode ref; + final SqlKind kind; + + private IsPredicate(RexNode ref, SqlKind kind) { + this.ref = requireNonNull(ref, "ref"); + this.kind = requireNonNull(kind, "kind"); + } + + /** Creates an IS predicate, or returns null. */ + static @Nullable IsPredicate of(RexNode e) { + switch (e.getKind()) { + case IS_NULL: + case IS_NOT_NULL: + RexNode pA = ((RexCall) e).getOperands().get(0); + if (!(RexUtil.isReferenceOrAccess(pA, true) || RexUtil.isDeterministic(pA))) { + return null; + } + return new IsPredicate(pA, e.getKind()); + default: + break; + } + return null; + } + } + + private static boolean isUpperBound(final RexNode e) { + final List operands; + switch (e.getKind()) { + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + operands = ((RexCall) e).getOperands(); + return RexUtil.isReferenceOrAccess(operands.get(0), true) + && operands.get(1).isA(SqlKind.LITERAL); + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + operands = ((RexCall) e).getOperands(); + return RexUtil.isReferenceOrAccess(operands.get(1), true) + && operands.get(0).isA(SqlKind.LITERAL); + default: + return false; + } + } + + private static boolean isLowerBound(final RexNode e) { + final List operands; + switch (e.getKind()) { + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + operands = ((RexCall) e).getOperands(); + return RexUtil.isReferenceOrAccess(operands.get(1), true) + && operands.get(0).isA(SqlKind.LITERAL); + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + operands = ((RexCall) e).getOperands(); + return RexUtil.isReferenceOrAccess(operands.get(0), true) + && operands.get(1).isA(SqlKind.LITERAL); + default: + return false; + } + } + + /** + * Combines predicates AND, optimizes, and returns null if the result is always false. + * + *

The expression is simplified on the assumption that an UNKNOWN value is always treated as + * FALSE. Therefore the simplified expression may sometimes evaluate to FALSE where the original + * expression would evaluate to UNKNOWN. + * + * @param predicates Filter condition predicates + * @return simplified conjunction of predicates for the filter, null if always false + */ + public @Nullable RexNode simplifyFilterPredicates(Iterable predicates) { + final RexNode simplifiedAnds = + withPredicateElimination(Bug.CALCITE_2401_FIXED) + .simplifyUnknownAsFalse(RexUtil.composeConjunction(rexBuilder, predicates)); + if (simplifiedAnds.isAlwaysFalse()) { + return null; + } + + // Remove cast of BOOLEAN NOT NULL to BOOLEAN or vice versa. Filter accepts + // nullable and not-nullable conditions, but a CAST might get in the way of + // other rewrites. + return removeNullabilityCast(simplifiedAnds); + } + + /** + * Replaces the last occurrence of one specified value in a list with another. + * + *

Does not change the size of the list. + * + *

Returns whether the value was found. + */ + private static boolean replaceLast(List list, E oldVal, E newVal) { + @SuppressWarnings("argument.type.incompatible") + final int index = list.lastIndexOf(oldVal); + if (index < 0) { + return false; + } + list.set(index, newVal); + return true; + } + + /** Gathers expressions that can be converted into {@link Sarg search arguments}. */ + static class SargCollector { + final Map map = new HashMap<>(); + private final RexBuilder rexBuilder; + private final boolean negate; + + /** + * Count of the new terms after converting all the operands to {@code SEARCH} on a {@link + * Sarg}. It is used to decide whether the new terms are simpler. + */ + private int newTermsCount; + + SargCollector(RexBuilder rexBuilder, boolean negate) { + this.rexBuilder = rexBuilder; + this.negate = negate; + } + + /** + * Accepts an expression and converts it to a Sarg if possible. + * + * @param term the expression to convert + * @param newTerms the list holding the result of the conversion or the original term if it + * cannot be converted + */ + private void accept(RexNode term, List newTerms) { + if (!accept_(term, newTerms)) { + newTerms.add(term); + } + newTermsCount = newTerms.size(); + } + + /** + * Accepts an expression and converts it to a Sarg if possible. Only certain kinds of + * expressions can be converted to Sargs. + * + * @param e the expression to convert + * @param newTerms the list to which the Sarg will be added if the expression is accepted + * @return true if the expression can be converted to a Sarg, false otherwise + */ + private boolean accept_(RexNode e, List newTerms) { + switch (e.getKind()) { + case LESS_THAN: + case LESS_THAN_OR_EQUAL: + case GREATER_THAN: + case GREATER_THAN_OR_EQUAL: + case EQUALS: + case NOT_EQUALS: + case SEARCH: + case IS_NOT_DISTINCT_FROM: + case IS_DISTINCT_FROM: + return accept2( + ((RexCall) e).operands.get(0), + ((RexCall) e).operands.get(1), + e.getKind(), + newTerms); + case IS_NULL: + case IS_NOT_NULL: + final RexNode arg = ((RexCall) e).operands.get(0); + return accept1(arg, e.getKind(), newTerms); + default: + return false; + } + } + + /** + * Accepts two operands from a binary comparison operator and converts them to a Sarg if + * possible. Only comparisons between a literal and a deterministic expression can be + * converted. Simplifications with non-deterministic expressions are generally avoided to + * ensure consistent results. + * + * @param left the left operand of the comparison + * @param right the right operand of the comparison + * @param kind the kind of comparison operator + * @param newTerms the list to which the Sarg will be added if accepted + * @return true if the operands can be converted to a Sarg, false otherwise + */ + private boolean accept2(RexNode left, RexNode right, SqlKind kind, List newTerms) { + if (right.isA(SqlKind.LITERAL) && RexUtil.isDeterministic(left)) { + return accept2b(left, kind, (RexLiteral) right, newTerms); + } + if (left.isA(SqlKind.LITERAL) && RexUtil.isDeterministic(right)) { + return accept2b(right, kind.reverse(), (RexLiteral) left, newTerms); + } + return false; + } + + private static E addFluent(List list, E e) { + list.add(e); + return e; + } + + /** + * Accepts an operand from a null predicate and converts it to a Sarg. + * + * @param e the operand of the null predicate + * @param kind the kind of the null predicate + * @param newTerms the list to which the Sarg is added + * @return true since the operand is always converted to a Sarg + */ + private boolean accept1(RexNode e, SqlKind kind, List newTerms) { + final RexSargBuilder b = + map.computeIfAbsent( + e, + e2 -> addFluent(newTerms, new RexSargBuilder(e2, rexBuilder, negate))); + switch (negate ? kind.negate() : kind) { + case IS_NULL: + b.nullAs = b.nullAs.or(TRUE); + return true; + case IS_NOT_NULL: + b.nullAs = b.nullAs.or(FALSE); + b.addAll(); + return true; + default: + throw new AssertionError("unexpected " + kind); + } + } + + /** + * Accepts two operands from a binary comparison operator and converts them to a Sarg when + * the literal is not null. The conversion depends on the kind of comparison operator and + * only certain operators are supported. + * + * @param e any kind of deterministic expression + * @param kind the kind of comparison operator + * @param literal the literal operand of the comparison + * @param newTerms the list to which the Sarg is added if accepted + * @return false if the literal operand is null, true otherwise + */ + private boolean accept2b( + RexNode e, SqlKind kind, RexLiteral literal, List newTerms) { + if (literal.getValue() == null) { + // Cannot include expressions 'x > NULL' in a Sarg. Comparing to a NULL + // literal is a daft thing to do, because it always returns UNKNOWN. It + // is better handled by other simplifications. + return false; + } + final RexSargBuilder b = + map.computeIfAbsent( + e, + e2 -> addFluent(newTerms, new RexSargBuilder(e2, rexBuilder, negate))); + if (negate) { + kind = kind.negateNullSafe(); + } + final Comparable value = requireNonNull(literal.getValueAs(Comparable.class), "value"); + switch (kind) { + case LESS_THAN: + b.addRange(Range.lessThan(value), literal.getType()); + return true; + case LESS_THAN_OR_EQUAL: + b.addRange(Range.atMost(value), literal.getType()); + return true; + case GREATER_THAN: + b.addRange(Range.greaterThan(value), literal.getType()); + return true; + case GREATER_THAN_OR_EQUAL: + b.addRange(Range.atLeast(value), literal.getType()); + return true; + case EQUALS: + b.addRange(Range.singleton(value), literal.getType()); + return true; + case IS_NOT_DISTINCT_FROM: + b.addRange(Range.singleton(value), literal.getType(), FALSE); + return true; + case NOT_EQUALS: + b.addRange(Range.lessThan(value), literal.getType()); + b.addRange(Range.greaterThan(value), literal.getType()); + return true; + case IS_DISTINCT_FROM: + b.addRange(Range.lessThan(value), literal.getType(), TRUE); + b.addRange(Range.greaterThan(value), literal.getType(), TRUE); + return true; + case SEARCH: + final Sarg sarg = (Sarg) value; + b.addSarg(sarg, negate, literal.getType()); + return true; + default: + throw new AssertionError("unexpected " + kind); + } + } + + /** Returns whether it is worth to fix and convert to {@code SEARCH} calls. */ + boolean needToFix() { + // Fix and converts to SEARCH if: + // 1. A Sarg has complexity greater than 1; + // 2. A Sarg was merged with another Sarg or range; + // 3. The terms are reduced as simpler Sarg points; + // 4. The terms are reduced as simpler Sarg comparison. + + // Ignore 'negate' just to be compatible with previous versions of this + // method. "build().complexity()" would be a better estimate, if we could + // switch to it breaking lots of plans. + final Collection builders = map.values(); + return builders.stream().anyMatch(b -> b.build(false).complexity() > 1 || b.mergedSarg) + || newTermsCount == 1 && builders.stream().allMatch(b -> simpleSarg(b.build())); + } + + /** + * Returns whether this Sarg can be expanded to more simple form, e.g. the IN call or single + * comparison. + */ + private static boolean simpleSarg(Sarg sarg) { + return sarg.isPoints() + || RangeSets.isOpenInterval(sarg.rangeSet) + || sarg.isComplementedPoints(); + } + + /** + * If a term is a call to {@code SEARCH} on a {@link RexSargBuilder}, converts it to a + * {@code SEARCH} on a {@link Sarg}. + */ + static RexNode fix(RexBuilder rexBuilder, RexNode term, RexUnknownAs unknownAs) { + if (term instanceof RexSargBuilder) { + final RexSargBuilder sargBuilder = (RexSargBuilder) term; + final Sarg sarg = sargBuilder.build(); + boolean isSmall = sarg.complexity() <= 1 || sarg.isAll() || sarg.isNone(); + if (isSmall && simpleSarg(sarg)) { + // Expand small sargs into comparisons in order to avoid plan changes + // and better readability. + return RexUtil.sargRef( + rexBuilder, sargBuilder.ref, sarg, term.getType(), unknownAs); + } + return rexBuilder.makeCall( + SqlStdOperatorTable.SEARCH, + sargBuilder.ref, + rexBuilder.makeSearchArgumentLiteral(sarg, term.getType())); + } + return term; + } + } + + /** + * Equivalent to a {@link RexLiteral} whose value is a {@link Sarg}, but mutable, so that the + * Sarg can be expanded as {@link SargCollector} traverses a list of OR or AND terms. + * + *

The {@link SargCollector#fix} method converts it to an immutable literal. + * + *

The {@link #nullAs} field will become {@link Sarg#nullAs}, as follows: + * + *

    + *
  • If there is at least one term that returns TRUE when the argument is NULL, then the + * overall value will be TRUE; failing that, + *
  • if there is at least one term that returns UNKNOWN when the argument is NULL, then the + * overall value will be UNKNOWN; failing that, + *
  • the value will be FALSE. + *
+ * + *

This is analogous to the behavior of OR in three-valued logic: {@code TRUE OR UNKNOWN OR + * FALSE} returns {@code TRUE}; {@code UNKNOWN OR FALSE OR UNKNOWN} returns {@code UNKNOWN}; + * {@code FALSE OR FALSE} returns {@code FALSE}. + */ + private static class RexSargBuilder extends RexNode { + final RexNode ref; + final RexBuilder rexBuilder; + final boolean negate; + final List types = new ArrayList<>(); + final RangeSet rangeSet = TreeRangeSet.create(); + boolean hasSarg; + boolean mergedSarg; + RexUnknownAs nullAs = FALSE; + + RexSargBuilder(RexNode ref, RexBuilder rexBuilder, boolean negate) { + this.ref = requireNonNull(ref, "ref"); + this.rexBuilder = requireNonNull(rexBuilder, "rexBuilder"); + this.negate = negate; + } + + @Override + public String toString() { + return "SEARCH(" + + ref + + ", " + + (negate ? "NOT " : "") + + rangeSet + + "; NULL AS " + + nullAs + + ")"; + } + + > Sarg build() { + return build(negate); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + > Sarg build(boolean negate) { + final RangeSet r = (RangeSet) this.rangeSet; + if (negate) { + return Sarg.of(nullAs.negate(), r.complement()); + } else { + return Sarg.of(nullAs, r); + } + } + + @Override + public RelDataType getType() { + if (this.types.isEmpty()) { + // Expression is "x IS NULL" + return ref.getType(); + } + final List distinctTypes = Util.distinctList(this.types); + return requireNonNull( + rexBuilder.typeFactory.leastRestrictive(distinctTypes), + () -> "Can't find leastRestrictive type among " + distinctTypes); + } + + @Override + public R accept(RexVisitor visitor) { + throw new UnsupportedOperationException(); + } + + @Override + public R accept(RexBiVisitor visitor, P arg) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean equals(@Nullable Object obj) { + throw new UnsupportedOperationException(); + } + + @Override + public int hashCode() { + throw new UnsupportedOperationException(); + } + + void addAll() { + rangeSet.add(Range.all()); + } + + void addRange(Range range, RelDataType type) { + addRange(range, type, UNKNOWN); + } + + void addRange(Range range, RelDataType type, RexUnknownAs unknownAs) { + types.add(type); + rangeSet.add(range); + mergedSarg |= hasSarg; + nullAs = nullAs.or(unknownAs); + } + + @SuppressWarnings({"rawtypes", "unchecked"}) + void addSarg(Sarg sarg, boolean negate, RelDataType type) { + final RangeSet r; + final RexUnknownAs nullAs; + if (negate) { + r = sarg.rangeSet.complement(); + nullAs = sarg.nullAs.negate(); + } else { + r = sarg.rangeSet; + nullAs = sarg.nullAs; + } + types.add(type); + rangeSet.addAll(r); + mergedSarg |= !rangeSet.isEmpty(); + hasSarg = true; + switch (nullAs) { + case TRUE: + this.nullAs = this.nullAs.or(TRUE); + break; + case FALSE: + this.nullAs = this.nullAs.or(FALSE); + break; + case UNKNOWN: + this.nullAs = this.nullAs.or(UNKNOWN); + break; + } + } + } +} diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java index 65e905b367472..fb9d701c18229 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rex/RexUtil.java @@ -25,12 +25,15 @@ import com.google.common.collect.Range; import org.apache.calcite.DataContexts; import org.apache.calcite.linq4j.function.Predicate1; +import org.apache.calcite.plan.PlanTooComplexError; import org.apache.calcite.plan.RelOptPredicateList; import org.apache.calcite.plan.RelOptUtil; import org.apache.calcite.rel.RelCollation; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Calc; +import org.apache.calcite.rel.core.CorrelationId; import org.apache.calcite.rel.core.Filter; import org.apache.calcite.rel.core.Join; import org.apache.calcite.rel.core.Project; @@ -805,6 +808,11 @@ public Boolean visitLambda(RexLambda lambda) { public Boolean visitLambdaRef(RexLambdaRef lambdaRef) { return false; } + + @Override + public Boolean visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + return false; + } } /** @@ -844,13 +852,13 @@ public Void visitCall(RexCall call) { } public static List retainDeterministic(List list) { - List conjuctions = new ArrayList<>(); + List conjunctions = new ArrayList<>(); for (RexNode x : list) { if (isDeterministic(x)) { - conjuctions.add(x); + conjunctions.add(x); } } - return conjuctions; + return conjunctions; } /** @@ -1659,22 +1667,31 @@ public static boolean isLosslessCast(RexNode node) { * @return 'true' when the conversion can certainly be determined to be loss-less cast, but may * return 'false' for some lossless casts. */ - @API(since = "1.22", status = API.Status.EXPERIMENTAL) + @API(since = "1.22", status = API.Status.STABLE) public static boolean isLosslessCast(RelDataType source, RelDataType target) { final SqlTypeName sourceSqlTypeName = source.getSqlTypeName(); final SqlTypeName targetSqlTypeName = target.getSqlTypeName(); - // 1) Both INT numeric types - if (SqlTypeFamily.INTEGER.getTypeNames().contains(sourceSqlTypeName) - && SqlTypeFamily.INTEGER.getTypeNames().contains(targetSqlTypeName)) { - return targetSqlTypeName.compareTo(sourceSqlTypeName) >= 0; + + // INT -> INT: use range containment (signed/unsigned) + if (SqlTypeUtil.isIntType(source) && SqlTypeUtil.isIntType(target)) { + final boolean sourceIsUnsigned = + SqlTypeFamily.UNSIGNED_NUMERIC.getTypeNames().contains(sourceSqlTypeName); + final boolean targetIsUnsigned = + SqlTypeFamily.UNSIGNED_NUMERIC.getTypeNames().contains(targetSqlTypeName); + if (!sourceIsUnsigned && targetIsUnsigned) { + return false; + } + return SqlTypeUtil.integerRangeContains(target, source); } - // 2) Both CHARACTER types: it depends on the precision (length) + + // CHARACTER -> CHARACTER: valid if target order >= source and length grows if (SqlTypeFamily.CHARACTER.getTypeNames().contains(sourceSqlTypeName) && SqlTypeFamily.CHARACTER.getTypeNames().contains(targetSqlTypeName)) { return targetSqlTypeName.compareTo(sourceSqlTypeName) >= 0 && source.getPrecision() <= target.getPrecision(); } - // 3) From NUMERIC family to CHARACTER family: it depends on the precision/scale + + // NUMERIC -> CHARACTER: allow when target length accommodates sign/scale if (sourceSqlTypeName.getFamily() == SqlTypeFamily.NUMERIC && targetSqlTypeName.getFamily() == SqlTypeFamily.CHARACTER) { int sourceLength = source.getPrecision() + 1; // include sign @@ -1684,7 +1701,66 @@ public static boolean isLosslessCast(RelDataType source, RelDataType target) { final int targetPrecision = target.getPrecision(); return targetPrecision == PRECISION_NOT_SPECIFIED || targetPrecision >= sourceLength; } - // Return FALSE by default + + // DECIMAL -> DECIMAL: allow when precision/scale can only expand + if (sourceSqlTypeName == SqlTypeName.DECIMAL && targetSqlTypeName == SqlTypeName.DECIMAL) { + int sourcePrecision = source.getPrecision(); + int sourceScale = Math.max(source.getScale(), 0); + int targetPrecision = target.getPrecision(); + int targetScale = Math.max(target.getScale(), 0); + if (sourcePrecision <= 0 || targetPrecision <= 0) { + return false; + } + return targetScale >= sourceScale + && (targetPrecision - targetScale) >= (sourcePrecision - sourceScale); + } + + // INT (signed/unsigned) -> DECIMAL: valid if integer digits fit within target precision + if (SqlTypeUtil.isIntType(source) && targetSqlTypeName == SqlTypeName.DECIMAL) { + int targetPrecision = target.getPrecision(); + int targetScale = Math.max(target.getScale(), 0); + int sourcePrecision = source.getPrecision(); + return sourcePrecision > 0 && (targetPrecision - targetScale) >= sourcePrecision; + } + + // DECIMAL -> INT: only when scale = 0 and range fits target integer type + if (sourceSqlTypeName == SqlTypeName.DECIMAL && SqlTypeUtil.isIntType(target)) { + if (source.getScale() != 0) { + return false; + } + return SqlTypeUtil.integerRangeContains(target, source); + } + + // APPROXIMATE NUMERIC -> APPROXIMATE NUMERIC: allow when target precision >= source + // precision + if (SqlTypeFamily.APPROXIMATE_NUMERIC.getTypeNames().contains(sourceSqlTypeName) + && SqlTypeFamily.APPROXIMATE_NUMERIC.getTypeNames().contains(targetSqlTypeName)) { + // Lossless if target has at least as many significant digits as source + final int sourcePrecision = source.getPrecision(); + final int targetPrecision = target.getPrecision(); + return targetPrecision >= sourcePrecision; + } + + // EXACT NUMERIC -> APPROXIMATE NUMERIC: allow only for scale=0 values within target digits + if (SqlTypeFamily.EXACT_NUMERIC.getTypeNames().contains(sourceSqlTypeName) + && SqlTypeFamily.APPROXIMATE_NUMERIC.getTypeNames().contains(targetSqlTypeName)) { + final int targetPrecision = target.getPrecision(); + + // DECIMAL -> APPROXIMATE NUMERIC + if (sourceSqlTypeName == SqlTypeName.DECIMAL) { + final int sourcePrecision = source.getPrecision(); + if (sourcePrecision <= 0 || source.getScale() != 0) { + return false; + } + // scale is 0, just check precision + return sourcePrecision <= targetPrecision; + } + + // INT (signed/unsigned) -> APPROXIMATE NUMERIC + int sourcePrecision = source.getPrecision(); + return sourcePrecision > 0 && sourcePrecision <= targetPrecision; + } + return false; } @@ -1865,6 +1941,36 @@ public RexNode visitInputRef(RexInputRef input) { }); } + /** + * Shifts every {@link RexFieldAccess} with {@link CorrelationId} in an {@link RelNode} by + * {@code offset}. + */ + public static RelNode shiftFieldAccess( + RexBuilder rexBuilder, + RelNode node, + final CorrelationId id, + RelNode outer, + final int offset) { + if (offset == 0) { + return node; + } + + RexNode correl = rexBuilder.makeCorrel(outer.getRowType(), id); + return node.accept( + new RexShuttle() { + @Override + public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable + && ((RexCorrelVariable) fieldAccess.getReferenceExpr()) + .id.equals(id)) { + return rexBuilder.makeFieldAccess( + correl, fieldAccess.getField().getIndex() + offset); + } + return fieldAccess; + } + }); + } + /** * Creates an equivalent version of a node where common factors among ORs are pulled up. * @@ -2984,12 +3090,6 @@ public RexNode visitInputRef(RexInputRef input) { } } - /** - * Exception to catch when optimizing the plan produces a result that is too complex, either at - * the Rel or at the Rex level. - */ - private static class PlanTooComplexError extends ControlFlowException {} - /** * Visitor that throws {@link org.apache.calcite.util.Util.FoundOne} if applied to an expression * that contains a {@link RexCorrelVariable}. diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/runtime/SqlFunctions.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/runtime/SqlFunctions.java index e5a41b36323ca..781331b9b817a 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/runtime/SqlFunctions.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/runtime/SqlFunctions.java @@ -42,6 +42,7 @@ import org.apache.calcite.linq4j.function.NonDeterministic; import org.apache.calcite.linq4j.function.Predicate1; import org.apache.calcite.linq4j.tree.Primitive; +import org.apache.calcite.linq4j.tree.UnsignedType; import org.apache.calcite.rel.type.TimeFrame; import org.apache.calcite.rel.type.TimeFrameSet; import org.apache.calcite.runtime.FlatLists.ComparableList; @@ -65,10 +66,15 @@ import org.apache.commons.codec.binary.Hex; import org.apache.commons.codec.digest.DigestUtils; import org.apache.commons.codec.language.Soundex; -import org.apache.commons.lang3.StringEscapeUtils; import org.apache.commons.math3.util.CombinatoricsUtils; +import org.apache.commons.text.StringEscapeUtils; import org.apache.commons.text.similarity.LevenshteinDistance; import org.checkerframework.checker.nullness.qual.Nullable; +import org.joou.UByte; +import org.joou.UInteger; +import org.joou.ULong; +import org.joou.UShort; +import org.joou.Unsigned; import java.lang.reflect.Field; import java.math.BigDecimal; @@ -2530,6 +2536,22 @@ public static Object plusAny(Object b0, Object b1) { throw notArithmetic("+", b0, b1); } + public static UByte plus(UByte b0, UByte b1) { + return (b0 == null || b1 == null) ? castNonNull(null) : b0.add(b1); + } + + public static UShort plus(UShort b0, UShort b1) { + return (b0 == null || b1 == null) ? castNonNull(null) : b0.add(b1); + } + + public static UInteger plus(UInteger b0, UInteger b1) { + return (b0 == null || b1 == null) ? castNonNull(null) : b0.add(b1); + } + + public static ULong plus(ULong b0, ULong b1) { + return (b0 == null || b1 == null) ? castNonNull(null) : b0.add(b1); + } + // checked + static byte intToByte(int value) { @@ -2564,6 +2586,22 @@ public static long checkedPlus(long b0, long b1) { return Math.addExact(b0, b1); } + public static UByte checkedPlus(UByte b0, UByte b1) { + return b0.add(b1); + } + + public static UShort checkedPlus(UShort b0, UShort b1) { + return b0.add(b1); + } + + public static UInteger checkedPlus(UInteger b0, UInteger b1) { + return b0.add(b1); + } + + public static ULong checkedPlus(ULong b0, ULong b1) { + return b0.add(b1); + } + // - /** SQL - operator applied to int values. */ @@ -2596,6 +2634,11 @@ public static Long minus(Integer b0, Long b1) { return (b0 == null || b1 == null) ? castNonNull(null) : (b0.longValue() - b1.longValue()); } + /** SQL - operator applied to nullable long and long values. */ + public static Long minus(Long b0, Long b1) { + return (b0 == null || b1 == null) ? castNonNull(null) : b0.longValue() - b1.longValue(); + } + /** SQL - operator applied to nullable BigDecimal values. */ public static BigDecimal minus(BigDecimal b0, BigDecimal b1) { return (b0 == null || b1 == null) ? castNonNull(null) : b0.subtract(b1); @@ -2617,6 +2660,23 @@ public static Object minusAny(Object b0, Object b1) { throw notArithmetic("-", b0, b1); } + public static UByte minus(UByte b0, UByte b1) { + return (b0 == null || b1 == null) ? castNonNull(null) : b0.subtract(b1); + } + + public static UShort minus(UShort b0, UShort b1) { + return (b0 == null || b1 == null) ? castNonNull(null) : b0.subtract(b1); + } + + public static UInteger minus(UInteger b0, UInteger b1) { + return (b0 == null || b1 == null) ? castNonNull(null) : b0.subtract(b1); + } + + /** SQL - operator applied to nullable unsigned long and long values. */ + public static ULong minus(ULong b0, ULong b1) { + return (b0 == null || b1 == null) ? castNonNull(null) : b0.subtract(b1); + } + // checked - public static byte checkedMinus(byte b0, byte b1) { @@ -2651,6 +2711,38 @@ public static long checkedUnaryMinus(long b) { return Math.subtractExact(0, b); } + public static UByte checkedMinus(UByte b0, UByte b1) { + return b0.subtract(b1); + } + + public static UShort checkedMinus(UShort b0, UShort b1) { + return b0.subtract(b1); + } + + public static UInteger checkedMinus(UInteger b0, UInteger b1) { + return b0.subtract(b1); + } + + public static ULong checkedMinus(ULong b0, ULong b1) { + return b0.subtract(b1); + } + + public static UByte checkedUnaryMinus(UByte b) { + return Unsigned.ubyte(0).subtract(b); + } + + public static UShort checkedUnaryMinus(UShort b) { + return Unsigned.ushort(0).subtract(b); + } + + public static UInteger checkedUnaryMinus(UInteger b) { + return Unsigned.uint(0).subtract(b); + } + + public static ULong checkedUnaryMinus(ULong b) { + return Unsigned.ulong(0).subtract(b); + } + // / /** SQL / operator applied to int values. */ @@ -2714,6 +2806,31 @@ public static long divide(long b0, BigDecimal b1) { return BigDecimal.valueOf(b0).divide(b1, RoundingMode.HALF_DOWN).longValue(); } + public static UByte divide(UByte b0, UByte b1) { + return (b0 == null || b1 == null) + ? castNonNull(null) + : UByte.valueOf(b0.intValue() / b1.intValue()); + } + + public static UShort divide(UShort b0, UShort b1) { + return (b0 == null || b1 == null) + ? castNonNull(null) + : UShort.valueOf(b0.intValue() / b1.intValue()); + } + + public static UInteger divide(UInteger b0, UInteger b1) { + return (b0 == null || b1 == null) + ? castNonNull(null) + : UInteger.valueOf(b0.longValue() / b1.longValue()); + } + + public static ULong divide(ULong b0, ULong b1) { + return (b0 == null || b1 == null) + ? castNonNull(null) + : ULong.valueOf( + UnsignedType.toBigInteger(b0).divide(UnsignedType.toBigInteger(b1))); + } + public static byte checkedDivide(byte b0, byte b1) { return intToByte(b0 / b1); } @@ -2742,6 +2859,22 @@ public static long checkedDivide(long b0, long b1) { } } + public static UByte checkedDivide(UByte b0, UByte b1) { + return UByte.valueOf(b0.intValue() / b1.intValue()); + } + + public static UShort checkedDivide(UShort b0, UShort b1) { + return UShort.valueOf(b0.intValue() / b1.intValue()); + } + + public static UInteger checkedDivide(UInteger b0, UInteger b1) { + return UInteger.valueOf(b0.longValue() / b1.longValue()); + } + + public static ULong checkedDivide(ULong b0, ULong b1) { + return ULong.valueOf(UnsignedType.toBigInteger(b0).divide(UnsignedType.toBigInteger(b1))); + } + // * /** SQL * operator applied to int values. */ @@ -2764,6 +2897,32 @@ public static Integer multiply(Integer b0, Integer b1) { return (b0 == null || b1 == null) ? castNonNull(null) : (b0 * b1); } + public static UByte multiply(UByte b0, UByte b1) { + return (b0 == null || b1 == null) + ? castNonNull(null) + : UByte.valueOf(b0.longValue() * b1.longValue()); + } + + public static UShort multiply(UShort b0, UShort b1) { + return (b0 == null || b1 == null) + ? castNonNull(null) + : UShort.valueOf(b0.intValue() * b1.intValue()); + } + + public static UInteger multiply(UInteger b0, UInteger b1) { + return (b0 == null || b1 == null) + ? castNonNull(null) + : UInteger.valueOf(b0.longValue() * b1.longValue()); + } + + public static ULong multiply(ULong b0, ULong b1) { + if (b0 == null || b1 == null) { + return castNonNull(null); + } + BigInteger result = UnsignedType.toBigInteger(b0).multiply(UnsignedType.toBigInteger(b1)); + return ULong.valueOf(result); + } + /** SQL * operator applied to nullable long and int values. */ public static Long multiply(Long b0, Integer b1) { return (b0 == null || b1 == null) ? castNonNull(null) : (b0.longValue() * b1.longValue()); @@ -2813,6 +2972,22 @@ public static long checkedMultiply(long b0, long b1) { return Math.multiplyExact(b0, b1); } + public static UByte checkedMultiply(UByte b0, UByte b1) { + return UByte.valueOf(b0.intValue() * b1.intValue()); + } + + public static UShort checkedMultiply(UShort b0, UShort b1) { + return UShort.valueOf(b0.intValue() * b1.intValue()); + } + + public static UInteger checkedMultiply(UInteger b0, UInteger b1) { + return UInteger.valueOf(b0.longValue() * b1.longValue()); + } + + public static ULong checkedMultiply(ULong b0, ULong b1) { + return ULong.valueOf(UnsignedType.toBigInteger(b0).multiply(UnsignedType.toBigInteger(b1))); + } + /** SQL SAFE_ADD function applied to long values. */ public static @Nullable Long safeAdd(long b0, long b1) { try { @@ -3105,6 +3280,126 @@ public static ByteString bitAnd(ByteString b0, ByteString b1) { return binaryOperator(b0, b1, (x, y) -> (byte) (x & y)); } + /** + * Bitwise function BITAND applied to {@link org.joou.UByte} values. Returns {@code + * null} if any operand is null. + */ + public static UByte bitAnd(UByte b0, UByte b1) { + return UByte.valueOf((short) (b0.shortValue() & b1.shortValue())); + } + + /** + * Bitwise function BITAND applied to {@link org.joou.UShort} values. Returns + * {@code null} if any operand is null. + */ + public static UShort bitAnd(UShort b0, UShort b1) { + return UShort.valueOf(b0.intValue() & b1.intValue()); + } + + /** + * Bitwise function BITAND applied to {@link org.joou.UInteger} values. Returns + * {@code null} if any operand is null. + */ + public static UInteger bitAnd(UInteger b0, UInteger b1) { + return UInteger.valueOf(b0.longValue() & b1.longValue()); + } + + /** + * Bitwise function BITAND applied to {@link org.joou.ULong} values. Returns {@code + * null} if any operand is null. + */ + public static ULong bitAnd(ULong b0, ULong b1) { + return ULong.valueOf(b0.longValue() & b1.longValue()); + } + + /** + * Bitwise function BITAND applied to {@link org.joou.UInteger} and {@link Integer} + * values. Returns {@code null} if any operand is null. + */ + public static long bitAnd(UInteger b0, long b1) { + return b0.intValue() & b1; + } + + /** + * Bitwise function BITAND applied to {@link org.joou.ULong} and {@link long} + * values. Returns {@code null} if any operand is null. + */ + public static long bitAnd(ULong b0, long b1) { + return b0.longValue() & b1; + } + + /** + * Bitwise function BITAND applied to {@link long} and {@link org.joou.UInteger} + * values. Returns {@code null} if any operand is null. + */ + public static long bitAnd(long b1, ULong b2) { + return b1 & b2.longValue(); + } + + /** + * Bitwise function BITAND applied to {@link long} and {@link org.joou.UInteger} + * values. Returns {@code null} if any operand is null. + */ + public static long bitAnd(long b1, UInteger b2) { + return b1 & b2.longValue(); + } + + /** + * Bitwise function BITAND applied to {@link org.joou.UShort} and {@link Integer} + * values. Returns {@code null} if any operand is null. + */ + public static long bitAnd(UShort b0, long b1) { + return b0.intValue() & b1; + } + + /** + * Bitwise function BITAND applied to {@link org.joou.UShort} and {@link Integer} + * values. Returns {@code null} if any operand is null. + */ + public static long bitAnd(long b0, UShort b1) { + return b0 & b1.intValue(); + } + + /** + * Bitwise function BITAND applied to {@link org.joou.UInteger} and {@link Integer} + * values. Returns {@code null} if any operand is null. + */ + public static UInteger bitAnd(UInteger b0, int b1) { + return UInteger.valueOf(b0.intValue() & b1); + } + + /** + * Bitwise function BITAND applied to {@link org.joou.UInteger} and {@link Integer} + * values. Returns {@code null} if any operand is null. + */ + public static ULong bitAnd(ULong b0, int b1) { + return ULong.valueOf(b0.longValue() & b1); + } + + /** + * Bitwise function BITAND applied to {@link Integer} and {@link org.joou.UInteger} + * values. Returns {@code null} if any operand is null. + */ + public static ULong bitAnd(int b1, ULong b2) { + return ULong.valueOf(b1 & b2.longValue()); + } + + /** + * Bitwise function BITAND applied to {@link org.joou.UShort} and {@link Integer} + * values. Returns {@code null} if any operand is null. + */ + public static Integer bitAnd(int b0, UShort b1) { + return b0 & b1.intValue(); + } + + /** + * Bitwise function BITAND applied to {@link org.joou.UShort} and {@link Integer} + * values. Returns {@code null} if any operand is null. + */ + public static Integer bitAnd(UShort b0, int b1) { + return b0.intValue() & b1; + } + /** * Helper function for implementing BITCOUNT. Counts the number of bits set in an * integer value. @@ -3232,16 +3527,16 @@ public static long bitXor(long b0, long b1) { } /** - * Bitwise function BITXOR applied to a Long and int value. Needed for handling - * NULL for the first argument. + * Bitwise function BITXOR applied to a Long and int value. Overload to support + * type coercion between boxed Long and primitive int. */ public static long bitXor(Long b0, int b1) { return b0 ^ b1; } /** - * Bitwise function BITXOR applied to a Long and int value. Needed for handling - * NULL for the second argument. + * Bitwise function BITXOR applied to a Long and int value. Overload to support + * type coercion between boxed Long and primitive int. */ public static long bitXor(int b0, Long b1) { return b0 ^ b1; @@ -3267,6 +3562,62 @@ public static ByteString bitNot(ByteString b) { return new ByteString(result); } + /** + * Bitwise function BITXOR applied to {@link Long} values. Returns {@code null} if + * any operand is null. + */ + public static long bitXor(Long b0, Long b1) { + return b0 ^ b1; + } + + /** + * Bitwise function BITXOR applied to {@link Integer} values. Returns {@code null} + * if any operand is null. + */ + public static long bitXor(Integer b0, Integer b1) { + return b0 ^ b1; + } + + /** + * Bitwise function BITXOR applied to {@link org.joou.UByte} values. Returns {@code + * null} if any operand is null. + */ + public static UByte bitXor(UByte b0, UByte b1) { + return UByte.valueOf(b0.shortValue() ^ b1.shortValue()); + } + + /** + * Bitwise function BITXOR applied to {@link org.joou.UShort} values. Returns + * {@code null} if any operand is null. + */ + public static UShort bitXor(UShort b0, UShort b1) { + return UShort.valueOf(b0.intValue() ^ b1.intValue()); + } + + /** + * Bitwise function BITXOR applied to {@link org.joou.UInteger} values. Returns + * {@code null} if any operand is null. + */ + public static UInteger bitXor(UInteger b0, UInteger b1) { + return UInteger.valueOf(b0.longValue() ^ b1.longValue()); + } + + /** + * Bitwise function BITXOR applied to {@link org.joou.ULong} values. Returns {@code + * null} if any operand is null. + */ + public static ULong bitXor(ULong b0, ULong b1) { + return ULong.valueOf(b0.longValue() ^ b1.longValue()); + } + + public static @Nullable Object bitXor(@Nullable Object b0, @Nullable Object b1) { + if (b0 == null || b1 == null) { + return null; + } + throw new IllegalArgumentException( + "Invalid arguments for BITXOR: " + "" + b0.getClass() + ", " + b1.getClass()); + } + /** * Utility for bitwise function applied to two byteString values. * @@ -3296,6 +3647,139 @@ private static ByteString binaryOperator( return new ByteString(result); } + /** + * Performs PostgresSQL-style bitwise shift on a 32-bit integer. + * + * @param x the integer value to shift + * @param y the shift amount (positive: left shift, negative: right shift) + * @return the shifted integer + */ + public static int leftShift(int x, int y) { + int shift = ((y % 32) + 32) % 32; // normalize to 0~31 + return y >= 0 ? x << shift : x >> shift; // arithmetic right shift + } + + // ----------------- long ----------------- + /** + * Performs PostgresSQL-style bitwise shift on a 64-bit long value. + * + * @param x the long value to shift + * @param y the shift amount + * @return the shifted long value + */ + public static long leftShift(long x, int y) { + int shift = ((y % 64) + 64) % 64; // normalize to 0~63 + return y >= 0 ? x << shift : x >> shift; + } + + /** + * Performs PostgresSQL-style bitwise shift on an int value with a long shift amount. + * + * @param x the int value to shift + * @param y the long shift amount + * @return the shifted value as long + */ + public static long leftShift(int x, long y) { + int shift = (int) (((y % 32) + 32) % 32); // normalize to 0~31 + return y >= 0 ? (long) x << shift : (long) x >> shift; + } + + /** + * Performs PostgresSQL-style bitwise shift on a byte array. Positive shift: left shift. + * Negative shift: treated as positive shift with modulo arithmetic. + * + * @param bytes the input byte array + * @param y the shift amount in bits + * @return the shifted byte array + */ + public static byte[] leftShift(byte[] bytes, int y) { + if (bytes.length == 0) { + return new byte[0]; + } + + int bitLen = bytes.length * 8; + + // PostgreSQL behavior: always treat as left shift with modulo arithmetic + // Negative y becomes equivalent positive shift + int shift = ((y % bitLen) + bitLen) % bitLen; + + if (shift == 0) { + return bytes.clone(); + } + + byte[] result = new byte[bytes.length]; + + // Always perform left shift (even for originally negative y) + int byteShift = shift / 8; + int bitShift = shift % 8; + + for (int i = 0; i < bytes.length; i++) { + int srcIndex = i - byteShift; + int val = 0; + + // Get the main byte + if (srcIndex >= 0) { + val = (bytes[srcIndex] & 0xFF) << bitShift; + } + + // Get carry bits from previous byte + if (srcIndex - 1 >= 0 && bitShift != 0) { + val |= (bytes[srcIndex - 1] & 0xFF) >>> (8 - bitShift); + } + + result[i] = (byte) val; + } + return result; + } + + /** + * Performs PostgresSQL-style bitwise shift on ByteString. + * + * @param bytes the ByteString to shift + * @param y the shift amount in bits + * @return shifted ByteString + */ + public static ByteString leftShift(ByteString bytes, int y) { + return new ByteString(leftShift(bytes.getBytes(), y)); + } + + /** Performs PostgresSQL-style bitwise shift on UByte. Overflow bits are masked to 8 bits. */ + public static UByte leftShift(UByte x, int y) { + int shift = ((y % 8) + 8) % 8; + int val = x.byteValue() & 0xFF; + val = (y >= 0) ? (val << shift) & 0xFF : (val >> shift) & 0xFF; + return UByte.valueOf((byte) val); + } + + /** Performs PostgresSQL-style bitwise shift on UShort. Overflow bits are masked to 16 bits. */ + public static UShort leftShift(UShort x, int y) { + int shift = ((y % 16) + 16) % 16; + int val = x.shortValue() & 0xFFFF; + val = (y >= 0) ? (val << shift) & 0xFFFF : (val >> shift) & 0xFFFF; + return UShort.valueOf((short) val); + } + + /** + * Performs PostgresSQL-style bitwise shift on UInteger. Overflow bits are masked to 32 bits. + */ + public static UInteger leftShift(UInteger x, int y) { + int shift = ((y % 32) + 32) % 32; + long val = x.longValue() & 0xFFFFFFFFL; + val = (y >= 0) ? (val << shift) & 0xFFFFFFFFL : (val >> shift) & 0xFFFFFFFFL; + return UInteger.valueOf(val); + } + + /** + * Performs PostgresSQL-style bitwise shift on ULong. Overflow bits are masked to 64 bits (long + * shifts naturally truncate). + */ + public static ULong leftShift(ULong x, int y) { + int shift = ((y % 64) + 64) % 64; + long val = x.longValue(); + val = (y >= 0) ? val << shift : val >> shift; + return ULong.valueOf(val); + } + // EXP /** SQL EXP operator applied to double values. */ @@ -5870,7 +6354,7 @@ public static String replace( return s.replace(search, replacement); } // for MSSQL's REPLACE function, search pattern is case-insensitive during matching - return org.apache.commons.lang3.StringUtils.replaceIgnoreCase(s, search, replacement); + return org.apache.commons.lang3.Strings.CI.replace(s, search, replacement); } /** @@ -6667,12 +7151,12 @@ public static String arrayToString(List list, String delimiter, @Nullable String * array by splitting the input string value into subvalues using the specified string value as * the "delimiter". Optionally, allows a specified string value to be interpreted as NULL. */ - public static List stringToArray(String string, @Nullable String delimiter) { + public static List<@Nullable String> stringToArray(String string, @Nullable String delimiter) { return stringToArray(string, delimiter, null); } /** SQL {@code STRING_TO_ARRAY(string, delimiter, nullString)} function. */ - public static List stringToArray( + public static List<@Nullable String> stringToArray( String string, @Nullable String delimiter, @Nullable String nullString) { String[] parts; if (delimiter == null) { @@ -6685,7 +7169,7 @@ public static List stringToArray( } else { parts = string.split(delimiter); } - List result = new ArrayList<>(parts.length); + List<@Nullable String> result = new ArrayList<>(parts.length); for (String part : parts) { if (nullString != null && nullString.equals(part)) { result.add(null); diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlCoalesceFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlCoalesceFunction.java index 556a0c2abd2bf..da666184d9e76 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlCoalesceFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/fun/SqlCoalesceFunction.java @@ -123,11 +123,10 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) { } if (returnType == null) { - throw new IllegalArgumentException( - "Cannot infer return type for " - + opBinding.getOperator() - + "; operand types: " - + opBinding.collectOperandTypes()); + throw opBinding.newError( + RESOURCE.cannotInferReturnType( + opBinding.getOperator().toString(), + opBinding.collectOperandTypes().toString())); } return returnType; diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java index 6e564abe68e47..eb8bbaff448ea 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeFactoryImpl.java @@ -329,6 +329,8 @@ private static void assertBasic(SqlTypeName typeName) { ? createSqlType(typeName, type.getPrecision()) : createSqlType(typeName); type = createTypeWithNullability(type, originalType.isNullable()); + // update java type's family + family = type.getFamily(); } if (resultType == null) { diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java index 2aa111e918a81..bab7f620637ff 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/type/SqlTypeUtil.java @@ -42,6 +42,7 @@ import org.apache.calcite.sql.validate.SqlValidator; import org.apache.calcite.sql.validate.SqlValidatorScope; import org.apache.calcite.sql.validate.SqlValidatorUtil; +import org.apache.calcite.sql2rel.SqlToRelConverter; import org.apache.calcite.util.NumberUtil; import org.apache.calcite.util.Pair; import org.apache.calcite.util.Util; @@ -50,6 +51,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import java.math.BigDecimal; +import java.math.BigInteger; import java.math.RoundingMode; import java.nio.charset.Charset; import java.util.AbstractList; @@ -120,6 +122,40 @@ public static boolean isCharTypeComparable(List argTypes) { return true; } + /** + * True if there are literals with the specified data type. Some data types do not have literals + * (e.g., UNSIGNED, ROW). + * + * @param type Type for literals. + */ + public static boolean hasLiterals(RelDataType type) { + switch (type.getSqlTypeName()) { + case UTINYINT: + case USMALLINT: + case UINTEGER: + case UBIGINT: + case ANY: + case SYMBOL: + case MULTISET: + case ARRAY: + case MAP: + case DISTINCT: + case STRUCTURED: + case ROW: + case OTHER: + case CURSOR: + case COLUMN_LIST: + case DYNAMIC_STAR: + case GEOMETRY: + case MEASURE: + case FUNCTION: + case SARG: + return false; + default: + return true; + } + } + /** * Returns whether the operands to a call are char type-comparable. * @@ -145,12 +181,58 @@ public static boolean isCharTypeComparable( return true; } + /** + * A namespace may contain multiple fields with the same name. However, a proper ROW type + * cannot; this function will assign unique names to fields when they are used to build a + * concrete ROW type. + * + * @param type A type for which type.isStruct() is true. + * @return A new version of this type where all fields have unique names. + *

Note: the same rule to rename fields is used by the {@link SqlToRelConverter} later, + * in convertNonAggregateSelectList (a private method). This ensures that the generated + * field names there will match with the inferred field names here. + */ + private static RelDataType uniquify(RelDataTypeFactory factory, RelDataType type) { + List unique = SqlValidatorUtil.uniquify(type.getFieldNames(), true); + List types = + type.getFieldList().stream() + .map(RelDataTypeField::getType) + .collect(Collectors.toList()); + return factory.createStructType(type.getStructKind(), types, unique); + } + + /** + * Derives component type for ARRAY, MULTISET, MAP when input is sub-query. + * + * @param factory Type factory used to generate new types if necessary + * @param origin original component type + * @return component type + */ + public static RelDataType deriveCollectionQueryComponentType( + RelDataTypeFactory factory, SqlTypeName collectionType, RelDataType origin) { + switch (collectionType) { + case ARRAY: + case MULTISET: + return origin.isStruct() && origin.getFieldCount() == 1 + ? origin.getFieldList().get(0).getType() + : uniquify(factory, origin); + case MAP: + return origin; + default: + throw new AssertionError( + "Impossible to derive component type for " + collectionType); + } + } + /** * Derives component type for ARRAY, MULTISET, MAP when input is sub-query. * * @param origin original component type * @return component type + * @deprecated Use {@link SqlTypeUtil#deriveCollectionQueryComponentType(RelDataTypeFactory, + * SqlTypeName, RelDataType)} */ + @Deprecated public static RelDataType deriveCollectionQueryComponentType( SqlTypeName collectionType, RelDataType origin) { switch (collectionType) { @@ -464,6 +546,10 @@ public static boolean isIntType(RelDataType type) { case SMALLINT: case INTEGER: case BIGINT: + case UTINYINT: + case USMALLINT: + case UINTEGER: + case UBIGINT: return true; default: return false; @@ -490,6 +576,10 @@ public static boolean isExactNumeric(RelDataType type) { case SMALLINT: case INTEGER: case BIGINT: + case UTINYINT: + case USMALLINT: + case UINTEGER: + case UBIGINT: case DECIMAL: return true; default: @@ -497,12 +587,88 @@ public static boolean isExactNumeric(RelDataType type) { } } + /** + * Returns whether {@code container} can represent every value produced by {@code content} + * without loss of information. + * + *

The {@code container} type must be one of the integer types (signed or unsigned). The + * {@code content} type can be integer, or a DECIMAL with scale {@code 0}. For all other types + * this method returns {@code false}. + * + * @throws IllegalArgumentException if {@code container} is not an integer type + */ + public static boolean integerRangeContains(RelDataType container, RelDataType content) { + checkArgument(isIntType(container), "container must be an integer type: %s", container); + + final SqlTypeName contentType = content.getSqlTypeName(); + final boolean contentIsDecimal = contentType == SqlTypeName.DECIMAL; + if (!isIntType(content) && (!contentIsDecimal || content.getScale() != 0)) { + return false; + } + + final BigInteger containerMin = integerBound(container, false); + final BigInteger containerMax = integerBound(container, true); + if (containerMin == null || containerMax == null) { + return false; + } + + final BigInteger contentMin = integerBound(content, false); + final BigInteger contentMax = integerBound(content, true); + if (contentMin == null || contentMax == null) { + return false; + } + + return containerMin.compareTo(contentMin) <= 0 && containerMax.compareTo(contentMax) >= 0; + } + + /** + * Returns the numeric bound for an integer or zero-scale decimal type. + * + * @param type Type whose bounds should be computed + * @param upper If {@code true}, returns the maximum inclusive bound; otherwise returns the + * minimum bound + * @return Bound as {@link BigInteger}, or {@code null} if the bound cannot be determined (for + * example, type is not integer or has non-zero scale) + */ + public static @Nullable BigInteger integerBound(RelDataType type, boolean upper) { + final SqlTypeName typeName = type.getSqlTypeName(); + + final boolean isDecimal = typeName == SqlTypeName.DECIMAL; + if (!isDecimal && !isIntType(type)) { + return null; + } + if (isDecimal && type.getScale() != 0) { + return null; + } + + final int precision = isDecimal ? type.getPrecision() : -1; + final int scale = isDecimal ? type.getScale() : -1; + final Object limit = + typeName.getLimit(upper, SqlTypeName.Limit.OVERFLOW, false, precision, scale); + if (limit == null) { + return null; + } + if (limit instanceof BigDecimal) { + try { + return ((BigDecimal) limit).toBigIntegerExact(); + } catch (ArithmeticException ignored) { + return null; + } + } + if (limit instanceof Number) { + return BigInteger.valueOf(((Number) limit).longValue()); + } + return null; + } + /** Returns whether a type's scale is set. */ public static boolean hasScale(RelDataType type) { return type.getScale() != Integer.MIN_VALUE; } - /** Returns the maximum value of an integral type, as a long value. */ + /** + * Returns the maximum value of an integral type, as a long value. DOES NOT WORK FOR UBIGINT. + */ public static long maxValue(RelDataType type) { assert SqlTypeUtil.isIntType(type); switch (type.getSqlTypeName()) { @@ -512,6 +678,12 @@ public static long maxValue(RelDataType type) { return Short.MAX_VALUE; case INTEGER: return Integer.MAX_VALUE; + case UTINYINT: + return 255; + case USMALLINT: + return 65535; + case UINTEGER: + return (1L << 32) - 1; case BIGINT: return Long.MAX_VALUE; default: @@ -648,8 +820,9 @@ public static int getMaxByteSize(RelDataType type) { /** * Returns the minimum unscaled value of a numeric type. * - * @param type a numeric type + * @deprecated Use {@link #integerBound(RelDataType, boolean)} with {@code upper = false} */ + @Deprecated // to be removed before 2.0 public static long getMinValue(RelDataType type) { SqlTypeName typeName = type.getSqlTypeName(); switch (typeName) { @@ -659,6 +832,11 @@ public static long getMinValue(RelDataType type) { return Short.MIN_VALUE; case INTEGER: return Integer.MIN_VALUE; + case UTINYINT: + case USMALLINT: + case UINTEGER: + case UBIGINT: + return 0; case BIGINT: case DECIMAL: return NumberUtil.getMinUnscaled(type.getPrecision()).longValue(); @@ -668,17 +846,25 @@ public static long getMinValue(RelDataType type) { } /** - * Returns the maximum unscaled value of a numeric type. + * Returns the maximum unscaled value of a numeric type. DOES NOT WORK CORRECTLY FOR U/BIGINT + * and many DECIMAL types. * - * @param type a numeric type + * @deprecated Use {@link #integerBound(RelDataType, boolean)} with {@code upper = true} */ + @Deprecated // to be removed before 2.0 public static long getMaxValue(RelDataType type) { SqlTypeName typeName = type.getSqlTypeName(); switch (typeName) { + case UTINYINT: + return 255; case TINYINT: return Byte.MAX_VALUE; + case USMALLINT: + return (1 << 16) - 1; case SMALLINT: return Short.MAX_VALUE; + case UINTEGER: + return (1L << 32) - 1; case INTEGER: return Integer.MAX_VALUE; case BIGINT: @@ -892,7 +1078,8 @@ public static boolean canCastFrom( return true; } if (toType.getSqlTypeName() == SqlTypeName.UUID) { - return fromType.getSqlTypeName() == SqlTypeName.UUID + return fromType.getSqlTypeName() == SqlTypeName.NULL + || fromType.getSqlTypeName() == SqlTypeName.UUID || fromType.getFamily() == SqlTypeFamily.CHARACTER || fromType.getFamily() == SqlTypeFamily.BINARY; } @@ -1081,7 +1268,8 @@ public static SqlDataTypeSpec convertTypeToSpec( if (isAtomic(type) || isNull(type) || type.getSqlTypeName() == SqlTypeName.UNKNOWN - || type.getSqlTypeName() == SqlTypeName.GEOMETRY) { + || type.getSqlTypeName() == SqlTypeName.GEOMETRY + || SqlTypeUtil.isInterval(type)) { int precision = typeName.allowsPrec() ? type.getPrecision() @@ -1099,9 +1287,11 @@ public static SqlDataTypeSpec convertTypeToSpec( new SqlBasicTypeNameSpec( typeName, precision, scale, charSetName, SqlParserPos.ZERO); } else if (isCollection(type)) { + RelDataType componentType = getComponentTypeOrThrow(type); typeNameSpec = new SqlCollectionTypeNameSpec( - convertTypeToSpec(getComponentTypeOrThrow(type)).getTypeNameSpec(), + convertTypeToSpec(componentType).getTypeNameSpec(), + componentType.isNullable(), typeName, SqlParserPos.ZERO); } else if (isRow(type)) { @@ -1139,9 +1329,7 @@ public static SqlDataTypeSpec convertTypeToSpec( // REVIEW angel 11-Jan-2006: // Use neg numbers to indicate unspecified precision/scale - // FLINK MODIFICATION BEGIN return new SqlDataTypeSpec(typeNameSpec, SqlParserPos.ZERO).withNullable(type.isNullable()); - // FLINK MODIFICATION BEGIN } /** @@ -1458,7 +1646,6 @@ public static boolean isFlat(RelDataType type) { * @return Whether types are comparable */ public static boolean isComparable(RelDataType type1, RelDataType type2) { - // FLINK MODIFICATION BEGIN Calcite-7230 final RelDataTypeFamily family1 = family(type1); final RelDataTypeFamily family2 = family(type2); @@ -1466,7 +1653,7 @@ public static boolean isComparable(RelDataType type1, RelDataType type2) { if (family1 == SqlTypeFamily.NULL || family2 == SqlTypeFamily.NULL) { return true; } - // FLINK MODIFICATION END + if (type1.isStruct() != type2.isStruct()) { return false; } @@ -1516,11 +1703,6 @@ public static boolean isComparable(RelDataType type1, RelDataType type2) { return true; } - // If one of the arguments is of type 'NULL', return true. - if (family1 == SqlTypeFamily.NULL || family2 == SqlTypeFamily.NULL) { - return true; - } - // We can implicitly convert from character to date return family1 == SqlTypeFamily.CHARACTER && canConvertStringInCompare(family2) || family2 == SqlTypeFamily.CHARACTER && canConvertStringInCompare(family1); @@ -1807,7 +1989,8 @@ public static boolean isAtomic(RelDataType type) { return SqlTypeUtil.isDatetime(type) || SqlTypeUtil.isNumeric(type) || SqlTypeUtil.isString(type) - || SqlTypeUtil.isBoolean(type); + || SqlTypeUtil.isBoolean(type) + || typeName == SqlTypeName.UUID; } /** Returns a DECIMAL type with the maximum precision for the current type system. */ diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java index 640765c51b48d..fa33edc4103e1 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java @@ -523,6 +523,15 @@ private boolean expandSelectItem( // calls. selectScope = getSelectScope(select); expanded = expandSelectExpr(selectItem, scope, select, expansions); + + // Non-strict GROUP BY: wrap non-aggregated, non-grouped columns in ANY_VALUE() + if (isAggregate(select) + && config.conformance().isNonStrictGroupBy() + && isNonAggregatedNonGroupedColumn(expanded, select)) { + expanded = + SqlStdOperatorTable.ANY_VALUE.createCall( + expanded.getParserPosition(), expanded); + } } final String alias = SqlValidatorUtil.alias(selectItem, aliases.size()); @@ -560,56 +569,32 @@ private boolean expandSelectItem( return false; } - private static SqlNode expandExprFromJoin( - SqlJoin join, SqlIdentifier identifier, SelectScope scope) { - if (join.getConditionType() != JoinConditionType.USING) { - return identifier; + /** Returns true if the node is a non-aggregated, non-grouped column in SELECT. */ + private boolean isNonAggregatedNonGroupedColumn(SqlNode node, SqlSelect select) { + if (aggFinder.findAgg(node) != null) { + return false; } - final Map fieldAliases = getFieldAliases(scope); - - for (String name : SqlIdentifier.simpleNames((SqlNodeList) getCondition(join))) { - if (identifier.getSimple().equals(name)) { - final List qualifiedNode = new ArrayList<>(); - for (ScopeChild child : requireNonNull(scope, "scope").children) { - if (child.namespace.getRowType().getFieldNames().contains(name)) { - final SqlIdentifier exp = - new SqlIdentifier( - ImmutableList.of(child.name, name), - identifier.getParserPosition()); - qualifiedNode.add(exp); - } - } - - assert qualifiedNode.size() == 2; - - // If there is an alias for the column, no need to wrap the coalesce with an AS - // operator - boolean haveAlias = fieldAliases.containsKey(name); - - final SqlCall coalesceCall = - SqlStdOperatorTable.COALESCE.createCall( - SqlParserPos.ZERO, qualifiedNode.get(0), qualifiedNode.get(1)); - - if (haveAlias) { - return coalesceCall; - } else { - return SqlStdOperatorTable.AS.createCall( - SqlParserPos.ZERO, - coalesceCall, - new SqlIdentifier(name, SqlParserPos.ZERO)); - } + if (node instanceof SqlIdentifier) { + SqlNodeList groupList = select.getGroup(); + if (groupList == null) { + return true; } + return groupList.getList().stream() + .noneMatch( + groupItem -> + groupItem != null && node.equalsDeep(groupItem, Litmus.IGNORE)); } - // Only need to try to expand the expr from the left input of join - // since it is always left-deep join. - final SqlNode node = join.getLeft(); - if (node instanceof SqlJoin) { - return expandExprFromJoin((SqlJoin) node, identifier, scope); - } else { - return identifier; + if (node instanceof SqlCall) { + return ((SqlCall) node) + .getOperandList().stream() + .anyMatch(operand -> isNonAggregatedNonGroupedColumn(operand, select)); + } else if (node instanceof SqlLiteral) { + return true; } + + return false; } private static Map getFieldAliases(final SelectScope scope) { @@ -661,31 +646,6 @@ private List deriveNaturalJoinColumnList(SqlJoin join) { getNamespaceOrThrow(join.getRight()).getRowType()); } - private static SqlNode expandCommonColumn( - SqlSelect sqlSelect, - SqlNode selectItem, - SelectScope scope, - SqlValidatorImpl validator) { - if (!(selectItem instanceof SqlIdentifier)) { - return selectItem; - } - - final SqlNode from = sqlSelect.getFrom(); - if (!(from instanceof SqlJoin)) { - return selectItem; - } - - final SqlIdentifier identifier = (SqlIdentifier) selectItem; - if (!identifier.isSimple()) { - if (!validator.config().conformance().allowQualifyingCommonColumn()) { - validateQualifiedCommonColumn((SqlJoin) from, identifier, scope, validator); - } - return selectItem; - } - - return expandExprFromJoin((SqlJoin) from, identifier, scope); - } - private static void validateQualifiedCommonColumn( SqlJoin join, SqlIdentifier identifier, SelectScope scope, SqlValidatorImpl validator) { List names = validator.usingNames(join); @@ -1393,9 +1353,7 @@ private SqlValidatorScope getScopeOrThrow(SqlNode node) { } // fall through case TABLE_REF: - // ----- FLINK MODIFICATION BEGIN ----- case LATERAL: - // ----- FLINK MODIFICATION END ----- case SNAPSHOT: case OVER: case COLLECTION_TABLE: @@ -3472,7 +3430,7 @@ public void validateLiteral(SqlLiteral literal) { final BigDecimal noTrailingZeros = bd.stripTrailingZeros(); // If we don't strip trailing zeros we may reject values such as 1.000....0. - final int maxPrecision = typeSystem.getMaxNumericPrecision(); + final int maxPrecision = typeSystem.getMaxPrecision(SqlTypeName.DECIMAL); if (noTrailingZeros.precision() > maxPrecision) { throw newValidationError( literal, RESOURCE.numberLiteralOutOfRange(bd.toString())); @@ -3735,6 +3693,24 @@ private void checkRollUpInUsing( } } + /** + * Get the number of scopes referenced by the specified node; the node represents a computation + * that will be converted to a Rel node eventually. + */ + private int getScopeCount(SqlNode node) { + SqlValidatorScope scope = scopes.get(node); + if (scope == null) { + // Not all nodes have an associated scope; count these as "1". + // For example, a VALUES node. + return 1; + } + if (scope instanceof ListScope) { + ListScope join = (ListScope) scope; + return join.children.size(); + } + return 1; + } + protected void validateJoin(SqlJoin join, SqlValidatorScope scope) { final SqlNode left = join.getLeft(); final SqlNode right = join.getRight(); @@ -3836,9 +3812,11 @@ protected void validateJoin(SqlJoin join, SqlValidatorScope scope) { condition, RESOURCE.asofConditionMustBeComparison()); } + int leftScopeCount = getScopeCount(left); CompareFromBothSides validateCompare = new CompareFromBothSides( joinScope, + leftScopeCount, catalogReader, RESOURCE.asofConditionMustBeComparison()); condition.accept(validateCompare); @@ -3858,7 +3836,10 @@ protected void validateJoin(SqlJoin join, SqlValidatorScope scope) { // Change the exception in validateCompare when we validate the match condition validateCompare = new CompareFromBothSides( - joinScope, catalogReader, RESOURCE.asofMatchMustBeComparison()); + joinScope, + leftScopeCount, + catalogReader, + RESOURCE.asofMatchMustBeComparison()); matchCondition.accept(validateCompare); break; } @@ -3874,16 +3855,21 @@ protected void validateJoin(SqlJoin join, SqlValidatorScope scope) { */ private class CompareFromBothSides extends SqlShuttle { final SqlValidatorScope scope; + // Number of children scopes on the left side of the join. + // Used to determine whether an identifier is from the left input or the right input. + final int leftScopeCount; final SqlValidatorCatalogReader catalogReader; final Resources.ExInst exception; private CompareFromBothSides( SqlValidatorScope scope, + int leftScopeCount, SqlValidatorCatalogReader catalogReader, Resources.ExInst exception) { this.scope = scope; this.catalogReader = catalogReader; this.exception = exception; + this.leftScopeCount = leftScopeCount; } @Override @@ -3917,16 +3903,11 @@ private CompareFromBothSides( id.names.subList(0, id.names.size() - 1), nameMatcher, false, resolved); SqlValidatorScope.Resolve resolve = resolved.only(); int index = resolve.path.steps().get(0).i; - if (index == 0) { + if (index < leftScopeCount) { leftFound = true; - } - if (index == 1) { + } else { rightFound = true; } - - if (!leftFound && !rightFound) { - throw newValidationError(call, this.exception); - } } if (!leftFound || !rightFound) { // The comparison does not look at both tables @@ -5686,13 +5667,6 @@ protected void checkTypeAssignment( // matched boolean isUpdateModifiableViewTable = false; if (query instanceof SqlUpdate) { - final SqlNodeList targetColumnList = - requireNonNull(((SqlUpdate) query).getTargetColumnList()); - final int targetColumnCount = targetColumnList.size(); - targetRowType = - SqlTypeUtil.extractLastNFields(typeFactory, targetRowType, targetColumnCount); - sourceRowType = - SqlTypeUtil.extractLastNFields(typeFactory, sourceRowType, targetColumnCount); isUpdateModifiableViewTable = table.unwrap(ModifiableViewTable.class) != null; } if (SqlTypeUtil.equalAsStructSansNullability( @@ -5813,6 +5787,16 @@ public void validateUpdate(SqlUpdate call) { final RelDataType sourceRowType = getValidatedNodeType(select); checkTypeAssignment(scopes.get(select), table, sourceRowType, targetRowType, call); + // Set validated sourceExpressionList from the source select. + // The last elements of sourceSelect are the expression list. + List sourceExpressionList = + Util.last(select.getSelectList(), call.getSourceExpressionList().size()); + call.setOperand( + 2, + SqlUtil.stripListAs( + new SqlNodeList( + sourceExpressionList, + call.getSourceExpressionList().getParserPosition()))); checkConstraint(table, call, targetRowType); validateAccess(call.getTargetTable(), table, SqlAccessEnum.UPDATE); @@ -7328,6 +7312,30 @@ public SqlNode go(SqlNode root) { return requireNonNull(root.accept(this), () -> this + " returned null for " + root); } + /** True if the exception ex indicates that name lookup has failed. */ + public static boolean isNotFoundException(Exception ex) { + if (!(ex instanceof CalciteContextException)) { + return false; + } + String message = ex.getMessage(); + if (message == null || !message.contains("not found")) { + return false; + } + return true; + } + + /** True if the exception ex indicates that a column is ambiguous. */ + public static boolean isAmbiguousException(Exception ex) { + if (!(ex instanceof CalciteContextException)) { + return false; + } + String message = ex.getMessage(); + if (message == null || !message.contains("is ambiguous")) { + return false; + } + return true; + } + @Override public @Nullable SqlNode visit(SqlIdentifier id) { // First check for builtin functions which don't have @@ -7383,6 +7391,88 @@ protected SqlNode expandDynamicStar(SqlIdentifier id, SqlIdentifier fqId) { } return fqId; } + + protected SqlNode expandCommonColumn( + SqlSelect sqlSelect, SqlNode selectItem, SelectScope scope) { + if (!(selectItem instanceof SqlIdentifier)) { + return selectItem; + } + + final SqlNode from = sqlSelect.getFrom(); + if (!(from instanceof SqlJoin)) { + return selectItem; + } + + final SqlIdentifier identifier = (SqlIdentifier) selectItem; + if (!identifier.isSimple()) { + if (!validator.config().conformance().allowQualifyingCommonColumn()) { + validateQualifiedCommonColumn((SqlJoin) from, identifier, scope, validator); + } + return selectItem; + } + + return expandExprFromJoin((SqlJoin) from, identifier, scope); + } + + private SqlNode expandExprFromJoin( + SqlJoin join, SqlIdentifier identifier, SelectScope scope) { + List commonColumnNames; + // must be NATURAL or USING here, and cannot specify NATURAL keyword with USING clause + if (join.isNatural()) { + commonColumnNames = validator.deriveNaturalJoinColumnList(join); + } else if (join.getConditionType() == JoinConditionType.USING) { + commonColumnNames = SqlIdentifier.simpleNames((SqlNodeList) getCondition(join)); + } else { + return identifier; + } + + final SqlNameMatcher matcher = validator.getCatalogReader().nameMatcher(); + final Map fieldAliases = getFieldAliases(scope); + + for (String name : commonColumnNames) { + if (matcher.matches(identifier.getSimple(), name)) { + final List qualifiedNode = new ArrayList<>(); + for (ScopeChild child : requireNonNull(scope, "scope").children) { + if (matcher.indexOf(child.namespace.getRowType().getFieldNames(), name) + >= 0) { + final SqlIdentifier exp = + new SqlIdentifier( + ImmutableList.of(child.name, name), + identifier.getParserPosition()); + qualifiedNode.add(exp); + } + } + + assert qualifiedNode.size() == 2; + + // If there is an alias for the column, no need to wrap the coalesce with an AS + // operator + boolean haveAlias = fieldAliases.containsKey(name); + + final SqlCall coalesceCall = + SqlStdOperatorTable.COALESCE.createCall( + SqlParserPos.ZERO, qualifiedNode.get(0), qualifiedNode.get(1)); + + if (haveAlias) { + return coalesceCall; + } else { + return SqlStdOperatorTable.AS.createCall( + SqlParserPos.ZERO, + coalesceCall, + new SqlIdentifier(identifier.getSimple(), SqlParserPos.ZERO)); + } + } + } + + // Only need to try to expand the expr from the left input of join + // since it is always left-deep join. + final SqlNode node = join.getLeft(); + if (node instanceof SqlJoin) { + return expandExprFromJoin((SqlJoin) node, identifier, scope); + } else { + return identifier; + } + } } /** @@ -7545,19 +7635,14 @@ public SqlNode go(SqlNode root) { @Override public @Nullable SqlNode visit(SqlIdentifier id) { - final SqlNode node = - expandCommonColumn(select, id, (SelectScope) getScope(), validator); + final SqlNode node = expandCommonColumn(select, id, (SelectScope) getScope()); if (node != id) { return node; } else { try { return super.visit(id); } catch (Exception ex) { - if (!(ex instanceof CalciteContextException)) { - throw ex; - } - String message = ex.getMessage(); - if (message != null && !message.contains("not found")) { + if (!isNotFoundException(ex)) { throw ex; } // This point is reached only if the name lookup failed using standard rules @@ -7642,7 +7727,10 @@ static class ExtendedExpander extends Expander { final SqlSelect select; final SqlNode root; final Clause clause; - // Retain only expandable aliases or ordinals to prevent their expansion in a SQL call expr. + // This contains the ordinal nodes in a GROUP BY that may need to be expanded + // into column names when using a conformance that allows ordinals in group bys + // E.g., For GROUP BY 1, 1 will be in this list + // For GROUP BY CUBE(1), 1 will also be in this list final Set aliasOrdinalExpandSet = Sets.newIdentityHashSet(); ExtendedExpander( @@ -7656,7 +7744,7 @@ static class ExtendedExpander extends Expander { this.root = root; this.clause = clause; if (clause == Clause.GROUP_BY) { - addExpandableExpressions(); + addExpandableOrdinals(); } } @@ -7668,18 +7756,53 @@ static class ExtendedExpander extends Expander { final SelectScope selectScope = validator.getRawSelectScopeNonNull(select); final boolean replaceAliases = clause.shouldReplaceAliases(validator.config); - if (!replaceAliases - || (clause == Clause.GROUP_BY && !aliasOrdinalExpandSet.contains(id))) { - SqlNode node = expandCommonColumn(select, id, selectScope, validator); - if (node != id) { - return node; + try { + // First try a standard expansion + if (clause == Clause.GROUP_BY) { + SqlNode node = expandCommonColumn(select, id, selectScope); + if (node != id) { + return node; + } + return super.visit(id); } - return super.visit(id); + } catch (Exception ex) { + // This behavior is from MySQL: + // - if there is no column in the FROM with the name used in the GROUP BY + // then look for a column alias in the SELECT + // - if there are multiple columns in the FROM with the name used in the GROUP BY, + // then also look for a column alias in the SELECT + if (!Expander.isNotFoundException(ex) && !Expander.isAmbiguousException(ex)) { + throw ex; + } + if (!replaceAliases) { + throw ex; + } + // Continue execution, trying to replace alias } + final SqlNameMatcher nameMatcher = validator.catalogReader.nameMatcher(); + + if (clause == Clause.HAVING) { + if (!replaceAliases) { + return id; + } + + // Do not expand aliases in HAVING if they are being grouped on + SqlNodeList list = select.getGroup(); + if (list != null) { + // HAVING can be used without GROUP BY + for (SqlNode node : list) { + if (node instanceof SqlIdentifier) { + SqlIdentifier grouped = (SqlIdentifier) node; + if (nameMatcher.matches(id.getSimple(), Util.last(grouped.names))) { + return id; + } + } + } + } + } String name = id.getSimple(); SqlNode expr = null; - final SqlNameMatcher nameMatcher = validator.catalogReader.nameMatcher(); int n = 0; for (SqlNode s : SqlNonNullableAccessors.getSelectList(select)) { final @Nullable String alias = SqlValidatorUtil.alias(s); @@ -7704,8 +7827,10 @@ static class ExtendedExpander extends Expander { // expr cannot be null; in that case n = 0 would have returned requireNonNull(expr, "expr"); - if (validator.getConformance().isSelectAlias() - != SqlConformance.SelectAliasLookup.UNSUPPORTED) { + if ((clause == Clause.SELECT + && (validator.getConformance().isSelectAlias() + != SqlConformance.SelectAliasLookup.UNSUPPORTED)) + || clause == Clause.GROUP_BY) { Map expansions = new HashMap<>(); final Expander expander = new SelectExpander(validator, selectScope, select, expansions); @@ -7751,18 +7876,13 @@ static class ExtendedExpander extends Expander { break; } } - return super.visit(literal); } - /** - * Add all possible expandable 'group by' expression to set, which is used to check whether - * expr could be expanded as alias or ordinal. - */ + /** Add all possible expandable 'group by' ordinals to {@link aliasOrdinalExpandSet}. */ @RequiresNonNull({"root"}) - private void addExpandableExpressions() { + private void addExpandableOrdinals() { switch (root.getKind()) { - case IDENTIFIER: case LITERAL: aliasOrdinalExpandSet.add(root); break; @@ -7772,7 +7892,7 @@ private void addExpandableExpressions() { if (root instanceof SqlBasicCall) { List operandList = ((SqlBasicCall) root).getOperandList(); for (SqlNode sqlNode : operandList) { - addIdentifierOrdinal2ExpandSet(sqlNode); + addOrdinal2ExpandSet(sqlNode); } } break; @@ -7782,18 +7902,17 @@ private void addExpandableExpressions() { } /** - * Identifier or literal in grouping sets, rollup, cube will be eligible for alias. + * Literal in grouping sets, rollup, cube will be expanded. * * @param sqlNode expression within grouping sets, rollup, cube */ - private void addIdentifierOrdinal2ExpandSet(SqlNode sqlNode) { + private void addOrdinal2ExpandSet(SqlNode sqlNode) { if (sqlNode.getKind() == SqlKind.ROW) { List rowOperandList = ((SqlCall) sqlNode).getOperandList(); for (SqlNode node : rowOperandList) { - addIdentifierOrdinal2ExpandSet(node); + addOrdinal2ExpandSet(node); } - } else if (sqlNode.getKind() == SqlKind.IDENTIFIER - || sqlNode.getKind() == SqlKind.LITERAL) { + } else if (sqlNode.getKind() == SqlKind.LITERAL) { aliasOrdinalExpandSet.add(sqlNode); } } @@ -8361,14 +8480,10 @@ private enum Clause { boolean shouldReplaceAliases(Config config) { switch (this) { case GROUP_BY: - return config.conformance().isGroupByAlias() - || (config.conformance().isSelectAlias() - != SqlConformance.SelectAliasLookup.UNSUPPORTED); + return config.conformance().isGroupByAlias(); case HAVING: - return config.conformance().isHavingAlias() - || (config.conformance().isSelectAlias() - != SqlConformance.SelectAliasLookup.UNSUPPORTED); + return config.conformance().isHavingAlias(); case QUALIFY: return true; diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java index d350cc7cca6ef..00661e9392b81 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/RelDecorrelator.java @@ -57,7 +57,6 @@ import org.apache.calcite.rel.core.Project; import org.apache.calcite.rel.core.RelFactories; import org.apache.calcite.rel.core.Sort; -import org.apache.calcite.rel.core.Values; import org.apache.calcite.rel.logical.LogicalAggregate; import org.apache.calcite.rel.logical.LogicalCorrelate; import org.apache.calcite.rel.logical.LogicalFilter; @@ -112,11 +111,14 @@ import org.slf4j.Logger; import java.math.BigDecimal; +import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; +import java.util.Deque; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.NavigableMap; @@ -153,10 +155,21 @@ public class RelDecorrelator implements ReflectiveVisitor { // map built during translation protected CorelMap cm; + /** + * Stack maintaining visible Frames to the currently invoked RelNode during top-down traversal. + * Each entry maps a CorrelationId to the Frame where its correlated variables originate. + */ + protected final Deque> frameStack = new ArrayDeque<>(); + @SuppressWarnings("method.invocation.invalid") protected final ReflectUtil.MethodDispatcher<@Nullable Frame> dispatcher = ReflectUtil.createMethodDispatcher( - Frame.class, getVisitor(), "decorrelateRel", RelNode.class, boolean.class); + Frame.class, + getVisitor(), + "decorrelateRel", + RelNode.class, + boolean.class, + boolean.class); // The rel which is being visited protected @Nullable RelNode currentRel; @@ -355,8 +368,28 @@ protected RelNode decorrelate(RelNode root) { // Perform decorrelation. map.clear(); - final Frame frame = getInvoke(root, false, null); + final Frame frame = getInvoke(root, false, null, true); if (frame != null) { + // Check if the frame has more fields than the original and discard the extra ones + RelNode result = frame.r; + int fields = frame.r.getRowType().getFieldCount(); + if (fields > frame.oldToNewOutputs.size()) { + relBuilder.push(result); + final List exprList = new ArrayList<>(); + List> entries = + new ArrayList<>(frame.oldToNewOutputs.entrySet()); + entries.sort(Map.Entry.comparingByKey()); + for (Map.Entry entry : entries) { + exprList.add(relBuilder.field(entry.getValue())); + } + relBuilder.project(exprList); + result = relBuilder.build(); + } else { + Litmus.THROW.check( + fields == frame.oldToNewOutputs.size(), + "Produced relation has fewer columns than the original relation"); + } + // has been rewritten; apply rules post-decorrelation final HepProgramBuilder builder = HepProgram.builder() @@ -376,7 +409,7 @@ protected RelNode decorrelate(RelNode root) { final HepProgram program2 = builder.build(); final HepPlanner planner2 = createPlanner(program2); - final RelNode newRoot = frame.r; + final RelNode newRoot = result; planner2.setRoot(newRoot); return planner2.findBestExp(); } @@ -504,14 +537,17 @@ protected RexNode removeCorrelationExpr( } /** Fallback if none of the other {@code decorrelateRel} methods match. */ - public @Nullable Frame decorrelateRel(RelNode rel, boolean isCorVarDefined) { + public @Nullable Frame decorrelateRel( + RelNode rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { RelNode newRel = rel.copy(rel.getTraitSet(), rel.getInputs()); if (!rel.getInputs().isEmpty()) { List oldInputs = rel.getInputs(); List newInputs = new ArrayList<>(); for (int i = 0; i < oldInputs.size(); ++i) { - final Frame frame = getInvoke(oldInputs.get(i), isCorVarDefined, rel); + final Frame frame = + getInvoke( + oldInputs.get(i), isCorVarDefined, rel, parentPropagatesNullValues); if (frame == null || !frame.corDefOutputs.isEmpty()) { // if input is not rewritten, or if it produces correlated // variables, terminate rewrite @@ -535,7 +571,8 @@ protected RexNode removeCorrelationExpr( ImmutableSortedMap.of()); } - public @Nullable Frame decorrelateRel(Sort rel, boolean isCorVarDefined) { + public @Nullable Frame decorrelateRel( + Sort rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { // // Rewrite logic: // @@ -552,7 +589,7 @@ protected RexNode removeCorrelationExpr( // need to call propagateExpr. final RelNode oldInput = rel.getInput(); - final Frame frame = getInvoke(oldInput, isCorVarDefined, rel); + final Frame frame = getInvoke(oldInput, isCorVarDefined, rel, true); if (frame == null) { // If input has not been rewritten, do not rewrite this rel. return null; @@ -590,16 +627,13 @@ protected RexNode removeCorrelationExpr( return register(rel, newSort, frame.oldToNewOutputs, frame.corDefOutputs); } - public @Nullable Frame decorrelateRel(Values rel, boolean isCorVarDefined) { - // There are no inputs, so rel does not need to be changed. - return null; - } - - public @Nullable Frame decorrelateRel(LogicalAggregate rel, boolean isCorVarDefined) { - return decorrelateRel((Aggregate) rel, isCorVarDefined); + public @Nullable Frame decorrelateRel( + LogicalAggregate rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { + return decorrelateRel((Aggregate) rel, isCorVarDefined, parentPropagatesNullValues); } - public @Nullable Frame decorrelateRel(Aggregate rel, boolean isCorVarDefined) { + public @Nullable Frame decorrelateRel( + Aggregate rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { // // Rewrite logic: // @@ -613,7 +647,7 @@ protected RexNode removeCorrelationExpr( assert !cm.mapRefRelToCorRef.containsKey(rel); final RelNode oldInput = rel.getInput(); - final Frame frame = getInvoke(oldInput, isCorVarDefined, rel); + final Frame frame = getInvoke(oldInput, isCorVarDefined, rel, parentPropagatesNullValues); if (frame == null) { // If input has not been rewritten, do not rewrite this rel. return null; @@ -648,7 +682,7 @@ protected RexNode removeCorrelationExpr( } // add mapping of group keys. - outputMap.put(idx, newPos); + outputMap.put(i, newPos); int newInputPos = requireNonNull(frame.oldToNewOutputs.get(idx)); RexInputRef.add2(projects, newInputPos, newInputOutput); mapNewInputToProjOutputs.put(newInputPos, newPos); @@ -762,9 +796,16 @@ protected RexNode removeCorrelationExpr( combinedMap.get(oldAggCall.filterArg), () -> "combinedMap.get(" + oldAggCall.filterArg + ")"); + boolean newHasEmptyGroup = newGroupSets == null && newGroupSet.isEmpty(); + if (newGroupSets != null) { + Iterator groupSetsIterator = newGroupSets.iterator(); + while (!newHasEmptyGroup && groupSetsIterator.hasNext()) { + newHasEmptyGroup |= groupSetsIterator.next().isEmpty(); + } + } newAggCalls.add( oldAggCall.adaptTo( - newProject, aggArgs, filterArg, oldGroupKeyCount, newGroupKeyCount)); + newProject, aggArgs, filterArg, rel.hasEmptyGroup(), newHasEmptyGroup)); // The old to new output position mapping will be the same as that // of newProject, plus any aggregates that the oldAgg produces. @@ -796,11 +837,161 @@ protected RexNode removeCorrelationExpr( RelNode newRel = relBuilder.build(); + for (AggregateCall aggCall : rel.getAggCallList()) { + if (aggCall.getAggregation() instanceof SqlCountAggFunction) { + parentPropagatesNullValues = false; + break; + } + } + + if (rel.getGroupType() == Aggregate.Group.SIMPLE + && rel.getGroupSet().isEmpty() + && !frame.corDefOutputs.isEmpty() + && !parentPropagatesNullValues) { + newRel = rewriteScalarAggregate(rel, newRel, outputMap, corDefOutputs); + } + // Aggregate does not change input ordering so corVars will be // located at the same position as the input newProject. return register(rel, newRel, outputMap, corDefOutputs); } + /** + * Special case where the group by is static (i.e., aggregation functions without group by). + * + *

Background: For the query: SELECT SUM(salary), COUNT(name) FROM A; When table A is empty, + * it returns [null, 0]. But for SELECT SUM(salary), COUNT(name) FROM A group by id When table A + * is empty, it returns empty. This causes result mismatch. In the general decorrelation + * framework, we add corVar as an additional groupKey to rewrite Correlate as JOIN. (See the + * code above for details) This means that when the input is empty, the result produced using a + * JOIN is incorrect. + * + *

We refer to this situation as: `The well-known count bug`, More details about this issue: + * Optimization of Nested SQL Queries Revisited (https://dl.acm.org/doi/pdf/10.1145/38714.38723) + * + *

To handle this situation, we using a LEFT JOIN to ensure that an output is always + * produced. + * + *

Given the SQL: SELECT deptno FROM dept d WHERE 0 = (SELECT COUNT(*) FROM emp e WHERE + * d.deptno = e.deptno) Corresponding plan: LogicalProject(DEPTNO=[$0]) + * LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{0}]) + * LogicalProject(DEPTNO=[$0]) LogicalTableScan(table=[[scott, DEPT]]) LogicalProject(cs=[true]) + * LogicalFilter(condition=[=(0, $0)]) LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) + * LogicalFilter(condition=[=($cor0.DEPTNO, $7)]) LogicalTableScan(table=[[scott, EMP]]) + * + *

Rewriting this as: SELECT d.deptno FROM dept d JOIN ( SELECT true, e.deptno FROM emp e + * WHERE e.deptno IS NOT NULL GROUP BY e.deptno HAVING COUNT(*) = 0 ) AS d0 ON d.deptno = + * d0.deptno produces an incorrect result. Corresponding plan: LogicalProject(DEPTNO=[$0]) + * LogicalJoin(condition=[=($0, $2)], joinType=[inner]) LogicalProject(DEPTNO=[$0]) + * LogicalTableScan(table=[[scott, DEPT]]) LogicalProject(cs=[true], DEPTNO=[$0]) + * LogicalFilter(condition=[=(0, $1)]) LogicalAggregate(group=[{0}], EXPR$0=[COUNT()]) // + * corresponds to {@code oldRel} LogicalProject(DEPTNO=[$7]) LogicalFilter(condition=[IS NOT + * NULL($7)]) LogicalTableScan(table=[[scott, EMP]]) We can clearly observe that due to the + * presence of the GROUP BY clause, COUNT(*) = 0 will never evaluate to true, since rows with + * zero records won't appear in the GROUP BY results. This produced incorrect results. + * + *

Rewrite Aggregate as: SELECT d.deptno FROM dept d JOIN ( SELECT true AS cs, deptno FROM ( + * SELECT d2.deptno, CASE WHEN cnt0 IS NOT NULL THEN cnt0 ELSE 0 END AS cnt FROM (SELECT deptno + * FROM dept GROUP BY deptno) d2 LEFT JOIN ( SELECT deptno, COUNT(e.empno) cnt0 FROM emp WHERE + * deptno IS NOT NULL GROUP BY deptno) e ON d2.deptno IS NOT DISTINCT FROM e.deptno ) AS + * case_count WHERE cnt = 0 ) AS d0 ON d.deptno = d0.deptno Corresponding plan: [01] + * LogicalProject(DEPTNO=[$0]) [02] LogicalJoin(condition=[=($0, $2)], joinType=[inner]) [03] + * LogicalProject(DEPTNO=[$0]) [04] LogicalTableScan(table=[[scott, DEPT]]) [05] + * LogicalProject(cs=[true], DEPTNO=[$0]) [06] LogicalFilter(condition=[=(0, $1)]) [07] + * LogicalProject(DEPTNO=[$0], EXPR$0=[CASE(IS NOT NULL($2), $2, 0)]) [08] + * LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[left]) [09] + * LogicalAggregate(group=[{0}]) [10] LogicalProject(DEPTNO=[$0]) [11] + * LogicalTableScan(table=[[scott, DEPT]]) [12] LogicalAggregate(group=[{0}], EXPR$0=[COUNT()]) + * [13] LogicalProject(DEPTNO=[$7]) [14] LogicalFilter(condition=[IS NOT NULL($7)]) [15] + * LogicalTableScan(table=[[scott, EMP]]) + * + *

Here we perform an early join, preserving all possible CorVar sets from the outer scope + * and their corresponding aggregation results. This ensures that for any row from the left + * input of the Correlation, there is always an aggregation result available for join output. + * + *

Implementation based on: Improving Unnesting of Complex Queries + * (https://dl.gi.de/server/api/core/bitstreams/c1918e8c-6a87-4da2-930a-bfed289f2388/content) + */ + private RelNode rewriteScalarAggregate( + Aggregate oldRel, + RelNode newRel, + Map outputMap, + NavigableMap corDefOutputs) { + final Pair outerFramePair = requireNonNull(this.frameStack.peek()); + final Frame outFrame = outerFramePair.right; + RexBuilder rexBuilder = relBuilder.getRexBuilder(); + + int groupKeySize = + (int) + corDefOutputs.keySet().stream() + .filter(a -> a.corr.equals(outerFramePair.left)) + .count(); + List newRelFields = newRel.getRowType().getFieldList(); + ImmutableBitSet.Builder corFieldBuilder = ImmutableBitSet.builder(); + + // Here we record the mapping between the original index and the new project. + // For the count, we map it as `case when x is null then 0 else x`. + final Map newProjectMap = new HashMap<>(); + final List conditions = new ArrayList<>(); + for (Map.Entry corDefOutput : corDefOutputs.entrySet()) { + CorDef corDef = corDefOutput.getKey(); + Integer corIndex = corDefOutput.getValue(); + if (corDef.corr.equals(outerFramePair.left)) { + int newIdx = requireNonNull(outFrame.oldToNewOutputs.get(corDef.field)); + corFieldBuilder.set(newIdx); + + RelDataType type = outFrame.r.getRowType().getFieldList().get(newIdx).getType(); + RexNode left = new RexInputRef(corFieldBuilder.cardinality() - 1, type); + newProjectMap.put(corIndex + groupKeySize, left); + conditions.add( + relBuilder.isNotDistinctFrom( + left, + new RexInputRef( + corIndex + groupKeySize, + newRelFields.get(corIndex).getType()))); + } + } + + ImmutableBitSet groupSet = corFieldBuilder.build(); + // Build [09] LogicalAggregate(group=[{0}]) to obtain the distinct set of + // corVar from outFrame. + relBuilder.push(outFrame.r).aggregate(relBuilder.groupKey(groupSet)); + + // Build [08] LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[left]) + // to ensure each corVar's aggregate result is output. + final RelNode join = relBuilder.push(newRel).join(JoinRelType.LEFT, conditions).build(); + + for (int i1 = 0; i1 < oldRel.getAggCallList().size(); i1++) { + AggregateCall aggCall = oldRel.getAggCallList().get(i1); + if (aggCall.getAggregation() instanceof SqlCountAggFunction) { + int index = requireNonNull(outputMap.get(i1 + oldRel.getGroupSet().size())); + final RexInputRef ref = RexInputRef.of(index + groupKeySize, join.getRowType()); + RexNode specificCountValue = + rexBuilder.makeCall( + SqlStdOperatorTable.CASE, + ImmutableList.of( + relBuilder.isNotNull(ref), ref, relBuilder.literal(0))); + newProjectMap.put(ref.getIndex(), specificCountValue); + } + } + + final List newProjects = new ArrayList<>(); + for (int index : ImmutableBitSet.range(groupKeySize, join.getRowType().getFieldCount())) { + if (newProjectMap.containsKey(index)) { + newProjects.add(requireNonNull(newProjectMap.get(index))); + } else { + newProjects.add(RexInputRef.of(index, join.getRowType())); + } + } + + // Build [07] LogicalProject(DEPTNO=[$0], EXPR$0=[CASE(IS NOT NULL($2), $2, 0)]) + // to handle COUNT function by converting nulls to zero. + return relBuilder + .push(join) + .project(newProjects, newRel.getRowType().getFieldNames()) + .build(); + } + /** * Shift the mapping to fixed offset from the {@code startIndex}. * @@ -817,8 +1008,18 @@ private static void shiftMapping(Map mapping, int startIndex, } } - public @Nullable Frame getInvoke(RelNode r, boolean isCorVarDefined, @Nullable RelNode parent) { - final Frame frame = dispatcher.invoke(r, isCorVarDefined); + /** + * Invokes decorrelation logic for a given relational expression. + * + * @param parentPropagatesNullValues True if the parent RelNode produces null when all of its + * inputs fields are null. + */ + public @Nullable Frame getInvoke( + RelNode r, + boolean isCorVarDefined, + @Nullable RelNode parent, + boolean parentPropagatesNullValues) { + final Frame frame = dispatcher.invoke(r, isCorVarDefined, parentPropagatesNullValues); currentRel = parent; if (frame != null) { map.put(r, frame); @@ -985,19 +1186,26 @@ private static void shiftMapping(Map mapping, int startIndex, return null; } - public @Nullable Frame decorrelateRel(LogicalProject rel, boolean isCorVarDefined) { - return decorrelateRel((Project) rel, isCorVarDefined); + public @Nullable Frame decorrelateRel( + LogicalProject rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { + return decorrelateRel((Project) rel, isCorVarDefined, parentPropagatesNullValues); } - public @Nullable Frame decorrelateRel(Project rel, boolean isCorVarDefined) { + public @Nullable Frame decorrelateRel( + Project rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { // // Rewrite logic: // // 1. Pass along any correlated variables coming from the input. // + for (RexNode project : rel.getProjects()) { + if (!Strong.isStrong(project)) { + parentPropagatesNullValues = false; + } + } final RelNode oldInput = rel.getInput(); - Frame frame = getInvoke(oldInput, isCorVarDefined, rel); + Frame frame = getInvoke(oldInput, isCorVarDefined, rel, parentPropagatesNullValues); if (frame == null) { // If input has not been rewritten, do not rewrite this rel. return null; @@ -1361,25 +1569,31 @@ private static boolean isWidening(RelDataType type, RelDataType type1) { && type.getPrecision() >= type1.getPrecision(); } - public @Nullable Frame decorrelateRel(LogicalSnapshot rel, boolean isCorVarDefined) { + public @Nullable Frame decorrelateRel( + LogicalSnapshot rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { if (RexUtil.containsCorrelation(rel.getPeriod())) { return null; } - return decorrelateRel((RelNode) rel, isCorVarDefined); + return decorrelateRel((RelNode) rel, isCorVarDefined, parentPropagatesNullValues); } - public @Nullable Frame decorrelateRel(LogicalTableFunctionScan rel, boolean isCorVarDefined) { + public @Nullable Frame decorrelateRel( + LogicalTableFunctionScan rel, + boolean isCorVarDefined, + boolean parentPropagatesNullValues) { if (RexUtil.containsCorrelation(rel.getCall())) { return null; } - return decorrelateRel((RelNode) rel, isCorVarDefined); + return decorrelateRel((RelNode) rel, isCorVarDefined, parentPropagatesNullValues); } - public @Nullable Frame decorrelateRel(LogicalFilter rel, boolean isCorVarDefined) { - return decorrelateRel((Filter) rel, isCorVarDefined); + public @Nullable Frame decorrelateRel( + LogicalFilter rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { + return decorrelateRel((Filter) rel, isCorVarDefined, parentPropagatesNullValues); } - public @Nullable Frame decorrelateRel(Filter rel, boolean isCorVarDefined) { + public @Nullable Frame decorrelateRel( + Filter rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { // // Rewrite logic: // @@ -1397,7 +1611,7 @@ private static boolean isWidening(RelDataType type, RelDataType type1) { // final RelNode oldInput = rel.getInput(); - Frame frame = getInvoke(oldInput, isCorVarDefined, rel); + Frame frame = getInvoke(oldInput, isCorVarDefined, rel, parentPropagatesNullValues); if (frame == null) { // If input has not been rewritten, do not rewrite this rel. return null; @@ -1428,11 +1642,13 @@ private static boolean isWidening(RelDataType type, RelDataType type1) { return register(rel, relBuilder.build(), frame.oldToNewOutputs, frame.corDefOutputs); } - public @Nullable Frame decorrelateRel(LogicalCorrelate rel, boolean isCorVarDefined) { - return decorrelateRel((Correlate) rel, isCorVarDefined); + public @Nullable Frame decorrelateRel( + LogicalCorrelate rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { + return decorrelateRel((Correlate) rel, isCorVarDefined, parentPropagatesNullValues); } - public @Nullable Frame decorrelateRel(Correlate rel, boolean isCorVarDefined) { + public @Nullable Frame decorrelateRel( + Correlate rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { // // Rewrite logic: // @@ -1446,15 +1662,18 @@ private static boolean isWidening(RelDataType type, RelDataType type1) { final RelNode oldLeft = rel.getInput(0); final RelNode oldRight = rel.getInput(1); - final Frame leftFrame = getInvoke(oldLeft, isCorVarDefined, rel); - final Frame rightFrame = getInvoke(oldRight, true, rel); - - if (leftFrame == null || rightFrame == null) { - // If any input has not been rewritten, do not rewrite this rel. + final Frame leftFrame = + getInvoke(oldLeft, isCorVarDefined, rel, parentPropagatesNullValues); + if (leftFrame == null) { + // If input has not been rewritten, do not rewrite this rel. return null; } - if (rightFrame.corDefOutputs.isEmpty()) { + frameStack.push(Pair.of(rel.getCorrelationId(), leftFrame)); + final Frame rightFrame = getInvoke(oldRight, true, rel, parentPropagatesNullValues); + frameStack.pop(); + + if (rightFrame == null || rightFrame.corDefOutputs.isEmpty()) { return null; } @@ -1542,16 +1761,18 @@ private static boolean isWidening(RelDataType type, RelDataType type1) { return register(rel, newJoin, mapOldToNewOutputs, corDefOutputs); } - public @Nullable Frame decorrelateRel(LogicalJoin rel, boolean isCorVarDefined) { - return decorrelateRel((Join) rel, isCorVarDefined); + public @Nullable Frame decorrelateRel( + LogicalJoin rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { + return decorrelateRel((Join) rel, isCorVarDefined, parentPropagatesNullValues); } - public @Nullable Frame decorrelateRel(Join rel, boolean isCorVarDefined) { + public @Nullable Frame decorrelateRel( + Join rel, boolean isCorVarDefined, boolean parentPropagatesNullValues) { // For SEMI/ANTI join decorrelate it's input directly, // because the correlate variables can only be propagated from // the left side, which is not supported yet. if (!rel.getJoinType().projectsRight()) { - return decorrelateRel((RelNode) rel, isCorVarDefined); + return decorrelateRel((RelNode) rel, isCorVarDefined, parentPropagatesNullValues); } // // Rewrite logic: @@ -1563,8 +1784,10 @@ private static boolean isWidening(RelDataType type, RelDataType type1) { final RelNode oldLeft = rel.getInput(0); final RelNode oldRight = rel.getInput(1); - final Frame leftFrame = getInvoke(oldLeft, isCorVarDefined, rel); - final Frame rightFrame = getInvoke(oldRight, isCorVarDefined, rel); + final Frame leftFrame = + getInvoke(oldLeft, isCorVarDefined, rel, parentPropagatesNullValues); + final Frame rightFrame = + getInvoke(oldRight, isCorVarDefined, rel, parentPropagatesNullValues); if (leftFrame == null || rightFrame == null) { // If any input has not been rewritten, do not rewrite this rel. @@ -2764,8 +2987,8 @@ public void onMatch(RelOptRuleCall call) { joinOutputProject, argList, filterArg, - aggregate.getGroupCount(), - groupCount)); + aggregate.hasEmptyGroup(), + groupCount == 0)); } ImmutableBitSet groupSet = ImmutableBitSet.range(groupCount); @@ -2876,7 +3099,7 @@ public void onMatch(RelOptRuleCall call) { onMatch2(d, call, correlate, left, aggOutputProject, aggregate); } - private void onMatch2( + private static void onMatch2( RelDecorrelator d, RelOptRuleCall call, Correlate correlate, @@ -3297,6 +3520,8 @@ assert allLessThan( Litmus.THROW); assert allLessThan( this.oldToNewOutputs.values(), r.getRowType().getFieldCount(), Litmus.THROW); + RelDataType rowType = oldRel.getRowType(); + assert this.oldToNewOutputs.size() >= rowType.getFieldCount(); } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java index abb0e9a882312..46a3e89c3de82 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java @@ -49,6 +49,7 @@ import org.apache.calcite.rel.RelCollationTraitDef; import org.apache.calcite.rel.RelCollations; import org.apache.calcite.rel.RelFieldCollation; +import org.apache.calcite.rel.RelHomogeneousShuttle; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.RelRoot; import org.apache.calcite.rel.RelShuttleImpl; @@ -101,6 +102,7 @@ import org.apache.calcite.rex.RexPatternFieldRef; import org.apache.calcite.rex.RexRangeRef; import org.apache.calcite.rex.RexShuttle; +import org.apache.calcite.rex.RexSimplify; import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.rex.RexWindowBound; @@ -1168,6 +1170,16 @@ private static SqlNode reg(SqlValidatorScope scope, SqlNode e) { return e; } + private RexNode simplifyPredicate(RexNode predicate) { + final RexNode converted = RexUtil.removeNullabilityCast(typeFactory, predicate); + List conjuncts = RelOptUtil.conjunctions(converted); + List simplified = + conjuncts.stream() + .map(e -> RexSimplify.simplifyComparisonWithNull(e, rexBuilder)) + .collect(Collectors.toList()); + return RexUtil.composeConjunction(rexBuilder, simplified); + } + /** * Converts a WHERE clause. * @@ -1181,7 +1193,7 @@ private void convertWhere(final Blackboard bb, final @Nullable SqlNode where) { SqlNode newWhere = pushDownNotForIn(bb.scope, where); replaceSubQueries(bb, newWhere, RelOptUtil.Logic.UNKNOWN_AS_FALSE); final RexNode convertedWhere = bb.convertExpression(newWhere); - final RexNode convertedWhere2 = RexUtil.removeNullabilityCast(typeFactory, convertedWhere); + final RexNode convertedWhere2 = simplifyPredicate(convertedWhere); // only allocate filter if the condition is not TRUE if (convertedWhere2.isAlwaysTrue()) { @@ -1284,23 +1296,8 @@ private void substituteSubQuery(Blackboard bb, SubQuery subQuery) { if (!config.isExpand()) { if (query instanceof SqlNodeList) { - // convert - // select * from "scott".emp where sal > some (4000, 2000) - // to - // select * from "scott".emp where sal > some (VALUES (4000), (2000)) - // The SqlNodeList become a RexSubQuery then optimized by - // SubQueryRemoveRule. - RelNode relNode = - convertRowValues( - bb, query, (SqlNodeList) query, false, targetRowType); - final ImmutableList.Builder builder = ImmutableList.builder(); - for (SqlNode node : leftSqlKeys) { - builder.add(bb.convertExpression(node)); - } - final ImmutableList list = builder.build(); - assert relNode != null; - subQuery.expr = - createSubquery(subQuery.node.getKind(), relNode, list, call); + convertNodeListToSubQuery( + bb, subQuery, query, leftSqlKeys, targetRowType, call); return; } return; @@ -1339,7 +1336,12 @@ private void substituteSubQuery(Blackboard bb, SubQuery subQuery) { // // In such case, when converting SqlUpdate#condition, bb.root is null // and it makes no sense to do the sub-query substitution. + // Instead, we convert to a RexSubQuery which can be optimized later. if (bb.root == null) { + if (query instanceof SqlNodeList) { + convertNodeListToSubQuery( + bb, subQuery, query, leftSqlKeys, targetRowType, call); + } return; } @@ -1921,6 +1923,48 @@ private RexNode ensureComparisonTypes(RexNode node) { return node; } + /** + * Converts a {@link SqlNodeList} (for example an IN-list or VALUES list) into a relational + * expression and produces a Rex-level sub-query that references that relational expression. + * + *

For example, converts: + * + *

{@code
+     * select * from "scott".emp where sal > some (4000, 2000)
+     * }
+ * + * to: + * + *
{@code
+     * select * from "scott".emp where sal > some (VALUES (4000), (2000))
+     * }
+ * + * The SqlNodeList becomes a RexSubQuery then optimized by SubQueryRemoveRule. + * + * @param bb Blackboard containing the context + * @param subQuery The SubQuery to populate with the converted expression + * @param query The query node (expected to be a SqlNodeList) + * @param leftSqlKeys Left-hand key expressions for IN/SOME/ALL + * @param targetRowType The target row type for the conversion + * @param call The original SQL call + */ + private void convertNodeListToSubQuery( + Blackboard bb, + SubQuery subQuery, + SqlNode query, + List leftSqlKeys, + RelDataType targetRowType, + SqlCall call) { + RelNode relNode = convertRowValues(bb, query, (SqlNodeList) query, false, targetRowType); + final ImmutableList.Builder builder = ImmutableList.builder(); + for (SqlNode node : leftSqlKeys) { + builder.add(bb.convertExpression(node)); + } + final ImmutableList list = builder.build(); + assert relNode != null; + subQuery.expr = createSubquery(subQuery.node.getKind(), relNode, list, call); + } + /** * Gets the list size threshold under which {@link #convertInToOr} is used. Lists of this size * or greater will instead be converted to use a join against an inline table ({@link @@ -2670,15 +2714,15 @@ protected void convertFrom( convertCollectionTable(bb, call2); return; - // ----- FLINK MODIFICATION BEGIN ----- case LATERAL: call = (SqlCall) from; + // Extract and analyze lateral part of join call. assert call.getOperandList().size() == 1; final SqlCall callLateral = call.operand(0); convertFrom(bb, callLateral, fieldNames); return; - // ----- FLINK MODIFICATION END ----- + default: throw new AssertionError("not a join operator " + from); } @@ -3350,12 +3394,8 @@ protected RelNode createAsofJoin( } final ImmutableBitSet.Builder requiredColumns = ImmutableBitSet.builder(); final List correlNames = new ArrayList<>(); - - // All correlations must refer the same namespace since correlation - // produces exactly one correlation source. - // The same source might be referenced by different variables since - // DeferredLookups are not de-duplicated at create time. - SqlValidatorNamespace prevNs = null; + // Mapping from (correlId, originalFieldIndex) to projectedFieldIndex for aggregation + final Map, Integer> fieldMapping = new HashMap<>(); for (CorrelationId correlName : correlatedVariables) { DeferredLookup lookup = @@ -3371,7 +3411,6 @@ protected RelNode createAsofJoin( ImmutableList.of(lookup.originalRelName), nameMatcher, false, resolved); assert resolved.count() == 1; final SqlValidatorScope.Resolve resolve = resolved.only(); - final SqlValidatorNamespace foundNs = resolve.namespace; final RelDataType rowType = resolve.rowType(); final int childNamespaceIndex = resolve.path.steps().get(0).i; final SqlValidatorScope ancestorScope = resolve.scope; @@ -3381,18 +3420,6 @@ protected RelNode createAsofJoin( continue; } - if (prevNs == null) { - prevNs = foundNs; - } else { - assert prevNs == foundNs - : "All correlation variables should resolve" - + " to the same namespace." - + " Prev ns=" - + prevNs - + ", new ns=" - + foundNs; - } - int namespaceOffset = 0; if (childNamespaceIndex > 0) { // If not the first child, need to figure out the width @@ -3410,9 +3437,9 @@ protected RelNode createAsofJoin( while (topLevelFieldAccess.getReferenceExpr() instanceof RexFieldAccess) { topLevelFieldAccess = (RexFieldAccess) topLevelFieldAccess.getReferenceExpr(); } + final int originalFieldIndex = topLevelFieldAccess.getField().getIndex(); final RelDataTypeField field = - rowType.getFieldList() - .get(topLevelFieldAccess.getField().getIndex() - namespaceOffset); + rowType.getFieldList().get(originalFieldIndex - namespaceOffset); int pos = namespaceOffset + field.getIndex(); assert field.getType() == topLevelFieldAccess.getField().getType(); @@ -3427,6 +3454,7 @@ protected RelNode createAsofJoin( // the root of the outer relation. Integer projection = exprProjection.get(pos); if (projection != null) { + fieldMapping.put(Pair.of(correlName, originalFieldIndex), projection); pos = projection; } else { // correl not grouped @@ -3458,6 +3486,19 @@ protected RelNode createAsofJoin( // Add new node to leaves. leaves.put(r, r.getRowType().getFieldCount()); } + + // If there are field mappings (due to aggregation), rewrite the RelNode tree + // to update correlation variable row type and field indices + if (!fieldMapping.isEmpty()) { + r = + r.accept( + new CorrelationFieldMappingShuttle( + rexBuilder, + correlNames.get(0), + bb.root().getRowType(), + fieldMapping)); + } + return new CorrelationUse(correlNames.get(0), requiredColumns.build(), r); } @@ -3532,7 +3573,7 @@ private void convertJoin(Blackboard bb, SqlJoin join) { final RelNode tempRightRel = requireNonNull(rightBlackboard.root, "rightBlackboard.root"); final JoinConditionType conditionType = join.getConditionType(); - final RexNode condition; + RexNode condition; RelNode rightRel; if (join.isNatural()) { condition = convertNaturalCondition(getNamespace(left), getNamespace(right)); @@ -3561,6 +3602,8 @@ private void convertJoin(Blackboard bb, SqlJoin join) { throw Util.unexpected(conditionType); } } + condition = simplifyPredicate(condition); + final RelNode joinRel; if (joinType == JoinType.ASOF || joinType == JoinType.LEFT_ASOF) { SqlNode sqlMatchCondition = @@ -3570,6 +3613,7 @@ private void convertJoin(Blackboard bb, SqlJoin join) { Pair conditionAndRightNode = convertOnCondition(fromBlackboard, sqlMatchCondition, leftRel, tempRightRel); RexNode matchCondition = conditionAndRightNode.left; + matchCondition = simplifyPredicate(matchCondition); rightRel = conditionAndRightNode.right; joinRel = createAsofJoin( @@ -4666,14 +4710,27 @@ private RelNode convertUpdate(SqlUpdate call) { targetColumnNameList.add(field.getName()); } - RelNode sourceRel = convertSelect(sourceSelect, false); + // `sourceSelect` should contain target columns values plus source expressions + if (sourceSelect.getSelectList().size() + != targetTable.getRowType().getFieldCount() + + call.getSourceExpressionList().size()) { + throw new AssertionError( + "Unexpected select list size. Select list should contain both target table columns and " + + "set expressions"); + } + RelNode sourceRel = convertSelect(sourceSelect, false); bb.setRoot(sourceRel, false); - ImmutableList.Builder rexNodeSourceExpressionListBuilder = ImmutableList.builder(); - for (SqlNode n : call.getSourceExpressionList()) { - RexNode rn = bb.convertExpression(n); - rexNodeSourceExpressionListBuilder.add(rn); - } + + // sourceRel already contains all source expressions. Only create references to those + // fields. + List rexExpressionList = + Util.transform( + Util.last( + sourceRel.getRowType().getFieldList(), targetColumnNameList.size()), + expressionField -> + new RexInputRef( + expressionField.getIndex(), expressionField.getType())); return LogicalTableModify.create( targetTable, @@ -4681,7 +4738,7 @@ private RelNode convertUpdate(SqlUpdate call) { sourceRel, LogicalTableModify.Operation.UPDATE, targetColumnNameList, - rexNodeSourceExpressionListBuilder.build(), + rexExpressionList, false); } @@ -5009,7 +5066,7 @@ public SqlNode getNode() { case ARRAY_QUERY_CONSTRUCTOR: case MAP_QUERY_CONSTRUCTOR: final RelRoot root = convertQuery(call.operand(0), false, true); - input = root.rel; + input = root.project(); break; default: lastList.add(operand); @@ -5315,6 +5372,7 @@ private void convertValuesImpl( mapping = null; } + List fields = targetRowType.getFieldList(); for (SqlNode rowConstructor : values.getOperandList()) { SqlCall newRowConst = (SqlCall) rowConstructor; Blackboard tmpBb = createBlackboard(bb.scope, null, false); @@ -5323,6 +5381,7 @@ private void convertValuesImpl( Ord.forEach( newRowConst.getOperandList(), (operand, i) -> { + RelDataType fieldType = fields.get(i).getType(); RexNode def; if (processDefaults && operand.getKind() == SqlKind.DEFAULT @@ -5336,6 +5395,9 @@ && requireNonNull(mapping, "mapping")[i] != -1) { } else { def = tmpBb.convertExpression(operand); } + if (!(def instanceof RexDynamicParam) && !def.getType().equals(fieldType)) { + def = rexBuilder.makeCast(operand.getParserPosition(), fieldType, def); + } exps.add(def, SqlValidatorUtil.alias(operand, i)); }); @@ -5937,20 +5999,20 @@ public RexNode convertExpression(SqlNode expr) { query = Iterables.getOnlyElement(call.getOperandList()); // let top=true to make the query be top-level query, // then ORDER BY will be reserved. - root = convertQueryRecursive(query, true, null); - return RexSubQuery.array(root.rel); + root = convertQuery(query, false, true); + return RexSubQuery.array(root.project()); case MAP_QUERY_CONSTRUCTOR: call = (SqlCall) expr; query = Iterables.getOnlyElement(call.getOperandList()); - root = convertQueryRecursive(query, false, null); - return RexSubQuery.map(root.rel); + root = convertQuery(query, false, true); + return RexSubQuery.map(root.project()); case MULTISET_QUERY_CONSTRUCTOR: call = (SqlCall) expr; query = Iterables.getOnlyElement(call.getOperandList()); - root = convertQueryRecursive(query, false, null); - return RexSubQuery.multiset(root.rel); + root = convertQuery(query, false, true); + return RexSubQuery.multiset(root.project()); default: break; @@ -6157,14 +6219,14 @@ private boolean isConvertedSubq(RexNode rex) { } @Override - public int getGroupCount() { + public boolean hasEmptyGroup() { if (agg != null) { - return agg.groupExprs.size(); + return SqlValidatorUtil.hasEmptyGroup(agg.groupExprs); } if (window != null) { - return window.isAlwaysNonEmpty() ? 1 : 0; + return !window.isAlwaysNonEmpty(); } - return -1; + return false; } @Override @@ -6323,6 +6385,60 @@ RexFieldAccess getFieldAccess(CorrelationId name) { } } + /** + * Shuttle that rewrites correlation field accesses to use projected field indices when + * correlation references aggregated relations. + */ + private static class CorrelationFieldMappingShuttle extends RelHomogeneousShuttle { + private final RexBuilder rexBuilder; + private final CorrelationId targetCorrelId; + private final RelDataType newCorrelRowType; + private final Map, Integer> fieldMapping; + + CorrelationFieldMappingShuttle( + RexBuilder rexBuilder, + CorrelationId targetCorrelId, + RelDataType newCorrelRowType, + Map, Integer> fieldMapping) { + this.rexBuilder = rexBuilder; + this.targetCorrelId = targetCorrelId; + this.newCorrelRowType = newCorrelRowType; + this.fieldMapping = fieldMapping; + } + + @Override + public RelNode visit(RelNode other) { + return super.visit(other) + .accept( + new RexShuttle() { + @Override + public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { + if (fieldAccess.getReferenceExpr() + instanceof RexCorrelVariable) { + RexCorrelVariable correlVar = + (RexCorrelVariable) fieldAccess.getReferenceExpr(); + if (correlVar.id.equals(targetCorrelId)) { + Integer newIndex = + fieldMapping.get( + Pair.of( + correlVar.id, + fieldAccess + .getField() + .getIndex())); + if (newIndex != null) { + return rexBuilder.makeFieldAccess( + rexBuilder.makeCorrel( + newCorrelRowType, correlVar.id), + newIndex); + } + } + } + return super.visitFieldAccess(fieldAccess); + } + }); + } + } + /** A default implementation of SubQueryConverter that does no conversion. */ private static class NoOpSubQueryConverter implements SubQueryConverter { @Override diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecMatch.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecMatch.java index 8da68878070a6..3e8918d181020 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecMatch.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecMatch.java @@ -34,6 +34,7 @@ import org.apache.flink.streaming.api.operators.ChainingStrategy; import org.apache.flink.streaming.api.transformations.OneInputTransformation; import org.apache.flink.table.api.TableException; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.data.RowData; import org.apache.flink.table.planner.codegen.CodeGenUtils; import org.apache.flink.table.planner.codegen.CodeGeneratorContext; @@ -68,6 +69,7 @@ import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexLiteral; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexNodeAndFieldIndex; import org.apache.calcite.sql.SqlMatchRecognize; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; @@ -367,6 +369,12 @@ public Pattern visitCall(RexCall call) { } } + @Override + public Pattern visitNodeAndFieldIndex( + RexNodeAndFieldIndex nodeAndFieldIndex) { + throw new ValidationException("not supported yet"); + } + @Override public Pattern visitNode(RexNode rexNode) { throw new TableException( diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java index d80062919e505..55a2f7d6b482f 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/LogicalCorrelateToJoinFromTemporalTableFunctionRule.java @@ -49,6 +49,7 @@ import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexNodeAndFieldIndex; import org.apache.calcite.rex.RexVisitorImpl; import org.apache.calcite.sql.SqlOperator; import org.immutables.value.Value; @@ -336,6 +337,11 @@ public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { return rexBuilder.makeInputRef(leftSide, leftIndex); } + @Override + public RexNode visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + throw new ValidationException("not supported yet"); + } + @Override public RexNode visitInputRef(RexInputRef inputRef) { return inputRef; diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java index 1515b0975ff88..005f4bf054049 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.plan.rules.logical; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan; @@ -37,6 +38,7 @@ import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexNodeAndFieldIndex; import org.apache.calcite.rex.RexProgram; import org.apache.calcite.rex.RexProgramBuilder; import org.apache.calcite.rex.RexUtil; @@ -131,6 +133,11 @@ public RexNode visitFieldAccess(RexFieldAccess fieldAccess) { } } + @Override + public RexNode visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + throw new ValidationException("not supported yet"); + } + @Override public RexNode visitCall(RexCall call) { List newProjects = diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.java index f284644717631..0a1acd46b23ae 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.plan.rules.logical; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate; import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalRel; @@ -34,6 +35,7 @@ import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexNodeAndFieldIndex; import org.apache.calcite.rex.RexProgram; import org.apache.calcite.rex.RexProgramBuilder; import org.apache.calcite.rex.RexUtil; @@ -217,6 +219,11 @@ public RexNode visitCall(RexCall call) { .collect(Collectors.toList())); } + @Override + public RexNode visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + throw new ValidationException("not supported yet"); + } + @Override public RexNode visitNode(RexNode rexNode) { return rexNode; diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java index 5476fbbbebeed..8766292dca29d 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/physical/common/CommonPhysicalMatchRule.java @@ -37,6 +37,7 @@ import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexNodeAndFieldIndex; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.util.ImmutableBitSet; @@ -169,6 +170,11 @@ public Object visitCall(RexCall call) { return null; } + @Override + public Object visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + throw new ValidationException("not supported yet"); + } + @Override public Object visitNode(RexNode rexNode) { return null; diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/AsyncUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/AsyncUtil.java index e283252b7043a..70ff3ce25fb4e 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/AsyncUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/AsyncUtil.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.plan.utils; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.FunctionKind; import org.apache.flink.table.planner.plan.rules.logical.RemoteCallFinder; @@ -26,6 +27,7 @@ import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexNodeAndFieldIndex; /** Utility class for working with async function calls in RexNodes. */ public class AsyncUtil { @@ -135,6 +137,11 @@ public Boolean visitCall(RexCall call) { || (recursive && call.getOperands().stream().anyMatch(node -> node.accept(this))); } + + @Override + public Boolean visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + throw new ValidationException("not supported yet"); + } } /** diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ConstantFoldingUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ConstantFoldingUtil.java index c8e3e40d7a313..23e62a406f0a5 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ConstantFoldingUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/ConstantFoldingUtil.java @@ -18,11 +18,13 @@ package org.apache.flink.table.planner.plan.utils; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.planner.utils.ShortcutUtils; import org.apache.calcite.rex.RexCall; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexNodeAndFieldIndex; /** Utility for deciding whether than expression supports constant folding or not. */ public class ConstantFoldingUtil { @@ -55,5 +57,10 @@ public Boolean visitCall(RexCall call) { return supportsConstantFolding && (call.getOperands().stream().allMatch(node -> node.accept(this))); } + + @Override + public Boolean visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + throw new ValidationException("not supported yet"); + } } } diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/PythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/PythonUtil.java index 65efb05f430ef..07c29b83407ae 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/PythonUtil.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/PythonUtil.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.plan.utils; +import org.apache.flink.table.api.ValidationException; import org.apache.flink.table.functions.DeclarativeAggregateFunction; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.functions.python.PythonFunction; @@ -39,6 +40,7 @@ import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexNodeAndFieldIndex; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlKind; @@ -272,6 +274,11 @@ public Boolean visitFieldAccess(RexFieldAccess fieldAccess) { return fieldAccess.getReferenceExpr().accept(this); } + @Override + public Boolean visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + throw new ValidationException("not supported yet"); + } + @Override public Boolean visitNode(RexNode rexNode) { return false; @@ -304,6 +311,11 @@ public Boolean visitFieldAccess(RexFieldAccess fieldAccess) { } } + @Override + public Boolean visitNodeAndFieldIndex(RexNodeAndFieldIndex nodeAndFieldIndex) { + throw new ValidationException("not supported yet"); + } + @Override public Boolean visitCall(RexCall call) { if (call.getKind() == SqlKind.AS) { diff --git a/flink-table/flink-table-planner/src/main/resources/META-INF/NOTICE b/flink-table/flink-table-planner/src/main/resources/META-INF/NOTICE index 0688c621bf0b6..943be5e3d5ff0 100644 --- a/flink-table/flink-table-planner/src/main/resources/META-INF/NOTICE +++ b/flink-table/flink-table-planner/src/main/resources/META-INF/NOTICE @@ -8,13 +8,13 @@ This project bundles the following dependencies under the Apache Software Licens - com.google.guava:guava:33.4.0-jre - com.google.guava:failureaccess:1.0.2 -- org.apache.calcite:calcite-core:1.40.0 -- org.apache.calcite:calcite-linq4j:1.40.0 -- org.apache.calcite.avatica:avatica-core:1.26.0 +- org.apache.calcite:calcite-core:1.41.0 +- org.apache.calcite:calcite-linq4j:1.41.0 +- org.apache.calcite.avatica:avatica-core:1.27.0 - org.apache.commons:commons-math3:3.6.1 -- org.apache.commons:commons-text:1.10.0 - commons-codec:commons-codec:1.15 - commons-io:commons-io:2.15.1 +- org.jooq:joou-java-6:0.9.4 This project bundles the following dependencies under the MIT License. (http://www.opensource.org/licenses/mit-license.php) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index b05eac28261ba..9427dace35da5 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -1026,4 +1026,9 @@ class ExprCodeGenerator( ShortcutUtils.isDeterministicThroughProgram( node, CodeGenUtils.getExprsFromProgramOrNull(rexProgram)) + + override def visitNodeAndFieldIndex( + nodeAndFieldIndex: RexNodeAndFieldIndex): GeneratedExpression = { + throw new CodeGenException("RexNodeAndFieldIndex are not supported yet.") + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala index 26baa44f43d82..9530871e43ddd 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcSplitRule.scala @@ -18,12 +18,13 @@ package org.apache.flink.table.planner.plan.rules.logical import org.apache.flink.table.functions.ScalarFunction +import org.apache.flink.table.planner.codegen.CodeGenException import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc import org.apache.flink.table.planner.plan.utils.{InputRefVisitor, RexDefaultVisitor} import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall} import org.apache.calcite.plan.RelOptRule.{any, operand} -import org.apache.calcite.rex.{RexBuilder, RexCall, RexCorrelVariable, RexFieldAccess, RexInputRef, RexLocalRef, RexNode, RexProgram} +import org.apache.calcite.rex.{RexBuilder, RexCall, RexCorrelVariable, RexFieldAccess, RexInputRef, RexLocalRef, RexNode, RexNodeAndFieldIndex, RexProgram} import org.apache.calcite.sql.validate.SqlValidatorUtil import java.util.function.Function @@ -498,6 +499,10 @@ class ScalarFunctionSplitter( new RexInputRef(fieldsRexCall(rexCallIndex), remoteCall.getType), node.getField.getIndex) } + + override def visitNodeAndFieldIndex(nodeAndFieldIndex: RexNodeAndFieldIndex): RexNode = { + throw new CodeGenException("RexNodeAndFieldIndex are not supported yet.") + } } /** @@ -546,4 +551,8 @@ private class ExtractedFunctionInputRewriter( } override def visitNode(rexNode: RexNode): RexNode = rexNode + + override def visitNodeAndFieldIndex(nodeAndFieldIndex: RexNodeAndFieldIndex): RexNode = { + throw new CodeGenException("RexNodeAndFieldIndex are not supported yet.") + } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/MatchUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/MatchUtil.scala index 0c1760e851020..6486114115d43 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/MatchUtil.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/MatchUtil.scala @@ -18,13 +18,14 @@ package org.apache.flink.table.planner.plan.utils import org.apache.flink.table.api.ValidationException +import org.apache.flink.table.planner.codegen.CodeGenException import org.apache.flink.table.planner.codegen.MatchCodeGenerator.ALL_PATTERN_VARIABLE import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable import org.apache.flink.table.planner.plan.logical.MatchRecognize import org.apache.flink.table.planner.plan.nodes.exec.spec.{MatchSpec, PartitionSpec} import _root_.scala.collection.JavaConversions._ -import org.apache.calcite.rex.{RexCall, RexNode, RexPatternFieldRef} +import org.apache.calcite.rex.{RexCall, RexNode, RexNodeAndFieldIndex, RexPatternFieldRef} import org.apache.calcite.sql.fun.SqlStdOperatorTable object MatchUtil { @@ -57,6 +58,10 @@ object MatchUtil { } override def visitNode(rexNode: RexNode): Option[String] = None + + override def visitNodeAndFieldIndex(nodeAndFieldIndex: RexNodeAndFieldIndex): Option[String] = { + throw new CodeGenException("RexNodeAndFieldIndex are not supported yet.") + } } /** Convert [[MatchRecognize]] to [[MatchSpec]]. */ diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala index 4e2d1d662053f..40029d6c6064a 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/RexNodeExtractor.scala @@ -27,6 +27,7 @@ import org.apache.flink.table.expressions.ApiExpressionUtils._ import org.apache.flink.table.functions.{BuiltInFunctionDefinition, FunctionIdentifier} import org.apache.flink.table.functions.BuiltInFunctionDefinitions.{AND, CAST, OR, TRY_CAST} import org.apache.flink.table.planner.calcite.FlinkTypeFactory +import org.apache.flink.table.planner.codegen.CodeGenException import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable import org.apache.flink.table.planner.utils.Logging import org.apache.flink.table.planner.utils.TimestampStringUtils.toLocalDateTime @@ -621,4 +622,8 @@ class RexNodeToExpressionConverter( str.replaceAll("\\s|_", "") } + override def visitNodeAndFieldIndex( + nodeAndFieldIndex: RexNodeAndFieldIndex): Option[ResolvedExpression] = { + throw new CodeGenException("RexNodeAndFieldIndex are not supported yet.") + } } diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/ValuesTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/ValuesTest.xml index 089c876dea6c0..fd253c2ec8793 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/ValuesTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/ValuesTest.xml @@ -25,9 +25,9 @@ limitations under the License. LogicalProject(a=[$0], b=[$1]), rowType=[RecordType(INTEGER a, DECIMAL(20, 1) b)] +- LogicalProject(a=[$0], b=[$1]), rowType=[RecordType(INTEGER a, DECIMAL(20, 1) b)] +- LogicalUnion(all=[true]), rowType=[RecordType(INTEGER EXPR$0, DECIMAL(20, 1) EXPR$1)] - :- LogicalProject(EXPR$0=[1], EXPR$1=[2.0:DECIMAL(2, 1)]), rowType=[RecordType(INTEGER EXPR$0, DECIMAL(2, 1) EXPR$1)] + :- LogicalProject(EXPR$0=[1], EXPR$1=[2.0:DECIMAL(20, 1)]), rowType=[RecordType(INTEGER EXPR$0, DECIMAL(20, 1) EXPR$1)] : +- LogicalValues(tuples=[[{ 0 }]]), rowType=[RecordType(INTEGER ZERO)] - +- LogicalProject(EXPR$0=[3], EXPR$1=[4:BIGINT]), rowType=[RecordType(INTEGER EXPR$0, BIGINT EXPR$1)] + +- LogicalProject(EXPR$0=[3], EXPR$1=[4.0:DECIMAL(20, 1)]), rowType=[RecordType(INTEGER EXPR$0, DECIMAL(20, 1) EXPR$1)] +- LogicalValues(tuples=[[{ 0 }]]), rowType=[RecordType(INTEGER ZERO)] ]]> diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupingSetsTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupingSetsTest.xml index ae9e14592854d..af6bc3c4d935c 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupingSetsTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupingSetsTest.xml @@ -182,10 +182,9 @@ LogicalUnion(all=[true]) : +- LogicalAggregate(group=[{0}], groups=[[{0}, {}]], c=[COUNT()]) : +- LogicalProject(deptno=[$7]) : +- LogicalTableScan(table=[[default_catalog, default_database, scott_emp]]) -+- LogicalProject(deptno=[$0], g=[1:BIGINT], c=[$1]) - +- LogicalAggregate(group=[{0}], groups=[[{}]], c=[COUNT()]) - +- LogicalProject(deptno=[$7]) - +- LogicalTableScan(table=[[default_catalog, default_database, scott_emp]]) ++- LogicalProject(deptno=[null:INTEGER], g=[1:BIGINT], c=[$0]) + +- LogicalAggregate(group=[{}], c=[COUNT()]) + +- LogicalTableScan(table=[[default_catalog, default_database, scott_emp]]) ]]> @@ -196,13 +195,14 @@ Union(all=[true], union=[deptno, g, c]) : +- Exchange(distribution=[hash[deptno, $e]]) : +- LocalHashAggregate(groupBy=[deptno, $e], select=[deptno, $e, Partial_COUNT(*) AS count1$0]) : +- Expand(projects=[{deptno, 0 AS $e}, {null AS deptno, 1 AS $e}]) -: +- Calc(select=[deptno])(reuse_id=[1]) -: +- TableSourceScan(table=[[default_catalog, default_database, scott_emp]], fields=[empno, ename, job, mgr, hiredate, sal, comm, deptno]) -+- Calc(select=[deptno, 1 AS g, c]) - +- HashAggregate(isMerge=[true], groupBy=[deptno], select=[deptno, Final_COUNT(count1$0) AS c]) - +- Exchange(distribution=[hash[deptno]]) - +- LocalHashAggregate(groupBy=[deptno], select=[deptno, Partial_COUNT(*) AS count1$0]) - +- Reused(reference_id=[1]) +: +- Calc(select=[deptno]) +: +- TableSourceScan(table=[[default_catalog, default_database, scott_emp]], fields=[empno, ename, job, mgr, hiredate, sal, comm, deptno])(reuse_id=[1]) ++- Calc(select=[null:INTEGER AS deptno, 1 AS g, c]) + +- HashAggregate(isMerge=[true], select=[Final_COUNT(count1$0) AS c]) + +- Exchange(distribution=[single]) + +- LocalHashAggregate(select=[Partial_COUNT(*) AS count1$0]) + +- Calc(select=[0 AS $f0]) + +- Reused(reference_id=[1]) ]]> diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/join/BroadcastHashJoinTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/join/BroadcastHashJoinTest.xml index dc0d9c3e721b3..1a38b9d05e958 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/join/BroadcastHashJoinTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/join/BroadcastHashJoinTest.xml @@ -472,7 +472,7 @@ Calc(select=[d, e, f]) @@ -43,9 +42,9 @@ FlinkLogicalUnion(all=[true]) : +- FlinkLogicalExpand(projects=[{a, 0 AS $e}, {null AS a, 1 AS $e}]) : +- FlinkLogicalCalc(select=[a]) : +- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c]) -+- FlinkLogicalCalc(select=[a, 1 AS g, c]) - +- FlinkLogicalAggregate(group=[{0}], groups=[[{}]], c=[COUNT()]) - +- FlinkLogicalCalc(select=[a]) ++- FlinkLogicalCalc(select=[null:INTEGER AS a, 1 AS g, c]) + +- FlinkLogicalAggregate(group=[{}], c=[COUNT()]) + +- FlinkLogicalCalc(select=[0 AS $f0]) +- FlinkLogicalTableSourceScan(table=[[default_catalog, default_database, MyTable]], fields=[a, b, c]) ]]> diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkFilterJoinRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkFilterJoinRuleTest.xml index 9d2ed41f76409..3850d355889a7 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkFilterJoinRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/FlinkFilterJoinRuleTest.xml @@ -563,7 +563,7 @@ LogicalProject(a1=[$0], b1=[$1], c1=[$2], b2=[$3], c2=[$4], a2=[$5]) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyFilterConditionRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyFilterConditionRuleTest.xml index c32350b43b61b..466f054085fa2 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyFilterConditionRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyFilterConditionRuleTest.xml @@ -179,10 +179,10 @@ LogicalProject(a=[$0], b=[$1], c=[$2]) ($0, 1)), =($1, 3)))]) LogicalTableScan(table=[[default_catalog, default_database, y]]) -}), true)]) +})]) +- LogicalTableScan(table=[[default_catalog, default_database, x]]) ]]> diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyJoinConditionRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyJoinConditionRuleTest.xml index ec5e4f7f1516e..1ec0eaf18bbb4 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyJoinConditionRuleTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/SimplifyJoinConditionRuleTest.xml @@ -59,15 +59,19 @@ LogicalAggregate(group=[{}], EXPR$0=[COUNT()]) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQueryAntiJoinTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQueryAntiJoinTest.xml index 0e9221a7d8541..5ee0b2fc55d14 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQueryAntiJoinTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQueryAntiJoinTest.xml @@ -96,33 +96,41 @@ LogicalProject(c=[$2]) +- LogicalFilter(condition=[OR(=($3, 0), AND(IS NULL($13), >=($4, $3)))]) +- LogicalJoin(condition=[=(CASE(AND(IS NOT TRUE(OR(IS NULL($1), =($6, 0))), IS NOT NULL($10)), 1, 2), $11)], joinType=[left]) :- LogicalJoin(condition=[AND(=($0, $9), =($1, $8))], joinType=[left]) - : :- LogicalJoin(condition=[=($0, $5)], joinType=[left]) + : :- LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $5)], joinType=[left]) : : :- LogicalJoin(condition=[true], joinType=[inner]) : : : :- LogicalTableScan(table=[[default_catalog, default_database, l]]) - : : : +- LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT($0, $1)]) - : : : +- LogicalProject(d=[$0], e=[$1]) + : : : +- LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT() FILTER $0]) + : : : +- LogicalProject($f2=[OR(IS NOT NULL($0), IS NOT NULL($1))]) : : : +- LogicalTableScan(table=[[default_catalog, default_database, r]]) - : : +- LogicalAggregate(group=[{0}], c=[COUNT()], ck=[COUNT($1)]) - : : +- LogicalProject(i=[$0], j=[$1]) - : : +- LogicalFilter(condition=[IS NOT NULL($0)]) - : : +- LogicalTableScan(table=[[default_catalog, default_database, t]]) + : : +- LogicalProject(i=[$0], c=[CASE(IS NOT NULL($2), $2, 0)], ck=[CASE(IS NOT NULL($3), $3, 0)]) + : : +- LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[left]) + : : :- LogicalAggregate(group=[{0}]) + : : : +- LogicalJoin(condition=[true], joinType=[inner]) + : : : :- LogicalTableScan(table=[[default_catalog, default_database, l]]) + : : : +- LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT() FILTER $0]) + : : : +- LogicalProject($f2=[OR(IS NOT NULL($0), IS NOT NULL($1))]) + : : : +- LogicalTableScan(table=[[default_catalog, default_database, r]]) + : : +- LogicalAggregate(group=[{0}], c=[COUNT()], ck=[COUNT($1)]) + : : +- LogicalProject(i=[$0], j=[$1]) + : : +- LogicalFilter(condition=[IS NOT NULL($0)]) + : : +- LogicalTableScan(table=[[default_catalog, default_database, t]]) : +- LogicalFilter(condition=[IS NOT NULL($0)]) : +- LogicalProject(j=[$0], i2=[$1], $f2=[true]) : +- LogicalAggregate(group=[{0, 1}]) : +- LogicalProject(j=[$1], i2=[$0]) : +- LogicalFilter(condition=[IS NOT NULL($0)]) : +- LogicalTableScan(table=[[default_catalog, default_database, t]]) - +- LogicalJoin(condition=[true], joinType=[left]) - :- LogicalJoin(condition=[true], joinType=[left]) - : :- LogicalFilter(condition=[=(4, $1)]) - : : +- LogicalProject(d=[$0], e=[$1], $f2=[true]) - : : +- LogicalAggregate(group=[{0, 1}]) - : : +- LogicalProject(d=[$0], e=[$1]) - : : +- LogicalTableScan(table=[[default_catalog, default_database, r]]) - : +- LogicalProject($f0=[true], dummy=[$0]) - : +- LogicalValues(tuples=[[{ true }]]) - +- LogicalProject($f0=[true], dummy=[$0]) - +- LogicalValues(tuples=[[{ true }]]) + +- LogicalFilter(condition=[=(CASE(IS NULL($3), 3, IS NOT NULL($5), 4, 5), $1)]) + +- LogicalJoin(condition=[true], joinType=[left]) + :- LogicalJoin(condition=[true], joinType=[left]) + : :- LogicalProject(d=[$0], e=[$1], $f2=[true]) + : : +- LogicalAggregate(group=[{0, 1}]) + : : +- LogicalProject(d=[$0], e=[$1]) + : : +- LogicalTableScan(table=[[default_catalog, default_database, r]]) + : +- LogicalProject($f0=[true], dummy=[$0]) + : +- LogicalValues(tuples=[[{ true }]]) + +- LogicalProject($f0=[true], dummy=[$0]) + +- LogicalValues(tuples=[[{ true }]]) ]]> @@ -901,15 +909,21 @@ LogicalProject(d=[$0], e=[$1]) =($4, $3)))]) - +- LogicalJoin(condition=[=(CASE(OR(=($5, 0), AND(IS NULL($8), >=($6, $5), IS NOT NULL($1))), 3, 4), $10)], joinType=[left]) - :- LogicalJoin(condition=[=($1, $7)], joinType=[left]) ++- LogicalFilter(condition=[OR(=($3, 0), AND(IS NULL($13), >=($4, $3)))]) + +- LogicalJoin(condition=[AND(=(CASE(IS NULL($6), 1, 2), $11), =(CASE(OR(=($7, 0), AND(IS NULL($10), >=($8, $7), IS NOT NULL($1))), 3, 4), $12))], joinType=[left]) + :- LogicalJoin(condition=[=($1, $9)], joinType=[left]) : :- LogicalJoin(condition=[true], joinType=[inner]) - : : :- LogicalJoin(condition=[true], joinType=[inner]) - : : : :- LogicalTableScan(table=[[default_catalog, default_database, l]]) - : : : +- LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT($0, $1)]) - : : : +- LogicalProject(d=[$0], e=[$1]) - : : : +- LogicalTableScan(table=[[default_catalog, default_database, r]]) + : : :- LogicalJoin(condition=[=($0, $5)], joinType=[left]) + : : : :- LogicalJoin(condition=[true], joinType=[inner]) + : : : : :- LogicalTableScan(table=[[default_catalog, default_database, l]]) + : : : : +- LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT() FILTER $0]) + : : : : +- LogicalProject($f2=[OR(IS NOT NULL($0), IS NOT NULL($1))]) + : : : : +- LogicalTableScan(table=[[default_catalog, default_database, r]]) + : : : +- LogicalProject(i1=[$0], $f1=[true]) + : : : +- LogicalAggregate(group=[{0}]) + : : : +- LogicalProject(i1=[$0]) + : : : +- LogicalFilter(condition=[IS NOT NULL($0)]) + : : : +- LogicalTableScan(table=[[default_catalog, default_database, t]]) : : +- LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT($0)]) : : +- LogicalProject(m=[$1]) : : +- LogicalTableScan(table=[[default_catalog, default_database, t2]]) @@ -917,17 +931,10 @@ LogicalProject(c=[$2]) : +- LogicalAggregate(group=[{0}]) : +- LogicalProject(m=[$1]) : +- LogicalTableScan(table=[[default_catalog, default_database, t2]]) - +- LogicalJoin(condition=[=($0, $3)], joinType=[left]) - :- LogicalFilter(condition=[=(2, $0)]) - : +- LogicalProject(d=[$0], e=[$1], $f2=[true]) - : +- LogicalAggregate(group=[{0, 1}]) - : +- LogicalProject(d=[$0], e=[$1]) - : +- LogicalTableScan(table=[[default_catalog, default_database, r]]) - +- LogicalProject(i1=[$0], $f1=[true]) - +- LogicalAggregate(group=[{0}]) - +- LogicalProject(i1=[$0]) - +- LogicalFilter(condition=[IS NOT NULL($0)]) - +- LogicalTableScan(table=[[default_catalog, default_database, t]]) + +- LogicalProject(d=[$0], e=[$1], $f2=[true]) + +- LogicalAggregate(group=[{0, 1}]) + +- LogicalProject(d=[$0], e=[$1]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) ]]> @@ -2182,16 +2189,24 @@ LogicalProject(b=[$1]) +- LogicalFilter(condition=[OR(=($3, 0), AND(IS NULL($12), >=($4, $3)))]) +- LogicalJoin(condition=[=(CASE(OR(=($6, 0), IS NOT TRUE(OR(IS NULL($0), IS NOT NULL($10), <($7, $6)))), 1, 2), $11)], joinType=[left]) :- LogicalJoin(condition=[=($0, $9)], joinType=[left]) - : :- LogicalJoin(condition=[=($0, $5)], joinType=[left]) + : :- LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $5)], joinType=[left]) : : :- LogicalJoin(condition=[true], joinType=[inner]) : : : :- LogicalTableScan(table=[[default_catalog, default_database, l]]) : : : +- LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT($0)]) : : : +- LogicalProject(d=[$0]) : : : +- LogicalTableScan(table=[[default_catalog, default_database, r]]) - : : +- LogicalAggregate(group=[{0}], c=[COUNT()], ck=[COUNT($1)]) - : : +- LogicalProject(i1=[$0], i=[$0]) - : : +- LogicalFilter(condition=[IS NOT NULL($0)]) - : : +- LogicalTableScan(table=[[default_catalog, default_database, t1]]) + : : +- LogicalProject(i1=[$0], c=[CASE(IS NOT NULL($2), $2, 0)], ck=[CASE(IS NOT NULL($3), $3, 0)]) + : : +- LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[left]) + : : :- LogicalAggregate(group=[{0}]) + : : : +- LogicalJoin(condition=[true], joinType=[inner]) + : : : :- LogicalTableScan(table=[[default_catalog, default_database, l]]) + : : : +- LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT($0)]) + : : : +- LogicalProject(d=[$0]) + : : : +- LogicalTableScan(table=[[default_catalog, default_database, r]]) + : : +- LogicalAggregate(group=[{0}], c=[COUNT()], ck=[COUNT($1)]) + : : +- LogicalProject(i1=[$0], i=[$0]) + : : +- LogicalFilter(condition=[IS NOT NULL($0)]) + : : +- LogicalTableScan(table=[[default_catalog, default_database, t1]]) : +- LogicalFilter(condition=[=($1, $0)]) : +- LogicalProject(i=[$0], i2=[$1], $f2=[true]) : +- LogicalAggregate(group=[{0, 1}]) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml index 8d81faec9ddd0..7f410f0ea78de 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/subquery/SubQuerySemiJoinTest.xml @@ -35,7 +35,7 @@ LogicalFilter(condition=[=($1, $cor0.b)]) ($3, 0), IS NOT NULL($6), IS NOT NULL($1)), 3, 4), $8))], joinType=[inner]) - :- LogicalJoin(condition=[=($1, $5)], joinType=[left]) ++- LogicalJoin(condition=[AND(=(CASE(IS NOT NULL($4), 1, 2), $9), =(CASE(AND(<>($5, 0), IS NOT NULL($8), IS NOT NULL($1)), 3, 4), $10))], joinType=[inner]) + :- LogicalJoin(condition=[=($1, $7)], joinType=[left]) : :- LogicalJoin(condition=[true], joinType=[inner]) - : : :- LogicalTableScan(table=[[default_catalog, default_database, l]]) + : : :- LogicalJoin(condition=[=($0, $3)], joinType=[left]) + : : : :- LogicalTableScan(table=[[default_catalog, default_database, l]]) + : : : +- LogicalProject(i1=[$0], $f1=[true]) + : : : +- LogicalAggregate(group=[{0}]) + : : : +- LogicalProject(i1=[$0]) + : : : +- LogicalFilter(condition=[IS NOT NULL($0)]) + : : : +- LogicalTableScan(table=[[default_catalog, default_database, t]]) : : +- LogicalAggregate(group=[{}], c=[COUNT()], ck=[COUNT($0)]) : : +- LogicalProject(m=[$1]) : : +- LogicalTableScan(table=[[default_catalog, default_database, t2]]) @@ -1557,15 +1563,9 @@ LogicalProject(c=[$2]) : +- LogicalAggregate(group=[{0}]) : +- LogicalProject(m=[$1]) : +- LogicalTableScan(table=[[default_catalog, default_database, t2]]) - +- LogicalJoin(condition=[=($0, $2)], joinType=[left]) - :- LogicalAggregate(group=[{0, 1}]) - : +- LogicalProject(d=[$0], e=[$1]) - : +- LogicalTableScan(table=[[default_catalog, default_database, r]]) - +- LogicalProject(i1=[$0], $f1=[true]) - +- LogicalAggregate(group=[{0}]) - +- LogicalProject(i1=[$0]) - +- LogicalFilter(condition=[IS NOT NULL($0)]) - +- LogicalTableScan(table=[[default_catalog, default_database, t]]) + +- LogicalAggregate(group=[{0, 1}]) + +- LogicalProject(d=[$0], e=[$1]) + +- LogicalTableScan(table=[[default_catalog, default_database, r]]) ]]> @@ -1612,7 +1612,7 @@ LogicalProject(d=[$1]) diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ValuesTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ValuesTest.xml index 42260c1b33fc6..e61d858649383 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ValuesTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/ValuesTest.xml @@ -25,9 +25,9 @@ limitations under the License. LogicalProject(a=[$0], b=[$1]), rowType=[RecordType(INTEGER a, DECIMAL(20, 1) b)] +- LogicalProject(a=[$0], b=[$1]), rowType=[RecordType(INTEGER a, DECIMAL(20, 1) b)] +- LogicalUnion(all=[true]), rowType=[RecordType(INTEGER EXPR$0, DECIMAL(20, 1) EXPR$1)] - :- LogicalProject(EXPR$0=[1], EXPR$1=[2.0:DECIMAL(2, 1)]), rowType=[RecordType(INTEGER EXPR$0, DECIMAL(2, 1) EXPR$1)] + :- LogicalProject(EXPR$0=[1], EXPR$1=[2.0:DECIMAL(20, 1)]), rowType=[RecordType(INTEGER EXPR$0, DECIMAL(20, 1) EXPR$1)] : +- LogicalValues(tuples=[[{ 0 }]]), rowType=[RecordType(INTEGER ZERO)] - +- LogicalProject(EXPR$0=[3], EXPR$1=[4:BIGINT]), rowType=[RecordType(INTEGER EXPR$0, BIGINT EXPR$1)] + +- LogicalProject(EXPR$0=[3], EXPR$1=[4.0:DECIMAL(20, 1)]), rowType=[RecordType(INTEGER EXPR$0, DECIMAL(20, 1) EXPR$1)] +- LogicalValues(tuples=[[{ 0 }]]), rowType=[RecordType(INTEGER ZERO)] ]]> diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/GroupingSetsTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/GroupingSetsTest.xml index a5a504566af3f..a56cec435289a 100644 --- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/GroupingSetsTest.xml +++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/GroupingSetsTest.xml @@ -159,6 +159,42 @@ Calc(select=[deptno, c]) +- Expand(projects=[{deptno, 0 AS $e}, {null AS deptno, 1 AS $e}]) +- Calc(select=[deptno]) +- TableSourceScan(table=[[default_catalog, default_database, emps]], fields=[empno, name, deptno, gender, city, empid, age, slacker, manager, joinedat]) +]]> + + + + + + + + + + + diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala index 4beb3b1802ad6..0c781ca357ea6 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala @@ -105,7 +105,7 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase { assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(8), null)) assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(9), null)) assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(10), null)) - assertEquals(17.13, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(11), null), 1e-2) + assertEquals(20.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(11), null), 1e-2) // id > 10 val expr1 = relBuilder @@ -124,7 +124,10 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase { assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(8), expr1)) assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(9), expr1)) assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(10), expr1)) - assertEquals(17.13, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(11), expr1), 1e-2) + assertEquals( + 16.464466094067262, + mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(11), expr1), + 1e-2) // age > 15 and class = 5 val expr2 = relBuilder @@ -146,7 +149,7 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase { assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(8), expr2)) assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(9), expr2)) assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(10), expr2)) - assertEquals(17.13, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(11), expr2), 1e-2) + assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(11), expr2), 1e-2) assertEquals(1.0, mq.getDistinctRowCount(logicalProject, ImmutableBitSet.of(0, 1), expr2)) } @@ -200,7 +203,10 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase { assertEquals(1.0, mq.getDistinctRowCount(calc, ImmutableBitSet.of(8), null)) assertEquals(1.0, mq.getDistinctRowCount(calc, ImmutableBitSet.of(9), null)) assertEquals(1.0, mq.getDistinctRowCount(calc, ImmutableBitSet.of(10), null)) - assertEquals(11.22, mq.getDistinctRowCount(calc, ImmutableBitSet.of(11), null), 1e-2) + assertEquals( + 16.464466094067262, + mq.getDistinctRowCount(calc, ImmutableBitSet.of(11), null), + 1e-2) // class = 5 relBuilder.push(calc) @@ -217,7 +223,10 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase { assertEquals(1.0, mq.getDistinctRowCount(calc, ImmutableBitSet.of(8), expr2)) assertEquals(1.0, mq.getDistinctRowCount(calc, ImmutableBitSet.of(9), expr2)) assertEquals(1.0, mq.getDistinctRowCount(calc, ImmutableBitSet.of(10), expr2)) - assertEquals(11.22, mq.getDistinctRowCount(calc, ImmutableBitSet.of(11), expr2), 1e-2) + assertEquals( + 10.257214207425065, + mq.getDistinctRowCount(calc, ImmutableBitSet.of(11), expr2), + 1e-2) } @Test diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPopulationSizeTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPopulationSizeTest.scala index 8de367956090e..12ad645c46216 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPopulationSizeTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdPopulationSizeTest.scala @@ -77,7 +77,7 @@ class FlinkRelMdPopulationSizeTest extends FlinkRelMdHandlerTestBase { assertEquals(1.0, mq.getPopulationSize(logicalProject, ImmutableBitSet.of(8))) assertEquals(1.0, mq.getPopulationSize(logicalProject, ImmutableBitSet.of(9))) assertEquals(1.0, mq.getPopulationSize(logicalProject, ImmutableBitSet.of(10))) - assertEquals(16.43, mq.getPopulationSize(logicalProject, ImmutableBitSet.of(11)), 1e-2) + assertEquals(20.0, mq.getPopulationSize(logicalProject, ImmutableBitSet.of(11)), 1e-2) assertEquals(50.0, mq.getPopulationSize(logicalProject, ImmutableBitSet.of(0, 1))) assertEquals(31.24, mq.getPopulationSize(logicalProject, ImmutableBitSet.of(1, 8)), 1e-2) @@ -113,7 +113,7 @@ class FlinkRelMdPopulationSizeTest extends FlinkRelMdHandlerTestBase { assertEquals(1.0, mq.getPopulationSize(logicalCalc, ImmutableBitSet.of(8))) assertEquals(1.0, mq.getPopulationSize(logicalCalc, ImmutableBitSet.of(9))) assertEquals(1.0, mq.getPopulationSize(logicalCalc, ImmutableBitSet.of(10))) - assertEquals(11.22, mq.getPopulationSize(logicalCalc, ImmutableBitSet.of(11)), 1e-2) + assertEquals(20.0, mq.getPopulationSize(logicalCalc, ImmutableBitSet.of(11)), 1e-2) assertEquals(50.0, mq.getPopulationSize(logicalCalc, ImmutableBitSet.of(0, 1))) assertEquals(19.64, mq.getPopulationSize(logicalCalc, ImmutableBitSet.of(1, 8)), 1e-2) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/GroupingSetsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/GroupingSetsTest.scala index 6b5c8ff9ca837..94a791d5a89d1 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/GroupingSetsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/agg/GroupingSetsTest.scala @@ -407,9 +407,7 @@ class GroupingSetsTest extends TableTestBase { |SELECT deptno, GROUP_ID() AS g, COUNT(*) AS c |FROM scott_emp GROUP BY GROUPING SETS (deptno, (), ()) """.stripMargin - assertThatThrownBy(() => util.verifyExecPlan(sqlQuery)) - .hasMessageContaining("GROUPING SETS are currently not supported") - .isInstanceOf[TableException] + util.verifyExecPlan(sqlQuery) } @Test diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/GroupingSetsITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/GroupingSetsITCase.scala index f79fb8ddbd296..d9aaae4e2790f 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/GroupingSetsITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/GroupingSetsITCase.scala @@ -618,13 +618,12 @@ class GroupingSetsITCase extends BatchTestBase { "select deptno, group_id() as g, count(*) as c " + "from scott_emp group by grouping sets (deptno, (), ())", Seq( + // Suddenly before CALCITE-7126 it returned result different from Postgres/Oracle row(10, 0, 3), - row(10, 1, 3), row(20, 0, 5), - row(20, 1, 5), row(30, 0, 6), - row(30, 1, 6), - row(null, 0, 14)) + row(null, 0, 14), + row(null, 1, 14)) ) } diff --git a/flink-table/pom.xml b/flink-table/pom.xml index c17c122c0e60a..b4d0ddd96546b 100644 --- a/flink-table/pom.xml +++ b/flink-table/pom.xml @@ -79,7 +79,7 @@ under the License. - 1.40.0 + 1.41.0 3.1.12 33.4.0-jre 2.5.2