Commit 37e4a23

fix: Convert Spark columnar batches to Arrow in CometNativeWriteExec (#2944)
1 parent: 3a6452b

2 files changed: 148 additions & 8 deletions

spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala

Lines changed: 39 additions & 8 deletions
@@ -26,16 +26,19 @@ import scala.jdk.CollectionConverters._
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext, TaskAttemptID, TaskID, TaskType}
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
+import org.apache.spark.TaskContext
 import org.apache.spark.internal.io.FileCommitProtocol
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.Utils

-import org.apache.comet.CometExecIterator
+import org.apache.comet.{CometConf, CometExecIterator}
 import org.apache.comet.serde.OperatorOuterClass.Operator
+import org.apache.comet.vector.CometVector

 /**
  * Comet physical operator for native Parquet write operations with FileCommitProtocol support.
@@ -138,16 +141,21 @@ case class CometNativeWriteExec(
   }

   override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    // Check if the child produces Arrow/Comet batches or Spark batches
+    val childIsComet = child.isInstanceOf[CometPlan]
+
     // Get the input data from the child operator
     val childRDD = if (child.supportsColumnar) {
       child.executeColumnar()
     } else {
-      // If child doesn't support columnar, convert to columnar
-      child.execute().mapPartitionsInternal { _ =>
-        // TODO this could delegate to CometRowToColumnar, but maybe Comet
-        // does not need to support this case?
-        throw new UnsupportedOperationException(
-          "Row-based child operators not yet supported for native write")
+      // If child doesn't support columnar, convert rows to Arrow columnar batches
+      val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf)
+      val timeZoneId = conf.sessionLocalTimeZone
+      val schema = child.schema
+      child.execute().mapPartitionsInternal { rowIter =>
+        val context = TaskContext.get()
+        CometArrowConverters
+          .rowToArrowBatchIter(rowIter, schema, maxRecordsPerBatch, timeZoneId, context)
       }
     }

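For context on what "Arrow/Comet batches" means in the hunk above: Comet wraps Arrow-backed columns in CometVector (the import added in the first hunk), while plain Spark operators such as RangeExec emit OnHeapColumnVector. A hypothetical guard, not part of this commit, illustrates the distinction the writer now bridges:

import org.apache.comet.vector.CometVector
import org.apache.spark.sql.vectorized.ColumnarBatch

// Illustration only (hypothetical helper): a batch can go straight to Comet's
// native code when every column is the Arrow-backed CometVector wrapper;
// anything else (e.g. OnHeapColumnVector) must be converted first.
def isArrowBacked(batch: ColumnarBatch): Boolean =
  (0 until batch.numCols()).forall(i => batch.column(i).isInstanceOf[CometVector])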
@@ -158,6 +166,10 @@
     val capturedJobTrackerID = jobTrackerID
     val capturedNativeOp = nativeOp
     val capturedAccumulator = taskCommitMessagesAccum // Capture accumulator for use in tasks
+    val capturedChildIsComet = childIsComet
+    val capturedSchema = child.schema
+    val capturedMaxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf)
+    val capturedTimeZoneId = conf.sessionLocalTimeZone

     // Execute native write operation with task-level commit protocol
     childRDD.mapPartitionsInternal { iter =>
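The captured* locals are not cosmetic renaming: anything referenced inside mapPartitionsInternal becomes part of the task closure that Spark serializes to executors, and touching conf or child there would pull the whole operator into that closure. A minimal sketch of the pattern, separate from the commit itself:

// Copy plain values out of the operator while still on the driver...
val capturedTimeZoneId = conf.sessionLocalTimeZone

childRDD.mapPartitionsInternal { iter =>
  // ...so the serialized closure carries only a String; referencing `conf`
  // here would capture `this`, i.e. the entire SparkPlan.
  val tz = capturedTimeZoneId
  iter
}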
@@ -201,9 +213,28 @@
       outputStream.close()
       val planBytes = outputStream.toByteArray

+      // Convert Spark columnar batches to Arrow format if child is not a Comet operator.
+      // Comet native execution expects Arrow arrays, but Spark operators like RangeExec
+      // produce OnHeapColumnVector which must be converted.
+      val arrowIter = if (capturedChildIsComet) {
+        // Child is already producing Arrow/Comet batches
+        iter
+      } else {
+        // Convert Spark columnar batches to Arrow format
+        val context = TaskContext.get()
+        iter.flatMap { sparkBatch =>
+          CometArrowConverters.columnarBatchToArrowBatchIter(
+            sparkBatch,
+            capturedSchema,
+            capturedMaxRecordsPerBatch,
+            capturedTimeZoneId,
+            context)
+        }
+      }
+
       val execIterator = new CometExecIterator(
         CometExec.newIterId,
-        Seq(iter),
+        Seq(arrowIter),
         numOutputCols,
         planBytes,
         nativeMetrics,

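One detail of the hunk above worth spelling out: the conversion uses flatMap rather than map because a single incoming Spark batch may be re-chunked into several Arrow batches once capturedMaxRecordsPerBatch (COMET_BATCH_SIZE) is applied. A condensed sketch of that per-partition shape, assuming only the CometArrowConverters signature visible in the diff:

import org.apache.spark.TaskContext
import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

// One Spark batch in, zero or more Arrow batches out, hence flatMap.
def toArrowBatches(
    batches: Iterator[ColumnarBatch],
    schema: StructType,
    maxRecordsPerBatch: Int,
    timeZoneId: String): Iterator[ColumnarBatch] =
  batches.flatMap { sparkBatch =>
    CometArrowConverters.columnarBatchToArrowBatchIter(
      sparkBatch, schema, maxRecordsPerBatch, timeZoneId, TaskContext.get())
  }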
spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala

Lines changed: 109 additions & 0 deletions
@@ -228,4 +228,113 @@ class CometParquetWriterSuite extends CometTestBase {
       }
     }
   }
+
+  test("parquet write with spark.range() as data source - with spark-to-arrow conversion") {
+    // Test that spark.range() works when CometSparkToColumnarExec is enabled to convert
+    // Spark's OnHeapColumnVector to Arrow format
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true",
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "true",
+        CometConf.COMET_SPARK_TO_ARROW_SUPPORTED_OPERATOR_LIST.key -> "Range") {
+
+        // Use a listener to capture the execution plan during write
+        var capturedPlan: Option[org.apache.spark.sql.execution.QueryExecution] = None
+
+        val listener = new org.apache.spark.sql.util.QueryExecutionListener {
+          override def onSuccess(
+              funcName: String,
+              qe: org.apache.spark.sql.execution.QueryExecution,
+              durationNs: Long): Unit = {
+            if (funcName == "save" || funcName.contains("command")) {
+              capturedPlan = Some(qe)
+            }
+          }
+
+          override def onFailure(
+              funcName: String,
+              qe: org.apache.spark.sql.execution.QueryExecution,
+              exception: Exception): Unit = {}
+        }
+
+        spark.listenerManager.register(listener)
+
+        try {
+          // spark.range() uses RangeExec which produces OnHeapColumnVector
+          // CometSparkToColumnarExec converts these to Arrow format
+          spark.range(1000).write.mode("overwrite").parquet(outputPath)
+
+          // Wait for listener
+          val maxWaitTimeMs = 15000
+          val checkIntervalMs = 100
+          var iterations = 0
+
+          while (capturedPlan.isEmpty && iterations < maxWaitTimeMs / checkIntervalMs) {
+            Thread.sleep(checkIntervalMs)
+            iterations += 1
+          }
+
+          // Verify that CometNativeWriteExec was used
+          capturedPlan.foreach { qe =>
+            val executedPlan = stripAQEPlan(qe.executedPlan)
+
+            var nativeWriteCount = 0
+            executedPlan.foreach {
+              case _: CometNativeWriteExec =>
+                nativeWriteCount += 1
+              case d: DataWritingCommandExec =>
+                d.child.foreach {
+                  case _: CometNativeWriteExec =>
+                    nativeWriteCount += 1
+                  case _ =>
+                }
+              case _ =>
+            }
+
+            assert(
+              nativeWriteCount == 1,
+              s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${executedPlan.treeString}")
+          }
+
+          // Verify the data was written correctly
+          val resultDf = spark.read.parquet(outputPath)
+          assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
+        } finally {
+          spark.listenerManager.unregister(listener)
+        }
+      }
+    }
+  }
+
+  test("parquet write with spark.range() - issue #2944 without spark-to-arrow") {
+    // This test reproduces https://github.com/apache/datafusion-comet/issues/2944
+    // Without CometSparkToColumnarExec enabled, the native writer should handle
+    // Spark columnar batches by converting them to Arrow format internally.
+    withTempPath { dir =>
+      val outputPath = new File(dir, "output.parquet").getAbsolutePath
+
+      withSQLConf(
+        CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key -> "true",
+        SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax",
+        CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]) -> "true",
+        CometConf.COMET_EXEC_ENABLED.key -> "true",
+        // Explicitly disable spark-to-arrow conversion to reproduce the issue
+        CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "false") {
+
+        // spark.range() uses RangeExec which produces OnHeapColumnVector (not Arrow)
+        // Without the fix, this would fail with:
+        // "Comet execution only takes Arrow Arrays, but got OnHeapColumnVector"
+        spark.range(1000).write.mode("overwrite").parquet(outputPath)
+
+        // Verify the data was written correctly
+        val resultDf = spark.read.parquet(outputPath)
+        assert(resultDf.count() == 1000, "Expected 1000 rows to be written")
+      }
+    }
+  }
 }
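For anyone wanting to trigger the fixed path by hand, a minimal spark-shell sketch using the same CometConf constants as the tests above (the output path is arbitrary):

import org.apache.comet.CometConf
import org.apache.spark.sql.execution.command.DataWritingCommandExec

// Enable Comet execution and the native Parquet writer; spark-to-arrow
// conversion stays disabled so the writer's internal conversion is exercised.
spark.conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
spark.conf.set(CometConf.COMET_NATIVE_PARQUET_WRITE_ENABLED.key, "true")
spark.conf.set(
  CometConf.getOperatorAllowIncompatConfigKey(classOf[DataWritingCommandExec]),
  "true")

// Before this commit, this failed with:
//   "Comet execution only takes Arrow Arrays, but got OnHeapColumnVector"
spark.range(1000).write.mode("overwrite").parquet("/tmp/range-repro.parquet")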
