Skip to content
13 changes: 13 additions & 0 deletions spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,19 @@ class NativeUtil {
provider,
arrowArray,
arrowSchema)
case cv: org.apache.spark.sql.execution.vectorized.ConstantColumnVector =>
  // Spark uses ConstantColumnVector for partition columns / per-batch constants (e.g.
  // partition values, synthetic columns). Materialise to a fresh Arrow vector so Comet's
  // native side -- which expects Arrow Arrays only -- can ingest the batch. Without this,
  // queries that pull constants through a Comet operator fail with "Comet execution only
  // takes Arrow Arrays".
  val rows = batch.numRows()
  numRows += rows
  val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
  val arrowArray = ArrowArray.wrap(arrayAddrs(index))
  val materialised = org.apache.spark.sql.comet.util.Utils
    .materializeConstantColumnVector(cv, cv.dataType(), rows, s"_const_$index", allocator)
  try {
    Data.exportVector(allocator, materialised, null, arrowArray, arrowSchema)
  } finally {
    // The exporter retains the buffers it publishes and releases them from the consumer's
    // release callback; it does not take over the vector itself. This temporary vector is
    // owned by nobody else, so drop its reference here (also on export failure), otherwise
    // the materialised buffers are never freed.
    materialised.close()
  }
case c =>
throw new SparkException(
"Comet execution only takes Arrow Arrays, but got " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@ package org.apache.spark.sql.comet.execution.arrow

import scala.jdk.CollectionConverters._

import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.complex._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarArray

Expand Down Expand Up @@ -91,6 +93,36 @@ private[arrow] object ArrowWriter {
}
}

/**
 * Materialises a Spark `ConstantColumnVector` (partition values / per-batch constants) into a
 * fresh Arrow `FieldVector` holding the constant repeated `numRows` times.
 *
 * The value is written through the `ArrowFieldWriter` obtained from
 * `ArrowWriter.createFieldWriter`, rather than a hand-rolled per-type switch, so the supported
 * types are exactly those that `Utils.toArrowField` and the field writers handle.
 * `ConstantColumnVector` returns its constant for any rowId, so a `ColumnarArray` view over rows
 * `[0, numRows)` writes the constant (or null) `numRows` times.
 *
 * Lives in this package because `ArrowWriter` is `private[arrow]`.
 *
 * Ownership: on success the caller owns the returned vector and must close it. Exporting the
 * vector through Arrow's C Data Interface only retains its buffers, so the vector still has to
 * be closed after the export. If allocation or writing fails, the partially built vector is
 * closed here before the exception propagates.
 */
object ConstantColumnVectors {

  /**
   * @param cv
   *   the constant column to expand
   * @param dt
   *   Spark data type used to derive the Arrow field
   * @param numRows
   *   number of times the constant is repeated in the result
   * @param name
   *   name given to the Arrow field
   * @param allocator
   *   allocator the resulting vector is created against
   * @param timeZoneId
   *   time zone id passed to `Utils.toArrowField`
   * @return
   *   a vector whose value count has been set by the field writer after writing `numRows` values
   */
  def materialize(
      cv: ConstantColumnVector,
      dt: DataType,
      numRows: Int,
      name: String,
      allocator: BufferAllocator,
      timeZoneId: String): FieldVector = {
    // The field is always declared nullable because the constant itself may be null.
    val field = Utils.toArrowField(name, dt, nullable = true, timeZoneId)
    val vector = field.createVector(allocator)
    var completed = false
    try {
      vector.allocateNew()
      val writer = ArrowWriter.createFieldWriter(vector)
      writer.writeCol(new ColumnarArray(cv, 0, numRows))
      writer.finish()
      completed = true
      vector
    } finally {
      // Nobody else holds a reference yet, so a failure above would otherwise leak the buffers.
      if (!completed) {
        vector.close()
      }
    }
  }
}

class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {

def schema: StructType = Utils.fromArrowSchema(root.getSchema())
Expand Down
49 changes: 49 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import java.nio.channels.Channels
import scala.jdk.CollectionConverters._

import org.apache.arrow.c.CDataDictionaryProvider
import org.apache.arrow.memory.BufferAllocator
import org.apache.arrow.vector._
import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
import org.apache.arrow.vector.dictionary.DictionaryProvider
Expand All @@ -38,6 +39,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
Expand Down Expand Up @@ -387,6 +389,7 @@ object Utils extends CometTypeShim with Logging {
def getBatchFieldVectors(
batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
var provider: Option[DictionaryProvider] = None
val rows = batch.numRows()
val fieldVectors = (0 until batch.numCols()).map { index =>
batch.column(index) match {
case a: CometVector =>
Expand All @@ -399,6 +402,17 @@ object Utils extends CometTypeShim with Logging {

getFieldVector(valueVector, "serialize")

case cv: ConstantColumnVector =>
// Spark wraps file-source partition columns and other per-batch constants in
// `ConstantColumnVector`. Materialise to an Arrow vector so the serialisation path
// doesn't reject the batch.
materializeConstantColumnVector(
cv,
cv.dataType(),
rows,
s"_const_$index",
org.apache.comet.CometArrowAllocator)

case c =>
throw new SparkException(
s"Comet execution only takes Arrow Arrays, but got ${c.getClass}. " +
Expand Down Expand Up @@ -426,4 +440,39 @@ object Utils extends CometTypeShim with Logging {
throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}")
}
}

/**
 * Expands a Spark `ConstantColumnVector` into a new Arrow `FieldVector` that repeats the
 * constant `numRows` times.
 *
 * Spark wraps file-source partition columns and other per-batch constants in
 * `ConstantColumnVector`. Comet paths such as `NativeUtil.exportBatch` and
 * `getBatchFieldVectors` otherwise accept only `CometVector` columns, so they call this helper
 * to turn the constant into an Arrow vector inline.
 *
 * The work is delegated to `ConstantColumnVectors.materialize`; the supported types are the ones
 * that helper can write.
 *
 * @param cv
 *   the constant column to expand
 * @param dt
 *   Spark data type of the column
 * @param numRows
 *   number of rows the resulting vector should hold
 * @param name
 *   name of the resulting Arrow field
 * @param allocator
 *   allocator the vector is created against
 * @return
 *   a vector owned by the caller, who is responsible for closing it
 */
def materializeConstantColumnVector(
    cv: ConstantColumnVector,
    dt: DataType,
    numRows: Int,
    name: String,
    allocator: BufferAllocator): FieldVector = {
  import org.apache.spark.sql.comet.execution.arrow.ConstantColumnVectors

  // The zone is fixed to "UTC" on purpose instead of the session-local timezone that
  // `toArrowSchema` threads through. A materialised constant sits next to non-constant columns
  // of the same batch/`VectorSchemaRoot`; Comet's non-constant `TimestampType` columns come
  // from native execution, which always tags them `Timestamp(us, "UTC")` (see native
  // `serde.rs`). Spark stores `TimestampType` as micros in UTC, so the constant already is a
  // UTC instant, and tagging it "UTC" keeps its Arrow metadata aligned with its sibling
  // timestamp columns -- using the session timezone would create the mismatch instead.
  // `TimestampNTZType` carries no zone regardless of this argument.
  val arrowTimeZoneId = "UTC"
  ConstantColumnVectors.materialize(
    cv,
    dt,
    numRows,
    name,
    allocator,
    timeZoneId = arrowTimeZoneId)
}
}
90 changes: 90 additions & 0 deletions spark/src/test/scala/org/apache/comet/vector/NativeUtilSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.comet.vector

import org.apache.arrow.c.{ArrowArray, ArrowSchema}
import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

class NativeUtilSuite extends CometTestBase {

  test("exportBatch round-trips a ConstantColumnVector through Arrow FFI") {
    // Smoke test for the ConstantColumnVector branch of the NativeUtil export path: a batch made
    // only of Spark ConstantColumnVectors is exported through the Arrow C Data Interface and
    // imported straight back, with no native callee in between. This covers the FFI wiring
    // (materialisation, Data.exportVector, allocator handoff) that the serializeBatches tests do
    // not touch.
    val rowCount = 4
    val rowIds = 0 until rowCount

    // Column 0: the integer 42 on every row. Column 1: null on every row.
    val fortyTwo = new ConstantColumnVector(rowCount, IntegerType)
    fortyTwo.setInt(42)
    val allNull = new ConstantColumnVector(rowCount, IntegerType)
    allNull.setNull()

    // Column 2: a non-null struct constant {id = 7, name = null}, which sends the complex-type
    // path (getStruct/getChild) through FFI as well.
    val structType = StructType(
      Seq(StructField("id", IntegerType), StructField("name", StringType, nullable = true)))
    val structConstant = new ConstantColumnVector(rowCount, structType)
    structConstant.setNotNull()
    val idField = new ConstantColumnVector(rowCount, IntegerType)
    idField.setInt(7)
    val nameField = new ConstantColumnVector(rowCount, StringType)
    nameField.setNull()
    structConstant.setChild(0, idField)
    structConstant.setChild(1, nameField)

    val input =
      new ColumnarBatch(Array[ColumnVector](fortyTwo, allNull, structConstant), rowCount)

    val nativeUtil = new NativeUtil
    var roundTripped: Option[ColumnarBatch] = None
    try {
      val (arrayAddrs, schemaAddrs, exportedRows) = nativeUtil.exportBatchToAddresses(input)
      assert(exportedRows == rowCount)

      val wrappedArrays = arrayAddrs.map(ArrowArray.wrap)
      val wrappedSchemas = schemaAddrs.map(ArrowSchema.wrap)
      val importedVectors = nativeUtil.importVector(wrappedArrays, wrappedSchemas)
      val result = new ColumnarBatch(importedVectors.toArray, rowCount)
      roundTripped = Some(result)

      assert(result.numCols() == 3)
      assert(result.numRows() == rowCount)

      val ints = rowIds.map(i => result.column(0).getInt(i))
      assert(ints.forall(_ == 42), s"expected all 42, got $ints")

      val nullFlags = rowIds.map(i => result.column(1).isNullAt(i))
      assert(nullFlags.forall(identity), s"expected all null, got $nullFlags")

      val structIds = rowIds.map(i => result.column(2).getStruct(i).getInt(0))
      assert(structIds.forall(_ == 7), s"expected all id 7, got $structIds")
      val structNameNulls = rowIds.map(i => result.column(2).getStruct(i).isNullAt(1))
      assert(
        structNameNulls.forall(identity),
        s"expected all name null, got $structNameNulls")
    } finally {
      // Release the imported batch first (if the import got that far), then the exporter.
      roundTripped.foreach(_.close())
      nativeUtil.close()
    }
  }
}
108 changes: 108 additions & 0 deletions spark/src/test/scala/org/apache/spark/sql/comet/util/UtilsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
package org.apache.spark.sql.comet.util

import org.apache.spark.sql.CometTestBase
import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}

class UtilsSuite extends CometTestBase {
Expand Down Expand Up @@ -51,4 +53,110 @@ class UtilsSuite extends CometTestBase {
val decoded = coalesced.iterator.flatMap(b => Utils.decodeBatches(b, "test")).toSeq
assert(decoded.map(_.numRows()).sum == expected)
}

test("serializeBatches materializes ConstantColumnVector columns") {
  // Spark hands file-source partition columns and other per-batch constants to Comet as
  // ConstantColumnVector. A batch like that must be turned into Arrow vectors on the
  // serialization path (getBatchFieldVectors) instead of being rejected with
  // "Comet execution only takes Arrow Arrays".
  val expectedRows = 4
  val rowIds = 0 until expectedRows

  val fortyTwo = new ConstantColumnVector(expectedRows, IntegerType)
  fortyTwo.setInt(42)
  val allNull = new ConstantColumnVector(expectedRows, IntegerType)
  allNull.setNull()
  val input = new ColumnarBatch(Array[ColumnVector](fortyTwo, allNull), expectedRows)

  val (serializedRows, bytes) = Utils.serializeBatches(Iterator(input)).next()
  assert(serializedRows == expectedRows)

  // Values are pulled out of the decoded batch BEFORE asking the iterator for more:
  // ArrowReaderIterator frees a batch's buffers once it advances past it (hasNext closes the
  // previous batch), so reading after the final hasNext would touch released memory.
  val decoded = Utils.decodeBatches(bytes, "test")
  assert(decoded.hasNext)
  val result = decoded.next()
  assert(result.numCols() == 2)
  assert(result.numRows() == expectedRows)
  val ints = rowIds.map(i => result.column(0).getInt(i))
  val nullFlags = rowIds.map(i => result.column(1).isNullAt(i))
  assert(!decoded.hasNext)

  assert(ints.forall(_ == 42), s"expected all 42, got $ints")
  assert(nullFlags.forall(identity), s"expected all null, got $nullFlags")
}

test("serializeBatches materializes a TimestampType ConstantColumnVector") {
  // Sends a TimestampType constant through the materialise + serialize + decode path and checks
  // that the stored micros come back unchanged. Spark keeps TimestampType as micros in UTC and
  // materializeConstantColumnVector tags the Arrow field "UTC" to match Comet's other timestamp
  // columns. Note that only the raw long values are asserted here; the Arrow field's timezone
  // metadata itself is not inspected by this test.
  val expectedRows = 3
  val rowIds = 0 until expectedRows
  // 2023-11-14T22:13:20Z in micros since epoch.
  val expectedMicros = 1700000000000000L

  val timestampConstant = new ConstantColumnVector(expectedRows, TimestampType)
  timestampConstant.setLong(expectedMicros)
  val input = new ColumnarBatch(Array[ColumnVector](timestampConstant), expectedRows)

  val (serializedRows, bytes) = Utils.serializeBatches(Iterator(input)).next()
  assert(serializedRows == expectedRows)

  // Read the values before the final hasNext, while the decoded batch is still current.
  val decoded = Utils.decodeBatches(bytes, "test")
  assert(decoded.hasNext)
  val result = decoded.next()
  assert(result.numCols() == 1)
  assert(result.numRows() == expectedRows)
  val roundTripped = rowIds.map(i => result.column(0).getLong(i))
  assert(!decoded.hasNext)

  assert(
    roundTripped.forall(_ == expectedMicros),
    s"expected all $expectedMicros, got $roundTripped")
}

test("serializeBatches materializes a nullable StructType ConstantColumnVector") {
  // Struct constants are read via getStruct(rowId) -> getChild(ordinal), a different writer
  // path from the scalar tests. Two columns are checked: a non-null struct holding a null
  // nested field, and a struct constant that is null as a whole.
  val expectedRows = 3
  val rowIds = 0 until expectedRows
  val structType = StructType(
    Seq(StructField("id", IntegerType), StructField("name", StringType, nullable = true)))

  // Column 0: {id = 7, name = null} -- shows that nested nullability survives the round trip.
  val idField = new ConstantColumnVector(expectedRows, IntegerType)
  idField.setInt(7)
  val nameField = new ConstantColumnVector(expectedRows, StringType)
  nameField.setNull()
  val populatedStruct = new ConstantColumnVector(expectedRows, structType)
  populatedStruct.setNotNull()
  populatedStruct.setChild(0, idField)
  populatedStruct.setChild(1, nameField)

  // Column 1: the struct itself is null; its children are present but never given a value.
  val nullStruct = new ConstantColumnVector(expectedRows, structType)
  nullStruct.setNull()
  nullStruct.setChild(0, new ConstantColumnVector(expectedRows, IntegerType))
  nullStruct.setChild(1, new ConstantColumnVector(expectedRows, StringType))

  val input =
    new ColumnarBatch(Array[ColumnVector](populatedStruct, nullStruct), expectedRows)

  val (serializedRows, bytes) = Utils.serializeBatches(Iterator(input)).next()
  assert(serializedRows == expectedRows)

  // Read everything from the decoded batch before the final hasNext call.
  val decoded = Utils.decodeBatches(bytes, "test")
  assert(decoded.hasNext)
  val result = decoded.next()
  assert(result.numCols() == 2)
  assert(result.numRows() == expectedRows)
  val structIds = rowIds.map(i => result.column(0).getStruct(i).getInt(0))
  val structNameNulls = rowIds.map(i => result.column(0).getStruct(i).isNullAt(1))
  val wholeStructNulls = rowIds.map(i => result.column(1).isNullAt(i))
  assert(!decoded.hasNext)

  assert(structIds.forall(_ == 7), s"expected all id 7, got $structIds")
  assert(structNameNulls.forall(identity), s"expected all name null, got $structNameNulls")
  assert(
    wholeStructNulls.forall(identity),
    s"expected all struct null, got $wholeStructNulls")
}
}