From cf054dce562341c7193b8b30ec6392f43f47db79 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Thu, 12 Feb 2026 12:03:34 +0100 Subject: [PATCH 1/9] add support aes_decrypt --- docs/spark_expressions_support.md | 2 +- .../apache/comet/serde/QueryPlanSerde.scala | 1 + .../comet/CometStringExpressionSuite.scala | 40 +++++++++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md index 5474894108..752ee9d28e 100644 --- a/docs/spark_expressions_support.md +++ b/docs/spark_expressions_support.md @@ -353,7 +353,7 @@ ### misc_funcs -- [ ] aes_decrypt +- [x] aes_decrypt - [ ] aes_encrypt - [ ] assert_true - [x] current_catalog diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8c39ba779d..6783801a9c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -149,6 +149,7 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Sha1] -> CometSha1) private val stringExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( + classOf[AesDecrypt] -> CometScalarFunction("aes_decrypt"), classOf[Ascii] -> CometScalarFunction("ascii"), classOf[BitLength] -> CometScalarFunction("bit_length"), classOf[Chr] -> CometScalarFunction("char"), diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index 121d7f7d5a..d071c595ff 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -478,4 +478,44 @@ class CometStringExpressionSuite extends CometTestBase { } } + test("aes_decrypt") { + withTable("aes_tbl") { + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + sql(""" + 
|CREATE TABLE aes_tbl( + | encrypted_default BINARY, + | encrypted_with_aad BINARY, + | `key` BINARY, + | mode STRING, + | padding STRING, + | aad BINARY + |) USING parquet + |""".stripMargin) + + sql(""" + |INSERT INTO aes_tbl + |SELECT + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT', + | unhex('00112233445566778899AABB')), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT', + | unhex('00112233445566778899AABB') + |""".stripMargin) + } + + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl") + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl") + } + } + } From c3c660210db520aa60fa423674cd0f1d262176b4 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Thu, 12 Feb 2026 12:33:38 +0100 Subject: [PATCH 2/9] add cargo clean before build with miri --- .github/workflows/miri.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/miri.yml b/.github/workflows/miri.yml index ea36e1359a..7c193d6d12 100644 --- a/.github/workflows/miri.yml +++ b/.github/workflows/miri.yml @@ -60,4 +60,5 @@ jobs: - name: Test with Miri run: | cd native + cargo clean --target-dir target/miri MIRIFLAGS="-Zmiri-disable-isolation" cargo miri test --lib --bins --tests --examples From d60747043408156809a1bc6bb1dc7d21dcf00b09 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Fri, 13 Feb 2026 14:09:07 +0100 Subject: [PATCH 3/9] tests --- .github/workflows/miri.yml | 1 - .../expressions/misc/aes_decrypt.sql | 54 ++++++++++ .../comet/CometMiscExpressionSuite.scala | 102 ++++++++++++++++++ .../comet/CometStringExpressionSuite.scala | 40 ------- 4 files changed, 156 insertions(+), 41 deletions(-) create mode 100644 
spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql create mode 100644 spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala diff --git a/.github/workflows/miri.yml b/.github/workflows/miri.yml index 7c193d6d12..ea36e1359a 100644 --- a/.github/workflows/miri.yml +++ b/.github/workflows/miri.yml @@ -60,5 +60,4 @@ jobs: - name: Test with Miri run: | cd native - cargo clean --target-dir target/miri MIRIFLAGS="-Zmiri-disable-isolation" cargo miri test --lib --bins --tests --examples diff --git a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql new file mode 100644 index 0000000000..f89f24d1ee --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql @@ -0,0 +1,54 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- MinSparkVersion: 3.5 + +statement +CREATE TABLE aes_tbl( + encrypted_default BINARY, + encrypted_with_aad BINARY, + `key` BINARY, + mode STRING, + padding STRING, + iv BINARY, + aad STRING +) USING parquet + +statement +INSERT INTO aes_tbl +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT', + unhex('00112233445566778899AABB'), + 'Comet AAD'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT', + unhex('00112233445566778899AABB'), + 'Comet AAD' + +query expect_fallback(Static invoke expression: aesDecrypt is not supported) +SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl + +query expect_fallback(Static invoke expression: aesDecrypt is not supported) +SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl diff --git a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala new file mode 100644 index 0000000000..d87a42d328 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.sql.CometTestBase + +import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus + +class CometMiscExpressionSuite extends CometTestBase { + + test("aes_decrypt") { + withTable("aes_tbl") { + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + sql(""" + |CREATE TABLE aes_tbl( + | encrypted_default BINARY, + | encrypted_with_aad BINARY, + | `key` BINARY, + | mode STRING, + | padding STRING, + | iv BINARY, + | aad STRING + |) USING parquet + |""".stripMargin) + + if (isSpark35Plus) { + sql(""" + |INSERT INTO aes_tbl + |SELECT + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT', + | unhex('00112233445566778899AABB'), + | 'Comet AAD'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT', + | unhex('00112233445566778899AABB'), + | 'Comet AAD' + |""".stripMargin) + } else { + sql(""" + |INSERT INTO aes_tbl + |SELECT + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), + | aes_encrypt( + | encode('Spark SQL', 'UTF-8'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT'), + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + | 'GCM', + | 'DEFAULT', + | cast(null as binary), + | cast(null as string) + |""".stripMargin) + } + } + + if (isSpark35Plus) { + 
checkSparkAnswerAndFallbackReason( + "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl", + "Static invoke expression: aesDecrypt is not supported") + checkSparkAnswerAndFallbackReason( + "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl", + "Static invoke expression: aesDecrypt is not supported") + } else { + checkSparkAnswerAndFallbackReason( + "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl", + "Static invoke expression: aesDecrypt is not supported") + checkSparkAnswerAndFallbackReason( + "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding) AS STRING) FROM aes_tbl", + "Static invoke expression: aesDecrypt is not supported") + } + } + } + +} diff --git a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala index d071c595ff..121d7f7d5a 100644 --- a/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometStringExpressionSuite.scala @@ -478,44 +478,4 @@ class CometStringExpressionSuite extends CometTestBase { } } - test("aes_decrypt") { - withTable("aes_tbl") { - withSQLConf(CometConf.COMET_ENABLED.key -> "false") { - sql(""" - |CREATE TABLE aes_tbl( - | encrypted_default BINARY, - | encrypted_with_aad BINARY, - | `key` BINARY, - | mode STRING, - | padding STRING, - | aad BINARY - |) USING parquet - |""".stripMargin) - - sql(""" - |INSERT INTO aes_tbl - |SELECT - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT', - | unhex('00112233445566778899AABB')), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT', - | unhex('00112233445566778899AABB') - |""".stripMargin) - } - - 
checkSparkAnswerAndOperator( - "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl") - checkSparkAnswerAndOperator( - "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl") - } - } - } From e9308b0ff149649bffb7acb5dd8fe1bc50c4ec77 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Fri, 13 Feb 2026 16:21:34 +0100 Subject: [PATCH 4/9] add support --- .../apache/comet/serde/QueryPlanSerde.scala | 2 +- .../org/apache/comet/serde/statics.scala | 40 +++++++- .../expressions/misc/aes_decrypt.sql | 4 +- .../comet/CometMiscExpressionSuite.scala | 93 +++++++------------ 4 files changed, 72 insertions(+), 67 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 6783801a9c..9a4226aade 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -149,7 +149,6 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[Sha1] -> CometSha1) private val stringExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( - classOf[AesDecrypt] -> CometScalarFunction("aes_decrypt"), classOf[Ascii] -> CometScalarFunction("ascii"), classOf[BitLength] -> CometScalarFunction("bit_length"), classOf[Chr] -> CometScalarFunction("char"), @@ -223,6 +222,7 @@ object QueryPlanSerde extends Logging with CometExprShim { private val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( // TODO PromotePrecision + classOf[AesDecrypt] -> CometAesDecrypt, classOf[Alias] -> CometAlias, classOf[AttributeReference] -> CometAttributeReference, classOf[BloomFilterMightContain] -> CometBloomFilterMightContain, diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 9dbc6d169f..0ff31908d4 100644 --- 
a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -19,11 +19,46 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils} +import org.apache.spark.sql.catalyst.expressions.{AesDecrypt, Attribute, Expression, ExpressionImplUtils} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} + +private object CometAesDecryptHelper { + def convertToAesDecryptExpr[T <: Expression]( + expr: T, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "aes_decrypt", + expr.dataType, + failOnError = false, + childExpr: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) + } +} + +object CometAesDecrypt extends CometExpressionSerde[AesDecrypt] { + override def convert( + expr: AesDecrypt, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) + } +} + +object CometAesDecryptStaticInvoke extends CometExpressionSerde[StaticInvoke] { + override def convert( + expr: StaticInvoke, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) + } +} object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { @@ -35,7 +70,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { Map( ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( "read_side_padding"), - ("isLuhnNumber", 
classOf[ExpressionImplUtils]) -> CometScalarFunction("luhn_check")) + ("isLuhnNumber", classOf[ExpressionImplUtils]) -> CometScalarFunction("luhn_check"), + ("aesDecrypt", classOf[ExpressionImplUtils]) -> CometAesDecryptStaticInvoke) override def convert( expr: StaticInvoke, diff --git a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql index f89f24d1ee..cca41c83d7 100644 --- a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql +++ b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql @@ -47,8 +47,8 @@ SELECT unhex('00112233445566778899AABB'), 'Comet AAD' -query expect_fallback(Static invoke expression: aesDecrypt is not supported) +query SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl -query expect_fallback(Static invoke expression: aesDecrypt is not supported) +query SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl diff --git a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala index d87a42d328..b6f7e921b4 100644 --- a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala @@ -26,75 +26,44 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus class CometMiscExpressionSuite extends CometTestBase { test("aes_decrypt") { - withTable("aes_tbl") { + withTempView("aes_tbl") { withSQLConf(CometConf.COMET_ENABLED.key -> "false") { - sql(""" - |CREATE TABLE aes_tbl( - | encrypted_default BINARY, - | encrypted_with_aad BINARY, - | `key` BINARY, - | mode STRING, - | padding STRING, - | iv BINARY, - | aad STRING - |) USING parquet - |""".stripMargin) - - if (isSpark35Plus) { - sql(""" - |INSERT INTO aes_tbl - |SELECT - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | 
encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT', - | unhex('00112233445566778899AABB'), - | 'Comet AAD'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT', - | unhex('00112233445566778899AABB'), - | 'Comet AAD' - |""".stripMargin) + val aesDf = if (isSpark35Plus) { + spark + .range(1) + .selectExpr( + "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')) as encrypted_default", + "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'GCM', 'DEFAULT', unhex('00112233445566778899AABB'), 'Comet AAD') as encrypted_with_aad", + "encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') as `key`", + "'GCM' as mode", + "'DEFAULT' as padding", + "unhex('00112233445566778899AABB') as iv", + "'Comet AAD' as aad") } else { - sql(""" - |INSERT INTO aes_tbl - |SELECT - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')), - | aes_encrypt( - | encode('Spark SQL', 'UTF-8'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT'), - | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), - | 'GCM', - | 'DEFAULT', - | cast(null as binary), - | cast(null as string) - |""".stripMargin) + spark + .range(1) + .selectExpr( + "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')) as encrypted_default", + "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'GCM', 'DEFAULT') as encrypted_with_aad", + "encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') as `key`", + "'GCM' as mode", + "'DEFAULT' as padding", + "cast(null as binary) as iv", + "cast(null as string) as aad") } + aesDf.createOrReplaceTempView("aes_tbl") } if (isSpark35Plus) { - checkSparkAnswerAndFallbackReason( - "SELECT 
CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl", - "Static invoke expression: aesDecrypt is not supported") - checkSparkAnswerAndFallbackReason( - "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl", - "Static invoke expression: aesDecrypt is not supported") + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl") + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl") } else { - checkSparkAnswerAndFallbackReason( - "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl", - "Static invoke expression: aesDecrypt is not supported") - checkSparkAnswerAndFallbackReason( - "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding) AS STRING) FROM aes_tbl", - "Static invoke expression: aesDecrypt is not supported") + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl") + checkSparkAnswerAndOperator( + "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding) AS STRING) FROM aes_tbl") } } } From ac96a5ae0dbe79fb0ac0cc38e9e6dc9622f48757 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Fri, 13 Feb 2026 16:23:18 +0100 Subject: [PATCH 5/9] add native support --- native/Cargo.lock | 184 ++++++---- native/spark-expr/Cargo.toml | 4 + native/spark-expr/src/comet_scalar_funcs.rs | 12 +- native/spark-expr/src/lib.rs | 2 +- .../spark-expr/src/math_funcs/aes_decrypt.rs | 323 ++++++++++++++++++ native/spark-expr/src/math_funcs/mod.rs | 2 + native/spark-expr/tests/spark_expr_reg.rs | 6 + 7 files changed, 461 insertions(+), 72 deletions(-) create mode 100644 native/spark-expr/src/math_funcs/aes_decrypt.rs diff --git a/native/Cargo.lock b/native/Cargo.lock index 230fc2a535..e03813118c 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -17,6 +17,16 @@ version = "2.0.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + [[package]] name = "aes" version = "0.8.4" @@ -28,6 +38,20 @@ dependencies = [ "cpufeatures 0.2.17", ] +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.8.12" @@ -98,9 +122,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anyhow" @@ -647,9 +671,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.96.0" +version = "1.97.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f64a6eded248c6b453966e915d32aeddb48ea63ad17932682774eb026fbef5b1" +checksum = "9aadc669e184501caaa6beafb28c6267fc1baef0810fb58f9b205485ca3f2567" dependencies = [ "aws-credential-types", "aws-runtime", @@ -671,9 +695,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.98.0" +version = "1.99.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db96d720d3c622fcbe08bae1c4b04a72ce6257d8b0584cb5418da00ae20a344f" +checksum = "1342a7db8f358d3de0aed2007a0b54e875458e39848d54cc1d46700b2bfcb0a8" dependencies = [ 
"aws-credential-types", "aws-runtime", @@ -695,9 +719,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.100.0" +version = "1.101.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fafbdda43b93f57f699c5dfe8328db590b967b8a820a13ccdd6687355dfcc7ca" +checksum = "ab41ad64e4051ecabeea802d6a17845a91e83287e1dd249e6963ea1ba78c428a" dependencies = [ "aws-credential-types", "aws-runtime", @@ -868,9 +892,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.4.6" +version = "1.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2b1117b3b2bbe166d11199b540ceed0d0f7676e36e7b962b5a437a9971eac75" +checksum = "9d73dbfbaa8e4bc57b9045137680b958d274823509a360abfd8e1d514d40c95c" dependencies = [ "base64-simd", "bytes", @@ -1100,9 +1124,9 @@ dependencies = [ [[package]] name = "bon" -version = "3.9.0" +version = "3.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d13a61f2963b88eef9c1be03df65d42f6996dfeac1054870d950fcf66686f83" +checksum = "f47dbe92550676ee653353c310dfb9cf6ba17ee70396e1f7cf0a2020ad49b2fe" dependencies = [ "bon-macros", "rustversion", @@ -1110,9 +1134,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.9.0" +version = "3.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d314cc62af2b6b0c65780555abb4d02a03dd3b799cd42419044f0c38d99738c0" +checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" dependencies = [ "darling 0.23.0", "ident_case", @@ -1204,9 +1228,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.56" +version = "1.2.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" dependencies = [ "find-msvc-tools", "jobserver", @@ -1326,18 +1350,18 @@ dependencies = [ [[package]] name = 
"clap" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ "anstyle", "clap_lex", @@ -1345,9 +1369,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "cmake" @@ -1583,6 +1607,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] @@ -1608,23 +1633,22 @@ dependencies = [ ] [[package]] -name = "darling" -version = "0.20.11" +name = "ctr" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" dependencies = [ - "darling_core 0.20.11", - "darling_macro 0.20.11", + "cipher", ] [[package]] name = "darling" -version = "0.21.3" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core 0.21.3", 
- "darling_macro 0.21.3", + "darling_core 0.20.11", + "darling_macro 0.20.11", ] [[package]] @@ -1651,20 +1675,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "darling_core" -version = "0.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn 2.0.117", -] - [[package]] name = "darling_core" version = "0.23.0" @@ -1689,17 +1699,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "darling_macro" -version = "0.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" -dependencies = [ - "darling_core 0.21.3", - "quote", - "syn 2.0.117", -] - [[package]] name = "darling_macro" version = "0.23.0" @@ -1925,12 +1924,16 @@ dependencies = [ name = "datafusion-comet-spark-expr" version = "0.15.0" dependencies = [ + "aes", + "aes-gcm", "arrow", "base64", + "cbc", "chrono", "chrono-tz", "criterion", "datafusion", + "ecb", "futures", "hex", "num", @@ -2626,9 +2629,9 @@ dependencies = [ [[package]] name = "dissimilar" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8975ffdaa0ef3661bfe02dbdcc06c9f829dfafe6a3c474de366a8d5e44276921" +checksum = "aeda16ab4059c5fd2a83f2b9c9e9c981327b18aa8e3b313f7e6563799d4f093e" [[package]] name = "dlv-list" @@ -2651,6 +2654,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "ecb" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a8bfa975b1aec2145850fcaa1c6fe269a16578c44705a532ae3edc92b8881c7" +dependencies = [ + "cipher", +] + [[package]] name = "either" version = "1.15.0" @@ -3003,6 +3015,16 @@ 
dependencies = [ "wasip3", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "gimli" version = "0.32.3" @@ -4259,6 +4281,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "opendal" version = "0.55.0" @@ -4649,6 +4677,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.13.1" @@ -4657,9 +4697,9 @@ checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a9db96d7fa8782dd8c15ce32ffe8680bbd1e978a43bf51a34d39483540495f5" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" dependencies = [ "portable-atomic", ] @@ -5537,9 +5577,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.17.0" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "381b283ce7bc6b476d903296fb59d0d36633652b633b27f64db4fb46dcbfc3b9" +checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" dependencies = [ "base64", "chrono", @@ -5556,11 
+5596,11 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.17.0" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6d4e30573c8cb306ed6ab1dca8423eec9a463ea0e155f45399455e0368b27e0" +checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" dependencies = [ - "darling 0.21.3", + "darling 0.23.0", "proc-macro2", "quote", "syn 2.0.117", @@ -6014,9 +6054,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" dependencies = [ "tinyvec_macros", ] @@ -6258,6 +6298,16 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "unsafe-any-ors" version = "1.0.0" diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index d4639c86ea..b7602f4f9c 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -41,6 +41,10 @@ twox-hash = "2.1.2" rand = { workspace = true } hex = "0.4.3" base64 = "0.22.1" +aes = "0.8.4" +aes-gcm = "0.10.3" +cbc = "0.1.2" +ecb = "0.1.2" [dev-dependencies] arrow = {workspace = true} diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index ff75de763b..15340a867b 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -20,10 +20,10 @@ use crate::math_funcs::abs::abs; use 
crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ - spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, - spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, EvalMode, SparkContains, SparkDateDiff, SparkDateTrunc, SparkMakeDate, - SparkSizeFunc, + spark_aes_decrypt, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, + spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, + spark_unhex, spark_unscaled_value, EvalMode, SparkContains, SparkDateDiff, SparkDateTrunc, + SparkMakeDate, SparkSizeFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -177,6 +177,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(abs); make_comet_scalar_udf!("abs", func, without data_type) } + "aes_decrypt" => { + let func = Arc::new(spark_aes_decrypt); + make_comet_scalar_udf!("aes_decrypt", func, without data_type) + } "split" => { let func = Arc::new(crate::string_funcs::spark_split); make_comet_scalar_udf!("split", func, without data_type) diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index ba19d6a9b2..01adb8c667 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -78,7 +78,7 @@ pub use error::{SparkError, SparkErrorWithContext, SparkResult}; pub use hash_funcs::*; pub use json_funcs::{FromJson, ToJson}; pub use math_funcs::{ - create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div, + create_modulo_expr, create_negate_expr, spark_aes_decrypt, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr, NormalizeNaNAndZero, 
WideDecimalBinaryExpr, WideDecimalOp, diff --git a/native/spark-expr/src/math_funcs/aes_decrypt.rs b/native/spark-expr/src/math_funcs/aes_decrypt.rs new file mode 100644 index 0000000000..7b6a9fc3a5 --- /dev/null +++ b/native/spark-expr/src/math_funcs/aes_decrypt.rs @@ -0,0 +1,323 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::sync::Arc; + +use aes::cipher::consts::{U12, U16}; +use aes::{Aes128, Aes192, Aes256}; +use aes_gcm::aead::{Aead, Payload}; +use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, KeyInit, Nonce}; +use arrow::array::{ + Array, ArrayRef, BinaryArray, BinaryBuilder, LargeBinaryArray, LargeStringArray, StringArray, +}; +use arrow::datatypes::DataType; +use cbc::cipher::{block_padding::Pkcs7, BlockDecryptMut, KeyIvInit}; +use datafusion::common::{exec_err, DataFusionError}; +use datafusion::logical_expr::ColumnarValue; + +const GCM_IV_LEN: usize = 12; +const CBC_IV_LEN: usize = 16; + +#[derive(Clone, Copy)] +enum AesMode { + Ecb, + Cbc, + Gcm, +} + +impl AesMode { + fn from_mode_padding(mode: &str, padding: &str) -> Result { + let is_none = padding.eq_ignore_ascii_case("NONE"); + let is_pkcs = padding.eq_ignore_ascii_case("PKCS"); + let is_default = padding.eq_ignore_ascii_case("DEFAULT"); + + if mode.eq_ignore_ascii_case("ECB") && (is_pkcs || is_default) { + Ok(Self::Ecb) + } else if mode.eq_ignore_ascii_case("CBC") && (is_pkcs || is_default) { + Ok(Self::Cbc) + } else if mode.eq_ignore_ascii_case("GCM") && (is_none || is_default) { + Ok(Self::Gcm) + } else { + exec_err!("Unsupported AES mode/padding combination: {mode}/{padding}") + } + } +} + +enum BinaryArg<'a> { + Binary(&'a BinaryArray), + LargeBinary(&'a LargeBinaryArray), +} + +impl<'a> BinaryArg<'a> { + fn from(arg_name: &str, arr: &'a ArrayRef) -> Result { + match arr.data_type() { + DataType::Binary => Ok(Self::Binary( + arr.as_any().downcast_ref::().ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to downcast {arg_name} to BinaryArray" + )) + })?, + )), + DataType::LargeBinary => Ok(Self::LargeBinary( + arr.as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to downcast {arg_name} to LargeBinaryArray" + )) + })?, + )), + other => exec_err!("{arg_name} must be Binary/LargeBinary, got {other:?}"), + } + } + + fn value(&self, i: usize) -> Option<&'a 
[u8]> { + match self { + Self::Binary(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + Self::LargeBinary(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + } + } +} + +enum StringArg<'a> { + Utf8(&'a StringArray), + LargeUtf8(&'a LargeStringArray), +} + +impl<'a> StringArg<'a> { + fn from(arg_name: &str, arr: &'a ArrayRef) -> Result { + match arr.data_type() { + DataType::Utf8 => Ok(Self::Utf8( + arr.as_any().downcast_ref::().ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to downcast {arg_name} to StringArray" + )) + })?, + )), + DataType::LargeUtf8 => Ok(Self::LargeUtf8( + arr.as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to downcast {arg_name} to LargeStringArray" + )) + })?, + )), + other => exec_err!("{arg_name} must be Utf8/LargeUtf8, got {other:?}"), + } + } + + fn value(&self, i: usize) -> Option<&'a str> { + match self { + Self::Utf8(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + Self::LargeUtf8(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + } + } +} + +type Aes128CbcDec = cbc::Decryptor; +type Aes192CbcDec = cbc::Decryptor; +type Aes256CbcDec = cbc::Decryptor; +type Aes128EcbDec = ecb::Decryptor; +type Aes192EcbDec = ecb::Decryptor; +type Aes256EcbDec = ecb::Decryptor; +type Aes192Gcm = AesGcm; + +fn decrypt_pkcs_cbc(input: &[u8], key: &[u8]) -> Result, DataFusionError> { + if input.len() < CBC_IV_LEN { + return exec_err!("AES decryption input is too short for CBC"); + } + let (iv, ciphertext) = input.split_at(CBC_IV_LEN); + let mut buf = ciphertext.to_vec(); + + let out = match key.len() { + 16 => Aes128CbcDec::new_from_slices(key, iv) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 24 => Aes192CbcDec::new_from_slices(key, iv) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? 
+ .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 32 => Aes256CbcDec::new_from_slices(key, iv) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + _ => return exec_err!("Invalid AES key length: {}", key.len()), + }; + + Ok(out.to_vec()) +} + +fn decrypt_pkcs_ecb(input: &[u8], key: &[u8]) -> Result, DataFusionError> { + let mut buf = input.to_vec(); + + let out = match key.len() { + 16 => Aes128EcbDec::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 24 => Aes192EcbDec::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 32 => Aes256EcbDec::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + _ => return exec_err!("Invalid AES key length: {}", key.len()), + }; + + Ok(out.to_vec()) +} + +fn decrypt_gcm(input: &[u8], key: &[u8], aad: &[u8]) -> Result, DataFusionError> { + if input.len() < GCM_IV_LEN { + return exec_err!("AES decryption input is too short for GCM"); + } + let (iv, ciphertext) = input.split_at(GCM_IV_LEN); + let nonce = Nonce::from_slice(iv); + let payload = Payload { + msg: ciphertext, + aad, + }; + + match key.len() { + 16 => Aes128Gcm::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? 
+ .decrypt(nonce, payload) + .map_err(|_| { + DataFusionError::Execution("AES crypto error: decrypt failed".to_string()) + }), + 24 => Aes192Gcm::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt(nonce, payload) + .map_err(|_| { + DataFusionError::Execution("AES crypto error: decrypt failed".to_string()) + }), + 32 => Aes256Gcm::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt(nonce, payload) + .map_err(|_| { + DataFusionError::Execution("AES crypto error: decrypt failed".to_string()) + }), + _ => exec_err!("Invalid AES key length: {}", key.len()), + } +} + +fn decrypt_one( + input: &[u8], + key: &[u8], + mode: &str, + padding: &str, + aad: &[u8], +) -> Result, DataFusionError> { + match AesMode::from_mode_padding(mode, padding)? { + AesMode::Ecb => decrypt_pkcs_ecb(input, key), + AesMode::Cbc => decrypt_pkcs_cbc(input, key), + AesMode::Gcm => decrypt_gcm(input, key, aad), + } +} + +pub fn spark_aes_decrypt(args: &[ColumnarValue]) -> Result { + if !(2..=5).contains(&args.len()) { + return exec_err!("aes_decrypt expects 2 to 5 arguments, got {}", args.len()); + } + + let are_scalars = args + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + let arrays = ColumnarValue::values_to_arrays(args)?; + let num_rows = arrays[0].len(); + + let input = BinaryArg::from("input", &arrays[0])?; + let key = BinaryArg::from("key", &arrays[1])?; + + let mode = if args.len() >= 3 { + Some(StringArg::from("mode", &arrays[2])?) + } else { + None + }; + let padding = if args.len() >= 4 { + Some(StringArg::from("padding", &arrays[3])?) + } else { + None + }; + let aad = if args.len() >= 5 { + Some(BinaryArg::from("aad", &arrays[4])?) 
+ } else { + None + }; + + let mut builder = BinaryBuilder::new(); + + for row in 0..num_rows { + let Some(input_value) = input.value(row) else { + builder.append_null(); + continue; + }; + let Some(key_value) = key.value(row) else { + builder.append_null(); + continue; + }; + + let mode_value = match mode.as_ref() { + Some(mode) => { + let Some(mode) = mode.value(row) else { + builder.append_null(); + continue; + }; + mode + } + None => "GCM", + }; + + let padding_value = match padding.as_ref() { + Some(padding) => { + let Some(padding) = padding.value(row) else { + builder.append_null(); + continue; + }; + padding + } + None => "DEFAULT", + }; + + let aad_value = match aad.as_ref() { + Some(aad) => { + let Some(aad) = aad.value(row) else { + builder.append_null(); + continue; + }; + aad + } + None => &[], + }; + + let plaintext = decrypt_one(input_value, key_value, mode_value, padding_value, aad_value)?; + builder.append_value(plaintext); + } + + let array = Arc::new(builder.finish()); + if are_scalars { + Ok(ColumnarValue::Scalar( + datafusion::common::ScalarValue::try_from_array(array.as_ref(), 0)?, + )) + } else { + Ok(ColumnarValue::Array(array)) + } +} diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs index 1219bc7208..808efde655 100644 --- a/native/spark-expr/src/math_funcs/mod.rs +++ b/native/spark-expr/src/math_funcs/mod.rs @@ -16,6 +16,7 @@ // under the License. 
pub(crate) mod abs; +pub(crate) mod aes_decrypt; mod ceil; pub(crate) mod checked_arithmetic; mod div; @@ -28,6 +29,7 @@ pub(crate) mod unhex; mod utils; mod wide_decimal_binary_expr; +pub use aes_decrypt::spark_aes_decrypt; pub use ceil::spark_ceil; pub use div::spark_decimal_div; pub use div::spark_decimal_integral_div; diff --git a/native/spark-expr/tests/spark_expr_reg.rs b/native/spark-expr/tests/spark_expr_reg.rs index 633b226068..f381b77881 100644 --- a/native/spark-expr/tests/spark_expr_reg.rs +++ b/native/spark-expr/tests/spark_expr_reg.rs @@ -35,6 +35,12 @@ mod tests { &session_state, None, )?); + let _ = session_state.register_udf(create_comet_physical_fun( + "aes_decrypt", + DataType::Binary, + &session_state, + None, + )?); let ctx = SessionContext::new_with_state(session_state); // 2. Execute SQL with literal values From b20248b09e89f39a212375dc8fe847cbc5aa5103 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Tue, 17 Feb 2026 23:12:10 +0100 Subject: [PATCH 6/9] add downcast macros --- native/spark-expr/src/downcast.rs | 36 +++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 native/spark-expr/src/downcast.rs diff --git a/native/spark-expr/src/downcast.rs b/native/spark-expr/src/downcast.rs new file mode 100644 index 0000000000..ade2ef961b --- /dev/null +++ b/native/spark-expr/src/downcast.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +macro_rules! opt_downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>() + }}; +} + +macro_rules! downcast_named_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + $NAME, + std::any::type_name::<$ARRAY_TYPE>() + ) + })? + }}; +} + +pub(crate) use {downcast_named_arg, opt_downcast_arg}; From b69be6f58f872cbd0b508a8e0567ec8af393f417 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Tue, 17 Feb 2026 23:12:21 +0100 Subject: [PATCH 7/9] edit implementation and move to misc reference --- native/spark-expr/src/lib.rs | 6 +- native/spark-expr/src/math_funcs/mod.rs | 2 - .../{math_funcs => misc_funcs}/aes_decrypt.rs | 152 ++++++++---------- native/spark-expr/src/misc_funcs/mod.rs | 20 +++ .../scala/org/apache/comet/serde/misc.scala | 59 +++++++ .../org/apache/comet/serde/statics.scala | 37 +---- 6 files changed, 155 insertions(+), 121 deletions(-) rename native/spark-expr/src/{math_funcs => misc_funcs}/aes_decrypt.rs (75%) create mode 100644 native/spark-expr/src/misc_funcs/mod.rs create mode 100644 spark/src/main/scala/org/apache/comet/serde/misc.scala diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index 01adb8c667..f737fa87fc 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -19,7 +19,9 @@ // The lint makes easier for code reader/reviewer separate references clones from more heavyweight ones 
#![deny(clippy::clone_on_ref_ptr)] +mod downcast; mod error; +pub(crate) use downcast::{downcast_named_arg, opt_downcast_arg}; mod query_context; pub mod kernels; @@ -58,6 +60,7 @@ pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain}; mod conditional_funcs; mod conversion_funcs; mod math_funcs; +mod misc_funcs; mod nondetermenistic_funcs; pub use array_funcs::*; @@ -78,12 +81,13 @@ pub use error::{SparkError, SparkErrorWithContext, SparkResult}; pub use hash_funcs::*; pub use json_funcs::{FromJson, ToJson}; pub use math_funcs::{ - create_modulo_expr, create_negate_expr, spark_aes_decrypt, spark_ceil, spark_decimal_div, + create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr, NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp, }; pub use query_context::{create_query_context_map, QueryContext, QueryContextMap}; +pub use misc_funcs::spark_aes_decrypt; pub use string_funcs::*; /// Spark supports three evaluation modes when evaluating expressions, which affect diff --git a/native/spark-expr/src/math_funcs/mod.rs b/native/spark-expr/src/math_funcs/mod.rs index 808efde655..1219bc7208 100644 --- a/native/spark-expr/src/math_funcs/mod.rs +++ b/native/spark-expr/src/math_funcs/mod.rs @@ -16,7 +16,6 @@ // under the License. 
pub(crate) mod abs; -pub(crate) mod aes_decrypt; mod ceil; pub(crate) mod checked_arithmetic; mod div; @@ -29,7 +28,6 @@ pub(crate) mod unhex; mod utils; mod wide_decimal_binary_expr; -pub use aes_decrypt::spark_aes_decrypt; pub use ceil::spark_ceil; pub use div::spark_decimal_div; pub use div::spark_decimal_integral_div; diff --git a/native/spark-expr/src/math_funcs/aes_decrypt.rs b/native/spark-expr/src/misc_funcs/aes_decrypt.rs similarity index 75% rename from native/spark-expr/src/math_funcs/aes_decrypt.rs rename to native/spark-expr/src/misc_funcs/aes_decrypt.rs index 7b6a9fc3a5..605fe50994 100644 --- a/native/spark-expr/src/math_funcs/aes_decrypt.rs +++ b/native/spark-expr/src/misc_funcs/aes_decrypt.rs @@ -15,18 +15,14 @@ // specific language governing permissions and limitations // under the License. -use std::sync::Arc; - use aes::cipher::consts::{U12, U16}; use aes::{Aes128, Aes192, Aes256}; use aes_gcm::aead::{Aead, Payload}; use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, KeyInit, Nonce}; -use arrow::array::{ - Array, ArrayRef, BinaryArray, BinaryBuilder, LargeBinaryArray, LargeStringArray, StringArray, -}; +use arrow::array::{Array, ArrayRef, BinaryArray, LargeBinaryArray, LargeStringArray, StringArray}; use arrow::datatypes::DataType; use cbc::cipher::{block_padding::Pkcs7, BlockDecryptMut, KeyIvInit}; -use datafusion::common::{exec_err, DataFusionError}; +use datafusion::common::{exec_err, DataFusionError, ScalarValue}; use datafusion::logical_expr::ColumnarValue; const GCM_IV_LEN: usize = 12; @@ -66,21 +62,19 @@ impl<'a> BinaryArg<'a> { fn from(arg_name: &str, arr: &'a ArrayRef) -> Result { match arr.data_type() { DataType::Binary => Ok(Self::Binary( - arr.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Failed to downcast {arg_name} to BinaryArray" - )) + crate::opt_downcast_arg!(arr, BinaryArray).ok_or_else(|| { + datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + arg_name, + 
std::any::type_name::() + ) })?, )), - DataType::LargeBinary => Ok(Self::LargeBinary( - arr.as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Failed to downcast {arg_name} to LargeBinaryArray" - )) - })?, - )), + DataType::LargeBinary => Ok(Self::LargeBinary(crate::downcast_named_arg!( + arr, + arg_name, + LargeBinaryArray + ))), other => exec_err!("{arg_name} must be Binary/LargeBinary, got {other:?}"), } } @@ -102,21 +96,19 @@ impl<'a> StringArg<'a> { fn from(arg_name: &str, arr: &'a ArrayRef) -> Result { match arr.data_type() { DataType::Utf8 => Ok(Self::Utf8( - arr.as_any().downcast_ref::().ok_or_else(|| { - DataFusionError::Internal(format!( - "Failed to downcast {arg_name} to StringArray" - )) + crate::opt_downcast_arg!(arr, StringArray).ok_or_else(|| { + datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + arg_name, + std::any::type_name::() + ) })?, )), - DataType::LargeUtf8 => Ok(Self::LargeUtf8( - arr.as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Internal(format!( - "Failed to downcast {arg_name} to LargeStringArray" - )) - })?, - )), + DataType::LargeUtf8 => Ok(Self::LargeUtf8(crate::downcast_named_arg!( + arr, + arg_name, + LargeStringArray + ))), other => exec_err!("{arg_name} must be Utf8/LargeUtf8, got {other:?}"), } } @@ -263,56 +255,52 @@ pub fn spark_aes_decrypt(args: &[ColumnarValue]) -> Result { - let Some(mode) = mode.value(row) else { - builder.append_null(); - continue; - }; - mode - } - None => "GCM", - }; - - let padding_value = match padding.as_ref() { - Some(padding) => { - let Some(padding) = padding.value(row) else { - builder.append_null(); - continue; - }; - padding - } - None => "DEFAULT", - }; - - let aad_value = match aad.as_ref() { - Some(aad) => { - let Some(aad) = aad.value(row) else { - builder.append_null(); - continue; - }; - aad - } - None => &[], - }; - - let plaintext = decrypt_one(input_value, key_value, mode_value, padding_value, 
aad_value)?; - builder.append_value(plaintext); - } - - let array = Arc::new(builder.finish()); + let values: Result, DataFusionError> = (0..num_rows) + .map(|row| { + let Some(input_value) = input.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + let Some(key_value) = key.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + + let mode_value = match mode.as_ref() { + Some(mode) => { + let Some(mode) = mode.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + mode + } + None => "GCM", + }; + + let padding_value = match padding.as_ref() { + Some(padding) => { + let Some(padding) = padding.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + padding + } + None => "DEFAULT", + }; + + let aad_value = match aad.as_ref() { + Some(aad) => { + let Some(aad) = aad.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + aad + } + None => &[], + }; + + let plaintext = + decrypt_one(input_value, key_value, mode_value, padding_value, aad_value)?; + Ok(ScalarValue::Binary(Some(plaintext))) + }) + .collect(); + + let array: ArrayRef = ScalarValue::iter_to_array(values?)?; if are_scalars { Ok(ColumnarValue::Scalar( datafusion::common::ScalarValue::try_from_array(array.as_ref(), 0)?, diff --git a/native/spark-expr/src/misc_funcs/mod.rs b/native/spark-expr/src/misc_funcs/mod.rs new file mode 100644 index 0000000000..c55b82811d --- /dev/null +++ b/native/spark-expr/src/misc_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub(crate) mod aes_decrypt; + +pub use aes_decrypt::spark_aes_decrypt; diff --git a/spark/src/main/scala/org/apache/comet/serde/misc.scala b/spark/src/main/scala/org/apache/comet/serde/misc.scala new file mode 100644 index 0000000000..9bb8416579 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/misc.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.{AesDecrypt, Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke + +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} + +private object CometAesDecryptHelper { + def convertToAesDecryptExpr[T <: Expression]( + expr: T, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "aes_decrypt", + expr.dataType, + failOnError = false, + childExpr: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) + } +} + +object CometAesDecrypt extends CometExpressionSerde[AesDecrypt] { + override def convert( + expr: AesDecrypt, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) + } +} + +object CometAesDecryptStaticInvoke extends CometExpressionSerde[StaticInvoke] { + override def convert( + expr: StaticInvoke, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 0ff31908d4..5356b3e11b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -19,46 +19,11 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{AesDecrypt, Attribute, Expression, ExpressionImplUtils} +import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionImplUtils} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import 
org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils import org.apache.comet.CometSparkSessionExtensions.withInfo -import org.apache.comet.serde.ExprOuterClass.Expr -import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} - -private object CometAesDecryptHelper { - def convertToAesDecryptExpr[T <: Expression]( - expr: T, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) - val optExpr = scalarFunctionExprToProtoWithReturnType( - "aes_decrypt", - expr.dataType, - failOnError = false, - childExpr: _*) - optExprWithInfo(optExpr, expr, expr.children: _*) - } -} - -object CometAesDecrypt extends CometExpressionSerde[AesDecrypt] { - override def convert( - expr: AesDecrypt, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) - } -} - -object CometAesDecryptStaticInvoke extends CometExpressionSerde[StaticInvoke] { - override def convert( - expr: StaticInvoke, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) - } -} object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { From 3b599e1acadada55b73505948f758bd1608d9136 Mon Sep 17 00:00:00 2001 From: Rafael Fernandez Date: Thu, 19 Feb 2026 21:50:29 +0100 Subject: [PATCH 8/9] suggestion - tests --- .../expressions/misc/aes_decrypt.sql | 110 ++++++++++++++++++ .../comet/CometMiscExpressionSuite.scala | 88 ++++++++++++++ 2 files changed, 198 insertions(+) diff --git a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql index cca41c83d7..5c93dd6a7a 100644 --- a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql +++ b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql @@ 
-52,3 +52,113 @@ SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl query SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl + +statement +CREATE TABLE aes_modes_tbl( + encrypted BINARY, + `key` BINARY, + mode STRING, + padding STRING, + label STRING +) USING parquet + +statement +INSERT INTO aes_modes_tbl +SELECT + aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'GCM', 'DEFAULT'), + encode('abcdefghijklmnop', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'gcm_128' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'GCM', + 'DEFAULT'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'gcm_192' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'gcm_256' +UNION ALL +SELECT + aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'CBC', 'PKCS'), + encode('abcdefghijklmnop', 'UTF-8'), + 'CBC', + 'PKCS', + 'cbc_128' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'CBC', + 'PKCS'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'CBC', + 'PKCS', + 'cbc_192' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'CBC', + 'PKCS'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'CBC', + 'PKCS', + 'cbc_256' +UNION ALL +SELECT + aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'ECB', 'PKCS'), + encode('abcdefghijklmnop', 'UTF-8'), + 'ECB', + 'PKCS', + 'ecb_128' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 'ECB', + 'PKCS'), + encode('abcdefghijklmnop12345678', 'UTF-8'), + 
'ECB', + 'PKCS', + 'ecb_192' +UNION ALL +SELECT + aes_encrypt( + encode('Spark SQL', 'UTF-8'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'ECB', + 'PKCS'), + encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), + 'ECB', + 'PKCS', + 'ecb_256' +UNION ALL +SELECT + cast(null AS binary), + encode('abcdefghijklmnop', 'UTF-8'), + 'GCM', + 'DEFAULT', + 'null_input' + +query +SELECT label, CAST(aes_decrypt(encrypted, `key`, mode, padding) AS STRING) +FROM aes_modes_tbl +ORDER BY label diff --git a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala index b6f7e921b4..2e6cc90cb8 100644 --- a/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometMiscExpressionSuite.scala @@ -68,4 +68,92 @@ class CometMiscExpressionSuite extends CometTestBase { } } + test("aes_decrypt mode and key-size combinations") { + withTempView("aes_modes_tbl") { + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + spark + .sql(""" + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'GCM', 'DEFAULT') AS encrypted, + | encode('abcdefghijklmnop', 'UTF-8') AS `key`, + | 'GCM' AS mode, + | 'DEFAULT' AS padding, + | 'gcm_128' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678', 'UTF-8'), 'GCM', 'DEFAULT') AS encrypted, + | encode('abcdefghijklmnop12345678', 'UTF-8') AS `key`, + | 'GCM' AS mode, + | 'DEFAULT' AS padding, + | 'gcm_192' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'GCM', 'DEFAULT') AS encrypted, + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') AS `key`, + | 'GCM' AS mode, + | 'DEFAULT' AS padding, + | 'gcm_256' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'CBC', 
'PKCS') AS encrypted, + | encode('abcdefghijklmnop', 'UTF-8') AS `key`, + | 'CBC' AS mode, + | 'PKCS' AS padding, + | 'cbc_128' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678', 'UTF-8'), 'CBC', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop12345678', 'UTF-8') AS `key`, + | 'CBC' AS mode, + | 'PKCS' AS padding, + | 'cbc_192' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'CBC', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') AS `key`, + | 'CBC' AS mode, + | 'PKCS' AS padding, + | 'cbc_256' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'ECB', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop', 'UTF-8') AS `key`, + | 'ECB' AS mode, + | 'PKCS' AS padding, + | 'ecb_128' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678', 'UTF-8'), 'ECB', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop12345678', 'UTF-8') AS `key`, + | 'ECB' AS mode, + | 'PKCS' AS padding, + | 'ecb_192' AS label + |UNION ALL + |SELECT + | aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'ECB', 'PKCS') AS encrypted, + | encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') AS `key`, + | 'ECB' AS mode, + | 'PKCS' AS padding, + | 'ecb_256' AS label + |UNION ALL + |SELECT + | cast(null AS binary) AS encrypted, + | encode('abcdefghijklmnop', 'UTF-8') AS `key`, + | 'GCM' AS mode, + | 'DEFAULT' AS padding, + | 'null_input' AS label + |""".stripMargin) + .createOrReplaceTempView("aes_modes_tbl") + } + + checkSparkAnswerAndOperator(""" + |SELECT + | label, + | CAST(aes_decrypt(encrypted, `key`, mode, padding) AS STRING) AS decrypted + |FROM aes_modes_tbl + |ORDER BY label + |""".stripMargin) + } + } + } From 4f9dcaa6b53ea972be3390893c9eec15aea02958 Mon Sep 17 
00:00:00 2001 From: Rafael Fernandez Date: Thu, 19 Mar 2026 13:28:57 +0100 Subject: [PATCH 9/9] refactor: remove unused CometAesDecrypt direct path AesDecrypt extends RuntimeReplaceable in Spark, so it is always rewritten to StaticInvoke before Comet's serde runs. Remove the direct CometAesDecrypt handler and CometAesDecryptHelper, keeping only the CometAesDecryptStaticInvoke path. --- .../apache/comet/serde/QueryPlanSerde.scala | 3 +- .../scala/org/apache/comet/serde/misc.scala | 29 +++++-------------- 2 files changed, 9 insertions(+), 23 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 9a4226aade..8d358ee29c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -222,7 +222,8 @@ object QueryPlanSerde extends Logging with CometExprShim { private val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( // TODO PromotePrecision - classOf[AesDecrypt] -> CometAesDecrypt, + // AesDecrypt extends RuntimeReplaceable and is rewritten to StaticInvoke before Comet's + // serde runs, so it is handled via CometStaticInvoke / CometAesDecryptStaticInvoke.
classOf[Alias] -> CometAlias, classOf[AttributeReference] -> CometAttributeReference, classOf[BloomFilterMightContain] -> CometBloomFilterMightContain, diff --git a/spark/src/main/scala/org/apache/comet/serde/misc.scala b/spark/src/main/scala/org/apache/comet/serde/misc.scala index 9bb8416579..1f1bc1854f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/misc.scala +++ b/spark/src/main/scala/org/apache/comet/serde/misc.scala @@ -19,15 +19,18 @@ package org.apache.comet.serde -import org.apache.spark.sql.catalyst.expressions.{AesDecrypt, Attribute, Expression} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.comet.serde.ExprOuterClass.Expr import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} -private object CometAesDecryptHelper { - def convertToAesDecryptExpr[T <: Expression]( - expr: T, +// AesDecrypt extends RuntimeReplaceable in Spark, so by the time Comet's serde runs it has +// already been replaced with a StaticInvoke. This handler is registered in CometStaticInvoke's +// staticInvokeExpressions map under the ("aesDecrypt", ExpressionImplUtils) key. 
+object CometAesDecryptStaticInvoke extends CometExpressionSerde[StaticInvoke] { + override def convert( + expr: StaticInvoke, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) @@ -39,21 +42,3 @@ private object CometAesDecryptHelper { optExprWithInfo(optExpr, expr, expr.children: _*) } } - -object CometAesDecrypt extends CometExpressionSerde[AesDecrypt] { - override def convert( - expr: AesDecrypt, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) - } -} - -object CometAesDecryptStaticInvoke extends CometExpressionSerde[StaticInvoke] { - override def convert( - expr: StaticInvoke, - inputs: Seq[Attribute], - binding: Boolean): Option[Expr] = { - CometAesDecryptHelper.convertToAesDecryptExpr(expr, inputs, binding) - } -}