diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java index 7c2c283d30e3a..5ebb8e12f9be0 100644 --- a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java +++ b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/BuiltinAggregationFunctionEnum.java @@ -42,7 +42,14 @@ public enum BuiltinAggregationFunctionEnum { AVG("avg"), SUM("sum"), MAX_BY("max_by"), - MIN_BY("min_by"); + MIN_BY("min_by"), + CORR("corr"), + COVAR_POP("covar_pop"), + COVAR_SAMP("covar_samp"), + REGR_SLOPE("regr_slope"), + REGR_INTERCEPT("regr_intercept"), + SKEWNESS("skewness"), + KURTOSIS("kurtosis"); private final String functionName; diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/TestConstant.java b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/TestConstant.java index 8a9d11516c5ae..75ff63cef2498 100644 --- a/integration-test/src/main/java/org/apache/iotdb/itbase/constant/TestConstant.java +++ b/integration-test/src/main/java/org/apache/iotdb/itbase/constant/TestConstant.java @@ -145,6 +145,34 @@ public static String varSamp(String path) { return String.format("var_samp(%s)", path); } + public static String corr(String path) { + return String.format("corr(%s)", path); + } + + public static String covarPop(String path) { + return String.format("covar_pop(%s)", path); + } + + public static String covarSamp(String path) { + return String.format("covar_samp(%s)", path); + } + + public static String regrSlope(String path) { + return String.format("regr_slope(%s)", path); + } + + public static String regrIntercept(String path) { + return String.format("regr_intercept(%s)", path); + } + + public static String kurtosis(String path) { + return String.format("kurtosis(%s)", path); + } + + public static String skewness(String path) { + return 
String.format("skewness(%s)", path); + } + public static String countUDAF(String path) { return String.format("count_udaf(%s)", path); } diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCorrelationIT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCorrelationIT.java new file mode 100644 index 0000000000000..7148388d3e4e6 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCorrelationIT.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iotdb.db.it.aggregation; + +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.ClusterIT; +import org.apache.iotdb.itbase.category.LocalStandaloneIT; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.Statement; + +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; +import static org.apache.iotdb.db.it.utils.TestUtils.resultSetEqualTest; +import static org.apache.iotdb.itbase.constant.TestConstant.DEVICE; +import static org.apache.iotdb.itbase.constant.TestConstant.TIMESTAMP_STR; +import static org.apache.iotdb.itbase.constant.TestConstant.corr; +import static org.junit.Assert.fail; + +@RunWith(IoTDBTestRunner.class) +@Category({LocalStandaloneIT.class, ClusterIT.class}) +public class IoTDBCorrelationIT { + + protected static final String[] SQLs = + new String[] { + "CREATE DATABASE root.db", + "CREATE TIMESERIES root.db.d1.s1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s3 WITH DATATYPE=BOOLEAN, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s4 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s5 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s6 WITH DATATYPE=TEXT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s4 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s5 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(1, 1, 1, true, 1, 1, \"1\")", + "INSERT INTO 
root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(2, 2, 2, false, 2, 2, \"2\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(3, 3, 2, false, 3, 2, \"2\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(10000000000, 4, 1, true, 4, 1, \"1\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(10000000001, 5, 1, true, 5, 1, \"1\")", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(1, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(2, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000000, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000001, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000002, 1, 2, 3, 4)", + "flush" + }; + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().getConfig().getCommonConfig().setPartitionInterval(1000); + EnvFactory.getEnv().initClusterEnvironment(); + prepareData(SQLs); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void testCorrWithUnsupportedTypesAndWrongArity() { + String typeError = + "Aggregate functions [CORR, COVAR_POP, COVAR_SAMP, REGR_SLOPE, REGR_INTERCEPT] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"; + String argError = "Error size of input expressions"; + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + try { + try (ResultSet resultSet = statement.executeQuery("SELECT corr(s3, s1) FROM root.db.d1")) { + resultSet.next(); + fail(); + } + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(typeError)); + } + + try { + try (ResultSet resultSet = statement.executeQuery("SELECT corr(s6, s1) FROM root.db.d1")) { + resultSet.next(); + fail(); + } + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), 
e.getMessage().contains(typeError)); + } + + try { + statement.executeQuery("SELECT corr(s1) FROM root.db.d1"); + fail(); + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(argError)); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCorrWithDifferentTypes() { + String[] expectedHeader = + new String[] {corr("root.db.d1.s1, root.db.d1.s2"), corr("root.db.d1.s4, root.db.d1.s5")}; + String[] retArray = new String[] {"-0.28867513459481287,-0.28867513459481287,"}; + resultSetEqualTest("select corr(s1,s2),corr(s4,s5) from root.db.d1", expectedHeader, retArray); + + retArray = new String[] {"0.8660254037844386,0.8660254037844386,"}; + resultSetEqualTest( + "select corr(s1,s2),corr(s4,s5) from root.db.d1 where time < 10", expectedHeader, retArray); + } + + @Test + public void testCorrAlignByDevice() { + String[] expectedHeader = new String[] {DEVICE, corr("s1, s2"), corr("s4, s5")}; + String[] retArray = new String[] {"root.db.d1,-0.28867513459481287,-0.28867513459481287,"}; + resultSetEqualTest( + "select corr(s1,s2),corr(s4,s5) from root.db.d1 align by device", expectedHeader, retArray); + + retArray = new String[] {"root.db.d1,0.8660254037844386,0.8660254037844386,"}; + resultSetEqualTest( + "select corr(s1,s2),corr(s4,s5) from root.db.d1 where time < 10 align by device", + expectedHeader, + retArray); + } + + @Test + public void testCorrInHaving() { + String[] expectedHeader = new String[] {corr("root.db.d1.s1, root.db.d1.s2")}; + String[] retArray = new String[] {"-0.28867513459481287,"}; + resultSetEqualTest( + "select corr(s1,s2) from root.db.d1 having corr(s1,s2) < 0", expectedHeader, retArray); + } + + @Test + public void testCorrMultiDeviceWithoutGroupByLevel() { + String[] expectedHeader = + new String[] { + corr("root.db.d1.s1, root.db.d1.s2"), + corr("root.db.d1.s1, root.db.d2.s2"), + corr("root.db.d2.s1, root.db.d1.s2"), + corr("root.db.d2.s1, root.db.d2.s2") + 
}; + String[] retArray = new String[] {"-0.28867513459481287,null,null,null,"}; + resultSetEqualTest("select corr(s1,s2) from root.db.d1,root.db.d2", expectedHeader, retArray); + } + + @Test + public void testCorrMultiDeviceWithGroupByLevel() { + String[] expectedHeader = new String[] {corr("root.*.*.s1, root.*.*.s2")}; + String[] retArray = new String[] {"-0.08111071056538134,"}; + resultSetEqualTest( + "select corr(s1,s2) from root.db.* group by level = 0", expectedHeader, retArray); + } + + @Test + public void testCorrWithSlidingWindow() { + String[] expectedHeader = new String[] {TIMESTAMP_STR, corr("root.db.d1.s1, root.db.d1.s2")}; + String[] retArray = new String[] {"1,0.8660254037844387,", "3,null,"}; + resultSetEqualTest( + "select corr(s1,s2) from root.db.d1 group by time([1,4),3ms,2ms)", + expectedHeader, + retArray); + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCovarianceIT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCovarianceIT.java new file mode 100644 index 0000000000000..4d384d9f8aa39 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBCovarianceIT.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.it.aggregation; + +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.ClusterIT; +import org.apache.iotdb.itbase.category.LocalStandaloneIT; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.Statement; + +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; +import static org.apache.iotdb.db.it.utils.TestUtils.resultSetEqualTest; +import static org.apache.iotdb.itbase.constant.TestConstant.DEVICE; +import static org.apache.iotdb.itbase.constant.TestConstant.TIMESTAMP_STR; +import static org.apache.iotdb.itbase.constant.TestConstant.covarPop; +import static org.apache.iotdb.itbase.constant.TestConstant.covarSamp; +import static org.junit.Assert.fail; + +@RunWith(IoTDBTestRunner.class) +@Category({LocalStandaloneIT.class, ClusterIT.class}) +public class IoTDBCovarianceIT { + + protected static final String[] SQLs = + new String[] { + "CREATE DATABASE root.db", + "CREATE TIMESERIES root.db.d1.s1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s3 WITH DATATYPE=BOOLEAN, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s4 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s5 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s6 WITH DATATYPE=TEXT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s4 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES 
root.db.d2.s5 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(1, 1, 1, true, 1, 1, \"1\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(2, 2, 2, false, 2, 2, \"2\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(3, 3, 2, false, 3, 2, \"2\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(10000000000, 4, 1, true, 4, 1, \"1\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(10000000001, 5, 1, true, 5, 1, \"1\")", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(1, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(2, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000000, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000001, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000002, 1, 2, 3, 4)", + "flush" + }; + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().getConfig().getCommonConfig().setPartitionInterval(1000); + EnvFactory.getEnv().initClusterEnvironment(); + prepareData(SQLs); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void testCovarWithUnsupportedTypesAndWrongArity() { + String typeError = + "Aggregate functions [CORR, COVAR_POP, COVAR_SAMP, REGR_SLOPE, REGR_INTERCEPT] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"; + String argError = "Error size of input expressions"; + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + try { + statement.executeQuery("SELECT covar_pop(s3, s1) FROM root.db.d1"); + fail(); + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(typeError)); + } + + try { + statement.executeQuery("SELECT covar_samp(s6, s1) FROM root.db.d1"); + fail(); 
+ } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(typeError)); + } + + try { + statement.executeQuery("SELECT covar_pop(s1) FROM root.db.d1"); + fail(); + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(argError)); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCovarWithDifferentTypes() { + String[] expectedHeader = + new String[] { + covarPop("root.db.d1.s1, root.db.d1.s2"), + covarSamp("root.db.d1.s1, root.db.d1.s2"), + covarPop("root.db.d1.s4, root.db.d1.s5"), + covarSamp("root.db.d1.s4, root.db.d1.s5") + }; + String[] retArray = + new String[] { + "-0.19999999999999998,-0.24999999999999997,-0.19999999999999998,-0.24999999999999997," + }; + resultSetEqualTest( + "select covar_pop(s1,s2),covar_samp(s1,s2),covar_pop(s4,s5),covar_samp(s4,s5) from root.db.d1", + expectedHeader, + retArray); + + retArray = + new String[] { + "0.3333333333333333,0.49999999999999994,0.3333333333333333,0.49999999999999994," + }; + resultSetEqualTest( + "select covar_pop(s1,s2),covar_samp(s1,s2),covar_pop(s4,s5),covar_samp(s4,s5) " + + "from root.db.d1 where time < 10", + expectedHeader, + retArray); + } + + @Test + public void testCovarAlignByDevice() { + String[] expectedHeader = + new String[] { + DEVICE, covarPop("s1, s2"), covarSamp("s1, s2"), covarPop("s4, s5"), covarSamp("s4, s5") + }; + String[] retArray = + new String[] { + "root.db.d1,-0.19999999999999998,-0.24999999999999997,-0.19999999999999998,-0.24999999999999997," + }; + resultSetEqualTest( + "select covar_pop(s1,s2),covar_samp(s1,s2),covar_pop(s4,s5),covar_samp(s4,s5) " + + "from root.db.d1 align by device", + expectedHeader, + retArray); + + retArray = + new String[] { + "root.db.d1,0.3333333333333333,0.49999999999999994,0.3333333333333333,0.49999999999999994," + }; + resultSetEqualTest( + "select covar_pop(s1,s2),covar_samp(s1,s2),covar_pop(s4,s5),covar_samp(s4,s5) " + + "from 
root.db.d1 where time < 10 align by device", + expectedHeader, + retArray); + } + + @Test + public void testCovarInHaving() { + String[] expectedHeader = + new String[] { + covarPop("root.db.d1.s1, root.db.d1.s2"), covarSamp("root.db.d1.s1, root.db.d1.s2") + }; + String[] retArray = new String[] {"-0.19999999999999998,-0.24999999999999997,"}; + resultSetEqualTest( + "select covar_pop(s1,s2),covar_samp(s1,s2) from root.db.d1 " + + "having covar_pop(s1,s2) < 0 and covar_samp(s1,s2) < 0", + expectedHeader, + retArray); + } + + @Test + public void testCovarWithGroupByLevel() { + String[] expectedHeader = + new String[] {covarPop("root.*.*.s1, root.*.*.s2"), covarSamp("root.*.*.s1, root.*.*.s2")}; + String[] retArray = new String[] {"-0.055555555555555566,-0.05882352941176478,"}; + resultSetEqualTest( + "select covar_pop(s1,s2),covar_samp(s1,s2) from root.db.* group by level = 0", + expectedHeader, + retArray); + } + + @Test + public void testCovarWithSlidingWindow() { + String[] expectedHeader = + new String[] { + TIMESTAMP_STR, + covarPop("root.db.d1.s1, root.db.d1.s2"), + covarSamp("root.db.d1.s1, root.db.d1.s2") + }; + String[] retArray = new String[] {"1,0.3333333333333333,0.5,", "3,0.0,null,"}; + resultSetEqualTest( + "select covar_pop(s1,s2),covar_samp(s1,s2) from root.db.d1 group by time([1,4),3ms,2ms)", + expectedHeader, + retArray); + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBRegressionIT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBRegressionIT.java new file mode 100644 index 0000000000000..6607738fa0a66 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBRegressionIT.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.it.aggregation; + +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.ClusterIT; +import org.apache.iotdb.itbase.category.LocalStandaloneIT; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.Statement; + +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; +import static org.apache.iotdb.db.it.utils.TestUtils.resultSetEqualTest; +import static org.apache.iotdb.itbase.constant.TestConstant.DEVICE; +import static org.apache.iotdb.itbase.constant.TestConstant.TIMESTAMP_STR; +import static org.apache.iotdb.itbase.constant.TestConstant.regrIntercept; +import static org.apache.iotdb.itbase.constant.TestConstant.regrSlope; +import static org.junit.Assert.fail; + +@RunWith(IoTDBTestRunner.class) +@Category({LocalStandaloneIT.class, ClusterIT.class}) +public class IoTDBRegressionIT { + + protected static final String[] SQLs = + new String[] { + "CREATE DATABASE root.db", + "CREATE TIMESERIES root.db.d1.s1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s3 WITH 
DATATYPE=BOOLEAN, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s4 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s5 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s6 WITH DATATYPE=TEXT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s4 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s5 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(1, 1, 1, true, 1, 1, \"1\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(2, 2, 2, false, 2, 2, \"2\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(3, 3, 2, false, 3, 2, \"2\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(10000000000, 4, 1, true, 4, 1, \"1\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(10000000001, 5, 1, true, 5, 1, \"1\")", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(1, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(2, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000000, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000001, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000002, 1, 2, 3, 4)", + "flush" + }; + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().getConfig().getCommonConfig().setPartitionInterval(1000); + EnvFactory.getEnv().initClusterEnvironment(); + prepareData(SQLs); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void testRegrWithUnsupportedTypesAndWrongArity() { + String typeError = + "Aggregate functions [CORR, COVAR_POP, COVAR_SAMP, REGR_SLOPE, REGR_INTERCEPT] only support " + + "numeric data types [INT32, 
INT64, FLOAT, DOUBLE, TIMESTAMP]"; + String argError = "Error size of input expressions"; + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + try { + statement.executeQuery("SELECT regr_slope(s3, s1) FROM root.db.d1"); + fail(); + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(typeError)); + } + + try { + statement.executeQuery("SELECT regr_intercept(s6, s1) FROM root.db.d1"); + fail(); + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(typeError)); + } + + try { + statement.executeQuery("SELECT regr_slope(s1) FROM root.db.d1"); + fail(); + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(argError)); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testRegrWithDifferentTypes() { + String[] expectedHeader = + new String[] { + regrSlope("root.db.d1.s1, root.db.d1.s2"), + regrIntercept("root.db.d1.s1, root.db.d1.s2"), + regrSlope("root.db.d1.s4, root.db.d1.s5"), + regrIntercept("root.db.d1.s4, root.db.d1.s5") + }; + String[] retArray = + new String[] { + "-0.8333333333333334,4.166666666666667,-0.8333333333333334,4.166666666666667," + }; + resultSetEqualTest( + "select regr_slope(s1,s2),regr_intercept(s1,s2),regr_slope(s4,s5),regr_intercept(s4,s5) " + + "from root.db.d1", + expectedHeader, + retArray); + + retArray = new String[] {"1.5,-0.5,1.5,-0.5,"}; + resultSetEqualTest( + "select regr_slope(s1,s2),regr_intercept(s1,s2),regr_slope(s4,s5),regr_intercept(s4,s5) " + + "from root.db.d1 where time < 10", + expectedHeader, + retArray); + } + + @Test + public void testRegrAlignByDevice() { + String[] expectedHeader = + new String[] { + DEVICE, + regrSlope("s1, s2"), + regrIntercept("s1, s2"), + regrSlope("s4, s5"), + regrIntercept("s4, s5") + }; + String[] retArray = + new String[] { + 
"root.db.d1,-0.8333333333333334,4.166666666666667,-0.8333333333333334,4.166666666666667," + }; + resultSetEqualTest( + "select regr_slope(s1,s2),regr_intercept(s1,s2),regr_slope(s4,s5),regr_intercept(s4,s5) " + + "from root.db.d1 align by device", + expectedHeader, + retArray); + + retArray = new String[] {"root.db.d1,1.5,-0.5,1.5,-0.5,"}; + resultSetEqualTest( + "select regr_slope(s1,s2),regr_intercept(s1,s2),regr_slope(s4,s5),regr_intercept(s4,s5) " + + "from root.db.d1 where time < 10 align by device", + expectedHeader, + retArray); + } + + @Test + public void testRegrInHaving() { + String[] expectedHeader = + new String[] { + regrSlope("root.db.d1.s1, root.db.d1.s2"), regrIntercept("root.db.d1.s1, root.db.d1.s2") + }; + String[] retArray = new String[] {"-0.8333333333333334,4.166666666666667,"}; + resultSetEqualTest( + "select regr_slope(s1,s2),regr_intercept(s1,s2) from root.db.d1 " + + "having regr_slope(s1,s2) < 0", + expectedHeader, + retArray); + } + + @Test + public void testRegrWithGroupByLevel() { + String[] expectedHeader = + new String[] { + regrSlope("root.*.*.s1, root.*.*.s2"), regrIntercept("root.*.*.s1, root.*.*.s2") + }; + String[] retArray = new String[] {"-0.25000000000000006,2.416666666666667,"}; + resultSetEqualTest( + "select regr_slope(s1,s2),regr_intercept(s1,s2) from root.db.* group by level = 0", + expectedHeader, + retArray); + } + + @Test + public void testRegrWithSlidingWindow() { + String[] expectedHeader = + new String[] { + TIMESTAMP_STR, + regrSlope("root.db.d1.s1, root.db.d1.s2"), + regrIntercept("root.db.d1.s1, root.db.d1.s2") + }; + String[] retArray = new String[] {"1,1.5,-0.5,", "3,null,null,"}; + resultSetEqualTest( + "select regr_slope(s1,s2),regr_intercept(s1,s2) from root.db.d1 group by time([1,4),3ms,2ms)", + expectedHeader, + retArray); + } + + @Test + public void testRegrWithConstantX() { + String[] expectedHeader = + new String[] { + regrSlope("root.db.d2.s1, root.db.d2.s2"), regrIntercept("root.db.d2.s1, 
root.db.d2.s2") + }; + String[] retArray = new String[] {"null,null,"}; + resultSetEqualTest( + "select regr_slope(s1,s2),regr_intercept(s1,s2) from root.db.d2", expectedHeader, retArray); + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBSkewnessKurtosisIT.java b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBSkewnessKurtosisIT.java new file mode 100644 index 0000000000000..7f0695ba4c44a --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/db/it/aggregation/IoTDBSkewnessKurtosisIT.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.it.aggregation; + +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.ClusterIT; +import org.apache.iotdb.itbase.category.LocalStandaloneIT; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.Statement; + +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; +import static org.apache.iotdb.db.it.utils.TestUtils.resultSetEqualTest; +import static org.apache.iotdb.itbase.constant.TestConstant.DEVICE; +import static org.apache.iotdb.itbase.constant.TestConstant.TIMESTAMP_STR; +import static org.apache.iotdb.itbase.constant.TestConstant.kurtosis; +import static org.apache.iotdb.itbase.constant.TestConstant.skewness; +import static org.junit.Assert.fail; + +@RunWith(IoTDBTestRunner.class) +@Category({LocalStandaloneIT.class, ClusterIT.class}) +public class IoTDBSkewnessKurtosisIT { + + protected static final String[] SQLs = + new String[] { + "CREATE DATABASE root.db", + "CREATE TIMESERIES root.db.d1.s1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s3 WITH DATATYPE=BOOLEAN, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s4 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s5 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d1.s6 WITH DATATYPE=TEXT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s1 WITH DATATYPE=INT32, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s2 WITH DATATYPE=INT64, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s4 WITH DATATYPE=FLOAT, ENCODING=PLAIN", + "CREATE TIMESERIES root.db.d2.s5 WITH DATATYPE=DOUBLE, ENCODING=PLAIN", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) 
values(1, 1, 1, true, 1, 1, \"1\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(2, 2, 2, false, 2, 2, \"2\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(3, 3, 2, false, 3, 2, \"2\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(10000000000, 4, 1, true, 4, 1, \"1\")", + "INSERT INTO root.db.d1(timestamp,s1,s2,s3,s4,s5,s6) values(10000000001, 5, 1, true, 5, 1, \"1\")", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(1, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(2, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000000, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000001, 1, 2, 3, 4)", + "INSERT INTO root.db.d2(timestamp,s1,s2,s4,s5) values(10000000002, 1, 2, 3, 4)", + "flush" + }; + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().getConfig().getCommonConfig().setPartitionInterval(1000); + EnvFactory.getEnv().initClusterEnvironment(); + prepareData(SQLs); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void testMomentsWithUnsupportedTypesAndWrongArity() { + String typeError = + "Aggregate functions [SKEWNESS, KURTOSIS] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"; + String argError = "Error size of input expressions"; + try (Connection connection = EnvFactory.getEnv().getConnection(); + Statement statement = connection.createStatement()) { + try { + statement.executeQuery("SELECT skewness(s3) FROM root.db.d1"); + fail(); + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(typeError)); + } + + try { + statement.executeQuery("SELECT kurtosis(s6) FROM root.db.d1"); + fail(); + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(typeError)); + } + + try { + statement.executeQuery("SELECT 
kurtosis(s1, s2) FROM root.db.d1"); + fail(); + } catch (Exception e) { + Assert.assertTrue(e.getMessage(), e.getMessage().contains(argError)); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testMomentsWithDifferentTypes() { + String[] expectedHeader = + new String[] { + skewness("root.db.d1.s1"), + skewness("root.db.d1.s2"), + kurtosis("root.db.d1.s1"), + kurtosis("root.db.d1.s2"), + skewness("root.db.d1.s4"), + skewness("root.db.d1.s5"), + kurtosis("root.db.d1.s4"), + kurtosis("root.db.d1.s5") + }; + String[] retArray = + new String[] { + "0.0,0.4082482904638631,-1.2000000000000002,-3.333333333333333,0.0,0.4082482904638631,-1.2000000000000002,-3.333333333333333," + }; + resultSetEqualTest( + "select skewness(s1),skewness(s2),kurtosis(s1),kurtosis(s2)," + + "skewness(s4),skewness(s5),kurtosis(s4),kurtosis(s5) from root.db.d1", + expectedHeader, + retArray); + + expectedHeader = + new String[] { + skewness("root.db.d1.s1"), skewness("root.db.d1.s2"), kurtosis("root.db.d1.s1") + }; + retArray = new String[] {"0.0,-0.7071067811865475,null,"}; + resultSetEqualTest( + "select skewness(s1),skewness(s2),kurtosis(s1) from root.db.d1 where time < 10", + expectedHeader, + retArray); + } + + @Test + public void testMomentsAlignByDevice() { + String[] expectedHeader = + new String[] { + DEVICE, skewness("s1"), skewness("s2"), kurtosis("s1"), kurtosis("s2"), + }; + String[] retArray = + new String[] {"root.db.d1,0.0,0.4082482904638631,-1.2000000000000002,-3.333333333333333,"}; + resultSetEqualTest( + "select skewness(s1),skewness(s2),kurtosis(s1),kurtosis(s2) from root.db.d1 align by device", + expectedHeader, + retArray); + + retArray = new String[] {"root.db.d1,0.0,-0.7071067811865475,null,null,"}; + resultSetEqualTest( + "select skewness(s1),skewness(s2),kurtosis(s1),kurtosis(s2) from root.db.d1 " + + "where time < 10 align by device", + expectedHeader, + retArray); + } + + @Test + public void 
testMomentsInHaving() { + String[] expectedHeader = new String[] {skewness("root.db.d1.s2"), kurtosis("root.db.d1.s2")}; + String[] retArray = new String[] {"0.4082482904638631,-3.333333333333333,"}; + resultSetEqualTest( + "select skewness(s2),kurtosis(s2) from root.db.d1 having skewness(s2) > 0", + expectedHeader, + retArray); + } + + @Test + public void testMomentsWithGroupByLevel() { + String[] expectedHeader = + new String[] { + skewness("root.*.*.s1"), + kurtosis("root.*.*.s1"), + skewness("root.*.*.s2"), + kurtosis("root.*.*.s2") + }; + String[] retArray = + new String[] { + "1.0606601717798214,0.25714285714285623,-0.8728715609439697,-1.224489795918367," + }; + resultSetEqualTest( + "select skewness(s1),kurtosis(s1),skewness(s2),kurtosis(s2) from root.db.* group by level = 0", + expectedHeader, + retArray); + } + + @Test + public void testMomentsWithSlidingWindow() { + String[] expectedHeader = + new String[] {TIMESTAMP_STR, skewness("root.db.d1.s1"), kurtosis("root.db.d1.s1")}; + String[] retArray = new String[] {"1,0.0,null,", "3,null,null,"}; + resultSetEqualTest( + "select skewness(s1),kurtosis(s1) from root.db.d1 group by time([1,4),3ms,2ms)", + expectedHeader, + retArray); + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java index ec13fa95bbbca..92bf54ff76d77 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java @@ -110,6 +110,28 @@ public class IoTDBTableAggregationIT { "INSERT INTO table1(time,province,city,region,device_id,color,type,s3,s5,s7,s9,s10) values 
(2024-09-24T06:15:30.000+00:00,'beijing','beijing','haidian','d16','yellow','BBBBBBBBBBBBBBBB',30.0,true,'beijing_haidian_yellow_B_d16_30',2024-09-24T06:15:30.000+00:00,'2024-09-24')", "INSERT INTO table1(time,province,city,region,device_id,color,type,s2,s9) values (2024-09-24T06:15:40.000+00:00,'beijing','beijing','haidian','d16','yellow','BBBBBBBBBBBBBBBB',40000,2024-09-24T06:15:40.000+00:00)", "INSERT INTO table1(time,province,city,region,device_id,color,type,s1,s4,s6,s8,s9) values (2024-09-24T06:15:55.000+00:00,'beijing','beijing','haidian','d16','yellow','BBBBBBBBBBBBBBBB',55,55.0,'beijing_haidian_yellow_B_d16_55',X'cafebabe55',2024-09-24T06:15:55.000+00:00)", + // stat_table for statistical aggregation function tests (CORR, COVAR_POP, COVAR_SAMP, + // REGR_SLOPE, REGR_INTERCEPT, KURTOSIS, SKEWNESS) + "CREATE TABLE stat_table(device_id STRING TAG, s1 INT32 FIELD, s2 INT64 FIELD, s3 FLOAT FIELD, s4 DOUBLE FIELD, s5 BOOLEAN FIELD, s6 TEXT FIELD)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4, s5, s6) VALUES (1, 'd1', 1, 1, 1.0, 1.0, true, 'a')", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4, s5, s6) VALUES (2, 'd1', 2, 2, 2.0, 2.0, false, 'b')", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4, s5, s6) VALUES (3, 'd1', 3, 2, 3.0, 2.0, false, 'c')", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4, s5, s6) VALUES (4, 'd1', 4, 1, 4.0, 1.0, true, 'd')", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4, s5, s6) VALUES (5, 'd1', 5, 1, 5.0, 1.0, true, 'e')", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (1, 'd2', 1, 2, 1.0, 2.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (2, 'd2', 1, 2, 1.0, 2.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (3, 'd2', 1, 2, 1.0, 2.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (4, 'd2', 1, 2, 1.0, 2.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (5, 'd2', 1, 2, 1.0, 
2.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (1, 'd3', 10, 100, 10.0, 100.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (2, 'd3', 20, 200, 20.0, 200.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (1, 'd4', 42, 99, 42.0, 99.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (1, 'n1', 10, 50, 10.0, 50.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (2, 'n1', 20, 40, 20.0, 40.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (3, 'n1', 30, 30, 30.0, 30.0)", + "INSERT INTO stat_table(time, device_id, s2, s3, s4) VALUES (4, 'n1', 20, 20.0, 20.0)", + "INSERT INTO stat_table(time, device_id, s1, s3, s4) VALUES (5, 'n1', 50, 50.0, 50.0)", + "INSERT INTO stat_table(time, device_id, s1, s2, s3, s4) VALUES (6, 'n1', 40, 20, 40.0, 20.0)", "FLUSH", "CLEAR ATTRIBUTE CACHE", }; @@ -5410,4 +5432,146 @@ public void emptyBlockInStreamOperatorTest() { retArray, DATABASE_NAME); } + + @Test + public void statFunctionsBasicAndCrossTypeTest() { + String[] expectedHeader = + new String[] { + "_col0", "_col1", "_col2", "_col3", "_col4", "_col5", "_col6", "_col7", "_col8", + "_col9", "_col10", "_col11", "_col12", "_col13", "_col14", "_col15", "_col16", "_col17" + }; + String[] retArray = + new String[] { + "-0.28867513459481287,-0.19999999999999998,-0.24999999999999997,-0.8333333333333334,4.166666666666667,0.0,0.4082482904638631,-1.2000000000000002,-3.333333333333333," + + "-0.28867513459481287,-0.19999999999999998,-0.24999999999999997,-0.8333333333333334,4.166666666666667,0.0,0.4082482904638631,-1.2000000000000002,-3.333333333333333," + }; + tableResultSetEqualTest( + "select corr(s1, s2), covar_pop(s1, s2), covar_samp(s1, s2), regr_slope(s1, s2), regr_intercept(s1, s2), " + + "skewness(s1), skewness(s2), kurtosis(s1), kurtosis(s2), " + + "corr(s3, s4), covar_pop(s3, s4), covar_samp(s3, s4), regr_slope(s3, s4), regr_intercept(s3, s4), " + + 
"skewness(s3), skewness(s4), kurtosis(s3), kurtosis(s4) " + + "from stat_table where device_id = 'd1'", + expectedHeader, + retArray, + DATABASE_NAME); + } + + @Test + public void statFunctionsSampleSizeAndFilterEdgeTest() { + String[] expectedHeader = + new String[] {"_col0", "_col1", "_col2", "_col3", "_col4", "_col5", "_col6"}; + + String[] filteredHeader = + new String[] {"_col0", "_col1", "_col2", "_col3", "_col4", "_col5", "_col6", "_col7"}; + String[] filteredRetArray = + new String[] { + "0.8660254037844386,0.3333333333333333,0.49999999999999994,1.5,-0.5,0.0,-0.7071067811865475,null," + }; + tableResultSetEqualTest( + "select corr(s1, s2), covar_pop(s1, s2), covar_samp(s1, s2), regr_slope(s1, s2), regr_intercept(s1, s2), " + + "skewness(s1), skewness(s2), kurtosis(s1) " + + "from stat_table where device_id = 'd1' and time <= 3", + filteredHeader, + filteredRetArray, + DATABASE_NAME); + + String[] retArray = new String[] {"null,0.0,null,null,null,null,null,"}; + tableResultSetEqualTest( + "select corr(s1, s2), covar_pop(s1, s2), covar_samp(s1, s2), regr_slope(s1, s2), regr_intercept(s1, s2), skewness(s1), kurtosis(s1) " + + "from stat_table where device_id = 'd4'", + expectedHeader, + retArray, + DATABASE_NAME); + + retArray = new String[] {"1.0,250.0,500.0,0.1,0.0,null,null,"}; + tableResultSetEqualTest( + "select corr(s1, s2), covar_pop(s1, s2), covar_samp(s1, s2), regr_slope(s1, s2), regr_intercept(s1, s2), skewness(s1), kurtosis(s1) " + + "from stat_table where device_id = 'd3'", + expectedHeader, + retArray, + DATABASE_NAME); + } + + @Test + public void statFunctionsNullAndZeroVarianceTest() { + String[] expectedHeader = + new String[] {"_col0", "_col1", "_col2", "_col3", "_col4", "_col5", "_col6"}; + + String[] retArray = + new String[] {"-1.0,-125.0,-166.66666666666666,-1.0,60.0,0.0,-1.2000000000000002,"}; + tableResultSetEqualTest( + "select corr(s1, s2), covar_pop(s1, s2), covar_samp(s1, s2), regr_slope(s1, s2), regr_intercept(s1, s2), skewness(s1), 
kurtosis(s1) " + + "from stat_table where device_id = 'n1'", + expectedHeader, + retArray, + DATABASE_NAME); + + retArray = new String[] {"null,0.0,0.0,null,null,null,null,"}; + tableResultSetEqualTest( + "select corr(s1, s2), covar_pop(s1, s2), covar_samp(s1, s2), regr_slope(s1, s2), regr_intercept(s1, s2), skewness(s1), kurtosis(s1) " + + "from stat_table where device_id = 'd2'", + expectedHeader, + retArray, + DATABASE_NAME); + } + + @Test + public void statFunctionsInHavingTest() { + String[] expectedHeader = new String[] {"device_id", "_col1", "_col2", "_col3", "_col4"}; + String[] retArray = + new String[] {"d1,-0.28867513459481287,-0.24999999999999997,0.0,-1.2000000000000002,"}; + tableResultSetEqualTest( + "select device_id, corr(s1, s2), covar_samp(s1, s2), skewness(s1), kurtosis(s1) " + + "from stat_table where device_id in ('d1', 'd2') group by device_id " + + "having corr(s1, s2) < 0 and covar_samp(s1, s2) < 0 order by device_id", + expectedHeader, + retArray, + DATABASE_NAME); + } + + @Test + public void statFunctionsErrorTest() { + String typeError = "only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"; + String argError = "Error size of input expressions"; + + tableAssertTestFail( + "select corr(s5, s1) from stat_table where device_id = 'd1'", typeError, DATABASE_NAME); + tableAssertTestFail( + "select covar_pop(s1, s6) from stat_table where device_id = 'd1'", + typeError, + DATABASE_NAME); + tableAssertTestFail( + "select covar_samp(s6, s1) from stat_table where device_id = 'd1'", + typeError, + DATABASE_NAME); + tableAssertTestFail( + "select regr_slope(s6, s1) from stat_table where device_id = 'd1'", + typeError, + DATABASE_NAME); + tableAssertTestFail( + "select regr_intercept(s5, s1) from stat_table where device_id = 'd1'", + typeError, + DATABASE_NAME); + tableAssertTestFail( + "select skewness(s5) from stat_table where device_id = 'd1'", typeError, DATABASE_NAME); + tableAssertTestFail( + "select kurtosis(s6) from 
stat_table where device_id = 'd1'", typeError, DATABASE_NAME); + + tableAssertTestFail( + "select corr(s1) from stat_table where device_id = 'd1'", argError, DATABASE_NAME); + tableAssertTestFail( + "select covar_pop(s1) from stat_table where device_id = 'd1'", argError, DATABASE_NAME); + tableAssertTestFail( + "select covar_samp(s1) from stat_table where device_id = 'd1'", argError, DATABASE_NAME); + tableAssertTestFail( + "select regr_slope(s1) from stat_table where device_id = 'd1'", argError, DATABASE_NAME); + tableAssertTestFail( + "select regr_intercept(s1) from stat_table where device_id = 'd1'", + argError, + DATABASE_NAME); + tableAssertTestFail( + "select skewness(s1, s2) from stat_table where device_id = 'd1'", argError, DATABASE_NAME); + tableAssertTestFail( + "select kurtosis(s1, s2) from stat_table where device_id = 'd1'", argError, DATABASE_NAME); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java index 24a998f54a917..c7ccaec5ce178 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/AccumulatorFactory.java @@ -69,6 +69,11 @@ public static boolean isMultiInputAggregation(TAggregationType aggregationType) switch (aggregationType) { case MAX_BY: case MIN_BY: + case CORR: + case COVAR_POP: + case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: return true; default: return false; @@ -84,6 +89,31 @@ public static Accumulator createBuiltinMultiInputAccumulator( case MIN_BY: checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); return new MinByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1)); + case CORR: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes 
size."); + return new CorrelationAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + CorrelationAccumulator.CorrelationType.CORR); + case COVAR_POP: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new CovarianceAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + CovarianceAccumulator.CovarianceType.COVAR_POP); + case COVAR_SAMP: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new CovarianceAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + CovarianceAccumulator.CovarianceType.COVAR_SAMP); + case REGR_SLOPE: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new RegressionAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + RegressionAccumulator.RegressionType.REGR_SLOPE); + case REGR_INTERCEPT: + checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size."); + return new RegressionAccumulator( + new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)}, + RegressionAccumulator.RegressionType.REGR_INTERCEPT); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } @@ -140,6 +170,12 @@ private static Accumulator createBuiltinSingleInputAccumulator( return new VarianceAccumulator(tsDataType, VarianceAccumulator.VarianceType.VAR_SAMP); case VAR_POP: return new VarianceAccumulator(tsDataType, VarianceAccumulator.VarianceType.VAR_POP); + case SKEWNESS: + return new CentralMomentAccumulator( + tsDataType, CentralMomentAccumulator.MomentType.SKEWNESS); + case KURTOSIS: + return new CentralMomentAccumulator( + tsDataType, CentralMomentAccumulator.MomentType.KURTOSIS); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java new file mode 100644 index 0000000000000..f4974b42e811c --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CentralMomentAccumulator.java @@ -0,0 +1,294 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.aggregation; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.BitMap; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class CentralMomentAccumulator implements Accumulator { + + public enum MomentType { + SKEWNESS, + KURTOSIS + } + + private final TSDataType seriesDataType; + private final MomentType momentType; + + private long count; + private double mean; + private double m2; + private double m3; + private double m4; + + public CentralMomentAccumulator(TSDataType seriesDataType, MomentType momentType) { + this.seriesDataType = seriesDataType; + this.momentType = momentType; + } + + @Override + public void addInput(Column[] columns, BitMap bitMap) { + + int size = columns[1].getPositionCount(); + for (int i = 0; i < size; i++) { + if 
(bitMap != null && !bitMap.isMarked(i)) { + continue; + } + if (columns[1].isNull(i)) { + continue; + } + update(getDoubleValue(columns[1], i)); + } + } + + private double getDoubleValue(Column column, int position) { + switch (seriesDataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnsupportedOperationException( + "Unsupported data type in CentralMoment Aggregation: " + seriesDataType); + } + } + + private void update(double value) { + long n1 = count; + count++; + + double delta = value - mean; + double delta_n = delta / count; + double delta_n2 = delta_n * delta_n; + double term1 = delta * delta_n * n1; + + mean += delta_n; + + m4 += term1 * delta_n2 * (count * count - 3 * count + 3) + 6 * delta_n2 * m2 - 4 * delta_n * m3; + + m3 += term1 * delta_n * (count - 2) - 3 * delta_n * m2; + + m2 += term1; + } + + @Override + public void addIntermediate(Column[] partialResult) { + checkArgument(partialResult.length == 1, "partialResult of CentralMoment should be 1"); + if (partialResult[0].isNull(0)) { + return; + } + byte[] bytes = partialResult[0].getBinary(0).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMean = buffer.getDouble(); + double otherM2 = buffer.getDouble(); + double otherM3 = buffer.getDouble(); + double otherM4 = buffer.getDouble(); + + merge(otherCount, otherMean, otherM2, otherM3, otherM4); + } + + private void merge(long nB, double meanB, double m2B, double m3B, double m4B) { + if (nB == 0) return; + if (count == 0) { + count = nB; + mean = meanB; + m2 = m2B; + m3 = m3B; + m4 = m4B; + } else { + long nA = count; + long nTotal = nA + nB; + double delta = meanB - mean; + double delta2 = delta * delta; + double delta3 = delta * delta2; + double delta4 = delta2 * delta2; + + m4 
+= + m4B + + delta4 * nA * nB * (nA * nA - nA * nB + nB * nB) / (nTotal * nTotal * nTotal) + + 6.0 * delta2 * (nA * nA * m2B + nB * nB * m2) / (nTotal * nTotal) + + 4.0 * delta * (nA * m3B - nB * m3) / nTotal; + + m3 += + m3B + + delta3 * nA * nB * (nA - nB) / (nTotal * nTotal) + + 3.0 * delta * (nA * m2B - nB * m2) / nTotal; + + m2 += m2B + delta2 * nA * nB / nTotal; + + mean += delta * nB / nTotal; + count = nTotal; + } + } + + @Override + public void outputIntermediate(ColumnBuilder[] columnBuilders) { + checkArgument(columnBuilders.length == 1, "partialResult should be 1"); + if (count == 0) { + columnBuilders[0].appendNull(); + } else { + + byte[] bytes = new byte[40]; + ByteBuffer buffer = ByteBuffer.wrap(bytes); + buffer.putLong(count); + buffer.putDouble(mean); + buffer.putDouble(m2); + buffer.putDouble(m3); + buffer.putDouble(m4); + columnBuilders[0].writeBinary(new Binary(bytes)); + } + } + + @Override + public void outputFinal(ColumnBuilder columnBuilder) { + if (count == 0 || m2 == 0) { + columnBuilder.appendNull(); + return; + } + + if (momentType == MomentType.SKEWNESS) { + if (count < 3) { + columnBuilder.appendNull(); + return; + } + double result = Math.sqrt((double) count) * m3 / Math.pow(m2, 1.5); + columnBuilder.writeDouble(result); + } else { + if (count < 4) { + columnBuilder.appendNull(); + } else { + + double variance = m2 / (count - 1); + double term1 = + (count * (count + 1) * m4) + / ((count - 1) * (count - 2) * (count - 3) * variance * variance); + double term2 = (3 * Math.pow(count - 1, 2)) / ((count - 2) * (count - 3)); + columnBuilder.writeDouble(term1 - term2); + } + } + } + + @Override + public void removeIntermediate(Column[] input) { + checkArgument(input.length == 1, "Input of CentralMoment should be 1"); + if (input[0].isNull(0)) { + return; + } + + byte[] bytes = input[0].getBinary(0).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long nB = buffer.getLong(); + if (nB == 0) { + return; + } + checkArgument(count 
>= nB, "CentralMoment state count is smaller than removed state count"); + + if (count == nB) { + reset(); + return; + } + + double meanB = buffer.getDouble(); + double m2B = buffer.getDouble(); + double m3B = buffer.getDouble(); + double m4B = buffer.getDouble(); + + long nTotal = count; + long nA = nTotal - nB; + + double meanA = ((double) nTotal * mean - (double) nB * meanB) / nA; + + double delta = meanB - meanA; + double delta2 = delta * delta; + double delta3 = delta * delta2; + double delta4 = delta2 * delta2; + + double m2A = m2 - m2B - delta2 * nA * nB / nTotal; + double m3A = + m3 + - m3B + - delta3 * nA * nB * (nA - nB) / ((double) nTotal * nTotal) + - 3.0 * delta * (nA * m2B - nB * m2A) / nTotal; + + double m4A = + m4 + - m4B + - delta4 + * nA + * nB + * ((double) nA * nA - (double) nA * nB + (double) nB * nB) + / ((double) nTotal * nTotal * nTotal) + - 6.0 + * delta2 + * ((double) nA * nA * m2B + (double) nB * nB * m2A) + / ((double) nTotal * nTotal) + - 4.0 * delta * (nA * m3B - nB * m3A) / nTotal; + + count = nA; + mean = meanA; + m2 = m2A; + m3 = m3A; + m4 = m4A; + } + + @Override + public void addStatistics(Statistics statistics) { + throw new UnsupportedOperationException(); + } + + @Override + public void setFinal(Column finalResult) {} + + @Override + public void reset() { + count = 0; + mean = 0; + m2 = 0; + m3 = 0; + m4 = 0; + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public TSDataType[] getIntermediateType() { + return new TSDataType[] {TSDataType.TEXT}; + } + + @Override + public TSDataType getFinalType() { + return TSDataType.DOUBLE; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java new file mode 100644 index 0000000000000..dacaf4efb38b7 --- /dev/null +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CorrelationAccumulator.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.aggregation; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.BitMap; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class CorrelationAccumulator implements Accumulator { + + public enum CorrelationType { + CORR + } + + private final TSDataType[] seriesDataTypes; + private final CorrelationType correlationType; + + private long count; + private double meanX; + private double meanY; + private double m2X; + private double m2Y; + private double c2; + + public CorrelationAccumulator(TSDataType[] seriesDataTypes, CorrelationType correlationType) { + this.seriesDataTypes = seriesDataTypes; + this.correlationType = correlationType; + } + + @Override + public void addInput(Column[] columns, 
BitMap bitMap) { + + int size = columns[0].getPositionCount(); + for (int i = 0; i < size; i++) { + if (bitMap != null && !bitMap.isMarked(i)) { + continue; + } + if (columns[1].isNull(i) || columns[2].isNull(i)) { + continue; + } + + double x = getDoubleValue(columns[1], i, seriesDataTypes[0]); + double y = getDoubleValue(columns[2], i, seriesDataTypes[1]); + + update(x, y); + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + } + + private void update(double x, double y) { + long newCount = count + 1; + + double oldMeanX = meanX; + double oldMeanY = meanY; + + meanX = oldMeanX + (x - oldMeanX) / newCount; + meanY = oldMeanY + (y - oldMeanY) / newCount; + + c2 += (x - oldMeanX) * (y - meanY); + m2X += (x - oldMeanX) * (x - meanX); + m2Y += (y - oldMeanY) * (y - meanY); + + count = newCount; + } + + @Override + public void addIntermediate(Column[] partialResult) { + checkArgument(partialResult.length == 1, "partialResult of Correlation should be 1"); + if (partialResult[0].isNull(0)) { + return; + } + byte[] bytes = partialResult[0].getBinary(0).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherM2Y = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherM2Y, otherC2); + } + + private void merge( + long otherCount, + double otherMeanX, + double otherMeanY, + double otherM2X, + double otherM2Y, + double otherC2) { + if (otherCount == 0) { + return; + } 
+ if (count == 0) { + count = otherCount; + meanX = otherMeanX; + meanY = otherMeanY; + m2X = otherM2X; + m2Y = otherM2Y; + c2 = otherC2; + } else { + long newCount = count + otherCount; + double deltaX = otherMeanX - meanX; + double deltaY = otherMeanY - meanY; + + c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount; + m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount; + m2Y += otherM2Y + deltaY * deltaY * count * otherCount / newCount; + + meanX += deltaX * otherCount / newCount; + meanY += deltaY * otherCount / newCount; + count = newCount; + } + } + + @Override + public void outputIntermediate(ColumnBuilder[] columnBuilders) { + checkArgument(columnBuilders.length == 1, "partialResult of Correlation should be 1"); + if (count == 0) { + columnBuilders[0].appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 5); + buffer.putLong(count); + buffer.putDouble(meanX); + buffer.putDouble(meanY); + buffer.putDouble(m2X); + buffer.putDouble(m2Y); + buffer.putDouble(c2); + columnBuilders[0].writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void outputFinal(ColumnBuilder columnBuilder) { + if (correlationType != CorrelationType.CORR) { + throw new UnsupportedOperationException("Unknown type: " + correlationType); + } + + if (count < 2) { + columnBuilder.appendNull(); + } else if (m2X == 0 || m2Y == 0) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2 / Math.sqrt(m2X * m2Y)); + } + } + + @Override + public void removeIntermediate(Column[] input) { + checkArgument(input.length == 1, "Input of Correlation should be 1"); + if (input[0].isNull(0)) { + return; + } + + byte[] bytes = input[0].getBinary(0).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + if (otherCount == 0) { + return; + } + checkArgument( + count >= otherCount, "Correlation state count is smaller than removed state count"); + + if (count == 
otherCount) { + reset(); + return; + } + + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherM2Y = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + long totalCount = count; + long newCount = totalCount - otherCount; + + double newMeanX = (totalCount * meanX - otherCount * otherMeanX) / newCount; + double newMeanY = (totalCount * meanY - otherCount * otherMeanY) / newCount; + + double deltaX = otherMeanX - newMeanX; + double deltaY = otherMeanY - newMeanY; + double correction = ((double) newCount * otherCount) / totalCount; + + c2 = c2 - otherC2 - deltaX * deltaY * correction; + m2X = m2X - otherM2X - deltaX * deltaX * correction; + m2Y = m2Y - otherM2Y - deltaY * deltaY * correction; + + meanX = newMeanX; + meanY = newMeanY; + count = newCount; + } + + @Override + public void addStatistics(Statistics statistics) { + throw new UnsupportedOperationException(getClass().getName()); + } + + @Override + public void setFinal(Column finalResult) {} + + @Override + public void reset() { + count = 0; + meanX = 0; + meanY = 0; + m2X = 0; + m2Y = 0; + c2 = 0; + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public TSDataType[] getIntermediateType() { + return new TSDataType[] {TSDataType.TEXT}; + } + + @Override + public TSDataType getFinalType() { + return TSDataType.DOUBLE; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CovarianceAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CovarianceAccumulator.java new file mode 100644 index 0000000000000..c125755fec4c2 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/aggregation/CovarianceAccumulator.java @@ -0,0 +1,247 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.iotdb.db.queryengine.execution.aggregation;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.file.metadata.statistics.Statistics;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.BitMap;

import java.nio.ByteBuffer;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Streaming accumulator for COVAR_POP / COVAR_SAMP over two numeric series.
 *
 * <p>State is the running count, the running means of x and y, and the running co-moment
 * {@code c2 = sum((x - meanX) * (y - meanY))}, maintained with a Welford-style one-pass
 * update so no raw values are retained. Intermediate state is serialized into a single
 * TEXT (binary) column; the final result is DOUBLE.
 */
public class CovarianceAccumulator implements Accumulator {

  /** Which covariance flavor this instance computes (population vs. sample). */
  public enum CovarianceType {
    COVAR_POP,
    COVAR_SAMP
  }

  // Input data types: seriesDataTypes[0] is x, seriesDataTypes[1] is y.
  private final TSDataType[] seriesDataTypes;
  private final CovarianceType covarianceType;

  // Running aggregation state (see class comment).
  private long count;
  private double meanX;
  private double meanY;
  private double c2;

  public CovarianceAccumulator(TSDataType[] seriesDataTypes, CovarianceType covarianceType) {
    this.seriesDataTypes = seriesDataTypes;
    this.covarianceType = covarianceType;
  }

  /**
   * Consumes raw input rows. columns[1] is x, columns[2] is y (columns[0] is presumably the
   * time column — its position count drives the loop; confirm against the caller). Rows where
   * either x or y is null, or that are unmarked in {@code bitMap}, are skipped.
   */
  @Override
  public void addInput(Column[] columns, BitMap bitMap) {
    int size = columns[0].getPositionCount();
    for (int i = 0; i < size; i++) {
      if (bitMap != null && !bitMap.isMarked(i)) {
        continue;
      }
      if (columns[1].isNull(i) || columns[2].isNull(i)) {
        continue;
      }

      double x = getDoubleValue(columns[1], i, seriesDataTypes[0]);
      double y = getDoubleValue(columns[2], i, seriesDataTypes[1]);
      update(x, y);
    }
  }

  /** Widens the supported numeric column types to double; rejects anything non-numeric. */
  private double getDoubleValue(Column column, int position, TSDataType dataType) {
    switch (dataType) {
      case INT32:
        return column.getInt(position);
      case INT64:
      case TIMESTAMP:
        return column.getLong(position);
      case FLOAT:
        return column.getFloat(position);
      case DOUBLE:
        return column.getDouble(position);
      default:
        throw new IllegalArgumentException("Unsupported data type: " + dataType);
    }
  }

  /**
   * Welford-style single-point update. Note the order dependence: the co-moment term uses the
   * OLD meanX but the NEW meanY — this exact pairing is what keeps c2 equal to
   * sum((x - meanX)(y - meanY)) without a second pass.
   */
  private void update(double x, double y) {
    long newCount = count + 1;
    double oldMeanX = meanX;
    meanX = oldMeanX + (x - oldMeanX) / newCount;
    double oldMeanY = meanY;
    double newMeanY = oldMeanY + (y - oldMeanY) / newCount;
    meanY = newMeanY;
    c2 += (x - oldMeanX) * (y - newMeanY);
    count = newCount;
  }

  /**
   * Merges one serialized partial state (produced by {@link #outputIntermediate}).
   * Only position 0 of the single partial column is read.
   */
  @Override
  public void addIntermediate(Column[] partialResult) {
    checkArgument(partialResult.length == 1, "partialResult of Covariance should be 1");
    if (partialResult[0].isNull(0)) {
      return;
    }

    byte[] bytes = partialResult[0].getBinary(0).getValues();
    ByteBuffer buffer = ByteBuffer.wrap(bytes);

    long otherCount = buffer.getLong();
    double otherMeanX = buffer.getDouble();
    double otherMeanY = buffer.getDouble();
    double otherC2 = buffer.getDouble();

    merge(otherCount, otherMeanX, otherMeanY, otherC2);
  }

  /**
   * Chan et al. pairwise combination of two partial states. The delta terms are double
   * expressions evaluated left-to-right, so the count products never overflow in long space.
   */
  private void merge(long otherCount, double otherMeanX, double otherMeanY, double otherC2) {
    if (otherCount == 0) {
      return;
    }
    if (count == 0) {
      // This side is empty: adopt the other state wholesale.
      count = otherCount;
      meanX = otherMeanX;
      meanY = otherMeanY;
      c2 = otherC2;
      return;
    }

    long newCount = count + otherCount;
    double deltaX = otherMeanX - meanX;
    double deltaY = otherMeanY - meanY;

    c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount;
    meanX += deltaX * otherCount / newCount;
    meanY += deltaY * otherCount / newCount;
    count = newCount;
  }

  /** Serializes state as {count, meanX, meanY, c2} into one binary value; null when empty. */
  @Override
  public void outputIntermediate(ColumnBuilder[] columnBuilders) {
    checkArgument(columnBuilders.length == 1, "partialResult of Covariance should be 1");
    if (count == 0) {
      columnBuilders[0].appendNull();
      return;
    }

    ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 3);
    buffer.putLong(count);
    buffer.putDouble(meanX);
    buffer.putDouble(meanY);
    buffer.putDouble(c2);
    columnBuilders[0].writeBinary(new Binary(buffer.array()));
  }

  /**
   * Emits the final covariance: c2/n for COVAR_POP (null when n == 0), c2/(n-1) for
   * COVAR_SAMP (null when n < 2, since the sample estimator needs two points).
   */
  @Override
  public void outputFinal(ColumnBuilder columnBuilder) {
    switch (covarianceType) {
      case COVAR_POP:
        if (count == 0) {
          columnBuilder.appendNull();
        } else {
          columnBuilder.writeDouble(c2 / count);
        }
        break;
      case COVAR_SAMP:
        if (count < 2) {
          columnBuilder.appendNull();
        } else {
          columnBuilder.writeDouble(c2 / (count - 1));
        }
        break;
      default:
        throw new UnsupportedOperationException("Unknown type: " + covarianceType);
    }
  }

  /**
   * Inverse of {@link #merge}: subtracts a previously-merged partial state (used by sliding
   * windows). Reconstructs the pre-merge means algebraically, then removes the co-moment
   * correction term. The removed state must have been merged earlier, hence the checkArgument.
   */
  @Override
  public void removeIntermediate(Column[] input) {
    checkArgument(input.length == 1, "Input of Covariance should be 1");
    if (input[0].isNull(0)) {
      return;
    }

    byte[] bytes = input[0].getBinary(0).getValues();
    ByteBuffer buffer = ByteBuffer.wrap(bytes);

    long otherCount = buffer.getLong();
    if (otherCount == 0) {
      return;
    }
    checkArgument(
        count >= otherCount, "Covariance state count is smaller than removed state count");

    if (count == otherCount) {
      // Removing everything: avoid the 0/0 below and just clear.
      reset();
      return;
    }

    double otherMeanX = buffer.getDouble();
    double otherMeanY = buffer.getDouble();
    double otherC2 = buffer.getDouble();

    long totalCount = count;
    long newCount = totalCount - otherCount;

    // Recover the means of the remaining points from the weighted-mean identity.
    double newMeanX = ((double) totalCount * meanX - (double) otherCount * otherMeanX) / newCount;
    double newMeanY = ((double) totalCount * meanY - (double) otherCount * otherMeanY) / newCount;

    double deltaX = otherMeanX - newMeanX;
    double deltaY = otherMeanY - newMeanY;
    double correction = ((double) newCount * otherCount) / totalCount;

    c2 = c2 - otherC2 - deltaX * deltaY * correction;
    meanX = newMeanX;
    meanY = newMeanY;
    count = newCount;
  }

  /** Statistics-based ingestion is not applicable to two-column aggregates. */
  @Override
  public void addStatistics(Statistics statistics) {
    throw new UnsupportedOperationException(getClass().getName());
  }

  /** Intentionally a no-op: this accumulator never receives a precomputed final value. */
  @Override
  public void setFinal(Column finalResult) {}

  /** Clears all running state back to the empty-input condition. */
  @Override
  public void reset() {
    count = 0;
    meanX = 0;
    meanY = 0;
    c2 = 0;
  }

  /** Always false: a streaming covariance can never short-circuit early. */
  @Override
  public boolean hasFinalResult() {
    return false;
  }

  /** Intermediate state travels as a single opaque binary (TEXT) column. */
  @Override
  public TSDataType[] getIntermediateType() {
    return new TSDataType[] {TSDataType.TEXT};
  }

  @Override
  public TSDataType getFinalType() {
    return TSDataType.DOUBLE;
  }
}
 */

package org.apache.iotdb.db.queryengine.execution.aggregation;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.file.metadata.statistics.Statistics;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.BitMap;

import java.nio.ByteBuffer;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Streaming accumulator for REGR_SLOPE / REGR_INTERCEPT (simple linear regression of y on x).
 *
 * <p>State: count, mean of x and y, the second central moment of x ({@code m2X}) and the
 * x/y co-moment ({@code c2}), maintained with one-pass Welford-style updates. The final
 * slope is c2/m2X and the intercept is meanY - slope * meanX. Intermediate state is one
 * binary (TEXT) column; the final result is DOUBLE.
 */
public class RegressionAccumulator implements Accumulator {

  /** Which regression statistic this instance emits. */
  public enum RegressionType {
    REGR_SLOPE,
    REGR_INTERCEPT
  }

  // seriesDataTypes[0] is the dependent series (y), seriesDataTypes[1] the independent (x),
  // matching the read order in addInput.
  private final TSDataType[] seriesDataTypes;
  private final RegressionType regressionType;

  // Running aggregation state (see class comment).
  private long count;
  private double meanX;
  private double meanY;
  private double m2X;
  private double c2;

  public RegressionAccumulator(TSDataType[] seriesDataTypes, RegressionType regressionType) {
    this.seriesDataTypes = seriesDataTypes;
    this.regressionType = regressionType;
  }

  /**
   * Consumes raw input rows: columns[1] is y, columns[2] is x. Rows unmarked in the bitmap or
   * with a null on either side are skipped.
   *
   * <p>NOTE(review): the loop is sized on columns[1] here, while the sibling
   * CovarianceAccumulator sizes on columns[0] — presumably all input columns share one
   * position count; confirm against the caller.
   */
  @Override
  public void addInput(Column[] columns, BitMap bitMap) {

    int size = columns[1].getPositionCount();
    for (int i = 0; i < size; i++) {
      if (bitMap != null && !bitMap.isMarked(i)) {
        continue;
      }
      if (columns[1].isNull(i) || columns[2].isNull(i)) {
        continue;
      }

      double y = getDoubleValue(columns[1], i, seriesDataTypes[0]);
      double x = getDoubleValue(columns[2], i, seriesDataTypes[1]);

      update(x, y);
    }
  }

  /** Widens the supported numeric column types to double; rejects anything non-numeric. */
  private double getDoubleValue(Column column, int position, TSDataType dataType) {
    switch (dataType) {
      case INT32:
        return column.getInt(position);
      case INT64:
      case TIMESTAMP:
        return column.getLong(position);
      case FLOAT:
        return column.getFloat(position);
      case DOUBLE:
        return column.getDouble(position);
      default:
        throw new IllegalArgumentException("Unsupported data type: " + dataType);
    }
  }

  /**
   * Welford-style single-point update. Order matters: deltaX/deltaY are taken against the OLD
   * means, while the moment increments pair the old delta with the NEW mean — this pairing is
   * what makes c2 and m2X exact one-pass central (co-)moments.
   */
  private void update(double x, double y) {
    long newCount = count + 1;
    double deltaX = x - meanX;
    double deltaY = y - meanY;

    meanX += deltaX / newCount;
    meanY += deltaY / newCount;

    c2 += deltaX * (y - meanY);
    m2X += deltaX * (x - meanX);

    count = newCount;
  }

  /**
   * Merges one serialized partial state (layout must match {@link #outputIntermediate}).
   * Only position 0 of the single partial column is read.
   */
  @Override
  public void addIntermediate(Column[] partialResult) {
    checkArgument(partialResult.length == 1, "partialResult of Regression should be 1");
    if (partialResult[0].isNull(0)) {
      return;
    }
    byte[] bytes = partialResult[0].getBinary(0).getValues();
    ByteBuffer buffer = ByteBuffer.wrap(bytes);

    long otherCount = buffer.getLong();
    double otherMeanX = buffer.getDouble();
    double otherMeanY = buffer.getDouble();
    double otherM2X = buffer.getDouble();
    double otherC2 = buffer.getDouble();

    merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherC2);
  }

  /**
   * Chan et al. pairwise combination of two partial states. The delta products are evaluated
   * in double left-to-right, so the count products cannot overflow in long space.
   */
  private void merge(
      long otherCount, double otherMeanX, double otherMeanY, double otherM2X, double otherC2) {
    if (otherCount == 0) {
      return;
    }
    if (count == 0) {
      // This side is empty: adopt the other state wholesale.
      count = otherCount;
      meanX = otherMeanX;
      meanY = otherMeanY;
      m2X = otherM2X;
      c2 = otherC2;
    } else {
      long newCount = count + otherCount;
      double deltaX = otherMeanX - meanX;
      double deltaY = otherMeanY - meanY;

      c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount;
      m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount;

      meanX += deltaX * otherCount / newCount;
      meanY += deltaY * otherCount / newCount;
      count = newCount;
    }
  }

  /**
   * Serializes state as {count, meanX, meanY, m2X, c2} — 40 bytes — into one binary value;
   * null when no rows have been consumed.
   */
  @Override
  public void outputIntermediate(ColumnBuilder[] columnBuilders) {
    checkArgument(columnBuilders.length == 1, "partialResult of Regression should be 1");
    if (count == 0) {
      columnBuilders[0].appendNull();
    } else {
      byte[] bytes = new byte[40]; // Long.BYTES + 4 * Double.BYTES
      ByteBuffer buffer = ByteBuffer.wrap(bytes);
      buffer.putLong(count);
      buffer.putDouble(meanX);
      buffer.putDouble(meanY);
      buffer.putDouble(m2X);
      buffer.putDouble(c2);
      columnBuilders[0].writeBinary(new Binary(bytes));
    }
  }

  /**
   * Emits slope = c2/m2X or intercept = meanY - slope*meanX. Null when there is no input or
   * when m2X == 0 (all x values identical — the regression line is undefined).
   */
  @Override
  public void outputFinal(ColumnBuilder columnBuilder) {
    if (count == 0) {
      columnBuilder.appendNull();
      return;
    }

    if (m2X == 0) {
      columnBuilder.appendNull();
      return;
    }

    double slope = c2 / m2X;

    switch (regressionType) {
      case REGR_SLOPE:
        columnBuilder.writeDouble(slope);
        break;
      case REGR_INTERCEPT:
        columnBuilder.writeDouble(meanY - slope * meanX);
        break;
      default:
        throw new UnsupportedOperationException("Unknown type: " + regressionType);
    }
  }

  /**
   * Inverse of {@link #merge}: subtracts a previously-merged partial state (sliding windows).
   * Reconstructs the pre-merge means algebraically, then removes the correction terms from
   * c2 and m2X. The removed state must have been merged earlier, hence the checkArgument.
   */
  @Override
  public void removeIntermediate(Column[] input) {
    checkArgument(input.length == 1, "Input of Regression should be 1");
    if (input[0].isNull(0)) {
      return;
    }

    byte[] bytes = input[0].getBinary(0).getValues();
    ByteBuffer buffer = ByteBuffer.wrap(bytes);

    long otherCount = buffer.getLong();
    if (otherCount == 0) {
      return;
    }
    checkArgument(
        count >= otherCount, "Regression state count is smaller than removed state count");

    if (count == otherCount) {
      // Removing everything: avoid the 0/0 below and just clear.
      reset();
      return;
    }

    double otherMeanX = buffer.getDouble();
    double otherMeanY = buffer.getDouble();
    double otherM2X = buffer.getDouble();
    double otherC2 = buffer.getDouble();

    long totalCount = count;
    long newCount = totalCount - otherCount;

    // Recover the means of the remaining points from the weighted-mean identity.
    double newMeanX = ((double) totalCount * meanX - (double) otherCount * otherMeanX) / newCount;
    double newMeanY = ((double) totalCount * meanY - (double) otherCount * otherMeanY) / newCount;

    double deltaX = otherMeanX - newMeanX;
    double deltaY = otherMeanY - newMeanY;
    double correction = ((double) newCount * otherCount) / totalCount;

    c2 = c2 - otherC2 - deltaX * deltaY * correction;
    m2X = m2X - otherM2X - deltaX * deltaX * correction;

    meanX = newMeanX;
    meanY = newMeanY;
    count = newCount;
  }

  /** Statistics-based ingestion is not applicable to two-column aggregates. */
  @Override
  public void addStatistics(Statistics statistics) {
    throw new UnsupportedOperationException(getClass().getName());
  }

  /** Intentionally a no-op: this accumulator never receives a precomputed final value. */
  @Override
  public void setFinal(Column finalResult) {}

  /** Clears all running state back to the empty-input condition. */
  @Override
  public void reset() {
    count = 0;
    meanX = 0;
    meanY = 0;
    m2X = 0;
    c2 = 0;
  }

  /** Always false: a streaming regression can never short-circuit early. */
  @Override
  public boolean hasFinalResult() {
    return false;
  }

  /** Intermediate state travels as a single opaque binary (TEXT) column. */
  @Override
  public TSDataType[] getIntermediateType() {
    return new TSDataType[] {TSDataType.TEXT};
  }

  @Override
  public TSDataType getFinalType() {
    return TSDataType.DOUBLE;
  }
}
import org.apache.iotdb.commons.udf.utils.UDFDataTypeTransformer; +import org.apache.iotdb.db.queryengine.execution.aggregation.CentralMomentAccumulator; +import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator; +import org.apache.iotdb.db.queryengine.execution.aggregation.CovarianceAccumulator; +import org.apache.iotdb.db.queryengine.execution.aggregation.RegressionAccumulator; import org.apache.iotdb.db.queryengine.execution.aggregation.VarianceAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.BinaryGroupedApproxMostFrequentAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.BlobGroupedApproxMostFrequentAccumulator; @@ -30,9 +34,12 @@ import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedApproxCountDistinctAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedAvgAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCentralMomentAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCorrelationAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountAllAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCountIfAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedCovarianceAccumulator; import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedExtremeAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedFirstAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedFirstByAccumulator; @@ -43,6 +50,7 @@ import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedMinByAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedModeAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedRegressionAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedSumAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedUserDefinedAggregateAccumulator; import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.GroupedVarianceAccumulator; @@ -256,6 +264,37 @@ private static GroupedAccumulator createBuiltinGroupedAccumulator( return new GroupedApproxCountDistinctAccumulator(inputDataTypes.get(0)); case APPROX_MOST_FREQUENT: return getGroupedApproxMostFrequentAccumulator(inputDataTypes.get(0)); + case CORR: + return new GroupedCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.CORR); + case COVAR_POP: + return new GroupedCovarianceAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CovarianceAccumulator.CovarianceType.COVAR_POP); + case COVAR_SAMP: + return new GroupedCovarianceAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CovarianceAccumulator.CovarianceType.COVAR_SAMP); + case 
REGR_SLOPE: + return new GroupedRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_SLOPE); + case REGR_INTERCEPT: + return new GroupedRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_INTERCEPT); + case SKEWNESS: + return new GroupedCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.SKEWNESS); + case KURTOSIS: + return new GroupedCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.KURTOSIS); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } @@ -325,6 +364,37 @@ public static TableAccumulator createBuiltinAccumulator( return new ApproxCountDistinctAccumulator(inputDataTypes.get(0)); case APPROX_MOST_FREQUENT: return getApproxMostFrequentAccumulator(inputDataTypes.get(0)); + case CORR: + return new TableCorrelationAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CorrelationAccumulator.CorrelationType.CORR); + case COVAR_POP: + return new TableCovarianceAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CovarianceAccumulator.CovarianceType.COVAR_POP); + case COVAR_SAMP: + return new TableCovarianceAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + CovarianceAccumulator.CovarianceType.COVAR_SAMP); + case REGR_SLOPE: + return new TableRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_SLOPE); + case REGR_INTERCEPT: + return new TableRegressionAccumulator( + inputDataTypes.get(0), + inputDataTypes.get(1), + RegressionAccumulator.RegressionType.REGR_INTERCEPT); + case SKEWNESS: + return new TableCentralMomentAccumulator( + inputDataTypes.get(0), CentralMomentAccumulator.MomentType.SKEWNESS); + case KURTOSIS: + return new TableCentralMomentAccumulator( + inputDataTypes.get(0), 
CentralMomentAccumulator.MomentType.KURTOSIS); default: throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType); } @@ -385,6 +455,12 @@ public static boolean isMultiInputAggregation(TAggregationType aggregationType) case MAX_BY: case MIN_BY: return true; + case CORR: + case COVAR_POP: + case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: + return true; default: return false; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java new file mode 100644 index 0000000000000..7b10df15cf3b1 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCentralMomentAccumulator.java @@ -0,0 +1,251 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation; + +import org.apache.iotdb.db.queryengine.execution.aggregation.CentralMomentAccumulator; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.statistics.Statistics; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class TableCentralMomentAccumulator implements TableAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(TableCentralMomentAccumulator.class); + + private final TSDataType seriesDataType; + private final CentralMomentAccumulator.MomentType momentType; + + private long count; + private double mean; + private double m2; + private double m3; + private double m4; + + public TableCentralMomentAccumulator( + TSDataType seriesDataType, CentralMomentAccumulator.MomentType momentType) { + this.seriesDataType = seriesDataType; + this.momentType = momentType; + } + + @Override + public void addInput(Column[] arguments, AggregationMask mask) { + int positionCount = mask.getSelectedPositionCount(); + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (!arguments[0].isNull(i)) { + update(getDoubleValue(arguments[0], i)); + } + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (!arguments[0].isNull(position)) { + 
update(getDoubleValue(arguments[0], position)); + } + } + } + } + + private double getDoubleValue(Column column, int position) { + switch (seriesDataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format( + "Unsupported data type in CentralMoment Aggregation: %s", seriesDataType)); + } + } + + private void update(double value) { + long n1 = count; + count++; + double delta = value - mean; + double delta_n = delta / count; + double delta_n2 = delta_n * delta_n; + double term1 = delta * delta_n * n1; + mean += delta_n; + m4 += term1 * delta_n2 * (count * count - 3 * count + 3) + 6 * delta_n2 * m2 - 4 * delta_n * m3; + m3 += term1 * delta_n * (count - 2) - 3 * delta_n * m2; + m2 += term1; + } + + @Override + public void addIntermediate(Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMean = buffer.getDouble(); + double otherM2 = buffer.getDouble(); + double otherM3 = buffer.getDouble(); + double otherM4 = buffer.getDouble(); + + merge(otherCount, otherMean, otherM2, otherM3, otherM4); + } + } + + private void merge(long nB, double meanB, double m2B, double m3B, double m4B) { + if (nB == 0) return; + if (count == 0) { + count = nB; + mean = meanB; + m2 = m2B; + m3 = m3B; + m4 = m4B; + } else { + long nA = count; + long nTotal = nA + nB; + double delta = meanB - mean; 
+ double delta2 = delta * delta; + double delta3 = delta * delta2; + double delta4 = delta2 * delta2; + + m4 += + m4B + + delta4 * nA * nB * (nA * nA - nA * nB + nB * nB) / (nTotal * nTotal * nTotal) + + 6.0 * delta2 * (nA * nA * m2B + nB * nB * m2) / (nTotal * nTotal) + + 4.0 * delta * (nA * m3B - nB * m3) / nTotal; + + m3 += + m3B + + delta3 * nA * nB * (nA - nB) / (nTotal * nTotal) + + 3.0 * delta * (nA * m2B - nB * m2) / nTotal; + + m2 += m2B + delta2 * nA * nB / nTotal; + + mean += delta * nB / nTotal; + count = nTotal; + } + } + + @Override + public void evaluateIntermediate(ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + if (count == 0) { + columnBuilder.appendNull(); + } else { + + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4); + buffer.putLong(count); + buffer.putDouble(mean); + buffer.putDouble(m2); + buffer.putDouble(m3); + buffer.putDouble(m4); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void evaluateFinal(ColumnBuilder columnBuilder) { + if (count == 0 || m2 == 0) { + columnBuilder.appendNull(); + return; + } + + if (momentType == CentralMomentAccumulator.MomentType.SKEWNESS) { + if (count < 3) { + columnBuilder.appendNull(); + return; + } + double result = Math.sqrt((double) count) * m3 / Math.pow(m2, 1.5); + columnBuilder.writeDouble(result); + } else { + if (count < 4) { + columnBuilder.appendNull(); + } else { + double variance = m2 / (count - 1); + double term1 = + (count * (count + 1) * m4) + / ((count - 1) * (count - 2) * (count - 3) * variance * variance); + double term2 = (3 * Math.pow(count - 1, 2)) / ((count - 2) * (count - 3)); + columnBuilder.writeDouble(term1 - term2); + } + } + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE; + } + + @Override + public TableAccumulator copy() { + return new TableCentralMomentAccumulator(seriesDataType, 
momentType); + } + + @Override + public void removeInput(Column[] arguments) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean hasFinalResult() { + return false; + } + + @Override + public void addStatistics(Statistics[] statistics) { + throw new UnsupportedOperationException(); + } + + @Override + public void reset() { + count = 0; + mean = 0; + m2 = 0; + m3 = 0; + m4 = 0; + } + + @Override + public boolean removable() { + return false; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java new file mode 100644 index 0000000000000..68c65311a6152 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/TableCorrelationAccumulator.java @@ -0,0 +1,249 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;

import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.file.metadata.statistics.Statistics;
import org.apache.tsfile.read.common.block.column.BinaryColumn;
import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.RamUsageEstimator;
import org.apache.tsfile.write.UnSupportedDataTypeException;

import java.nio.ByteBuffer;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Table-model accumulator for the Pearson correlation aggregation ({@code corr(x, y)}) over a pair
 * of numeric columns.
 *
 * <p>State is kept as streaming moments: {@code count}, the running means of x and y, the centered
 * second moments {@code m2X}/{@code m2Y}, and the co-moment {@code c2}. Rows where either input is
 * null are skipped. The final value is {@code c2 / sqrt(m2X * m2Y)}.
 *
 * <p>Intermediate (partial-aggregation) state is serialized as a Binary blob with layout
 * {@code [count:long][meanX:double][meanY:double][m2X:double][m2Y:double][c2:double]}
 * (big-endian, ByteBuffer default) — see {@link #evaluateIntermediate} / {@link #addIntermediate}.
 *
 * <p>NOTE(review): not thread-safe; presumably each instance is confined to one driver thread —
 * confirm against the operator framework.
 */
public class TableCorrelationAccumulator implements TableAccumulator {

  private static final long INSTANCE_SIZE =
      RamUsageEstimator.shallowSizeOfInstance(TableCorrelationAccumulator.class);
  // Input column types; x is arguments[0], y is arguments[1].
  private final TSDataType xDataType;
  private final TSDataType yDataType;
  private final CorrelationAccumulator.CorrelationType correlationType;

  // Streaming moment state. count == 0 means "no non-null row seen yet".
  private long count;
  private double meanX;
  private double meanY;
  private double m2X; // sum of squared deviations of x from its running mean
  private double m2Y; // sum of squared deviations of y from its running mean
  private double c2; // co-moment: sum of (x - meanX) * (y - meanY)

  public TableCorrelationAccumulator(
      TSDataType xDataType,
      TSDataType yDataType,
      CorrelationAccumulator.CorrelationType correlationType) {
    this.xDataType = xDataType;
    this.yDataType = yDataType;
    this.correlationType = correlationType;
  }

  @Override
  public long getEstimatedSize() {
    return INSTANCE_SIZE;
  }

  /** Returns a fresh accumulator with the same configuration and empty state. */
  @Override
  public TableAccumulator copy() {
    return new TableCorrelationAccumulator(xDataType, yDataType, correlationType);
  }

  /**
   * Consumes raw input rows. arguments[0] is the x column, arguments[1] the y column; a row
   * contributes only when both values are non-null.
   */
  @Override
  public void addInput(Column[] arguments, AggregationMask mask) {
    int positionCount = mask.getSelectedPositionCount();

    if (mask.isSelectAll()) {
      for (int i = 0; i < positionCount; i++) {
        if (arguments[0].isNull(i) || arguments[1].isNull(i)) {
          continue;
        }
        double x = getDoubleValue(arguments[0], i, xDataType);
        double y = getDoubleValue(arguments[1], i, yDataType);
        update(x, y);
      }
    } else {
      // Only the positions listed by the mask participate.
      int[] selectedPositions = mask.getSelectedPositions();
      for (int i = 0; i < positionCount; i++) {
        int position = selectedPositions[i];
        if (arguments[0].isNull(position) || arguments[1].isNull(position)) {
          continue;
        }
        double x = getDoubleValue(arguments[0], position, xDataType);
        double y = getDoubleValue(arguments[1], position, yDataType);
        update(x, y);
      }
    }
  }

  /** Widens a numeric column value at {@code position} to double; rejects non-numeric types. */
  private double getDoubleValue(Column column, int position, TSDataType dataType) {
    switch (dataType) {
      case INT32:
      case DATE:
        return column.getInt(position);
      case INT64:
      case TIMESTAMP:
        return column.getLong(position);
      case FLOAT:
        return column.getFloat(position);
      case DOUBLE:
        return column.getDouble(position);
      default:
        throw new UnSupportedDataTypeException(
            String.format("Unsupported data type in Correlation Aggregation: %s", dataType));
    }
  }

  /**
   * Welford-style single-observation update. The co-moment update deliberately mixes the OLD x
   * mean with the NEW y mean — that pairing is what makes the streaming formula exact.
   */
  private void update(double x, double y) {
    count++;
    double oldMeanX = meanX;
    meanX = oldMeanX + (x - oldMeanX) / count;
    double oldMeanY = meanY;
    double newMeanY = oldMeanY + (y - oldMeanY) / count;
    meanY = newMeanY;

    c2 += (x - oldMeanX) * (y - newMeanY);
    // (x - oldMean) * (x - newMean) is the numerically stable M2 increment.
    m2X += (x - oldMeanX) * (x - meanX);
    m2Y += (y - oldMeanY) * (y - meanY);
  }

  @Override
  public void removeInput(Column[] arguments) {
    // Sliding-window retraction is not supported (removable() returns false).
    throw new UnsupportedOperationException("Remove not implemented for Correlation Accumulator");
  }

  /** Merges serialized partial states produced by {@link #evaluateIntermediate}. */
  @Override
  public void addIntermediate(Column argument) {
    checkArgument(
        argument instanceof BinaryColumn
            || (argument instanceof RunLengthEncodedColumn
                && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn),
        "intermediate input and output should be BinaryColumn");

    for (int i = 0; i < argument.getPositionCount(); i++) {
      if (argument.isNull(i)) {
        continue;
      }

      // Layout must match evaluateIntermediate: count, meanX, meanY, m2X, m2Y, c2.
      byte[] bytes = argument.getBinary(i).getValues();
      ByteBuffer buffer = ByteBuffer.wrap(bytes);

      long otherCount = buffer.getLong();
      double otherMeanX = buffer.getDouble();
      double otherMeanY = buffer.getDouble();
      double otherM2X = buffer.getDouble();
      double otherM2Y = buffer.getDouble();
      double otherC2 = buffer.getDouble();

      merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherM2Y, otherC2);
    }
  }

  /**
   * Pairwise combination of two partial moment states (parallel-merge form of the streaming
   * update). Count products are promoted to double by the delta factors, so no long overflow.
   */
  private void merge(
      long otherCount,
      double otherMeanX,
      double otherMeanY,
      double otherM2X,
      double otherM2Y,
      double otherC2) {
    if (otherCount == 0) {
      return;
    }
    if (count == 0) {
      // This side is empty: adopt the other state wholesale.
      count = otherCount;
      meanX = otherMeanX;
      meanY = otherMeanY;
      m2X = otherM2X;
      m2Y = otherM2Y;
      c2 = otherC2;
    } else {
      long newCount = count + otherCount;
      double deltaX = otherMeanX - meanX;
      double deltaY = otherMeanY - meanY;

      // Cross terms use the PRE-merge count, so they must be applied before count is updated.
      c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount;
      m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount;
      m2Y += otherM2Y + deltaY * deltaY * count * otherCount / newCount;

      meanX += deltaX * otherCount / newCount;
      meanY += deltaY * otherCount / newCount;
      count = newCount;
    }
  }

  /** Serializes the current state; null when no row has been accumulated. */
  @Override
  public void evaluateIntermediate(ColumnBuilder columnBuilder) {
    checkArgument(
        columnBuilder instanceof BinaryColumnBuilder,
        "intermediate input and output should be BinaryColumn");

    if (count == 0) {
      columnBuilder.appendNull();
    } else {
      ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 5);
      buffer.putLong(count);
      buffer.putDouble(meanX);
      buffer.putDouble(meanY);
      buffer.putDouble(m2X);
      buffer.putDouble(m2Y);
      buffer.putDouble(c2);
      columnBuilder.writeBinary(new Binary(buffer.array()));
    }
  }

  /**
   * Emits the correlation, or null when it is undefined: fewer than two rows, or either variable
   * has zero variance (division by zero in the Pearson formula).
   */
  @Override
  public void evaluateFinal(ColumnBuilder columnBuilder) {
    if (correlationType != CorrelationAccumulator.CorrelationType.CORR) {
      throw new UnsupportedOperationException("Unknown type: " + correlationType);
    }

    if (count < 2) {
      columnBuilder.appendNull();
    } else if (m2X == 0 || m2Y == 0) {
      columnBuilder.appendNull();
    } else {
      columnBuilder.writeDouble(c2 / Math.sqrt(m2X * m2Y));
    }
  }

  @Override
  public boolean hasFinalResult() {
    // Correlation can never short-circuit; every row matters.
    return false;
  }

  @Override
  public void addStatistics(Statistics[] statistics) {
    // Pre-aggregated file statistics cannot reconstruct the co-moment.
    throw new UnsupportedOperationException(getClass().getName());
  }

  /** Clears all moment state so the accumulator can be reused. */
  @Override
  public void reset() {
    count = 0;
    meanX = 0;
    meanY = 0;
    m2X = 0;
    m2Y = 0;
    c2 = 0;
  }

  @Override
  public boolean removable() {
    return false;
  }
}
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;

import org.apache.iotdb.db.queryengine.execution.aggregation.CovarianceAccumulator;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.file.metadata.statistics.Statistics;
import org.apache.tsfile.read.common.block.column.BinaryColumn;
import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.RamUsageEstimator;
import org.apache.tsfile.write.UnSupportedDataTypeException;

import java.nio.ByteBuffer;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Table-model accumulator for covariance aggregations ({@code covar_pop} / {@code covar_samp})
 * over a pair of numeric columns.
 *
 * <p>Maintains a streaming state of {@code count}, the running means of x and y, and the
 * co-moment {@code c2}; rows where either input is null are ignored. The final result is
 * {@code c2 / count} for the population variant and {@code c2 / (count - 1)} for the sample
 * variant (null when undefined).
 *
 * <p>Partial state is serialized as {@code [count:long][meanX:double][meanY:double][c2:double]}.
 */
public class TableCovarianceAccumulator implements TableAccumulator {

  private static final long INSTANCE_SIZE =
      RamUsageEstimator.shallowSizeOfInstance(TableCovarianceAccumulator.class);

  private final TSDataType xDataType;
  private final TSDataType yDataType;
  private final CovarianceAccumulator.CovarianceType covarianceType;

  // Streaming state; count == 0 means no non-null pair has been seen.
  private long count;
  private double meanX;
  private double meanY;
  private double c2;

  public TableCovarianceAccumulator(
      TSDataType xDataType,
      TSDataType yDataType,
      CovarianceAccumulator.CovarianceType covarianceType) {
    this.xDataType = xDataType;
    this.yDataType = yDataType;
    this.covarianceType = covarianceType;
  }

  @Override
  public long getEstimatedSize() {
    return INSTANCE_SIZE;
  }

  /** Creates an empty accumulator with identical configuration. */
  @Override
  public TableAccumulator copy() {
    return new TableCovarianceAccumulator(xDataType, yDataType, covarianceType);
  }

  /** Feeds raw rows: x from arguments[0], y from arguments[1]; null pairs are skipped. */
  @Override
  public void addInput(Column[] arguments, AggregationMask mask) {
    int selected = mask.getSelectedPositionCount();

    if (mask.isSelectAll()) {
      for (int row = 0; row < selected; row++) {
        accumulateRow(arguments, row);
      }
    } else {
      int[] rows = mask.getSelectedPositions();
      for (int idx = 0; idx < selected; idx++) {
        accumulateRow(arguments, rows[idx]);
      }
    }
  }

  /** Accumulates one row if both values are present. */
  private void accumulateRow(Column[] arguments, int row) {
    if (arguments[0].isNull(row) || arguments[1].isNull(row)) {
      return;
    }
    update(
        getDoubleValue(arguments[0], row, xDataType),
        getDoubleValue(arguments[1], row, yDataType));
  }

  /** Widens a numeric column value to double; non-numeric types are rejected. */
  private double getDoubleValue(Column column, int position, TSDataType dataType) {
    switch (dataType) {
      case INT32:
      case DATE:
        return column.getInt(position);
      case INT64:
      case TIMESTAMP:
        return column.getLong(position);
      case FLOAT:
        return column.getFloat(position);
      case DOUBLE:
        return column.getDouble(position);
      default:
        throw new UnSupportedDataTypeException(
            String.format("Unsupported data type in Covariance Aggregation: %s", dataType));
    }
  }

  /**
   * Streaming (Welford-style) single-pair update; the co-moment pairs the old x mean with the
   * new y mean, which keeps the formula exact.
   */
  private void update(double x, double y) {
    long n = count + 1;
    double prevMeanX = meanX;
    meanX = prevMeanX + (x - prevMeanX) / n;
    double prevMeanY = meanY;
    double curMeanY = prevMeanY + (y - prevMeanY) / n;
    meanY = curMeanY;
    c2 += (x - prevMeanX) * (y - curMeanY);
    count = n;
  }

  @Override
  public void removeInput(Column[] arguments) {
    throw new UnsupportedOperationException("Remove not implemented for Covariance Accumulator");
  }

  /** Folds serialized partial states (produced by {@link #evaluateIntermediate}) into this one. */
  @Override
  public void addIntermediate(Column argument) {
    checkArgument(
        argument instanceof BinaryColumn
            || (argument instanceof RunLengthEncodedColumn
                && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn),
        "intermediate input and output should be BinaryColumn");

    int positions = argument.getPositionCount();
    for (int pos = 0; pos < positions; pos++) {
      if (argument.isNull(pos)) {
        continue;
      }
      ByteBuffer state = ByteBuffer.wrap(argument.getBinary(pos).getValues());
      long partialCount = state.getLong();
      double partialMeanX = state.getDouble();
      double partialMeanY = state.getDouble();
      double partialC2 = state.getDouble();
      merge(partialCount, partialMeanX, partialMeanY, partialC2);
    }
  }

  /** Pairwise merge of two partial covariance states. */
  private void merge(long otherCount, double otherMeanX, double otherMeanY, double otherC2) {
    if (otherCount == 0) {
      return;
    }
    if (count == 0) {
      // Nothing accumulated locally yet — take over the other side's state.
      count = otherCount;
      meanX = otherMeanX;
      meanY = otherMeanY;
      c2 = otherC2;
      return;
    }

    long combined = count + otherCount;
    double deltaX = otherMeanX - meanX;
    double deltaY = otherMeanY - meanY;

    // The cross term uses the pre-merge count, so update c2 before the means and count.
    c2 += otherC2 + deltaX * deltaY * count * otherCount / combined;
    meanX += deltaX * otherCount / combined;
    meanY += deltaY * otherCount / combined;
    count = combined;
  }

  /** Serializes the running state; null when empty. */
  @Override
  public void evaluateIntermediate(ColumnBuilder columnBuilder) {
    checkArgument(
        columnBuilder instanceof BinaryColumnBuilder,
        "intermediate input and output should be BinaryColumn");

    if (count == 0) {
      columnBuilder.appendNull();
      return;
    }

    ByteBuffer state = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 3);
    state.putLong(count);
    state.putDouble(meanX);
    state.putDouble(meanY);
    state.putDouble(c2);
    columnBuilder.writeBinary(new Binary(state.array()));
  }

  /** Emits the population or sample covariance, or null when it is undefined. */
  @Override
  public void evaluateFinal(ColumnBuilder columnBuilder) {
    if (covarianceType == CovarianceAccumulator.CovarianceType.COVAR_POP) {
      if (count == 0) {
        columnBuilder.appendNull();
      } else {
        columnBuilder.writeDouble(c2 / count);
      }
    } else if (covarianceType == CovarianceAccumulator.CovarianceType.COVAR_SAMP) {
      // The sample estimator divides by (n - 1) and needs at least two rows.
      if (count < 2) {
        columnBuilder.appendNull();
      } else {
        columnBuilder.writeDouble(c2 / (count - 1));
      }
    } else {
      throw new UnsupportedOperationException("Unknown type: " + covarianceType);
    }
  }

  @Override
  public boolean hasFinalResult() {
    return false;
  }

  @Override
  public void addStatistics(Statistics[] statistics) {
    // File-level statistics cannot supply the cross-column co-moment.
    throw new UnsupportedOperationException(getClass().getName());
  }

  /** Clears all state for reuse. */
  @Override
  public void reset() {
    count = 0;
    meanX = 0;
    meanY = 0;
    c2 = 0;
  }

  @Override
  public boolean removable() {
    return false;
  }
}
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;

import org.apache.iotdb.db.queryengine.execution.aggregation.RegressionAccumulator;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.file.metadata.statistics.Statistics;
import org.apache.tsfile.read.common.block.column.BinaryColumn;
import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.RamUsageEstimator;
import org.apache.tsfile.write.UnSupportedDataTypeException;

import java.nio.ByteBuffer;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Table-model accumulator for simple linear regression aggregations
 * ({@code regr_slope(y, x)} / {@code regr_intercept(y, x)}).
 *
 * <p>Note the SQL argument order: arguments[0] is the DEPENDENT variable y, arguments[1] the
 * independent variable x. State is {@code count}, the running means of x and y, the centered
 * second moment of x ({@code m2X}) and the co-moment {@code c2}; slope = c2 / m2X and
 * intercept = meanY - slope * meanX. Rows with a null in either column are skipped.
 *
 * <p>Partial state layout: {@code [count:long][meanX:double][meanY:double][m2X:double][c2:double]}.
 */
public class TableRegressionAccumulator implements TableAccumulator {

  private static final long INSTANCE_SIZE =
      RamUsageEstimator.shallowSizeOfInstance(TableRegressionAccumulator.class);

  private final TSDataType yDataType;
  private final TSDataType xDataType;
  private final RegressionAccumulator.RegressionType regressionType;

  // Streaming state; count == 0 means no non-null (y, x) pair has been seen.
  private long count;
  private double meanX;
  private double meanY;
  private double m2X; // sum of squared deviations of x from its running mean
  private double c2; // co-moment of x and y

  public TableRegressionAccumulator(
      TSDataType yDataType,
      TSDataType xDataType,
      RegressionAccumulator.RegressionType regressionType) {
    this.yDataType = yDataType;
    this.xDataType = xDataType;
    this.regressionType = regressionType;
  }

  @Override
  public long getEstimatedSize() {
    return INSTANCE_SIZE;
  }

  /** Returns an empty accumulator with the same configuration. */
  @Override
  public TableAccumulator copy() {
    return new TableRegressionAccumulator(yDataType, xDataType, regressionType);
  }

  /** Feeds raw rows: y from arguments[0], x from arguments[1]; null pairs are skipped. */
  @Override
  public void addInput(Column[] arguments, AggregationMask mask) {

    int positionCount = mask.getSelectedPositionCount();

    if (mask.isSelectAll()) {
      for (int i = 0; i < positionCount; i++) {
        if (arguments[0].isNull(i) || arguments[1].isNull(i)) {
          continue;
        }
        double y = getDoubleValue(arguments[0], i, yDataType);
        double x = getDoubleValue(arguments[1], i, xDataType);
        update(x, y);
      }
    } else {
      int[] selectedPositions = mask.getSelectedPositions();
      for (int i = 0; i < positionCount; i++) {
        int position = selectedPositions[i];
        if (arguments[0].isNull(position) || arguments[1].isNull(position)) {
          continue;
        }
        double y = getDoubleValue(arguments[0], position, yDataType);
        double x = getDoubleValue(arguments[1], position, xDataType);
        update(x, y);
      }
    }
  }

  /** Widens a numeric column value to double; rejects non-numeric types. */
  private double getDoubleValue(Column column, int position, TSDataType dataType) {
    switch (dataType) {
      case INT32:
      case DATE:
        return column.getInt(position);
      case INT64:
      case TIMESTAMP:
        return column.getLong(position);
      case FLOAT:
        return column.getFloat(position);
      case DOUBLE:
        return column.getDouble(position);
      default:
        throw new UnSupportedDataTypeException(
            String.format("Unsupported data type in Regression Aggregation: %s", dataType));
    }
  }

  /**
   * Welford-style single-pair update. deltaX/deltaY are computed against the OLD means, while the
   * second factors use the NEW means — this pairing keeps the streaming moments exact.
   */
  private void update(double x, double y) {
    long newCount = count + 1;
    double deltaX = x - meanX;
    double deltaY = y - meanY;
    meanX += deltaX / newCount;
    meanY += deltaY / newCount;
    c2 += deltaX * (y - meanY);
    m2X += deltaX * (x - meanX);
    count = newCount;
  }

  /** Folds serialized partial states (produced by {@link #evaluateIntermediate}) into this one. */
  @Override
  public void addIntermediate(Column argument) {
    checkArgument(
        argument instanceof BinaryColumn
            || (argument instanceof RunLengthEncodedColumn
                && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn),
        // Message added for consistency with the sibling covariance/correlation accumulators.
        "intermediate input and output should be BinaryColumn");

    for (int i = 0; i < argument.getPositionCount(); i++) {
      if (argument.isNull(i)) {
        continue;
      }
      // Layout must match evaluateIntermediate: count, meanX, meanY, m2X, c2.
      byte[] bytes = argument.getBinary(i).getValues();
      ByteBuffer buffer = ByteBuffer.wrap(bytes);

      long otherCount = buffer.getLong();
      double otherMeanX = buffer.getDouble();
      double otherMeanY = buffer.getDouble();
      double otherM2X = buffer.getDouble();
      double otherC2 = buffer.getDouble();

      merge(otherCount, otherMeanX, otherMeanY, otherM2X, otherC2);
    }
  }

  /** Pairwise merge of two partial regression states. */
  private void merge(
      long otherCount, double otherMeanX, double otherMeanY, double otherM2X, double otherC2) {
    if (otherCount == 0) {
      return;
    }
    if (count == 0) {
      // This side is empty — adopt the other state wholesale.
      count = otherCount;
      meanX = otherMeanX;
      meanY = otherMeanY;
      m2X = otherM2X;
      c2 = otherC2;
    } else {
      long newCount = count + otherCount;
      double deltaX = otherMeanX - meanX;
      double deltaY = otherMeanY - meanY;
      // Cross terms use the PRE-merge count; apply them before updating means and count.
      c2 += otherC2 + deltaX * deltaY * count * otherCount / newCount;
      m2X += otherM2X + deltaX * deltaX * count * otherCount / newCount;
      meanX += deltaX * otherCount / newCount;
      meanY += deltaY * otherCount / newCount;
      count = newCount;
    }
  }

  /** Serializes the running state; null when empty. */
  @Override
  public void evaluateIntermediate(ColumnBuilder columnBuilder) {
    if (count == 0) {
      columnBuilder.appendNull();
    } else {
      // Named size expression instead of the magic number 40 (1 long + 4 doubles).
      ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4);
      buffer.putLong(count);
      buffer.putDouble(meanX);
      buffer.putDouble(meanY);
      buffer.putDouble(m2X);
      buffer.putDouble(c2);
      columnBuilder.writeBinary(new Binary(buffer.array()));
    }
  }

  /**
   * Emits the slope or intercept, or null when the fit is undefined (no rows, or x has zero
   * variance so the slope would divide by zero).
   */
  @Override
  public void evaluateFinal(ColumnBuilder columnBuilder) {
    if (count == 0 || m2X == 0) {
      columnBuilder.appendNull();
      return;
    }
    double slope = c2 / m2X;
    switch (regressionType) {
      case REGR_SLOPE:
        columnBuilder.writeDouble(slope);
        break;
      case REGR_INTERCEPT:
        columnBuilder.writeDouble(meanY - slope * meanX);
        break;
      default:
        throw new UnsupportedOperationException("Unknown type: " + regressionType);
    }
  }

  @Override
  public void removeInput(Column[] arguments) {
    // Sliding-window retraction is not supported (removable() returns false).
    throw new UnsupportedOperationException("Remove not implemented for Regression Accumulator");
  }

  @Override
  public boolean hasFinalResult() {
    return false;
  }

  @Override
  public void addStatistics(Statistics[] statistics) {
    // File-level statistics cannot supply the cross-column co-moment.
    throw new UnsupportedOperationException(getClass().getName());
  }

  /** Clears all state for reuse. */
  @Override
  public void reset() {
    count = 0;
    meanX = 0;
    meanY = 0;
    m2X = 0;
    c2 = 0;
  }

  @Override
  public boolean removable() {
    return false;
  }
}
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java new file mode 100644 index 0000000000000..62d2e261ceab0 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCentralMomentAccumulator.java @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;

import org.apache.iotdb.db.queryengine.execution.aggregation.CentralMomentAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.column.BinaryColumn;
import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.RamUsageEstimator;
import org.apache.tsfile.write.UnSupportedDataTypeException;

import java.nio.ByteBuffer;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Grouped accumulator for higher central-moment aggregations (skewness / kurtosis), keeping
 * per-group streaming state (count, mean, M2, M3, M4) in big arrays.
 *
 * <p>Single-value updates and pairwise merges follow the standard one-pass central-moment
 * recurrences (Pébay's formulas). All count polynomials are evaluated in {@code double}: the
 * original long expressions such as {@code nTotal * nTotal * nTotal} overflow 64-bit arithmetic
 * once a group exceeds roughly 2^21 rows, silently corrupting the result.
 *
 * <p>Partial state layout per group:
 * {@code [count:long][mean:double][m2:double][m3:double][m4:double]}.
 */
public class GroupedCentralMomentAccumulator implements GroupedAccumulator {

  private static final long INSTANCE_SIZE =
      RamUsageEstimator.shallowSizeOfInstance(GroupedCentralMomentAccumulator.class);

  private final TSDataType seriesDataType;
  private final CentralMomentAccumulator.MomentType momentType;

  // Per-group streaming moments, indexed by groupId.
  private final LongBigArray counts = new LongBigArray();
  private final DoubleBigArray means = new DoubleBigArray();
  private final DoubleBigArray m2s = new DoubleBigArray();
  private final DoubleBigArray m3s = new DoubleBigArray();
  private final DoubleBigArray m4s = new DoubleBigArray();

  public GroupedCentralMomentAccumulator(
      TSDataType seriesDataType, CentralMomentAccumulator.MomentType momentType) {
    this.seriesDataType = seriesDataType;
    this.momentType = momentType;
  }

  @Override
  public long getEstimatedSize() {
    return INSTANCE_SIZE
        + counts.sizeOf()
        + means.sizeOf()
        + m2s.sizeOf()
        + m3s.sizeOf()
        + m4s.sizeOf();
  }

  /** Grows the per-group arrays to hold at least {@code groupCount} groups. */
  @Override
  public void setGroupCount(long groupCount) {
    counts.ensureCapacity(groupCount);
    means.ensureCapacity(groupCount);
    m2s.ensureCapacity(groupCount);
    m3s.ensureCapacity(groupCount);
    m4s.ensureCapacity(groupCount);
  }

  /** Feeds raw rows; null values are skipped, each row updates the state of its group. */
  @Override
  public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
    int positionCount = mask.getSelectedPositionCount();
    if (mask.isSelectAll()) {
      for (int i = 0; i < positionCount; i++) {
        if (!arguments[0].isNull(i)) {
          update(groupIds[i], getDoubleValue(arguments[0], i));
        }
      }
    } else {
      int[] selectedPositions = mask.getSelectedPositions();
      for (int i = 0; i < positionCount; i++) {
        int position = selectedPositions[i];
        if (!arguments[0].isNull(position)) {
          update(groupIds[position], getDoubleValue(arguments[0], position));
        }
      }
    }
  }

  /** Widens a numeric column value to double; rejects non-numeric types. */
  private double getDoubleValue(Column column, int position) {
    switch (seriesDataType) {
      case INT32:
      case DATE:
        return column.getInt(position);
      case INT64:
      case TIMESTAMP:
        return column.getLong(position);
      case FLOAT:
        return column.getFloat(position);
      case DOUBLE:
        return column.getDouble(position);
      default:
        throw new UnSupportedDataTypeException(
            String.format(
                "Unsupported data type in CentralMoment Aggregation: %s", seriesDataType));
    }
  }

  /** One-pass single-observation update of mean/M2/M3/M4 for a group. */
  private void update(int groupId, double value) {
    long n1 = counts.get(groupId);
    long newCount = n1 + 1;
    double mean = means.get(groupId);
    double m2 = m2s.get(groupId);
    double m3 = m3s.get(groupId);
    double m4 = m4s.get(groupId);

    double delta = value - mean;
    double deltaN = delta / newCount;
    double deltaN2 = deltaN * deltaN;
    double term1 = delta * deltaN * n1;
    // Evaluate the count polynomial in double: (newCount * newCount) as long overflows for
    // very large groups.
    double n = newCount;

    mean += deltaN;
    // M4 must be updated before M3, and M3 before M2, since each uses the lower moments'
    // pre-update values.
    m4 += term1 * deltaN2 * (n * n - 3 * n + 3) + 6 * deltaN2 * m2 - 4 * deltaN * m3;
    m3 += term1 * deltaN * (n - 2) - 3 * deltaN * m2;
    m2 += term1;

    counts.set(groupId, newCount);
    means.set(groupId, mean);
    m2s.set(groupId, m2);
    m3s.set(groupId, m3);
    m4s.set(groupId, m4);
  }

  /** Folds serialized partial states into the corresponding groups. */
  @Override
  public void addIntermediate(int[] groupIds, Column argument) {
    checkArgument(
        argument instanceof BinaryColumn
            || (argument instanceof RunLengthEncodedColumn
                && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn),
        "intermediate input and output should be BinaryColumn");

    for (int i = 0; i < argument.getPositionCount(); i++) {
      if (argument.isNull(i)) {
        continue;
      }
      // Layout must match evaluateIntermediate: count, mean, m2, m3, m4.
      byte[] bytes = argument.getBinary(i).getValues();
      ByteBuffer buffer = ByteBuffer.wrap(bytes);

      long otherCount = buffer.getLong();
      double otherMean = buffer.getDouble();
      double otherM2 = buffer.getDouble();
      double otherM3 = buffer.getDouble();
      double otherM4 = buffer.getDouble();

      merge(groupIds[i], otherCount, otherMean, otherM2, otherM3, otherM4);
    }
  }

  /**
   * Pairwise merge of two partial moment states (Pébay's parallel formulas). Count products are
   * computed in double because e.g. {@code nTotal^3} overflows long for groups larger than
   * about 2^21 rows.
   */
  private void merge(int groupId, long nB, double meanB, double m2B, double m3B, double m4B) {
    if (nB == 0) {
      return;
    }
    long nA = counts.get(groupId);
    if (nA == 0) {
      // The group had no local state — adopt the incoming partial wholesale.
      counts.set(groupId, nB);
      means.set(groupId, meanB);
      m2s.set(groupId, m2B);
      m3s.set(groupId, m3B);
      m4s.set(groupId, m4B);
    } else {
      long nTotal = nA + nB;
      double na = nA;
      double nb = nB;
      double n = nTotal;
      double delta = meanB - means.get(groupId);
      double delta2 = delta * delta;
      double delta3 = delta * delta2;
      double delta4 = delta2 * delta2;

      double m2 = m2s.get(groupId);
      double m3 = m3s.get(groupId);
      double m4 = m4s.get(groupId);

      m4 +=
          m4B
              + delta4 * na * nb * (na * na - na * nb + nb * nb) / (n * n * n)
              + 6.0 * delta2 * (na * na * m2B + nb * nb * m2) / (n * n)
              + 4.0 * delta * (na * m3B - nb * m3) / n;

      m3 +=
          m3B
              + delta3 * na * nb * (na - nb) / (n * n)
              + 3.0 * delta * (na * m2B - nb * m2) / n;

      m2 += m2B + delta2 * na * nb / n;

      means.add(groupId, delta * nb / n);
      counts.set(groupId, nTotal);
      m2s.set(groupId, m2);
      m3s.set(groupId, m3);
      m4s.set(groupId, m4);
    }
  }

  /** Serializes a group's state; null when the group is empty. */
  @Override
  public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
    checkArgument(
        columnBuilder instanceof BinaryColumnBuilder,
        "intermediate input and output should be BinaryColumn");

    if (counts.get(groupId) == 0) {
      columnBuilder.appendNull();
    } else {
      ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4);
      buffer.putLong(counts.get(groupId));
      buffer.putDouble(means.get(groupId));
      buffer.putDouble(m2s.get(groupId));
      buffer.putDouble(m3s.get(groupId));
      buffer.putDouble(m4s.get(groupId));
      columnBuilder.writeBinary(new Binary(buffer.array()));
    }
  }

  /**
   * Emits the skewness (biased, needs >= 3 rows) or sample excess kurtosis (needs >= 4 rows)
   * for a group; null when undefined or when the variance is zero.
   */
  @Override
  public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
    long count = counts.get(groupId);
    double m2 = m2s.get(groupId);

    if (count == 0 || m2 == 0) {
      columnBuilder.appendNull();
      return;
    }

    if (momentType == CentralMomentAccumulator.MomentType.SKEWNESS) {
      if (count < 3) {
        columnBuilder.appendNull();
        return;
      }
      double m3 = m3s.get(groupId);
      double result = Math.sqrt((double) count) * m3 / Math.pow(m2, 1.5);
      columnBuilder.writeDouble(result);
    } else {
      if (count < 4) {
        columnBuilder.appendNull();
      } else {
        double m4 = m4s.get(groupId);
        // Compute in double: the long product (count-1)*(count-2)*(count-3) overflows for
        // groups beyond ~2^21 rows.
        double n = count;
        double variance = m2 / (n - 1);
        double term1 = (n * (n + 1) * m4) / ((n - 1) * (n - 2) * (n - 3) * variance * variance);
        double term2 = (3 * (n - 1) * (n - 1)) / ((n - 2) * (n - 3));
        columnBuilder.writeDouble(term1 - term2);
      }
    }
  }

  @Override
  public void prepareFinal() {}

  /** Clears all per-group state. */
  @Override
  public void reset() {
    counts.reset();
    means.reset();
    m2s.reset();
    m3s.reset();
    m4s.reset();
  }
}
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java new file mode 100644 index 0000000000000..1adcf4bf285cd --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCorrelationAccumulator.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;

import org.apache.iotdb.db.queryengine.execution.aggregation.CorrelationAccumulator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.column.BinaryColumn;
import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder;
import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.RamUsageEstimator;
import org.apache.tsfile.write.UnSupportedDataTypeException;

import java.nio.ByteBuffer;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Grouped accumulator for the Pearson correlation aggregation ({@code corr(x, y)}), keeping
 * per-group streaming state (count, means, centered second moments and co-moment) in big arrays.
 *
 * <p>Per-group final value is {@code c2 / sqrt(m2X * m2Y)}; rows where either input is null are
 * skipped. Partial state per group is serialized as
 * {@code [count:long][meanX:double][meanY:double][m2X:double][m2Y:double][c2:double]}.
 */
public class GroupedCorrelationAccumulator implements GroupedAccumulator {

  private static final long INSTANCE_SIZE =
      RamUsageEstimator.shallowSizeOfInstance(GroupedCorrelationAccumulator.class);
  // Input column types; x is arguments[0], y is arguments[1].
  private final TSDataType xDataType;
  private final TSDataType yDataType;
  private final CorrelationAccumulator.CorrelationType correlationType;

  // Per-group streaming moments, indexed by groupId. counts == 0 means "group empty".
  private final LongBigArray counts = new LongBigArray();
  private final DoubleBigArray meanXs = new DoubleBigArray();
  private final DoubleBigArray meanYs = new DoubleBigArray();
  private final DoubleBigArray m2Xs = new DoubleBigArray();
  private final DoubleBigArray m2Ys = new DoubleBigArray();
  private final DoubleBigArray c2s = new DoubleBigArray();

  public GroupedCorrelationAccumulator(
      TSDataType xDataType,
      TSDataType yDataType,
      CorrelationAccumulator.CorrelationType correlationType) {
    this.xDataType = xDataType;
    this.yDataType = yDataType;
    this.correlationType = correlationType;
  }

  @Override
  public long getEstimatedSize() {
    return INSTANCE_SIZE
        + counts.sizeOf()
        + meanXs.sizeOf()
        + meanYs.sizeOf()
        + m2Xs.sizeOf()
        + m2Ys.sizeOf()
        + c2s.sizeOf();
  }

  /** Grows every per-group array to hold at least {@code groupCount} groups. */
  @Override
  public void setGroupCount(long groupCount) {
    counts.ensureCapacity(groupCount);
    meanXs.ensureCapacity(groupCount);
    meanYs.ensureCapacity(groupCount);
    m2Xs.ensureCapacity(groupCount);
    m2Ys.ensureCapacity(groupCount);
    c2s.ensureCapacity(groupCount);
  }

  /**
   * Consumes raw rows: x from arguments[0], y from arguments[1]; a row contributes to its group
   * only when both values are non-null.
   */
  @Override
  public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) {
    int positionCount = mask.getSelectedPositionCount();

    if (mask.isSelectAll()) {
      for (int i = 0; i < positionCount; i++) {
        if (arguments[0].isNull(i) || arguments[1].isNull(i)) {
          continue;
        }
        double x = getDoubleValue(arguments[0], i, xDataType);
        double y = getDoubleValue(arguments[1], i, yDataType);
        update(groupIds[i], x, y);
      }
    } else {
      // Only the positions listed by the mask participate.
      int[] selectedPositions = mask.getSelectedPositions();
      for (int i = 0; i < positionCount; i++) {
        int position = selectedPositions[i];
        if (arguments[0].isNull(position) || arguments[1].isNull(position)) {
          continue;
        }
        double x = getDoubleValue(arguments[0], position, xDataType);
        double y = getDoubleValue(arguments[1], position, yDataType);
        update(groupIds[position], x, y);
      }
    }
  }

  /** Widens a numeric column value to double; rejects non-numeric types. */
  private double getDoubleValue(Column column, int position, TSDataType dataType) {
    switch (dataType) {
      case INT32:
      case DATE:
        return column.getInt(position);
      case INT64:
      case TIMESTAMP:
        return column.getLong(position);
      case FLOAT:
        return column.getFloat(position);
      case DOUBLE:
        return column.getDouble(position);
      default:
        throw new UnSupportedDataTypeException(
            String.format("Unsupported data type in Correlation Aggregation: %s", dataType));
    }
  }

  /**
   * Welford-style single-observation update for one group. The co-moment pairs the OLD x mean
   * with the NEW y mean, and each M2 pairs the old and new mean of its own variable — that is
   * what keeps the streaming formulas exact.
   */
  private void update(int groupId, double x, double y) {
    long newCount = counts.get(groupId) + 1;
    double oldMeanX = meanXs.get(groupId);
    double oldMeanY = meanYs.get(groupId);
    double newMeanX = oldMeanX + (x - oldMeanX) / newCount;
    double newMeanY = oldMeanY + (y - oldMeanY) / newCount;

    meanXs.set(groupId, newMeanX);
    meanYs.set(groupId, newMeanY);
    c2s.add(groupId, (x - oldMeanX) * (y - newMeanY));
    m2Xs.add(groupId, (x - oldMeanX) * (x - newMeanX));
    m2Ys.add(groupId, (y - oldMeanY) * (y - newMeanY));
    counts.set(groupId, newCount);
  }

  /** Folds serialized partial states into the corresponding groups. */
  @Override
  public void addIntermediate(int[] groupIds, Column argument) {
    checkArgument(
        argument instanceof BinaryColumn
            || (argument instanceof RunLengthEncodedColumn
                && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn),
        "intermediate input and output should be BinaryColumn");

    for (int i = 0; i < argument.getPositionCount(); i++) {
      if (argument.isNull(i)) {
        continue;
      }

      // Layout must match evaluateIntermediate: count, meanX, meanY, m2X, m2Y, c2.
      byte[] bytes = argument.getBinary(i).getValues();
      ByteBuffer buffer = ByteBuffer.wrap(bytes);

      long otherCount = buffer.getLong();
      double otherMeanX = buffer.getDouble();
      double otherMeanY = buffer.getDouble();
      double otherM2X = buffer.getDouble();
      double otherM2Y = buffer.getDouble();
      double otherC2 = buffer.getDouble();

      merge(groupIds[i], otherCount, otherMeanX, otherMeanY, otherM2X, otherM2Y, otherC2);
    }
  }

  /**
   * Pairwise combination of two partial moment states for one group. The count products are
   * promoted to double by the preceding delta factors, so no long overflow occurs.
   */
  private void merge(
      int groupId,
      long otherCount,
      double otherMeanX,
      double otherMeanY,
      double otherM2X,
      double otherM2Y,
      double otherC2) {
    if (otherCount == 0) {
      return;
    }
    if (counts.get(groupId) == 0) {
      // The group had no local state — adopt the incoming partial wholesale.
      counts.set(groupId, otherCount);
      meanXs.set(groupId, otherMeanX);
      meanYs.set(groupId, otherMeanY);
      m2Xs.set(groupId, otherM2X);
      m2Ys.set(groupId, otherM2Y);
      c2s.set(groupId, otherC2);
    } else {
      long newCount = counts.get(groupId) + otherCount;
      double deltaX = otherMeanX - meanXs.get(groupId);
      double deltaY = otherMeanY - meanYs.get(groupId);

      // Cross terms use the PRE-merge count; apply them before updating means and count.
      c2s.add(groupId, otherC2 + deltaX * deltaY * counts.get(groupId) * otherCount / newCount);
      m2Xs.add(groupId, otherM2X + deltaX * deltaX * counts.get(groupId) * otherCount / newCount);
      m2Ys.add(groupId, otherM2Y + deltaY * deltaY * counts.get(groupId) * otherCount / newCount);

      meanXs.add(groupId, deltaX * otherCount / newCount);
      meanYs.add(groupId, deltaY * otherCount / newCount);
      counts.set(groupId, newCount);
    }
  }

  /** Serializes a group's state; null when the group is empty. */
  @Override
  public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) {
    checkArgument(
        columnBuilder instanceof BinaryColumnBuilder,
        "intermediate input and output should be BinaryColumn");

    if (counts.get(groupId) == 0) {
      columnBuilder.appendNull();
    } else {
      ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 5);
      buffer.putLong(counts.get(groupId));
      buffer.putDouble(meanXs.get(groupId));
      buffer.putDouble(meanYs.get(groupId));
      buffer.putDouble(m2Xs.get(groupId));
      buffer.putDouble(m2Ys.get(groupId));
      buffer.putDouble(c2s.get(groupId));
      columnBuilder.writeBinary(new Binary(buffer.array()));
    }
  }

  /**
   * Emits the correlation of a group, or null when it is undefined: fewer than two rows, or
   * either variable has zero variance.
   */
  @Override
  public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) {
    if (correlationType != CorrelationAccumulator.CorrelationType.CORR) {
      throw new UnsupportedOperationException("Unknown type: " + correlationType);
    }

    if (counts.get(groupId) < 2) {
      columnBuilder.appendNull();
    } else if (m2Xs.get(groupId) == 0 || m2Ys.get(groupId) == 0) {
      columnBuilder.appendNull();
    } else {
      columnBuilder.writeDouble(
          c2s.get(groupId) / Math.sqrt(m2Xs.get(groupId) * m2Ys.get(groupId)));
    }
  }

  @Override
  public void prepareFinal() {}

  /** Clears all per-group state. */
  @Override
  public void reset() {
    counts.reset();
    meanXs.reset();
    meanYs.reset();
    m2Xs.reset();
    m2Ys.reset();
    c2s.reset();
  }
}
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCovarianceAccumulator.java new file mode 100644 index 0000000000000..3cee6379b291e --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedCovarianceAccumulator.java @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped; + +import org.apache.iotdb.db.queryengine.execution.aggregation.CovarianceAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class GroupedCovarianceAccumulator implements GroupedAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(GroupedCovarianceAccumulator.class); + + private final TSDataType xDataType; + private final TSDataType yDataType; + private final CovarianceAccumulator.CovarianceType covarianceType; + + private final LongBigArray counts = new LongBigArray(); + private final DoubleBigArray meanXs = new DoubleBigArray(); + private final DoubleBigArray meanYs = new DoubleBigArray(); + private final DoubleBigArray c2s = new DoubleBigArray(); + + public GroupedCovarianceAccumulator( + TSDataType xDataType, + TSDataType yDataType, + CovarianceAccumulator.CovarianceType covarianceType) { + this.xDataType = xDataType; + this.yDataType = yDataType; + this.covarianceType = covarianceType; + } + + 
@Override + public long getEstimatedSize() { + return INSTANCE_SIZE + counts.sizeOf() + meanXs.sizeOf() + meanYs.sizeOf() + c2s.sizeOf(); + } + + @Override + public void setGroupCount(long groupCount) { + counts.ensureCapacity(groupCount); + meanXs.ensureCapacity(groupCount); + meanYs.ensureCapacity(groupCount); + c2s.ensureCapacity(groupCount); + } + + @Override + public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) { + int positionCount = mask.getSelectedPositionCount(); + + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (arguments[0].isNull(i) || arguments[1].isNull(i)) { + continue; + } + double x = getDoubleValue(arguments[0], i, xDataType); + double y = getDoubleValue(arguments[1], i, yDataType); + update(groupIds[i], x, y); + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (arguments[0].isNull(position) || arguments[1].isNull(position)) { + continue; + } + double x = getDoubleValue(arguments[0], position, xDataType); + double y = getDoubleValue(arguments[1], position, yDataType); + update(groupIds[position], x, y); + } + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format("Unsupported data type in Covariance Aggregation: %s", dataType)); + } + } + + private void update(int groupId, double x, double y) { + long newCount = counts.get(groupId) + 1; + double oldMeanX = meanXs.get(groupId); + double oldMeanY = meanYs.get(groupId); + double newMeanX = oldMeanX + (x - oldMeanX) / newCount; + double newMeanY = oldMeanY + (y - oldMeanY) / newCount; + + 
meanXs.set(groupId, newMeanX); + meanYs.set(groupId, newMeanY); + c2s.add(groupId, (x - oldMeanX) * (y - newMeanY)); + counts.set(groupId, newCount); + } + + @Override + public void addIntermediate(int[] groupIds, Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(groupIds[i], otherCount, otherMeanX, otherMeanY, otherC2); + } + } + + private void merge( + int groupId, long otherCount, double otherMeanX, double otherMeanY, double otherC2) { + if (otherCount == 0) { + return; + } + + long count = counts.get(groupId); + if (count == 0) { + counts.set(groupId, otherCount); + meanXs.set(groupId, otherMeanX); + meanYs.set(groupId, otherMeanY); + c2s.set(groupId, otherC2); + return; + } + + long newCount = count + otherCount; + double meanX = meanXs.get(groupId); + double meanY = meanYs.get(groupId); + double deltaX = otherMeanX - meanX; + double deltaY = otherMeanY - meanY; + + c2s.add(groupId, otherC2 + deltaX * deltaY * count * otherCount / newCount); + meanXs.add(groupId, deltaX * otherCount / newCount); + meanYs.add(groupId, deltaY * otherCount / newCount); + counts.set(groupId, newCount); + } + + @Override + public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + long count = counts.get(groupId); + if (count == 0) { + columnBuilder.appendNull(); + 
return; + } + + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 3); + buffer.putLong(count); + buffer.putDouble(meanXs.get(groupId)); + buffer.putDouble(meanYs.get(groupId)); + buffer.putDouble(c2s.get(groupId)); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + + @Override + public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) { + long count = counts.get(groupId); + switch (covarianceType) { + case COVAR_POP: + if (count == 0) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2s.get(groupId) / count); + } + break; + case COVAR_SAMP: + if (count < 2) { + columnBuilder.appendNull(); + } else { + columnBuilder.writeDouble(c2s.get(groupId) / (count - 1)); + } + break; + default: + throw new UnsupportedOperationException("Unknown type: " + covarianceType); + } + } + + @Override + public void prepareFinal() {} + + @Override + public void reset() { + counts.reset(); + meanXs.reset(); + meanYs.reset(); + c2s.reset(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java new file mode 100644 index 0000000000000..97aabf8c96a19 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedRegressionAccumulator.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped; + +import org.apache.iotdb.db.queryengine.execution.aggregation.RegressionAccumulator; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.DoubleBigArray; +import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray; + +import org.apache.tsfile.block.column.Column; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.column.BinaryColumn; +import org.apache.tsfile.read.common.block.column.BinaryColumnBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.RamUsageEstimator; +import org.apache.tsfile.write.UnSupportedDataTypeException; + +import java.nio.ByteBuffer; + +import static com.google.common.base.Preconditions.checkArgument; + +public class GroupedRegressionAccumulator implements GroupedAccumulator { + + private static final long INSTANCE_SIZE = + RamUsageEstimator.shallowSizeOfInstance(GroupedRegressionAccumulator.class); + + private final TSDataType yDataType; + private final TSDataType xDataType; + private final RegressionAccumulator.RegressionType regressionType; + + private final LongBigArray counts = new LongBigArray(); + private final 
DoubleBigArray meanXs = new DoubleBigArray(); + private final DoubleBigArray meanYs = new DoubleBigArray(); + private final DoubleBigArray m2Xs = new DoubleBigArray(); + private final DoubleBigArray c2s = new DoubleBigArray(); + + public GroupedRegressionAccumulator( + TSDataType yDataType, + TSDataType xDataType, + RegressionAccumulator.RegressionType regressionType) { + this.yDataType = yDataType; + this.xDataType = xDataType; + this.regressionType = regressionType; + } + + @Override + public long getEstimatedSize() { + return INSTANCE_SIZE + + counts.sizeOf() + + meanXs.sizeOf() + + meanYs.sizeOf() + + m2Xs.sizeOf() + + c2s.sizeOf(); + } + + @Override + public void setGroupCount(long groupCount) { + counts.ensureCapacity(groupCount); + meanXs.ensureCapacity(groupCount); + meanYs.ensureCapacity(groupCount); + m2Xs.ensureCapacity(groupCount); + c2s.ensureCapacity(groupCount); + } + + @Override + public void addInput(int[] groupIds, Column[] arguments, AggregationMask mask) { + + int positionCount = mask.getSelectedPositionCount(); + + if (mask.isSelectAll()) { + for (int i = 0; i < positionCount; i++) { + if (arguments[0].isNull(i) || arguments[1].isNull(i)) { + continue; + } + double y = getDoubleValue(arguments[0], i, yDataType); + double x = getDoubleValue(arguments[1], i, xDataType); + update(groupIds[i], x, y); + } + } else { + int[] selectedPositions = mask.getSelectedPositions(); + for (int i = 0; i < positionCount; i++) { + int position = selectedPositions[i]; + if (arguments[0].isNull(position) || arguments[1].isNull(position)) { + continue; + } + double y = getDoubleValue(arguments[0], position, yDataType); + double x = getDoubleValue(arguments[1], position, xDataType); + update(groupIds[position], x, y); + } + } + } + + private double getDoubleValue(Column column, int position, TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return column.getInt(position); + case INT64: + case TIMESTAMP: + return column.getLong(position); + 
case FLOAT: + return column.getFloat(position); + case DOUBLE: + return column.getDouble(position); + default: + throw new UnSupportedDataTypeException( + String.format("Unsupported data type in Regression Aggregation: %s", dataType)); + } + } + + private void update(int groupId, double x, double y) { + long newCount = counts.get(groupId) + 1; + double deltaX = x - meanXs.get(groupId); + double deltaY = y - meanYs.get(groupId); + + meanXs.add(groupId, deltaX / newCount); + meanYs.add(groupId, deltaY / newCount); + + c2s.add(groupId, deltaX * (y - meanYs.get(groupId))); + m2Xs.add(groupId, deltaX * (x - meanXs.get(groupId))); + + counts.set(groupId, newCount); + } + + @Override + public void addIntermediate(int[] groupIds, Column argument) { + checkArgument( + argument instanceof BinaryColumn + || (argument instanceof RunLengthEncodedColumn + && ((RunLengthEncodedColumn) argument).getValue() instanceof BinaryColumn), + "intermediate input and output should be BinaryColumn"); + + for (int i = 0; i < argument.getPositionCount(); i++) { + if (argument.isNull(i)) { + continue; + } + + byte[] bytes = argument.getBinary(i).getValues(); + ByteBuffer buffer = ByteBuffer.wrap(bytes); + + long otherCount = buffer.getLong(); + double otherMeanX = buffer.getDouble(); + double otherMeanY = buffer.getDouble(); + double otherM2X = buffer.getDouble(); + double otherC2 = buffer.getDouble(); + + merge(groupIds[i], otherCount, otherMeanX, otherMeanY, otherM2X, otherC2); + } + } + + private void merge( + int groupId, + long otherCount, + double otherMeanX, + double otherMeanY, + double otherM2X, + double otherC2) { + if (otherCount == 0) { + return; + } + if (counts.get(groupId) == 0) { + counts.set(groupId, otherCount); + meanXs.set(groupId, otherMeanX); + meanYs.set(groupId, otherMeanY); + m2Xs.set(groupId, otherM2X); + c2s.set(groupId, otherC2); + } else { + long newCount = counts.get(groupId) + otherCount; + double deltaX = otherMeanX - meanXs.get(groupId); + double deltaY = 
otherMeanY - meanYs.get(groupId); + + c2s.add(groupId, otherC2 + deltaX * deltaY * counts.get(groupId) * otherCount / newCount); + m2Xs.add(groupId, otherM2X + deltaX * deltaX * counts.get(groupId) * otherCount / newCount); + + meanXs.add(groupId, deltaX * otherCount / newCount); + meanYs.add(groupId, deltaY * otherCount / newCount); + counts.set(groupId, newCount); + } + } + + @Override + public void evaluateIntermediate(int groupId, ColumnBuilder columnBuilder) { + checkArgument( + columnBuilder instanceof BinaryColumnBuilder, + "intermediate input and output should be BinaryColumn"); + + if (counts.get(groupId) == 0) { + columnBuilder.appendNull(); + } else { + ByteBuffer buffer = ByteBuffer.allocate(Long.BYTES + Double.BYTES * 4); + buffer.putLong(counts.get(groupId)); + buffer.putDouble(meanXs.get(groupId)); + buffer.putDouble(meanYs.get(groupId)); + buffer.putDouble(m2Xs.get(groupId)); + buffer.putDouble(c2s.get(groupId)); + columnBuilder.writeBinary(new Binary(buffer.array())); + } + } + + @Override + public void evaluateFinal(int groupId, ColumnBuilder columnBuilder) { + if (counts.get(groupId) == 0 || m2Xs.get(groupId) == 0) { + columnBuilder.appendNull(); + return; + } + double slope = c2s.get(groupId) / m2Xs.get(groupId); + switch (regressionType) { + case REGR_SLOPE: + columnBuilder.writeDouble(slope); + break; + case REGR_INTERCEPT: + columnBuilder.writeDouble(meanYs.get(groupId) - slope * meanXs.get(groupId)); + break; + default: + throw new UnsupportedOperationException("Unknown type: " + regressionType); + } + } + + @Override + public void prepareFinal() {} + + @Override + public void reset() { + counts.reset(); + meanXs.reset(); + meanYs.reset(); + m2Xs.reset(); + c2s.reset(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java index adc80c7bb1522..7bc866fdd2bf5 
100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ExpressionTypeAnalyzer.java @@ -361,6 +361,7 @@ public TSDataType visitFunctionExpression( } if (functionExpression.isBuiltInAggregationFunctionExpression()) { + return setExpressionType( functionExpression, TypeInferenceUtils.getBuiltinAggregationDataType( @@ -542,9 +543,21 @@ private TSDataType getInputExpressionTypeForAggregation( case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: return expressionTypes.get(NodeRef.of(inputExpressions.get(0))); + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + TypeInferenceUtils.verifyIsAggregationDataTypeMatchedForBothInputs( + aggregateFunctionName, + expressionTypes.get(NodeRef.of(inputExpressions.get(0))), + expressionTypes.get(NodeRef.of(inputExpressions.get(1)))); + return expressionTypes.get(NodeRef.of(inputExpressions.get(0))); default: throw new IllegalArgumentException( "Invalid Aggregation function: " + aggregateFunctionName); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java index 706d14f052cd2..b8b2117e976a0 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/parser/ASTVisitor.java @@ -3186,6 +3186,8 @@ private void checkAggregationFunctionInput(FunctionExpression functionExpression case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.SKEWNESS: + case 
SqlConstant.KURTOSIS: checkFunctionExpressionInputSize( functionExpression.getExpressionString(), functionExpression.getExpressions().size(), @@ -3194,6 +3196,11 @@ private void checkAggregationFunctionInput(FunctionExpression functionExpression case SqlConstant.COUNT_IF: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: checkFunctionExpressionInputSize( functionExpression.getExpressionString(), functionExpression.getExpressions().size(), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java index ac30dcf505afd..36c6470078853 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/AggregationDescriptor.java @@ -187,6 +187,27 @@ public List getActualAggregationNames(boolean isPartial) { case VAR_SAMP: outputAggregationNames.add(addPartialSuffix(SqlConstant.VAR_SAMP)); break; + case CORR: + outputAggregationNames.add(addPartialSuffix(SqlConstant.CORR)); + break; + case COVAR_POP: + outputAggregationNames.add(addPartialSuffix(SqlConstant.COVAR_POP)); + break; + case COVAR_SAMP: + outputAggregationNames.add(addPartialSuffix(SqlConstant.COVAR_SAMP)); + break; + case REGR_SLOPE: + outputAggregationNames.add(addPartialSuffix(SqlConstant.REGR_SLOPE)); + break; + case REGR_INTERCEPT: + outputAggregationNames.add(addPartialSuffix(SqlConstant.REGR_INTERCEPT)); + break; + case SKEWNESS: + outputAggregationNames.add(addPartialSuffix(SqlConstant.SKEWNESS)); + break; + case KURTOSIS: + outputAggregationNames.add(addPartialSuffix(SqlConstant.KURTOSIS)); + 
break; case MAX_BY: outputAggregationNames.add(addPartialSuffix(SqlConstant.MAX_BY)); break; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 8934de172e9c5..1bc661eaf3052 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -597,6 +597,45 @@ && isIntegerNumber(argumentTypes.get(2)))) { functionName)); } break; + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + if (argumentTypes.size() != 2) { + throw new SemanticException( + String.format( + "Error size of input expressions. expression: %s, actual size: %s, expected size: [2].", + functionName.toUpperCase(), argumentTypes.size())); + } + if (!isNumericType(argumentTypes.get(0))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } + if (!isNumericType(argumentTypes.get(1))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } + break; + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: + if (argumentTypes.size() != 1) { + throw new SemanticException( + String.format( + "Error size of input expressions. 
expression: %s, actual size: %s, expected size: [1].", + functionName.toUpperCase(), argumentTypes.size())); + } + if (!isNumericType(argumentTypes.get(0))) { + throw new SemanticException( + String.format( + "Aggregate functions [%s] only support numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]", + functionName.toUpperCase())); + } + break; case SqlConstant.MIN: case SqlConstant.MAX: case SqlConstant.MODE: @@ -701,6 +740,13 @@ && isIntegerNumber(argumentTypes.get(2)))) { case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: return DOUBLE; case SqlConstant.APPROX_MOST_FREQUENT: return STRING; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java index 773de36a067fd..71a24bc23678f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/SchemaUtils.java @@ -88,6 +88,13 @@ public static TSDataType getBuiltinAggregationTypeByFuncName(String aggregation) case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: return TSDataType.DOUBLE; // Partial aggregation names case SqlConstant.STDDEV + "_partial": @@ -96,6 +103,13 @@ public static TSDataType getBuiltinAggregationTypeByFuncName(String aggregation) case SqlConstant.VARIANCE + "_partial": case SqlConstant.VAR_POP + "_partial": case SqlConstant.VAR_SAMP + "_partial": + case SqlConstant.CORR + "_partial": + case SqlConstant.COVAR_POP + 
"_partial": + case SqlConstant.COVAR_SAMP + "_partial": + case SqlConstant.REGR_SLOPE + "_partial": + case SqlConstant.REGR_INTERCEPT + "_partial": + case SqlConstant.SKEWNESS + "_partial": + case SqlConstant.KURTOSIS + "_partial": case SqlConstant.MAX_BY + "_partial": case SqlConstant.MIN_BY + "_partial": return TSDataType.TEXT; @@ -163,6 +177,20 @@ public static String getBuiltinAggregationName(TAggregationType aggregationType) return SqlConstant.VAR_POP; case VAR_SAMP: return SqlConstant.VAR_SAMP; + case CORR: + return SqlConstant.CORR; + case COVAR_POP: + return SqlConstant.COVAR_POP; + case COVAR_SAMP: + return SqlConstant.COVAR_SAMP; + case REGR_SLOPE: + return SqlConstant.REGR_SLOPE; + case REGR_INTERCEPT: + return SqlConstant.REGR_INTERCEPT; + case SKEWNESS: + return SqlConstant.SKEWNESS; + case KURTOSIS: + return SqlConstant.KURTOSIS; default: return null; } @@ -198,6 +226,13 @@ public static boolean isConsistentWithScanOrder( case VAR_SAMP: case MAX_BY: case MIN_BY: + case CORR: + case COVAR_POP: + case COVAR_SAMP: + case REGR_SLOPE: + case REGR_INTERCEPT: + case SKEWNESS: + case KURTOSIS: case UDAF: return true; default: @@ -232,6 +267,20 @@ public static List splitPartialBuiltinAggregation(TAggregationType aggre return Collections.singletonList(addPartialSuffix(SqlConstant.VAR_POP)); case VAR_SAMP: return Collections.singletonList(addPartialSuffix(SqlConstant.VAR_SAMP)); + case CORR: + return Collections.singletonList(addPartialSuffix(SqlConstant.CORR)); + case COVAR_POP: + return Collections.singletonList(addPartialSuffix(SqlConstant.COVAR_POP)); + case COVAR_SAMP: + return Collections.singletonList(addPartialSuffix(SqlConstant.COVAR_SAMP)); + case REGR_SLOPE: + return Collections.singletonList(addPartialSuffix(SqlConstant.REGR_SLOPE)); + case REGR_INTERCEPT: + return Collections.singletonList(addPartialSuffix(SqlConstant.REGR_INTERCEPT)); + case SKEWNESS: + return Collections.singletonList(addPartialSuffix(SqlConstant.SKEWNESS)); + case KURTOSIS: + 
return Collections.singletonList(addPartialSuffix(SqlConstant.KURTOSIS)); case MAX_BY: return Collections.singletonList(addPartialSuffix(SqlConstant.MAX_BY)); case MIN_BY: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java index 8fc1d647dc36e..10413c6f2537e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/TypeInferenceUtils.java @@ -154,6 +154,13 @@ public static TSDataType getBuiltinAggregationDataType( case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: return TSDataType.DOUBLE; default: throw new IllegalArgumentException( @@ -190,7 +197,22 @@ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa return; } throw new SemanticException( - "Aggregate functions [AVG, SUM, EXTREME, STDDEV, STDDEV_POP, STDDEV_SAMP, VARIANCE, VAR_POP, VAR_SAMP] only support numeric data types [INT32, INT64, FLOAT, DOUBLE]"); + "Aggregate functions [AVG, SUM, EXTREME, STDDEV, STDDEV_POP, STDDEV_SAMP, " + + "VARIANCE, VAR_POP, VAR_SAMP] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE]"); + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: + if (dataType.isNumeric() || TSDataType.TIMESTAMP.equals(dataType)) { + return; + } + throw new SemanticException( + "Aggregate functions [SKEWNESS, KURTOSIS] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: case SqlConstant.COUNT: case 
SqlConstant.COUNT_TIME: case SqlConstant.MIN_TIME: @@ -215,6 +237,30 @@ private static void verifyIsAggregationDataTypeMatched(String aggrFuncName, TSDa } } + public static void verifyIsAggregationDataTypeMatchedForBothInputs( + String aggrFuncName, TSDataType firstDataType, TSDataType secondDataType) { + switch (aggrFuncName.toLowerCase()) { + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + if ((firstDataType != null + && !firstDataType.isNumeric() + && !TSDataType.TIMESTAMP.equals(firstDataType)) + || (secondDataType != null + && !secondDataType.isNumeric() + && !TSDataType.TIMESTAMP.equals(secondDataType))) { + throw new SemanticException( + "Aggregate functions [CORR, COVAR_POP, COVAR_SAMP, REGR_SLOPE, REGR_INTERCEPT] only support " + + "numeric data types [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]"); + } + return; + default: + break; + } + } + /** * Bind Type for non-series input Expressions of AggregationFunction and check Semantic * @@ -245,6 +291,13 @@ public static void bindTypeForBuiltinAggregationNonSeriesInputExpressions( case SqlConstant.VARIANCE: case SqlConstant.VAR_POP: case SqlConstant.VAR_SAMP: + case SqlConstant.CORR: + case SqlConstant.COVAR_POP: + case SqlConstant.COVAR_SAMP: + case SqlConstant.REGR_SLOPE: + case SqlConstant.REGR_INTERCEPT: + case SqlConstant.SKEWNESS: + case SqlConstant.KURTOSIS: case SqlConstant.MAX_BY: case SqlConstant.MIN_BY: return; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java index 8120aff6059ba..32d43cedd5da3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/utils/constant/SqlConstant.java @@ -74,6 +74,13 @@ protected SqlConstant() { public static final String VARIANCE = 
"variance"; public static final String VAR_POP = "var_pop"; public static final String VAR_SAMP = "var_samp"; + public static final String CORR = "corr"; + public static final String COVAR_POP = "covar_pop"; + public static final String COVAR_SAMP = "covar_samp"; + public static final String REGR_SLOPE = "regr_slope"; + public static final String REGR_INTERCEPT = "regr_intercept"; + public static final String SKEWNESS = "skewness"; + public static final String KURTOSIS = "kurtosis"; public static final String COUNT_TIME = "count_time"; public static final String COUNT_TIME_HEADER = "count_time(*)"; diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java index 1c6b25ef53aaf..c8274b94d1c44 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/BuiltinAggregationFunction.java @@ -46,7 +46,14 @@ public enum BuiltinAggregationFunction { VAR_POP("var_pop"), VAR_SAMP("var_samp"), MAX_BY("max_by"), - MIN_BY("min_by"); + MIN_BY("min_by"), + CORR("corr"), + COVAR_POP("covar_pop"), + COVAR_SAMP("covar_samp"), + REGR_SLOPE("regr_slope"), + REGR_INTERCEPT("regr_intercept"), + SKEWNESS("skewness"), + KURTOSIS("kurtosis"); private final String functionName; @@ -97,6 +104,13 @@ public static boolean canUseStatistics(String name) { case "var_samp": case "max_by": case "min_by": + case "corr": + case "covar_pop": + case "covar_samp": + case "regr_slope": + case "regr_intercept": + case "skewness": + case "kurtosis": return false; default: throw new IllegalArgumentException("Invalid Aggregation function: " + name); @@ -131,6 +145,13 @@ public static boolean canSplitToMultiPhases(String name) { case "var_samp": case "max_by": case "min_by": + case "corr": + case 
"covar_pop": + case "covar_samp": + case "regr_slope": + case "regr_intercept": + case "skewness": + case "kurtosis": return true; case "count_if": case "count_time": diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java index 39f7cde84c490..151df10988d65 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/udf/builtin/relational/TableBuiltinAggregationFunction.java @@ -58,7 +58,14 @@ public enum TableBuiltinAggregationFunction { VAR_POP("var_pop"), VAR_SAMP("var_samp"), APPROX_COUNT_DISTINCT("approx_count_distinct"), - APPROX_MOST_FREQUENT("approx_most_frequent"); + APPROX_MOST_FREQUENT("approx_most_frequent"), + CORR("corr"), + COVAR_POP("covar_pop"), + COVAR_SAMP("covar_samp"), + REGR_SLOPE("regr_slope"), + REGR_INTERCEPT("regr_intercept"), + SKEWNESS("skewness"), + KURTOSIS("kurtosis"); private final String functionName; @@ -103,6 +110,13 @@ public static Type getIntermediateType(String name, List originalArgumentT case "variance": case "var_pop": case "var_samp": + case "corr": + case "covar_pop": + case "covar_samp": + case "regr_slope": + case "regr_intercept": + case "skewness": + case "kurtosis": case "approx_count_distinct": return RowType.anonymous(Collections.emptyList()); case "extreme": diff --git a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift index 3287f35b9bb92..fa4ea97fef110 100644 --- a/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift +++ b/iotdb-protocol/thrift-commons/src/main/thrift/common.thrift @@ -293,7 +293,14 @@ enum TAggregationType { MAX, COUNT_ALL, APPROX_COUNT_DISTINCT, - APPROX_MOST_FREQUENT + 
APPROX_MOST_FREQUENT, + CORR, + COVAR_POP, + COVAR_SAMP, + REGR_SLOPE, + REGR_INTERCEPT, + SKEWNESS, + KURTOSIS } struct TShowConfigurationTemplateResp {