From 5ded3893eb88c8cc1578896a256962ce08259263 Mon Sep 17 00:00:00 2001 From: FearfulTomcat27 <1471335448@qq.com> Date: Thu, 26 Mar 2026 17:16:11 +0800 Subject: [PATCH 1/2] fix: Validate weight parameter in approx_percentile function and handle invalid cases --- .../query/recent/IoTDBTableAggregationIT.java | 14 ++++++ ...ApproxPercentileWithWeightAccumulator.java | 50 +++++++++++++++++++ .../metadata/TableMetadataImpl.java | 21 ++++---- 3 files changed, 76 insertions(+), 9 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java index df6ed279ea00c..3ddb613baa218 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java @@ -4352,6 +4352,12 @@ public void approxPercentileTest() { "2024-09-24T06:15:55.000Z,shanghai,55,null,", }, DATABASE_NAME); + + tableResultSetEqualTest( + "select approx_percentile(s1,null,0.5) from table1", + new String[] {"_col0"}, + new String[] {"null,"}, + DATABASE_NAME); } @Test @@ -4432,6 +4438,14 @@ public void exceptionTest() { "select approx_percentile(s5,0.5) from table1", "701: Aggregation functions [approx_percentile] should have value column as numeric type [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", DATABASE_NAME); + tableAssertTestFail( + "select approx_percentile(s1,-1,0.5) from table1", + "701: weight must be >= 1, was -1", + DATABASE_NAME); + tableAssertTestFail( + "select approx_percentile(s1,s2,0.5) from table1", + "701: Aggregation functions [approx_percentile] do not support weight as INT64 type", + DATABASE_NAME); } // ================================================================== diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/ApproxPercentileWithWeightAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/ApproxPercentileWithWeightAccumulator.java index 0e7de6c612b88..f52165472f7ec 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/ApproxPercentileWithWeightAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/ApproxPercentileWithWeightAccumulator.java @@ -14,6 +14,8 @@ package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; +import org.apache.iotdb.db.exception.sql.SemanticException; + import org.apache.tsfile.block.column.Column; import org.apache.tsfile.enums.TSDataType; @@ -32,6 +34,12 @@ public void addIntInput(Column[] arguments, AggregationMask mask) { if (mask.isSelectAll()) { for (int i = 0; i < valueColumn.getPositionCount(); i++) { + if (weightColumn.isNull(i)) { + continue; + } + if (weightColumn.getInt(i) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i)); + } if (!valueColumn.isNull(i)) { tDigest.add(valueColumn.getInt(i), weightColumn.getInt(i)); } @@ -41,6 +49,12 @@ public void addIntInput(Column[] arguments, AggregationMask mask) { int position; for (int i = 0; i < positionCount; i++) { position = selectedPositions[i]; + if (weightColumn.isNull(position)) { + continue; + } + if (weightColumn.getInt(position) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position)); + } if (!valueColumn.isNull(position)) { tDigest.add(valueColumn.getInt(position), weightColumn.getInt(position)); } @@ -57,6 +71,12 @@ public void addLongInput(Column[] arguments, AggregationMask mask) { if (mask.isSelectAll()) { for (int i = 0; i < valueColumn.getPositionCount(); i++) { + if (weightColumn.isNull(i)) { + continue; + } + if (weightColumn.getInt(i) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i)); + } if (!valueColumn.isNull(i)) { tDigest.add(toDoubleExact(valueColumn.getLong(i)), weightColumn.getInt(i)); } @@ -66,6 +86,12 @@ public void addLongInput(Column[] arguments, AggregationMask mask) { int position; for (int i = 0; i < positionCount; i++) { position = selectedPositions[i]; + if (weightColumn.isNull(position)) { + continue; + } + if (weightColumn.getInt(position) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position)); + } if (!valueColumn.isNull(position)) { tDigest.add(toDoubleExact(valueColumn.getLong(position)), weightColumn.getInt(position)); } @@ -82,6 +108,12 @@ public void addFloatInput(Column[] arguments, AggregationMask mask) { if (mask.isSelectAll()) { for (int i = 0; i < valueColumn.getPositionCount(); i++) { + if (weightColumn.isNull(i)) { + continue; + } + if (weightColumn.getInt(i) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i)); + } if (!valueColumn.isNull(i)) { tDigest.add(valueColumn.getFloat(i), weightColumn.getInt(i)); } @@ -91,6 +123,12 @@ public void addFloatInput(Column[] arguments, AggregationMask mask) { int position; for (int i = 0; i < positionCount; i++) { position = selectedPositions[i]; + if (weightColumn.isNull(position)) { + continue; + } + if (weightColumn.getInt(position) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position)); + } if (!valueColumn.isNull(position)) { tDigest.add(valueColumn.getFloat(position), weightColumn.getInt(position)); } @@ -107,6 +145,12 @@ public void addDoubleInput(Column[] arguments, AggregationMask mask) { if (mask.isSelectAll()) { for (int i = 0; i < valueColumn.getPositionCount(); i++) { + if (weightColumn.isNull(i)) { + continue; + } + if (weightColumn.getInt(i) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i)); + } if (!valueColumn.isNull(i)) { tDigest.add(valueColumn.getDouble(i), weightColumn.getInt(i)); } @@ -116,6 +160,12 @@ public void addDoubleInput(Column[] arguments, AggregationMask mask) { int position; for (int i = 0; i < positionCount; i++) { position = selectedPositions[i]; + if (weightColumn.isNull(position)) { + continue; + } + if (weightColumn.getInt(position) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position)); + } if (!valueColumn.isNull(position)) { tDigest.add(valueColumn.getDouble(position), weightColumn.getInt(position)); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index ed0be5b10d02c..000d9cddf808e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -1194,19 +1194,22 @@ && isIntegerNumber(argumentTypes.get(2)))) { functionName)); } - // Validate percentage and weight parameters - boolean hasInvalidTypes = - (argumentSize == 2 && !isDecimalType(argumentTypes.get(1))) - || (argumentSize == 3 - && (!isIntegerNumber(argumentTypes.get(1)) - || !isDecimalType(argumentTypes.get(2)))); - - if (hasInvalidTypes) { + Type percentageType = argumentTypes.get(argumentSize - 1); + if (!isDecimalType(percentageType)) { throw new SemanticException( String.format( - "Aggregation functions [%s] should have weight as integer type and percentage as decimal type", + "Aggregation functions [%s] should have percentage as decimal type", functionName)); } + if (argumentSize == 3) { + Type weightType = argumentTypes.get(1); + if (!INT32.equals(weightType) && !isUnknownType(weightType)) { + throw new SemanticException( + String.format( + "Aggregation functions [%s] do not support weight as %s type", + functionName, weightType.getDisplayName())); + } + } break; case SqlConstant.COUNT: From a9a0069b45baf6a926a045b0304ec6065342be3c Mon Sep 17 00:00:00 2001 From: FearfulTomcat27 <1471335448@qq.com> Date: Thu, 26 Mar 2026 17:43:29 +0800 Subject: [PATCH 2/2] fix: Validate weight parameter in GroupedApproxPercentileWithWeightAccumulator and handle null cases --- .../query/recent/IoTDBTableAggregationIT.java | 10 ++++ ...ApproxPercentileWithWeightAccumulator.java | 49 +++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java index 3ddb613baa218..460fe83a397d2 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java @@ -4358,6 +4358,12 @@ public void approxPercentileTest() { new String[] {"_col0"}, new String[] {"null,"}, DATABASE_NAME); + + tableResultSetEqualTest( + "select 1 as g, approx_percentile(s1,null,0.5) from table1 group by 1", + new String[] {"g", "_col1"}, + new String[] {"1,null,"}, + DATABASE_NAME); } @Test @@ -4446,6 +4452,10 @@ public void exceptionTest() { "select approx_percentile(s1,s2,0.5) from table1", "701: Aggregation functions [approx_percentile] do not support weight as INT64 type", DATABASE_NAME); + tableAssertTestFail( + "select 1 as g, approx_percentile(s1,s2,0.5) from table1 group by 1", + "701: Aggregation functions [approx_percentile] do not support weight as INT64 type", + DATABASE_NAME); } // ================================================================== diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedApproxPercentileWithWeightAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedApproxPercentileWithWeightAccumulator.java index abde1dde650a9..b7c587df0d0b4 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedApproxPercentileWithWeightAccumulator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedApproxPercentileWithWeightAccumulator.java @@ -14,6 +14,7 @@ package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped; +import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.approximate.TDigest; @@ -36,6 +37,12 @@ public void addIntInput(int[] groupIds, Column[] arguments, AggregationMask mask if (mask.isSelectAll()) { for (int i = 0; i < positionCount; i++) { + if (weightColumn.isNull(i)) { + continue; + } + if (weightColumn.getInt(i) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i)); + } int groupId = groupIds[i]; TDigest tDigest = array.get(groupId); if (!valueColumn.isNull(i)) { @@ -48,6 +55,12 @@ public void addIntInput(int[] groupIds, Column[] arguments, AggregationMask mask int groupId; for (int i = 0; i < positionCount; i++) { position = selectedPositions[i]; + if (weightColumn.isNull(position)) { + continue; + } + if (weightColumn.getInt(position) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position)); + } groupId = groupIds[position]; TDigest tDigest = array.get(groupId); if (!valueColumn.isNull(position)) { @@ -66,6 +79,12 @@ public void addLongInput(int[] groupIds, Column[] arguments, AggregationMask mas if (mask.isSelectAll()) { for (int i = 0; i < positionCount; i++) { + if (weightColumn.isNull(i)) { + continue; + } + if (weightColumn.getInt(i) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i)); + } int groupId = groupIds[i]; TDigest tDigest = array.get(groupId); if (!valueColumn.isNull(i)) { @@ -78,6 +97,12 @@ public void addLongInput(int[] groupIds, Column[] arguments, AggregationMask mas int groupId; for (int i = 0; i < positionCount; i++) { position = selectedPositions[i]; + if (weightColumn.isNull(position)) { + continue; + } + if (weightColumn.getInt(position) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position)); + } groupId = groupIds[position]; TDigest tDigest = array.get(groupId); if (!valueColumn.isNull(position)) { @@ -96,6 +121,12 @@ public void addFloatInput(int[] groupIds, Column[] arguments, AggregationMask ma if (mask.isSelectAll()) { for (int i = 0; i < positionCount; i++) { + if (weightColumn.isNull(i)) { + continue; + } + if (weightColumn.getInt(i) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i)); + } int groupId = groupIds[i]; TDigest tDigest = array.get(groupId); if (!valueColumn.isNull(i)) { @@ -108,6 +139,12 @@ public void addFloatInput(int[] groupIds, Column[] arguments, AggregationMask ma int groupId; for (int i = 0; i < positionCount; i++) { position = selectedPositions[i]; + if (weightColumn.isNull(position)) { + continue; + } + if (weightColumn.getInt(position) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position)); + } groupId = groupIds[position]; TDigest tDigest = array.get(groupId); if (!valueColumn.isNull(position)) { @@ -126,6 +163,12 @@ public void addDoubleInput(int[] groupIds, Column[] arguments, AggregationMask m if (mask.isSelectAll()) { for (int i = 0; i < positionCount; i++) { + if (weightColumn.isNull(i)) { + continue; + } + if (weightColumn.getInt(i) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i)); + } int groupId = groupIds[i]; TDigest tDigest = array.get(groupId); if (!valueColumn.isNull(i)) { @@ -138,6 +181,12 @@ public void addDoubleInput(int[] groupIds, Column[] arguments, AggregationMask m int groupId; for (int i = 0; i < positionCount; i++) { position = selectedPositions[i]; + if (weightColumn.isNull(position)) { + continue; + } + if (weightColumn.getInt(position) < 1) { + throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position)); + } groupId = groupIds[position]; TDigest tDigest = array.get(groupId); if (!valueColumn.isNull(position)) {