diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md index 5474894108..752ee9d28e 100644 --- a/docs/spark_expressions_support.md +++ b/docs/spark_expressions_support.md @@ -353,7 +353,7 @@ ### misc_funcs -- [ ] aes_decrypt +- [x] aes_decrypt - [ ] aes_encrypt - [ ] assert_true - [x] current_catalog diff --git a/native/Cargo.lock b/native/Cargo.lock index 230fc2a535..e03813118c 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -17,6 +17,16 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + [[package]] name = "aes" version = "0.8.4" @@ -28,6 +38,20 @@ dependencies = [ "cpufeatures 0.2.17", ] +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "ahash" version = "0.8.12" @@ -98,9 +122,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.13" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" +checksum = "940b3a0ca603d1eade50a4846a2afffd5ef57a9feac2c0e2ec2e14f9ead76000" [[package]] name = "anyhow" @@ -647,9 +671,9 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.96.0" +version = "1.97.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f64a6eded248c6b453966e915d32aeddb48ea63ad17932682774eb026fbef5b1" 
+checksum = "9aadc669e184501caaa6beafb28c6267fc1baef0810fb58f9b205485ca3f2567" dependencies = [ "aws-credential-types", "aws-runtime", @@ -671,9 +695,9 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.98.0" +version = "1.99.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db96d720d3c622fcbe08bae1c4b04a72ce6257d8b0584cb5418da00ae20a344f" +checksum = "1342a7db8f358d3de0aed2007a0b54e875458e39848d54cc1d46700b2bfcb0a8" dependencies = [ "aws-credential-types", "aws-runtime", @@ -695,9 +719,9 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.100.0" +version = "1.101.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fafbdda43b93f57f699c5dfe8328db590b967b8a820a13ccdd6687355dfcc7ca" +checksum = "ab41ad64e4051ecabeea802d6a17845a91e83287e1dd249e6963ea1ba78c428a" dependencies = [ "aws-credential-types", "aws-runtime", @@ -868,9 +892,9 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.4.6" +version = "1.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2b1117b3b2bbe166d11199b540ceed0d0f7676e36e7b962b5a437a9971eac75" +checksum = "9d73dbfbaa8e4bc57b9045137680b958d274823509a360abfd8e1d514d40c95c" dependencies = [ "base64-simd", "bytes", @@ -1100,9 +1124,9 @@ dependencies = [ [[package]] name = "bon" -version = "3.9.0" +version = "3.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2d13a61f2963b88eef9c1be03df65d42f6996dfeac1054870d950fcf66686f83" +checksum = "f47dbe92550676ee653353c310dfb9cf6ba17ee70396e1f7cf0a2020ad49b2fe" dependencies = [ "bon-macros", "rustversion", @@ -1110,9 +1134,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.9.0" +version = "3.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d314cc62af2b6b0c65780555abb4d02a03dd3b799cd42419044f0c38d99738c0" +checksum = "519bd3116aeeb42d5372c29d982d16d0170d3d4a5ed85fc7dd91642ffff3c67c" 
dependencies = [ "darling 0.23.0", "ident_case", @@ -1204,9 +1228,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.56" +version = "1.2.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" dependencies = [ "find-msvc-tools", "jobserver", @@ -1326,18 +1350,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +checksum = "b193af5b67834b676abd72466a96c1024e6a6ad978a1f484bd90b85c94041351" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.60" +version = "4.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +checksum = "714a53001bf66416adb0e2ef5ac857140e7dc3a0c48fb28b2f10762fc4b5069f" dependencies = [ "anstyle", "clap_lex", @@ -1345,9 +1369,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "1.0.0" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" +checksum = "c8d4a3bb8b1e0c1050499d1815f5ab16d04f0959b233085fb31653fbfc9d98f9" [[package]] name = "cmake" @@ -1583,6 +1607,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", + "rand_core 0.6.4", "typenum", ] @@ -1608,23 +1633,22 @@ dependencies = [ ] [[package]] -name = "darling" -version = "0.20.11" +name = "ctr" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +checksum = 
"0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" dependencies = [ - "darling_core 0.20.11", - "darling_macro 0.20.11", + "cipher", ] [[package]] name = "darling" -version = "0.21.3" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core 0.21.3", - "darling_macro 0.21.3", + "darling_core 0.20.11", + "darling_macro 0.20.11", ] [[package]] @@ -1651,20 +1675,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "darling_core" -version = "0.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" -dependencies = [ - "fnv", - "ident_case", - "proc-macro2", - "quote", - "strsim", - "syn 2.0.117", -] - [[package]] name = "darling_core" version = "0.23.0" @@ -1689,17 +1699,6 @@ dependencies = [ "syn 2.0.117", ] -[[package]] -name = "darling_macro" -version = "0.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" -dependencies = [ - "darling_core 0.21.3", - "quote", - "syn 2.0.117", -] - [[package]] name = "darling_macro" version = "0.23.0" @@ -1925,12 +1924,16 @@ dependencies = [ name = "datafusion-comet-spark-expr" version = "0.15.0" dependencies = [ + "aes", + "aes-gcm", "arrow", "base64", + "cbc", "chrono", "chrono-tz", "criterion", "datafusion", + "ecb", "futures", "hex", "num", @@ -2626,9 +2629,9 @@ dependencies = [ [[package]] name = "dissimilar" -version = "1.0.10" +version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8975ffdaa0ef3661bfe02dbdcc06c9f829dfafe6a3c474de366a8d5e44276921" +checksum = "aeda16ab4059c5fd2a83f2b9c9e9c981327b18aa8e3b313f7e6563799d4f093e" [[package]] name = "dlv-list" @@ 
-2651,6 +2654,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" +[[package]] +name = "ecb" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a8bfa975b1aec2145850fcaa1c6fe269a16578c44705a532ae3edc92b8881c7" +dependencies = [ + "cipher", +] + [[package]] name = "either" version = "1.15.0" @@ -3003,6 +3015,16 @@ dependencies = [ "wasip3", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "gimli" version = "0.32.3" @@ -4259,6 +4281,12 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "opendal" version = "0.55.0" @@ -4649,6 +4677,18 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", + "opaque-debug", + "universal-hash", +] + [[package]] name = "portable-atomic" version = "1.13.1" @@ -4657,9 +4697,9 @@ checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" [[package]] name = "portable-atomic-util" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a9db96d7fa8782dd8c15ce32ffe8680bbd1e978a43bf51a34d39483540495f5" +checksum = 
"091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" dependencies = [ "portable-atomic", ] @@ -5537,9 +5577,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.17.0" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "381b283ce7bc6b476d903296fb59d0d36633652b633b27f64db4fb46dcbfc3b9" +checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" dependencies = [ "base64", "chrono", @@ -5556,11 +5596,11 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.17.0" +version = "3.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6d4e30573c8cb306ed6ab1dca8423eec9a463ea0e155f45399455e0368b27e0" +checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" dependencies = [ - "darling 0.21.3", + "darling 0.23.0", "proc-macro2", "quote", "syn 2.0.117", @@ -6014,9 +6054,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +checksum = "3e61e67053d25a4e82c844e8424039d9745781b3fc4f32b8d55ed50f5f667ef3" dependencies = [ "tinyvec_macros", ] @@ -6258,6 +6298,16 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "unsafe-any-ors" version = "1.0.0" diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index d4639c86ea..b7602f4f9c 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -41,6 +41,10 @@ twox-hash = "2.1.2" rand 
= { workspace = true } hex = "0.4.3" base64 = "0.22.1" +aes = "0.8.4" +aes-gcm = "0.10.3" +cbc = "0.1.2" +ecb = "0.1.2" [dev-dependencies] arrow = {workspace = true} diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index ff75de763b..15340a867b 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -20,10 +20,10 @@ use crate::math_funcs::abs::abs; use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub}; use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ - spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, - spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, EvalMode, SparkContains, SparkDateDiff, SparkDateTrunc, SparkMakeDate, - SparkSizeFunc, + spark_aes_decrypt, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, + spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, + spark_unhex, spark_unscaled_value, EvalMode, SparkContains, SparkDateDiff, SparkDateTrunc, + SparkMakeDate, SparkSizeFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -177,6 +177,10 @@ pub fn create_comet_physical_fun_with_eval_mode( let func = Arc::new(abs); make_comet_scalar_udf!("abs", func, without data_type) } + "aes_decrypt" => { + let func = Arc::new(spark_aes_decrypt); + make_comet_scalar_udf!("aes_decrypt", func, without data_type) + } "split" => { let func = Arc::new(crate::string_funcs::spark_split); make_comet_scalar_udf!("split", func, without data_type) diff --git a/native/spark-expr/src/downcast.rs b/native/spark-expr/src/downcast.rs new file mode 100644 index 0000000000..ade2ef961b --- /dev/null +++ b/native/spark-expr/src/downcast.rs @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation 
(ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +macro_rules! opt_downcast_arg { + ($ARG:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>() + }}; +} + +macro_rules! downcast_named_arg { + ($ARG:expr, $NAME:expr, $ARRAY_TYPE:ident) => {{ + $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { + datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + $NAME, + std::any::type_name::<$ARRAY_TYPE>() + ) + })? 
+ }}; +} + +pub(crate) use {downcast_named_arg, opt_downcast_arg}; diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs index ba19d6a9b2..f737fa87fc 100644 --- a/native/spark-expr/src/lib.rs +++ b/native/spark-expr/src/lib.rs @@ -19,7 +19,9 @@ // The lint makes easier for code reader/reviewer separate references clones from more heavyweight ones #![deny(clippy::clone_on_ref_ptr)] +mod downcast; mod error; +pub(crate) use downcast::{downcast_named_arg, opt_downcast_arg}; mod query_context; pub mod kernels; @@ -58,6 +60,7 @@ pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain}; mod conditional_funcs; mod conversion_funcs; mod math_funcs; +mod misc_funcs; mod nondetermenistic_funcs; pub use array_funcs::*; @@ -84,6 +87,7 @@ pub use math_funcs::{ NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp, }; pub use query_context::{create_query_context_map, QueryContext, QueryContextMap}; +pub use misc_funcs::spark_aes_decrypt; pub use string_funcs::*; /// Spark supports three evaluation modes when evaluating expressions, which affect diff --git a/native/spark-expr/src/misc_funcs/aes_decrypt.rs b/native/spark-expr/src/misc_funcs/aes_decrypt.rs new file mode 100644 index 0000000000..605fe50994 --- /dev/null +++ b/native/spark-expr/src/misc_funcs/aes_decrypt.rs @@ -0,0 +1,311 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use aes::cipher::consts::{U12, U16}; +use aes::{Aes128, Aes192, Aes256}; +use aes_gcm::aead::{Aead, Payload}; +use aes_gcm::{Aes128Gcm, Aes256Gcm, AesGcm, KeyInit, Nonce}; +use arrow::array::{Array, ArrayRef, BinaryArray, LargeBinaryArray, LargeStringArray, StringArray}; +use arrow::datatypes::DataType; +use cbc::cipher::{block_padding::Pkcs7, BlockDecryptMut, KeyIvInit}; +use datafusion::common::{exec_err, DataFusionError, ScalarValue}; +use datafusion::logical_expr::ColumnarValue; + +const GCM_IV_LEN: usize = 12; +const CBC_IV_LEN: usize = 16; + +#[derive(Clone, Copy)] +enum AesMode { + Ecb, + Cbc, + Gcm, +} + +impl AesMode { + fn from_mode_padding(mode: &str, padding: &str) -> Result { + let is_none = padding.eq_ignore_ascii_case("NONE"); + let is_pkcs = padding.eq_ignore_ascii_case("PKCS"); + let is_default = padding.eq_ignore_ascii_case("DEFAULT"); + + if mode.eq_ignore_ascii_case("ECB") && (is_pkcs || is_default) { + Ok(Self::Ecb) + } else if mode.eq_ignore_ascii_case("CBC") && (is_pkcs || is_default) { + Ok(Self::Cbc) + } else if mode.eq_ignore_ascii_case("GCM") && (is_none || is_default) { + Ok(Self::Gcm) + } else { + exec_err!("Unsupported AES mode/padding combination: {mode}/{padding}") + } + } +} + +enum BinaryArg<'a> { + Binary(&'a BinaryArray), + LargeBinary(&'a LargeBinaryArray), +} + +impl<'a> BinaryArg<'a> { + fn from(arg_name: &str, arr: &'a ArrayRef) -> Result { + match arr.data_type() { + DataType::Binary => Ok(Self::Binary( + crate::opt_downcast_arg!(arr, BinaryArray).ok_or_else(|| { + 
datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + arg_name, + std::any::type_name::() + ) + })?, + )), + DataType::LargeBinary => Ok(Self::LargeBinary(crate::downcast_named_arg!( + arr, + arg_name, + LargeBinaryArray + ))), + other => exec_err!("{arg_name} must be Binary/LargeBinary, got {other:?}"), + } + } + + fn value(&self, i: usize) -> Option<&'a [u8]> { + match self { + Self::Binary(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + Self::LargeBinary(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + } + } +} + +enum StringArg<'a> { + Utf8(&'a StringArray), + LargeUtf8(&'a LargeStringArray), +} + +impl<'a> StringArg<'a> { + fn from(arg_name: &str, arr: &'a ArrayRef) -> Result { + match arr.data_type() { + DataType::Utf8 => Ok(Self::Utf8( + crate::opt_downcast_arg!(arr, StringArray).ok_or_else(|| { + datafusion::common::internal_datafusion_err!( + "could not cast {} to {}", + arg_name, + std::any::type_name::() + ) + })?, + )), + DataType::LargeUtf8 => Ok(Self::LargeUtf8(crate::downcast_named_arg!( + arr, + arg_name, + LargeStringArray + ))), + other => exec_err!("{arg_name} must be Utf8/LargeUtf8, got {other:?}"), + } + } + + fn value(&self, i: usize) -> Option<&'a str> { + match self { + Self::Utf8(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + Self::LargeUtf8(arr) => (!arr.is_null(i)).then(|| arr.value(i)), + } + } +} + +type Aes128CbcDec = cbc::Decryptor; +type Aes192CbcDec = cbc::Decryptor; +type Aes256CbcDec = cbc::Decryptor; +type Aes128EcbDec = ecb::Decryptor; +type Aes192EcbDec = ecb::Decryptor; +type Aes256EcbDec = ecb::Decryptor; +type Aes192Gcm = AesGcm; + +fn decrypt_pkcs_cbc(input: &[u8], key: &[u8]) -> Result, DataFusionError> { + if input.len() < CBC_IV_LEN { + return exec_err!("AES decryption input is too short for CBC"); + } + let (iv, ciphertext) = input.split_at(CBC_IV_LEN); + let mut buf = ciphertext.to_vec(); + + let out = match key.len() { + 16 => Aes128CbcDec::new_from_slices(key, iv) + .map_err(|e| 
DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 24 => Aes192CbcDec::new_from_slices(key, iv) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 32 => Aes256CbcDec::new_from_slices(key, iv) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + _ => return exec_err!("Invalid AES key length: {}", key.len()), + }; + + Ok(out.to_vec()) +} + +fn decrypt_pkcs_ecb(input: &[u8], key: &[u8]) -> Result, DataFusionError> { + let mut buf = input.to_vec(); + + let out = match key.len() { + 16 => Aes128EcbDec::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 24 => Aes192EcbDec::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + 32 => Aes256EcbDec::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? 
+ .decrypt_padded_mut::(&mut buf) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))?, + _ => return exec_err!("Invalid AES key length: {}", key.len()), + }; + + Ok(out.to_vec()) +} + +fn decrypt_gcm(input: &[u8], key: &[u8], aad: &[u8]) -> Result, DataFusionError> { + if input.len() < GCM_IV_LEN { + return exec_err!("AES decryption input is too short for GCM"); + } + let (iv, ciphertext) = input.split_at(GCM_IV_LEN); + let nonce = Nonce::from_slice(iv); + let payload = Payload { + msg: ciphertext, + aad, + }; + + match key.len() { + 16 => Aes128Gcm::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt(nonce, payload) + .map_err(|_| { + DataFusionError::Execution("AES crypto error: decrypt failed".to_string()) + }), + 24 => Aes192Gcm::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt(nonce, payload) + .map_err(|_| { + DataFusionError::Execution("AES crypto error: decrypt failed".to_string()) + }), + 32 => Aes256Gcm::new_from_slice(key) + .map_err(|e| DataFusionError::Execution(format!("AES crypto error: {e}")))? + .decrypt(nonce, payload) + .map_err(|_| { + DataFusionError::Execution("AES crypto error: decrypt failed".to_string()) + }), + _ => exec_err!("Invalid AES key length: {}", key.len()), + } +} + +fn decrypt_one( + input: &[u8], + key: &[u8], + mode: &str, + padding: &str, + aad: &[u8], +) -> Result, DataFusionError> { + match AesMode::from_mode_padding(mode, padding)? 
{ + AesMode::Ecb => decrypt_pkcs_ecb(input, key), + AesMode::Cbc => decrypt_pkcs_cbc(input, key), + AesMode::Gcm => decrypt_gcm(input, key, aad), + } +} + +pub fn spark_aes_decrypt(args: &[ColumnarValue]) -> Result { + if !(2..=5).contains(&args.len()) { + return exec_err!("aes_decrypt expects 2 to 5 arguments, got {}", args.len()); + } + + let are_scalars = args + .iter() + .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + let arrays = ColumnarValue::values_to_arrays(args)?; + let num_rows = arrays[0].len(); + + let input = BinaryArg::from("input", &arrays[0])?; + let key = BinaryArg::from("key", &arrays[1])?; + + let mode = if args.len() >= 3 { + Some(StringArg::from("mode", &arrays[2])?) + } else { + None + }; + let padding = if args.len() >= 4 { + Some(StringArg::from("padding", &arrays[3])?) + } else { + None + }; + let aad = if args.len() >= 5 { + Some(BinaryArg::from("aad", &arrays[4])?) + } else { + None + }; + + let values: Result, DataFusionError> = (0..num_rows) + .map(|row| { + let Some(input_value) = input.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + let Some(key_value) = key.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + + let mode_value = match mode.as_ref() { + Some(mode) => { + let Some(mode) = mode.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + mode + } + None => "GCM", + }; + + let padding_value = match padding.as_ref() { + Some(padding) => { + let Some(padding) = padding.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + padding + } + None => "DEFAULT", + }; + + let aad_value = match aad.as_ref() { + Some(aad) => { + let Some(aad) = aad.value(row) else { + return Ok(ScalarValue::Binary(None)); + }; + aad + } + None => &[], + }; + + let plaintext = + decrypt_one(input_value, key_value, mode_value, padding_value, aad_value)?; + Ok(ScalarValue::Binary(Some(plaintext))) + }) + .collect(); + + let array: ArrayRef = ScalarValue::iter_to_array(values?)?; + if are_scalars { + 
Ok(ColumnarValue::Scalar( + datafusion::common::ScalarValue::try_from_array(array.as_ref(), 0)?, + )) + } else { + Ok(ColumnarValue::Array(array)) + } +} diff --git a/native/spark-expr/src/misc_funcs/mod.rs b/native/spark-expr/src/misc_funcs/mod.rs new file mode 100644 index 0000000000..c55b82811d --- /dev/null +++ b/native/spark-expr/src/misc_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub(crate) mod aes_decrypt; + +pub use aes_decrypt::spark_aes_decrypt; diff --git a/native/spark-expr/tests/spark_expr_reg.rs b/native/spark-expr/tests/spark_expr_reg.rs index 633b226068..f381b77881 100644 --- a/native/spark-expr/tests/spark_expr_reg.rs +++ b/native/spark-expr/tests/spark_expr_reg.rs @@ -35,6 +35,12 @@ mod tests { &session_state, None, )?); + let _ = session_state.register_udf(create_comet_physical_fun( + "aes_decrypt", + DataType::Binary, + &session_state, + None, + )?); let ctx = SessionContext::new_with_state(session_state); // 2. 
Execute SQL with literal values diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 8c39ba779d..8d358ee29c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -222,6 +222,8 @@ object QueryPlanSerde extends Logging with CometExprShim { private val miscExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map( // TODO PromotePrecision + // AesDecrypt extends RuntimeReplaceable and is rewritten to StaticInvoke before Comet's + // serde runs, so it is handled via CometStaticInvoke / CometAesDecryptStaticInvoke. classOf[Alias] -> CometAlias, classOf[AttributeReference] -> CometAttributeReference, classOf[BloomFilterMightContain] -> CometBloomFilterMightContain, diff --git a/spark/src/main/scala/org/apache/comet/serde/misc.scala b/spark/src/main/scala/org/apache/comet/serde/misc.scala new file mode 100644 index 0000000000..1f1bc1854f --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/misc.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet.serde + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke + +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProtoWithReturnType} + +// AesDecrypt extends RuntimeReplaceable in Spark, so by the time Comet's serde runs it has +// already been replaced with a StaticInvoke. This handler is registered in CometStaticInvoke's +// staticInvokeExpressions map under the ("aesDecrypt", ExpressionImplUtils) key. +object CometAesDecryptStaticInvoke extends CometExpressionSerde[StaticInvoke] { + override def convert( + expr: StaticInvoke, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + val childExpr = expr.children.map(exprToProtoInternal(_, inputs, binding)) + val optExpr = scalarFunctionExprToProtoWithReturnType( + "aes_decrypt", + expr.dataType, + failOnError = false, + childExpr: _*) + optExprWithInfo(optExpr, expr, expr.children: _*) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/statics.scala b/spark/src/main/scala/org/apache/comet/serde/statics.scala index 9dbc6d169f..5356b3e11b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/statics.scala +++ b/spark/src/main/scala/org/apache/comet/serde/statics.scala @@ -35,7 +35,8 @@ object CometStaticInvoke extends CometExpressionSerde[StaticInvoke] { Map( ("readSidePadding", classOf[CharVarcharCodegenUtils]) -> CometScalarFunction( "read_side_padding"), - ("isLuhnNumber", classOf[ExpressionImplUtils]) -> CometScalarFunction("luhn_check")) + ("isLuhnNumber", classOf[ExpressionImplUtils]) -> CometScalarFunction("luhn_check"), + ("aesDecrypt", classOf[ExpressionImplUtils]) -> CometAesDecryptStaticInvoke) override def convert( expr: StaticInvoke, diff --git a/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql 
b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql new file mode 100644 index 0000000000..5c93dd6a7a --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/misc/aes_decrypt.sql @@ -0,0 +1,164 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
-- MinSparkVersion: 3.5

-- aes_decrypt round-trip fixture: rows are produced by Spark's aes_encrypt, then read back
-- through aes_decrypt in the `query` sections below.

statement
CREATE TABLE aes_tbl(
  encrypted_default BINARY,
  encrypted_with_aad BINARY,
  `key` BINARY,
  mode STRING,
  padding STRING,
  iv BINARY,
  aad STRING
) USING parquet

-- One row holding two ciphertexts: one from the default aes_encrypt settings, and one
-- produced with an explicit GCM mode, iv, and AAD.
statement
INSERT INTO aes_tbl
SELECT
  aes_encrypt(
    encode('Spark SQL', 'UTF-8'),
    encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')),
  aes_encrypt(
    encode('Spark SQL', 'UTF-8'),
    encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'),
    'GCM',
    'DEFAULT',
    unhex('00112233445566778899AABB'),
    'Comet AAD'),
  encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'),
  'GCM',
  'DEFAULT',
  unhex('00112233445566778899AABB'),
  'Comet AAD'

query
SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl

query
SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl

-- Coverage matrix: every mode (GCM, CBC, ECB) crossed with every key size
-- (128/192/256-bit keys), plus a NULL-ciphertext row.
statement
CREATE TABLE aes_modes_tbl(
  encrypted BINARY,
  `key` BINARY,
  mode STRING,
  padding STRING,
  label STRING
) USING parquet

statement
INSERT INTO aes_modes_tbl
SELECT
  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'GCM', 'DEFAULT'),
  encode('abcdefghijklmnop', 'UTF-8'),
  'GCM',
  'DEFAULT',
  'gcm_128'
UNION ALL
SELECT
  aes_encrypt(
    encode('Spark SQL', 'UTF-8'),
    encode('abcdefghijklmnop12345678', 'UTF-8'),
    'GCM',
    'DEFAULT'),
  encode('abcdefghijklmnop12345678', 'UTF-8'),
  'GCM',
  'DEFAULT',
  'gcm_192'
UNION ALL
SELECT
  aes_encrypt(
    encode('Spark SQL', 'UTF-8'),
    encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'),
    'GCM',
    'DEFAULT'),
  encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'),
  'GCM',
  'DEFAULT',
  'gcm_256'
UNION ALL
SELECT
  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'CBC', 'PKCS'),
  encode('abcdefghijklmnop', 'UTF-8'),
  'CBC',
  'PKCS',
  'cbc_128'
UNION ALL
SELECT
  aes_encrypt(
    encode('Spark SQL', 'UTF-8'),
    encode('abcdefghijklmnop12345678', 'UTF-8'),
    'CBC',
    'PKCS'),
  encode('abcdefghijklmnop12345678', 'UTF-8'),
  'CBC',
  'PKCS',
  'cbc_192'
UNION ALL
SELECT
  aes_encrypt(
    encode('Spark SQL', 'UTF-8'),
    encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'),
    'CBC',
    'PKCS'),
  encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'),
  'CBC',
  'PKCS',
  'cbc_256'
UNION ALL
SELECT
  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'ECB', 'PKCS'),
  encode('abcdefghijklmnop', 'UTF-8'),
  'ECB',
  'PKCS',
  'ecb_128'
UNION ALL
SELECT
  aes_encrypt(
    encode('Spark SQL', 'UTF-8'),
    encode('abcdefghijklmnop12345678', 'UTF-8'),
    'ECB',
    'PKCS'),
  encode('abcdefghijklmnop12345678', 'UTF-8'),
  'ECB',
  'PKCS',
  'ecb_192'
UNION ALL
SELECT
  aes_encrypt(
    encode('Spark SQL', 'UTF-8'),
    encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'),
    'ECB',
    'PKCS'),
  encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'),
  'ECB',
  'PKCS',
  'ecb_256'
UNION ALL
SELECT
  cast(null AS binary),
  encode('abcdefghijklmnop', 'UTF-8'),
  'GCM',
  'DEFAULT',
  'null_input'

query
SELECT label, CAST(aes_decrypt(encrypted, `key`, mode, padding) AS STRING)
FROM aes_modes_tbl
ORDER BY label
package org.apache.comet

import org.apache.spark.sql.CometTestBase

import org.apache.comet.CometSparkSessionExtensions.isSpark35Plus

// Suite covering miscellaneous Spark expressions accelerated by Comet; currently exercises
// aes_decrypt. Fixtures are always encrypted by Spark itself (Comet disabled) so only the
// decrypt path runs through Comet.
class CometMiscExpressionSuite extends CometTestBase {

  test("aes_decrypt") {
    withTempView("aes_tbl") {
      // Build the encrypted fixture with Comet disabled: aes_encrypt runs on Spark,
      // keeping the test focused on Comet's aes_decrypt.
      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
        // Explicit iv/aad arguments to aes_encrypt are only available on Spark 3.5+;
        // older versions fall back to the 4-argument form with NULL iv/aad columns.
        val aesDf = if (isSpark35Plus) {
          spark
            .range(1)
            .selectExpr(
              "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')) as encrypted_default",
              "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'GCM', 'DEFAULT', unhex('00112233445566778899AABB'), 'Comet AAD') as encrypted_with_aad",
              "encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') as `key`",
              "'GCM' as mode",
              "'DEFAULT' as padding",
              "unhex('00112233445566778899AABB') as iv",
              "'Comet AAD' as aad")
        } else {
          spark
            .range(1)
            .selectExpr(
              "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8')) as encrypted_default",
              "aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'GCM', 'DEFAULT') as encrypted_with_aad",
              "encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') as `key`",
              "'GCM' as mode",
              "'DEFAULT' as padding",
              "cast(null as binary) as iv",
              "cast(null as string) as aad")
        }
        aesDf.createOrReplaceTempView("aes_tbl")
      }

      // Decrypt with Comet enabled and verify both result parity with Spark and that the
      // plan actually ran on Comet operators.
      if (isSpark35Plus) {
        checkSparkAnswerAndOperator(
          "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl")
        checkSparkAnswerAndOperator(
          "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding, aad) AS STRING) FROM aes_tbl")
      } else {
        checkSparkAnswerAndOperator(
          "SELECT CAST(aes_decrypt(encrypted_default, `key`) AS STRING) FROM aes_tbl")
        checkSparkAnswerAndOperator(
          "SELECT CAST(aes_decrypt(encrypted_with_aad, `key`, mode, padding) AS STRING) FROM aes_tbl")
      }
    }
  }

  test("aes_decrypt mode and key-size combinations") {
    withTempView("aes_modes_tbl") {
      // Fixture matrix: modes GCM/CBC/ECB crossed with 128/192/256-bit keys, plus a
      // NULL-ciphertext row; encrypted by Spark with Comet disabled.
      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
        spark
          .sql("""
            |SELECT
            |  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'GCM', 'DEFAULT') AS encrypted,
            |  encode('abcdefghijklmnop', 'UTF-8') AS `key`,
            |  'GCM' AS mode,
            |  'DEFAULT' AS padding,
            |  'gcm_128' AS label
            |UNION ALL
            |SELECT
            |  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678', 'UTF-8'), 'GCM', 'DEFAULT') AS encrypted,
            |  encode('abcdefghijklmnop12345678', 'UTF-8') AS `key`,
            |  'GCM' AS mode,
            |  'DEFAULT' AS padding,
            |  'gcm_192' AS label
            |UNION ALL
            |SELECT
            |  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'GCM', 'DEFAULT') AS encrypted,
            |  encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') AS `key`,
            |  'GCM' AS mode,
            |  'DEFAULT' AS padding,
            |  'gcm_256' AS label
            |UNION ALL
            |SELECT
            |  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'CBC', 'PKCS') AS encrypted,
            |  encode('abcdefghijklmnop', 'UTF-8') AS `key`,
            |  'CBC' AS mode,
            |  'PKCS' AS padding,
            |  'cbc_128' AS label
            |UNION ALL
            |SELECT
            |  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678', 'UTF-8'), 'CBC', 'PKCS') AS encrypted,
            |  encode('abcdefghijklmnop12345678', 'UTF-8') AS `key`,
            |  'CBC' AS mode,
            |  'PKCS' AS padding,
            |  'cbc_192' AS label
            |UNION ALL
            |SELECT
            |  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'CBC', 'PKCS') AS encrypted,
            |  encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') AS `key`,
            |  'CBC' AS mode,
            |  'PKCS' AS padding,
            |  'cbc_256' AS label
            |UNION ALL
            |SELECT
            |  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop', 'UTF-8'), 'ECB', 'PKCS') AS encrypted,
            |  encode('abcdefghijklmnop', 'UTF-8') AS `key`,
            |  'ECB' AS mode,
            |  'PKCS' AS padding,
            |  'ecb_128' AS label
            |UNION ALL
            |SELECT
            |  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678', 'UTF-8'), 'ECB', 'PKCS') AS encrypted,
            |  encode('abcdefghijklmnop12345678', 'UTF-8') AS `key`,
            |  'ECB' AS mode,
            |  'PKCS' AS padding,
            |  'ecb_192' AS label
            |UNION ALL
            |SELECT
            |  aes_encrypt(encode('Spark SQL', 'UTF-8'), encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8'), 'ECB', 'PKCS') AS encrypted,
            |  encode('abcdefghijklmnop12345678ABCDEFGH', 'UTF-8') AS `key`,
            |  'ECB' AS mode,
            |  'PKCS' AS padding,
            |  'ecb_256' AS label
            |UNION ALL
            |SELECT
            |  cast(null AS binary) AS encrypted,
            |  encode('abcdefghijklmnop', 'UTF-8') AS `key`,
            |  'GCM' AS mode,
            |  'DEFAULT' AS padding,
            |  'null_input' AS label
            |""".stripMargin)
          .createOrReplaceTempView("aes_modes_tbl")
      }

      // Decrypt every combination through Comet; ORDER BY label makes the result
      // deterministic for comparison against Spark.
      checkSparkAnswerAndOperator("""
        |SELECT
        |  label,
        |  CAST(aes_decrypt(encrypted, `key`, mode, padding) AS STRING) AS decrypted
        |FROM aes_modes_tbl
        |ORDER BY label
        |""".stripMargin)
    }
  }

}