diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 094777e796..0e9a7fc755 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -71,7 +71,8 @@ use datafusion::{ }; use datafusion_comet_spark_expr::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, - BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SumInteger, ToCsv, + BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc, + SumInteger, ToCsv, }; use iceberg::expr::Bind; @@ -96,6 +97,7 @@ use datafusion::physical_expr::LexOrdering; use crate::parquet::parquet_exec::init_datasource_exec; +use crate::execution::planner::expression_registry::ExpressionType::ArraysZip; use arrow::array::{ new_empty_array, Array, ArrayRef, BinaryBuilder, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, ListArray, @@ -604,6 +606,17 @@ impl PhysicalPlanner { csv_write_options, ))) } + ExprStruct::ArraysZip(expr) => { + assert!(!expr.children.is_empty()); + + let children = expr + .children + .iter() + .map(|child| self.create_expr(child, Arc::clone(&input_schema))) + .collect::<Result<Vec<_>, _>>()?; + + Ok(Arc::new(SparkArraysZipFunc::new(children))) + } expr => Err(GeneralError(format!("Not implemented: {expr:?}"))), } } diff --git a/native/core/src/execution/planner/expression_registry.rs b/native/core/src/execution/planner/expression_registry.rs index 34aa3de179..3a5b93b99b 100644 --- a/native/core/src/execution/planner/expression_registry.rs +++ b/native/core/src/execution/planner/expression_registry.rs @@ -103,6 +103,7 @@ pub enum ExpressionType { Randn, SparkPartitionId, MonotonicallyIncreasingId, + ArraysZip, // Time functions Hour, @@ -370,6 +371,7 @@ impl ExpressionRegistry { Some(ExprStruct::MonotonicallyIncreasingId(_)) => { Ok(ExpressionType::MonotonicallyIncreasingId) } + 
Some(ExprStruct::ArraysZip(_)) => Ok(ExpressionType::ArraysZip), Some(ExprStruct::Hour(_)) => Ok(ExpressionType::Hour), Some(ExprStruct::Minute(_)) => Ok(ExpressionType::Minute), diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index 944505ba1c..b7557c2778 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -88,6 +88,7 @@ message Expr { UnixTimestamp unix_timestamp = 65; FromJson from_json = 66; ToCsv to_csv = 67; + ArraysZip arrays_zip = 68; } } @@ -440,3 +441,7 @@ message ArrayJoin { message Rand { int64 seed = 1; } + +message ArraysZip { + repeated Expr children = 1; +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c5880e00ed..77288ca94d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -64,7 +64,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ElementAt] -> CometElementAt, classOf[Flatten] -> CometFlatten, classOf[GetArrayItem] -> CometGetArrayItem, - classOf[Size] -> CometSize) + classOf[Size] -> CometSize, + classOf[ArraysZip] -> CometArraysZip) private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf) diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index b7ebb9ba7b..ad1bcf9fd6 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -20,8 +20,9 @@ package org.apache.comet.serde import scala.annotation.tailrec +import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, 
ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size} +import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArraysZip, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -657,6 +658,35 @@ object CometSize extends CometExpressionSerde[Size] { } +object CometArraysZip extends CometExpressionSerde[ArraysZip] { + override def getSupportLevel(expr: ArraysZip): SupportLevel = { + expr.dataType match { + case _: ArrayType => Compatible() + case other => + // this should be unreachable because Spark only supports array inputs + Unsupported(Some(s"Unsupported child data type: $other")) + } + } + + override def convert( + expr: ArraysZip, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + + val exprChildren = expr.children.map(exprToProto(_, inputs, binding)) + if (exprChildren.forall(_.isDefined)) { + val builder = ExprOuterClass.ArraysZip + .newBuilder() + .addAllChildren(exprChildren.map(_.get).asJava) + + Some(ExprOuterClass.Expr.newBuilder().setArraysZip(builder).build()) + } else { + withInfo(expr, expr.children: _*) + None + } + } +} + trait ArraysBase { def isTypeSupported(dt: DataType): Boolean = { diff --git a/spark/src/test/resources/sql-tests/expressions/array/arrays_zip.sql b/spark/src/test/resources/sql-tests/expressions/array/arrays_zip.sql new file mode 100644 index 0000000000..49818b4bd9 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/arrays_zip.sql @@ -0,0 +1,27 @@ +-- Licensed to the 
Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +query +SELECT arrays_zip(array(1, 2), array(2, 3), array(3, 4)); + +query +SELECT arrays_zip(array(1, 2, 3), array('a', 'b')); + +query +SELECT arrays_zip(array(1, null, 3), array('x', 'y', 'z'));