Skip to content

Commit 5f7b017

Browse files
committed
fix: array to array cast
1 parent 726c4b7 commit 5f7b017

3 files changed

Lines changed: 43 additions & 4 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ apache-rat-*.jar
1818
venv
1919
dev/release/comet-rm/workdir
2020
spark/benchmarks
21+
/comet-event-trace.json

native/spark-expr/src/conversion_funcs/cast.rs

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1076,8 +1076,22 @@ fn cast_array(
10761076
cast_options,
10771077
)?),
10781078
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
1079-
(List(_), List(_)) if can_cast_types(from_type, to_type) => {
1080-
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
1079+
(List(from), List(to))
1080+
if can_cast_types(from_type, to_type)
1081+
|| (matches!(from.data_type(), Decimal128(_, _))
1082+
&& matches!(to.data_type(), Boolean)) =>
1083+
{
1084+
let list_array = array.as_list::<i32>();
1085+
Ok(Arc::new(ListArray::new(
1086+
Arc::clone(to),
1087+
list_array.offsets().clone(),
1088+
cast_array(
1089+
Arc::clone(list_array.values()),
1090+
to.data_type(),
1091+
cast_options,
1092+
)?,
1093+
list_array.nulls().cloned(),
1094+
)) as ArrayRef)
10811095
}
10821096
(UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64)
10831097
if cast_options.allow_cast_unsigned_ints =>

spark/src/test/scala/org/apache/comet/CometCastSuite.scala

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,30 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11541154
}
11551155
}
11561156

1157+
test("cast ArrayType to ArrayType") {
1158+
val types = Seq(
1159+
BooleanType,
1160+
StringType,
1161+
ByteType,
1162+
IntegerType,
1163+
LongType,
1164+
ShortType,
1165+
DecimalType(10, 2),
1166+
DecimalType(38, 18))
1167+
for (fromType <- types) {
1168+
for (toType <- types) {
1169+
if (fromType != toType &&
1170+
!tags
1171+
.get(s"cast $fromType to $toType")
1172+
.exists(s => s.contains("org.scalatest.Ignore")) &&
1173+
Cast.canCast(fromType, toType) &&
1174+
CometCast.isSupported(fromType, toType, None, CometEvalMode.LEGACY) == Compatible()) {
1175+
castTest(generateArrays(100, fromType), ArrayType(toType))
1176+
}
1177+
}
1178+
}
1179+
}
1180+
11571181
private def generateFloats(): DataFrame = {
11581182
withNulls(gen.generateFloats(dataSize)).toDF("a")
11591183
}
@@ -1182,10 +1206,10 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
11821206
withNulls(gen.generateLongs(dataSize)).toDF("a")
11831207
}
11841208

1185-
private def generateArrays(rowSize: Int, elementType: DataType): DataFrame = {
1209+
private def generateArrays(rowNum: Int, elementType: DataType): DataFrame = {
11861210
import scala.collection.JavaConverters._
11871211
val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
1188-
spark.createDataFrame(gen.generateRows(rowSize, schema).asJava, schema)
1212+
spark.createDataFrame(gen.generateRows(rowNum, schema).asJava, schema)
11891213
}
11901214

11911215
// https://github.com/apache/datafusion-comet/issues/2038

0 commit comments

Comments
 (0)