Skip to content

Commit fae74f6

Browse files
committed
fix: array-to-array cast
1 parent f6d84b1 commit fae74f6

2 files changed

Lines changed: 83 additions & 7 deletions

File tree

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ pub(crate) fn cast_array(
294294
};
295295

296296
let cast_result = match (&from_type, to_type) {
297+
(Null, _) => Ok(cast_with_options(&array, to_type, &native_cast_options)?),
297298
(Utf8, Boolean) => spark_cast_utf8_to_boolean::<i32>(&array, eval_mode),
298299
(LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::<i64>(&array, eval_mode),
299300
(Utf8, Timestamp(_, _)) => {
@@ -366,8 +367,22 @@ pub(crate) fn cast_array(
366367
cast_options,
367368
)?),
368369
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
369-
(List(_), List(_)) if can_cast_types(&from_type, to_type) => {
370-
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
370+
(List(from), List(to))
371+
if can_cast_types(&from_type, to_type)
372+
|| (matches!(from.data_type(), Decimal128(_, _))
373+
&& matches!(to.data_type(), Boolean)) =>
374+
{
375+
let list_array = array.as_list::<i32>();
376+
Ok(Arc::new(ListArray::new(
377+
Arc::clone(to),
378+
list_array.offsets().clone(),
379+
cast_array(
380+
Arc::clone(list_array.values()),
381+
to.data_type(),
382+
cast_options,
383+
)?,
384+
list_array.nulls().cloned(),
385+
)) as ArrayRef)
371386
}
372387
(Map(_, _), Map(_, _)) => Ok(cast_map_to_map(&array, &from_type, to_type, cast_options)?),
373388
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
@@ -803,7 +818,8 @@ fn cast_binary_formatter(value: &[u8]) -> String {
803818
#[cfg(test)]
804819
mod tests {
805820
use super::*;
806-
use arrow::array::StringArray;
821+
use arrow::array::{BooleanArray, Decimal128Array, ListArray, StringArray};
822+
use arrow::buffer::OffsetBuffer;
807823
use arrow::datatypes::TimestampMicrosecondType;
808824
use arrow::datatypes::{Field, Fields};
809825
#[test]
@@ -955,8 +971,6 @@ mod tests {
955971

956972
#[test]
957973
fn test_cast_i32_array_to_string() {
958-
use arrow::array::ListArray;
959-
use arrow::buffer::OffsetBuffer;
960974
let values_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]);
961975
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
962976
let item_field = Arc::new(Field::new("item", DataType::Int32, true));
@@ -977,4 +991,42 @@ mod tests {
977991
assert_eq!(r#"[null]"#, string_array.value(2));
978992
assert_eq!(r#"[]"#, string_array.value(3));
979993
}
994+
995+
#[test]
996+
fn test_cast_decimal_array_to_boolean_array() {
997+
let values_array =
998+
Decimal128Array::from(vec![Some(0), Some(100), None, Some(-100), Some(0)])
999+
.with_precision_and_scale(10, 2)
1000+
.unwrap();
1001+
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 2, 3, 5].into());
1002+
let item_field = Arc::new(Field::new("item", DataType::Decimal128(10, 2), true));
1003+
let list_array = Arc::new(ListArray::new(
1004+
item_field,
1005+
offsets_buffer,
1006+
Arc::new(values_array),
1007+
Some(vec![true, false, true].into()),
1008+
));
1009+
1010+
let casted_array = cast_array(
1011+
list_array,
1012+
&DataType::List(Arc::new(Field::new("item", DataType::Boolean, true))),
1013+
&SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
1014+
)
1015+
.unwrap();
1016+
let casted_list = casted_array.as_list::<i32>();
1017+
1018+
assert_eq!(3, casted_list.len());
1019+
assert!(!casted_list.is_null(0));
1020+
assert!(casted_list.is_null(1));
1021+
assert!(!casted_list.is_null(2));
1022+
1023+
let values = casted_list
1024+
.values()
1025+
.as_any()
1026+
.downcast_ref::<BooleanArray>()
1027+
.expect("Casted list values should be a BooleanArray");
1028+
let expected =
1029+
BooleanArray::from(vec![Some(false), Some(true), None, Some(true), Some(false)]);
1030+
assert_eq!(&expected, values);
1031+
}
9801032
}

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,6 +1266,30 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12661266
}
12671267
}
12681268

1269+
test("cast ArrayType to ArrayType") {
1270+
val types = Seq(
1271+
BooleanType,
1272+
StringType,
1273+
ByteType,
1274+
IntegerType,
1275+
LongType,
1276+
ShortType,
1277+
DecimalType(10, 2),
1278+
DecimalType(38, 18))
1279+
for (fromType <- types) {
1280+
for (toType <- types) {
1281+
if (fromType != toType &&
1282+
!tags
1283+
.get(s"cast $fromType to $toType")
1284+
.exists(s => s.contains("org.scalatest.Ignore")) &&
1285+
Cast.canCast(fromType, toType) &&
1286+
CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) == Compatible()) {
1287+
castTest(generateArrays(100, fromType), ArrayType(toType))
1288+
}
1289+
}
1290+
}
1291+
}
1292+
12691293
private def generateFloats(): DataFrame = {
12701294
withNulls(gen.generateFloats(dataSize)).toDF("a")
12711295
}
@@ -1294,10 +1318,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
12941318
withNulls(gen.generateLongs(dataSize)).toDF("a")
12951319
}
12961320

1297-
private def generateArrays(rowSize: Int, elementType: DataType): DataFrame = {
1321+
private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = {
12981322
import scala.collection.JavaConverters._
12991323
val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
1300-
spark.createDataFrame(gen.generateRows(rowSize, schema).asJava, schema)
1324+
spark.createDataFrame(gen.generateRows(rowNum, schema).asJava, schema)
13011325
}
13021326

13031327
// https://github.com/apache/datafusion-comet/issues/2038

0 commit comments

Comments (0)