From 4518deaa38baf03f57adfc744f345b34d4e858fb Mon Sep 17 00:00:00 2001
From: Ruslan Iushchenko
Date: Wed, 6 May 2026 16:04:52 +0200
Subject: [PATCH] #847 Fix array flattening in Spark 4+.

---
 .../cobrix/spark/cobol/utils/SparkUtils.scala | 36 ++++++++++++++------
 1 file changed, 25 insertions(+), 11 deletions(-)

diff --git a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
index aa79dcad..ab58dffd 100644
--- a/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
+++ b/spark-cobol/src/main/scala/za/co/absa/cobrix/spark/cobol/utils/SparkUtils.scala
@@ -80,6 +80,19 @@ object SparkUtils extends Logging {
     val fields = new mutable.ListBuffer[Column]()
     val stringFields = new mutable.ListBuffer[String]()
     val usedNames = new mutable.HashSet[String]()
+    val isUseArrayGet = df.sparkSession.version.split('.').head.toInt >= 4
+
+    def getArrayIndexExpr(path: String, index: Int): String = {
+      // Spark 4 enables ANSI mode by default, so the index operator requires the index to be within
+      // bounds and throws an exception otherwise: [INVALID_ARRAY_INDEX] The index 0 is out of bounds.
+      // The 'get()' Spark SQL function (introduced in Spark 3.4) is used instead since it returns NULL
+      // for out-of-bounds indexes. Older Spark versions keep using square bracket indexing.
+      if (isUseArrayGet) {
+        s"get($path, $index)"
+      } else {
+        s"$path[$index]"
+      }
+    }
 
     def getNewFieldName(desiredName: String): String = {
       var name = desiredName
@@ -102,21 +115,22 @@ object SparkUtils extends Logging {
       */
     def flattenStructArray(path: String, fieldNamePrefix: String, structField: StructField, arrayType: ArrayType): Unit = {
       val maxInd = getMaxArraySize(s"$path${structField.name}")
+      val fieldName = s"$path`${structField.name}`"
       var i = 0
       while (i < maxInd) {
         arrayType.elementType match {
           case st: StructType =>
             val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
-            flattenGroup(s"$path`${structField.name}`[$i].", newFieldNamePrefix, st)
+            flattenGroup(s"${getArrayIndexExpr(fieldName, i)}.", newFieldNamePrefix, st)
           case ar: ArrayType =>
             val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
-            flattenArray(s"$path`${structField.name}`[$i].", newFieldNamePrefix, structField, ar)
+            flattenArray(s"${getArrayIndexExpr(fieldName, i)}.", newFieldNamePrefix, structField, ar)
           // AtomicType is protected on package 'sql' level so have to enumerate all subtypes :(
           case _ =>
             val newFieldNamePrefix = s"${fieldNamePrefix}${i}"
             val newFieldName = getNewFieldName(s"$newFieldNamePrefix")
-            fields += expr(s"$path`${structField.name}`[$i]").as(newFieldName, structField.metadata)
-            stringFields += s"""expr("$path`${structField.name}`[$i] AS `$newFieldName`")"""
+            fields += expr(getArrayIndexExpr(fieldName, i)).as(newFieldName, structField.metadata)
+            stringFields += s"""expr("${getArrayIndexExpr(fieldName, i)} AS `$newFieldName`")"""
         }
         i += 1
       }
@@ -128,17 +142,17 @@ object SparkUtils extends Logging {
       while (i < maxInd) {
         arrayType.elementType match {
           case st: StructType =>
-            val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
-            flattenGroup(s"$path[$i]", newFieldNamePrefix, st)
+            val newFieldNamePrefix = s"$fieldNamePrefix${i}_"
+            flattenGroup(getArrayIndexExpr(path, i), newFieldNamePrefix, st)
           case ar: ArrayType =>
-            val newFieldNamePrefix = s"${fieldNamePrefix}${i}_"
-            flattenNestedArrays(s"$path[$i]", newFieldNamePrefix, ar, metadata)
+            val newFieldNamePrefix = s"$fieldNamePrefix${i}_"
+            flattenNestedArrays(getArrayIndexExpr(path, i), newFieldNamePrefix, ar, metadata)
           // AtomicType is protected on package 'sql' level so have to enumerate all subtypes :(
           case _ =>
-            val newFieldNamePrefix = s"${fieldNamePrefix}${i}"
+            val newFieldNamePrefix = s"$fieldNamePrefix${i}"
             val newFieldName = getNewFieldName(s"$newFieldNamePrefix")
-            fields += expr(s"$path[$i]").as(newFieldName, metadata)
-            stringFields += s"""expr("$path`[$i] AS `$newFieldName`")"""
+            fields += expr(getArrayIndexExpr(path, i)).as(newFieldName, metadata)
+            stringFields += s"""expr("${getArrayIndexExpr(path, i)} AS `$newFieldName`")"""
         }
         i += 1
       }
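
Note on the behavior change (not part of the patch itself): a minimal
spark-shell sketch of the difference the new getArrayIndexExpr() helper
accounts for. The DataFrame and column name are made up for the example.

    import org.apache.spark.sql.functions.expr
    import spark.implicits._

    val df = Seq(Seq(1, 2)).toDF("arr")

    // Spark 3.x (ANSI mode off by default): an out-of-bounds index yields NULL.
    // Spark 4.x (ANSI mode on by default): throws
    //   [INVALID_ARRAY_INDEX] The index 5 is out of bounds. The array has 2 elements.
    df.select(expr("arr[5]")).show()

    // get() (available since Spark 3.4) tolerates out-of-bounds indexes and
    // returns NULL, matching the pre-ANSI bracket-indexing behavior.
    df.select(expr("get(arr, 5)")).show()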
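For context on how these private helpers are reached: a usage sketch of the
public flattening API, assuming SparkUtils.flattenSchema and its
useShortFieldNames parameter as documented in the Cobrix README.

    import za.co.absa.cobrix.spark.cobol.utils.SparkUtils

    // Expands nested structs and arrays into top-level columns; array elements
    // become separate FIELD_0, FIELD_1, ... columns. With this patch the
    // generated projections use get(...) instead of [...] on Spark 4+.
    val dfFlat = SparkUtils.flattenSchema(df, useShortFieldNames = false)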