diff --git a/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 4f027cd9e7..895a2a2a9a 100644
--- a/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -137,6 +137,27 @@ class NativeUtil {
             provider,
             arrowArray,
             arrowSchema)
+        case cv: org.apache.spark.sql.execution.vectorized.ConstantColumnVector =>
+          // Spark uses ConstantColumnVector for partition columns / per-batch constants (e.g.
+          // partition values, synthetic columns). Materialise to a fresh Arrow vector so Comet's
+          // native side -- which expects Arrow Arrays only -- can ingest the batch. Without this,
+          // queries that pull constants through a Comet operator fail with "Comet execution only
+          // takes Arrow Arrays".
+          val rows = batch.numRows()
+          numRows += rows
+          val materialised = org.apache.spark.sql.comet.util.Utils
+            .materializeConstantColumnVector(cv, cv.dataType(), rows, s"_const_$index", allocator)
+          try {
+            val arrowSchema = ArrowSchema.wrap(schemaAddrs(index))
+            val arrowArray = ArrowArray.wrap(arrayAddrs(index))
+            Data.exportVector(allocator, materialised, null, arrowArray, arrowSchema)
+          } finally {
+            // `Data.exportVector` retains the buffers on behalf of the exported array (they are
+            // released by the consumer's release callback); it does not take over this vector.
+            // Nothing else holds `materialised`, so drop our reference here -- on success and on
+            // failure -- or its buffers are never freed.
+            materialised.close()
+          }
         case c =>
           throw new SparkException(
             "Comet execution only takes Arrow Arrays, but got " +
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala
index 342441ce28..092805cb20 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala
@@ -21,12 +21,14 @@ package org.apache.spark.sql.comet.execution.arrow
 
 import scala.jdk.CollectionConverters._
 
+import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.complex._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarArray
 
@@ -91,6 +93,56 @@ private[arrow] object ArrowWriter {
   }
 }
 
+/**
+ * Materialises a Spark `ConstantColumnVector` (partition values / per-batch constants) into a
+ * fresh Arrow `FieldVector` holding the constant repeated `numRows` times.
+ *
+ * Reuses the per-type `ArrowFieldWriter`s above -- so EVERY type is covered (scalars, decimal,
+ * timestamps, and complex struct/array/map) and the logic stays in sync with Spark -- rather than
+ * a hand-rolled per-type switch. `ConstantColumnVector` returns its constant for any rowId, so a
+ * `ColumnarArray` view over rows `[0, numRows)` writes the constant (or null) `numRows` times.
+ *
+ * Lives in this package because `ArrowWriter` is `private[arrow]`. The caller owns the returned
+ * vector and must close it. Exporting it through `Data.exportVector` does NOT hand ownership
+ * over: the export retains the buffers for the consumer, so the vector must still be closed.
+ */
+object ConstantColumnVectors {
+
+  /**
+   * @param cv the constant column to expand
+   * @param dt Spark type used to derive the Arrow field
+   * @param numRows number of times the constant is written
+   * @param name name given to the Arrow field
+   * @param allocator allocator backing the new vector
+   * @param timeZoneId zone passed to `Utils.toArrowField` for the Arrow field
+   * @return a newly allocated vector; if writing fails it is closed before the error propagates
+   */
+  def materialize(
+      cv: ConstantColumnVector,
+      dt: DataType,
+      numRows: Int,
+      name: String,
+      allocator: BufferAllocator,
+      timeZoneId: String): FieldVector = {
+    val field = Utils.toArrowField(name, dt, nullable = true, timeZoneId)
+    val vector = field.createVector(allocator)
+    // Until the vector is returned nobody else can release it, so close it on any failure.
+    var completed = false
+    try {
+      vector.allocateNew()
+      val writer = ArrowWriter.createFieldWriter(vector)
+      writer.writeCol(new ColumnarArray(cv, 0, numRows))
+      writer.finish()
+      completed = true
+      vector
+    } finally {
+      if (!completed) {
+        vector.close()
+      }
+    }
+  }
+}
+
 class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) {
 
   def schema: StructType = Utils.fromArrowSchema(root.getSchema())
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 15e1e2c410..9ae8625f27 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -26,6 +26,7 @@ import java.nio.channels.Channels
 import scala.jdk.CollectionConverters._
 
 import org.apache.arrow.c.CDataDictionaryProvider
+import org.apache.arrow.memory.BufferAllocator
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector}
 import org.apache.arrow.vector.dictionary.DictionaryProvider
@@ -38,6 +39,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator
+import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
@@ -387,6 +389,7 @@ object Utils extends CometTypeShim with Logging {
   def getBatchFieldVectors(
       batch: ColumnarBatch): (Seq[FieldVector], Option[DictionaryProvider]) = {
     var provider: Option[DictionaryProvider] = None
+    val rows = batch.numRows()
     val fieldVectors = (0 until batch.numCols()).map { index =>
       batch.column(index) match {
         case a: CometVector =>
@@ -399,6 +402,17 @@ object Utils extends CometTypeShim with Logging {
 
           getFieldVector(valueVector, "serialize")
 
+        case cv: ConstantColumnVector =>
+          // Spark wraps file-source partition columns and other per-batch constants in
+          // `ConstantColumnVector`. Materialise to an Arrow vector so the serialisation path
+          // doesn't reject the batch.
+          materializeConstantColumnVector(
+            cv,
+            cv.dataType(),
+            rows,
+            s"_const_$index",
+            org.apache.comet.CometArrowAllocator)
+
         case c =>
           throw new SparkException(
             s"Comet execution only takes Arrow Arrays, but got ${c.getClass}. " +
@@ -426,4 +440,40 @@ object Utils extends CometTypeShim with Logging {
         throw new SparkException(s"Unsupported Arrow Vector for $reason: ${valueVector.getClass}")
     }
   }
+
+  /**
+   * Materialize a Spark `ConstantColumnVector` into a fresh Arrow `FieldVector` whose value is
+   * the same constant repeated `numRows` times.
+   *
+   * Spark wraps file-source partition columns and other per-batch constants in
+   * `ConstantColumnVector`; downstream Comet operators feeding `NativeUtil.exportBatch` or
+   * `getBatchFieldVectors` trip on it because those paths only handle `CometVector`. This helper
+   * materializes the constant into an Arrow vector inline.
+   *
+   * The caller owns the returned vector and must close it. Exporting it through Arrow's
+   * `Data.exportVector` does not transfer ownership (the export only retains the buffers), so
+   * the vector still has to be closed afterwards. The vector is allocated against `allocator`
+   * and filled with the constant value (or null when `cv.isNullAt(0)`) for `numRows` rows.
+   *
+   * All Spark types are supported (delegates to the per-type ArrowFieldWriters, which include
+   * struct/array/map); throws only for a type Arrow itself can't represent.
+   */
+  def materializeConstantColumnVector(
+      cv: ConstantColumnVector,
+      dt: DataType,
+      numRows: Int,
+      name: String,
+      allocator: BufferAllocator): FieldVector = {
+    // "UTC" is deliberate here, NOT the session-local timezone that `toArrowSchema` threads
+    // through. These constants are materialised alongside non-constant columns in the same
+    // batch/`VectorSchemaRoot`, and Comet's non-constant `TimestampType` columns are Arrow
+    // vectors exported from native execution, where Comet always tags them `Timestamp(us, "UTC")`
+    // (see native `serde.rs`). Spark itself stores `TimestampType` as micros in UTC, so the
+    // constant's value is already a UTC instant. Tagging the materialised constant "UTC" keeps its
+    // Arrow timezone metadata consistent with its sibling timestamp columns; threading the
+    // session-local timezone here would instead introduce the mismatch. `TimestampNTZType` carries
+    // no zone regardless of this argument.
+    org.apache.spark.sql.comet.execution.arrow.ConstantColumnVectors
+      .materialize(cv, dt, numRows, name, allocator, "UTC")
+  }
 }
diff --git a/spark/src/test/scala/org/apache/comet/vector/NativeUtilSuite.scala b/spark/src/test/scala/org/apache/comet/vector/NativeUtilSuite.scala
new file mode 100644
index 0000000000..6024873fba
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/vector/NativeUtilSuite.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.vector
+
+import org.apache.arrow.c.{ArrowArray, ArrowSchema}
+import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+
+class NativeUtilSuite extends CometTestBase {
+
+  test("exportBatch round-trips a ConstantColumnVector through Arrow FFI") {
+    // Smoke test for the ConstantColumnVector arm of NativeUtil.exportBatch: a batch carrying
+    // Spark ConstantColumnVectors (partition values / per-batch constants) is exported across the
+    // Arrow C Data Interface and imported back, exercising materializeConstantColumnVector +
+    // Data.exportVector + the allocator handoff -- the FFI wiring that the serializeBatches test
+    // does not cover. Mirrors the export/import round trip that NativeUtil.getNextBatch performs
+    // in production, just without a native callee.
+    val numRows = 4
+
+    val valueCol = new ConstantColumnVector(numRows, IntegerType)
+    valueCol.setInt(42)
+    val nullCol = new ConstantColumnVector(numRows, IntegerType)
+    nullCol.setNull()
+
+    // A struct constant exercises the complex-type export path (getStruct/getChild) through FFI.
+    val structSchema = StructType(
+      Seq(StructField("id", IntegerType), StructField("name", StringType, nullable = true)))
+    val structCol = new ConstantColumnVector(numRows, structSchema)
+    structCol.setNotNull()
+    val idChild = new ConstantColumnVector(numRows, IntegerType)
+    idChild.setInt(7)
+    val nameChild = new ConstantColumnVector(numRows, StringType)
+    nameChild.setNull()
+    structCol.setChild(0, idChild)
+    structCol.setChild(1, nameChild)
+
+    val batch =
+      new ColumnarBatch(Array[ColumnVector](valueCol, nullCol, structCol), numRows)
+
+    val nativeUtil = new NativeUtil
+    var imported: ColumnarBatch = null
+    try {
+      val (arrayAddrs, schemaAddrs, exportedRows) = nativeUtil.exportBatchToAddresses(batch)
+      assert(exportedRows == numRows)
+
+      val arrays = arrayAddrs.map(ArrowArray.wrap)
+      val schemas = schemaAddrs.map(ArrowSchema.wrap)
+      val vectors = nativeUtil.importVector(arrays, schemas)
+      imported = new ColumnarBatch(vectors.toArray, numRows)
+
+      assert(imported.numCols() == 3)
+      assert(imported.numRows() == numRows)
+
+      val values = (0 until numRows).map(i => imported.column(0).getInt(i))
+      assert(values.forall(_ == 42), s"expected all 42, got $values")
+
+      val nulls = (0 until numRows).map(i => imported.column(1).isNullAt(i))
+      assert(nulls.forall(identity), s"expected all null, got $nulls")
+
+      val ids = (0 until numRows).map(i => imported.column(2).getStruct(i).getInt(0))
+      assert(ids.forall(_ == 7), s"expected all id 7, got $ids")
+      val nameNulls = (0 until numRows).map(i => imported.column(2).getStruct(i).isNullAt(1))
+      assert(nameNulls.forall(identity), s"expected all name null, got $nameNulls")
+    } finally {
+      if (imported != null) {
+        imported.close()
+      }
+      nativeUtil.close()
+    }
+  }
+}
diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/util/UtilsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/util/UtilsSuite.scala
index a79b862793..c3b00a2814 100644
--- a/spark/src/test/scala/org/apache/spark/sql/comet/util/UtilsSuite.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/comet/util/UtilsSuite.scala
@@ -20,6 +20,8 @@
 package org.apache.spark.sql.comet.util
 
 import org.apache.spark.sql.CometTestBase
+import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimestampType}
 import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
 
 class UtilsSuite extends CometTestBase {
@@ -51,4 +53,110 @@ class UtilsSuite extends CometTestBase {
     val decoded = coalesced.iterator.flatMap(b => Utils.decodeBatches(b, "test")).toSeq
     assert(decoded.map(_.numRows()).sum == expected)
   }
+
+  test("serializeBatches materializes ConstantColumnVector columns") {
+    // Spark wraps file-source partition columns and other per-batch constants in
+    // ConstantColumnVector. When such a batch reaches Comet's serialization/export path
+    // (getBatchFieldVectors), it must be materialized to an Arrow vector rather than
+    // rejected with "Comet execution only takes Arrow Arrays".
+    val numRows = 4
+
+    val valueCol = new ConstantColumnVector(numRows, IntegerType)
+    valueCol.setInt(42)
+    val nullCol = new ConstantColumnVector(numRows, IntegerType)
+    nullCol.setNull()
+    val batch = new ColumnarBatch(Array[ColumnVector](valueCol, nullCol), numRows)
+
+    val (rowCount, buf) = Utils.serializeBatches(Iterator(batch)).next()
+    assert(rowCount == numRows)
+
+    // Read the decoded values eagerly: ArrowReaderIterator releases a batch's buffers once the
+    // iterator advances past it (hasNext closes the previous batch), so values must be read from
+    // the current batch before calling hasNext/next again.
+    val it = Utils.decodeBatches(buf, "test")
+    assert(it.hasNext)
+    val out = it.next()
+    assert(out.numCols() == 2)
+    assert(out.numRows() == numRows)
+    val values = (0 until numRows).map(i => out.column(0).getInt(i))
+    val nulls = (0 until numRows).map(i => out.column(1).isNullAt(i))
+    assert(!it.hasNext)
+
+    assert(values.forall(_ == 42), s"expected all 42, got $values")
+    assert(nulls.forall(identity), s"expected all null, got $nulls")
+  }
+
+  test("serializeBatches materializes a TimestampType ConstantColumnVector") {
+    // Covers the TimestampType materialize path (TimestampWriter -> TimeStampMicroTZVector) and
+    // pins down the "UTC" timezone choice in materializeConstantColumnVector: Spark stores
+    // TimestampType as micros in UTC, and Comet tags its timestamp Arrow vectors "UTC", so the
+    // constant micros round-trip unchanged. This guards against anyone later swapping the zone
+    // argument, which would make the materialised constant's Arrow field metadata diverge from the
+    // sibling non-constant timestamp columns it shares a VectorSchemaRoot with.
+    val numRows = 3
+    // 2023-11-14T22:13:20Z in micros since epoch.
+    val micros = 1700000000000000L
+
+    val tsCol = new ConstantColumnVector(numRows, TimestampType)
+    tsCol.setLong(micros)
+    val batch = new ColumnarBatch(Array[ColumnVector](tsCol), numRows)
+
+    val (rowCount, buf) = Utils.serializeBatches(Iterator(batch)).next()
+    assert(rowCount == numRows)
+
+    val it = Utils.decodeBatches(buf, "test")
+    assert(it.hasNext)
+    val out = it.next()
+    assert(out.numCols() == 1)
+    assert(out.numRows() == numRows)
+    val got = (0 until numRows).map(i => out.column(0).getLong(i))
+    assert(!it.hasNext)
+
+    assert(got.forall(_ == micros), s"expected all $micros, got $got")
+  }
+
+  test("serializeBatches materializes a nullable StructType ConstantColumnVector") {
+    // Exercises a different ArrowFieldWriter path than the scalar cases: a struct constant is
+    // written via getStruct(rowId) -> getChild(ordinal). Covers both a non-null struct (with a
+    // null nested field) and a wholly-null struct constant.
+    val numRows = 3
+    val schema = StructType(
+      Seq(StructField("id", IntegerType), StructField("name", StringType, nullable = true)))
+
+    // Non-null struct whose `name` field is null, proving nested nullability round-trips.
+    val structCol = new ConstantColumnVector(numRows, schema)
+    structCol.setNotNull()
+    val idChild = new ConstantColumnVector(numRows, IntegerType)
+    idChild.setInt(7)
+    val nameChild = new ConstantColumnVector(numRows, StringType)
+    nameChild.setNull()
+    structCol.setChild(0, idChild)
+    structCol.setChild(1, nameChild)
+
+    // A wholly-null struct constant.
+    val nullStructCol = new ConstantColumnVector(numRows, schema)
+    nullStructCol.setNull()
+    nullStructCol.setChild(0, new ConstantColumnVector(numRows, IntegerType))
+    nullStructCol.setChild(1, new ConstantColumnVector(numRows, StringType))
+
+    val batch =
+      new ColumnarBatch(Array[ColumnVector](structCol, nullStructCol), numRows)
+
+    val (rowCount, buf) = Utils.serializeBatches(Iterator(batch)).next()
+    assert(rowCount == numRows)
+
+    val it = Utils.decodeBatches(buf, "test")
+    assert(it.hasNext)
+    val out = it.next()
+    assert(out.numCols() == 2)
+    assert(out.numRows() == numRows)
+    val ids = (0 until numRows).map(i => out.column(0).getStruct(i).getInt(0))
+    val nameNulls = (0 until numRows).map(i => out.column(0).getStruct(i).isNullAt(1))
+    val structNulls = (0 until numRows).map(i => out.column(1).isNullAt(i))
+    assert(!it.hasNext)
+
+    assert(ids.forall(_ == 7), s"expected all id 7, got $ids")
+    assert(nameNulls.forall(identity), s"expected all name null, got $nameNulls")
+    assert(structNulls.forall(identity), s"expected all struct null, got $structNulls")
+  }
 }