diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs index 6e5a80c84a..8f4e0adc0c 100644 --- a/native/spark-expr/src/conversion_funcs/cast.rs +++ b/native/spark-expr/src/conversion_funcs/cast.rs @@ -43,7 +43,6 @@ use arrow::array::builder::StringBuilder; use arrow::array::{ BinaryBuilder, DictionaryArray, GenericByteArray, ListArray, MapArray, StringArray, StructArray, }; -use arrow::compute::can_cast_types; use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType, Schema}; use arrow::datatypes::{Field, Fields, GenericBinaryType}; use arrow::error::ArrowError; @@ -294,6 +293,7 @@ pub(crate) fn cast_array( }; let cast_result = match (&from_type, to_type) { + (Null, _) => Ok(cast_with_options(&array, to_type, &native_cast_options)?), (Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode), (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode), (Utf8, Timestamp(_, _)) => { @@ -366,8 +366,18 @@ pub(crate) fn cast_array( cast_options, )?), (List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?), - (List(_), List(_)) if can_cast_types(&from_type, to_type) => { - Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) 
+ (List(_), List(to)) => { + let list_array = array.as_list::<i32>(); + Ok(Arc::new(ListArray::new( + Arc::clone(to), + list_array.offsets().clone(), + cast_array( + Arc::clone(list_array.values()), + to.data_type(), + cast_options, + )?, + list_array.nulls().cloned(), + )) as ArrayRef) } (Map(_, _), Map(_, _)) => Ok(cast_map_to_map(&array, &from_type, to_type, cast_options)?), (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64) @@ -803,7 +813,8 @@ fn cast_binary_formatter(value: &[u8]) -> String { #[cfg(test)] mod tests { use super::*; - use arrow::array::StringArray; + use arrow::array::{ListArray, StringArray}; + use arrow::buffer::OffsetBuffer; use arrow::datatypes::TimestampMicrosecondType; use arrow::datatypes::{Field, Fields}; #[test] @@ -929,8 +940,6 @@ mod tests { #[test] fn test_cast_string_array_to_string() { - use arrow::array::ListArray; - use arrow::buffer::OffsetBuffer; let values_array = StringArray::from(vec![Some("a"), Some("b"), Some("c"), Some("a"), None, None]); let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into()); @@ -955,8 +964,6 @@ mod tests { #[test] fn test_cast_i32_array_to_string() { - use arrow::array::ListArray; - use arrow::buffer::OffsetBuffer; let values_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]); let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into()); let item_field = Arc::new(Field::new("item", DataType::Int32, true)); diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala index 3d9acc39e4..d2c2dad127 100644 --- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala @@ -1266,6 +1266,30 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("cast ArrayType to ArrayType") { + val types = Seq( + BooleanType, + StringType, + ByteType, + IntegerType, + LongType, + ShortType, + 
DecimalType(10, 2), + DecimalType(38, 18)) + for (fromType <- types) { + for (toType <- types) { + if (fromType != toType && + !tags + .get(s"cast $fromType to $toType") + .exists(s => s.contains("org.scalatest.Ignore")) && + Cast.canCast(fromType, toType) && + CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) == Compatible()) { + castTest(generateArrays(100, fromType), ArrayType(toType)) + } + } + } + } + private def generateFloats(): DataFrame = { withNulls(gen.generateFloats(dataSize)).toDF("a") } @@ -1294,10 +1318,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper { withNulls(gen.generateLongs(dataSize)).toDF("a") } - private def generateArrays(rowSize: Int, elementType: DataType): DataFrame = { + private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = { import scala.collection.JavaConverters._ val schema = StructType(Seq(StructField("a", ArrayType(elementType), true))) - spark.createDataFrame(gen.generateRows(rowSize, schema).asJava, schema) + spark.createDataFrame(gen.generateRows(rowNum, schema).asJava, schema) } // https://github.com/apache/datafusion-comet/issues/2038