From 827a8ca063d284b254e6c9d1120e6c2a2dd9b70c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:14:25 -0600 Subject: [PATCH 01/16] docs: add design spec for shuffle direct read optimization Adds a design document for bypassing Arrow FFI in the shuffle read path when both the shuffle writer and downstream operator are native. --- .../2026-03-18-shuffle-direct-read-design.md | 163 ++++++++++++++++++ 1 file changed, 163 insertions(+) create mode 100644 docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md diff --git a/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md b/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md new file mode 100644 index 0000000000..2f002a2d89 --- /dev/null +++ b/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md @@ -0,0 +1,163 @@ +# Shuffle Direct Read: Bypass FFI for Native Shuffle Read Path + +## Problem + +When a native shuffle exchange feeds into a downstream native operator, shuffle data crosses the JVM/native FFI boundary twice: + +1. **Native to JVM**: `decodeShuffleBlock` JNI call decompresses Arrow IPC, creates a `RecordBatch`, and exports it via Arrow C Data Interface (per-column `FFI_ArrowArray` + `FFI_ArrowSchema` allocation, export, and import). +2. **JVM to Native**: `CometBatchIterator` re-exports the `ColumnarBatch` via Arrow C Data Interface back to native, where `ScanExec` imports and copies/unpacks the arrays. + +Each crossing involves per-column schema serialization, struct allocation, and array copying. For queries with many shuffle stages or wide schemas, this overhead is significant. + +## Solution + +Introduce a direct read path where native code consumes compressed shuffle blocks directly, bypassing Arrow FFI entirely. The JVM reads raw bytes from Spark's shuffle infrastructure and hands them to native via a `DirectByteBuffer` (zero-copy pointer access). Native decompresses and decodes in-place, feeding `RecordBatch` directly into the execution plan. 
+ +### Data Flow Comparison + +**Current path (double FFI):** + +``` +Shuffle stream + -> NativeBatchDecoderIterator (JVM) + -> JNI: decodeShuffleBlock + -> FFI export: RecordBatch -> ArrowArray/Schema (native -> JVM) + -> ColumnarBatch on JVM + -> CometBatchIterator + -> FFI export: ColumnarBatch -> ArrowArray/Schema (JVM -> native) + -> ScanExec imports + copies arrays + -> Native operators +``` + +**New path (zero FFI):** + +``` +Shuffle stream + -> CometShuffleBlockIterator (JVM) + -> reads header + compressed body into DirectByteBuffer + -> holds bytes, waits for native pull + +ShuffleScanExec (native, pull-based) + -> JNI callback: iterator.hasNext()/getBuffer() + -> read_ipc_compressed() -> RecordBatch + -> feeds directly into native execution plan +``` + +## Scope + +- Native shuffle (`CometNativeShuffle`) only. JVM columnar shuffle is excluded because its per-batch dictionary encoding decisions can change the schema between batches. +- Both paths (old and new) are retained. A config flag controls which is used. + +## Components + +### New JVM Components + +#### `CometShuffleBlockIterator` (Java) + +A new class that wraps a shuffle `InputStream` and exposes raw compressed blocks for native consumption. Absorbs the header-reading and buffer-management logic from `NativeBatchDecoderIterator`, but does not decode. + +JNI-callable interface: + +- `hasNext() -> int`: Reads the next block's header from the stream. The header is 16 bytes: 8-byte compressed length (includes the 8-byte field count but not itself) + 8-byte field count. The field count from the header is discarded — the schema is determined by the `ShuffleScan` protobuf's `fields` list, which is authoritative. Returns the compressed body length in bytes (i.e., `compressedLength - 8`, which includes the 4-byte codec prefix + compressed IPC data), or -1 for EOF. 
+- `getBuffer() -> ByteBuffer`: Returns the `DirectByteBuffer` containing the current block's compressed bytes (4-byte codec prefix + compressed IPC data). This buffer is only valid until the next `hasNext()` call — the caller must fully consume it (via `read_ipc_compressed()`, which decompresses into a new allocation) before pulling the next block. + +Uses its own `DirectByteBuffer` instance (not shared with `NativeBatchDecoderIterator`) with the same pooling strategy: initial 128KB, grows as needed, reset on close. + +**Lifecycle**: Implements `Closeable`. `close()` closes the underlying shuffle `InputStream` and resets the buffer. `CometBlockStoreShuffleReader` registers a task completion listener to close it, matching the existing pattern for `NativeBatchDecoderIterator`. + +### New Native Components + +#### `ShuffleScanExec` (Rust) + +Location: `native/core/src/execution/operators/shuffle_scan.rs` + +A new `ExecutionPlan` operator that replaces `ScanExec` at shuffle boundaries. On each `poll_next`: + +1. Calls JNI into `CometShuffleBlockIterator.hasNext()` to get the next block's byte length (or -1 for EOF). +2. Calls `CometShuffleBlockIterator.getBuffer()` to get a `DirectByteBuffer`. +3. Obtains the buffer's raw pointer via `JNIEnv::get_direct_buffer_address()` and creates a slice over it (zero-copy, same pattern as `decodeShuffleBlock`). +4. Calls `read_ipc_compressed()` to decompress and decode into a `RecordBatch`. This allocates new memory for the decompressed data — the `DirectByteBuffer` can be safely reused afterward. +5. Returns the `RecordBatch` directly to the downstream native operator. + +No `FFI_ArrowArray`, `FFI_ArrowSchema`, `ArrowImporter`, or `CometVector` involved. + +Implements `on_close` for cleanup (releasing the JNI `GlobalRef`), matching the `ScanExec` pattern. 
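As a reference for the framing that `ShuffleScanExec` consumes, the header arithmetic can be sketched as a standalone snippet (the `BlockHeader` class below is illustrative only; the real logic lives in `CometShuffleBlockIterator` and `read_ipc_compressed()`):

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class BlockHeader {
  public final long compressedLength;
  public final long fieldCount;

  private BlockHeader(long compressedLength, long fieldCount) {
    this.compressedLength = compressedLength;
    this.fieldCount = fieldCount;
  }

  /** Parse the 16-byte little-endian header that precedes every block. */
  public static BlockHeader parse(byte[] header) {
    ByteBuffer buf = ByteBuffer.wrap(header).order(ByteOrder.LITTLE_ENDIAN);
    return new BlockHeader(buf.getLong(), buf.getLong());
  }

  /** Body length (codec prefix + compressed IPC); this is what hasNext() returns. */
  public long bodyLength() {
    return compressedLength - 8;
  }

  public static void main(String[] args) {
    // Header for a block whose body is 100 bytes: compressedLength = 108.
    ByteBuffer buf = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN);
    buf.putLong(108).putLong(3); // compressedLength, then field count
    BlockHeader h = BlockHeader.parse(buf.array());
    System.out.println(h.bodyLength()); // prints 100
  }
}
```

The same arithmetic appears on both sides of the boundary: the JVM uses it to size the `DirectByteBuffer`, and native code uses the returned length to bound the slice it decodes.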
+ +#### `ShuffleScan` Protobuf Message + +Location: `native/proto/src/proto/operator.proto` + +New message alongside existing `Scan`: + +```protobuf +message ShuffleScan { + repeated spark.spark_expression.DataType fields = 1; + string source = 2; // Informational label (e.g., "CometShuffleExchangeExec [id=5]") +} +``` + +The `Operator` message gains a new `shuffle_scan` field in its oneof. + +### Modified JVM Components + +#### `CometExchangeSink` / `CometExecRule` + +The decision to use `ShuffleScan` vs `Scan` is made when `CometNativeExec` is constructed (not during the bottom-up conversion pass). At that point, the operator tree is already converted: `CometExecRule.convertBlock()` wraps a contiguous group of native operators into `CometNativeExec` and serializes the protobuf plan. The children (including `CometSinkPlaceHolder` wrapping shuffle exchanges) are already known. So the check is: when serializing a `CometSinkPlaceHolder` whose `originalPlan` is a `CometShuffleExchangeExec` with `shuffleType == CometNativeShuffle`, and the config flag is enabled, emit `ShuffleScan` instead of `Scan`. + +Conditions for `ShuffleScan`: + +1. Shuffle type is `CometNativeShuffle` +2. The sink is inside a `CometNativeExec` block (always true at serialization time — this is where sinks get serialized) +3. Config `spark.comet.shuffle.directRead.enabled` is true (default: true) + +#### `CometNativeExec` (operators.scala) + +When collecting input RDDs and creating iterators, distinguish the two cases: + +- `ShuffleScan` input: Wrap the shuffle RDD's `Iterator[ColumnarBatch]` stream in `CometShuffleBlockIterator` — but note that `CometShuffleBlockIterator` wraps the raw `InputStream` from shuffle blocks, not decoded `ColumnarBatch`. This means the RDD must provide the raw shuffle `InputStream` rather than going through `NativeBatchDecoderIterator`. 
The `CometShuffledBatchRDD` / `CometBlockStoreShuffleReader` needs a mode where it yields raw `InputStream` objects per block instead of decoded batches. +- `Scan` input: Wrap in `CometBatchIterator` (existing behavior) + +#### `CometExecIterator` — JNI Input Contract + +Currently `CometExecIterator` wraps all inputs as `CometBatchIterator` and passes them to `Native.createPlan()` as `Array[CometBatchIterator]`. To support `CometShuffleBlockIterator`: + +- Change the JNI parameter from `Array[CometBatchIterator]` to `Array[Object]`. On the native side in `createPlan`, the planner already knows from the protobuf whether each input is a `Scan` or `ShuffleScan`, so it knows which JNI methods to call on each `GlobalRef` — no type checking needed at runtime. +- `CometExecIterator` populates the array with either `CometBatchIterator` or `CometShuffleBlockIterator` based on whether the corresponding leaf in the protobuf plan is `Scan` or `ShuffleScan`. + +### Native Planner Changes + +In `planner.rs`, handle the `ShuffleScan` protobuf variant: + +- Consume an input from `inputs.remove(0)` (same pattern as `Scan`) +- Create `ShuffleScanExec` instead of `ScanExec` +- The `GlobalRef` points to a `CometShuffleBlockIterator` Java object + +## Fallback Behavior + +The new path is used only when all conditions above are met. Otherwise, the existing path is used unchanged. The most common fallback case is a shuffle whose output is consumed by a non-native Spark operator (e.g., `collect()`, or an unsupported operator), where the JVM needs a materialized `ColumnarBatch`. + +## Configuration + +| Config | Default | Description | +|--------|---------|-------------| +| `spark.comet.shuffle.directRead.enabled` | `true` | Use direct native read path for native shuffle when downstream operator is native | + +## Error Handling + +- `ShuffleScanExec` reuses `read_ipc_compressed()`, which handles corrupt data and unsupported codecs. 
+- JNI errors from `CometShuffleBlockIterator` (stream closed, EOF, I/O errors) propagate through the existing `try_unwrap_or_throw` pattern. +- If the JVM iterator throws, the exception surfaces as a Rust error and propagates through DataFusion's error handling. +- Empty batches (zero rows): `read_ipc_compressed()` calls `reader.next().unwrap()` which panics if the stream contains no batches. The shuffle writer never writes zero-row blocks (guarded by `if batch.num_rows() == 0 { return Ok(0) }` in `ShuffleBlockWriter.write_batch`), so this case does not arise. + +## Metrics + +`ShuffleScanExec` tracks and reports: + +- `decodeTime`: Time spent in `read_ipc_compressed()` (decompression + IPC decode). Same metric as `NativeBatchDecoderIterator` reports today. +- Shuffle read metrics (`recordsRead`, `bytesRead`) continue to be reported by `CometBlockStoreShuffleReader` and the `ShuffleBlockFetcherIterator`, which are upstream of the new code and unchanged. + +## Testing + +- Existing shuffle tests (`CometShuffleSuite`) run with the config defaulting to true, automatically covering the new path. +- Add a test that runs the same queries with the config flag on and off, asserting identical results. +- Add a Rust unit test for `ShuffleScanExec` with pre-built compressed IPC blocks (no JNI), using the `TEST_EXEC_CONTEXT_ID` pattern from `ScanExec` tests. 
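To make the framing and the pull contract concrete, the following self-contained sketch round-trips two blocks through the header format described above. All names are hypothetical, and plain heap buffers stand in for the real `DirectByteBuffer`:

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class PullContractDemo {

  /** Frame one block: 16-byte little-endian header, then the body bytes. */
  public static void writeBlock(ByteArrayOutputStream out, byte[] body) throws IOException {
    ByteBuffer header = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN);
    header.putLong(body.length + 8); // compressedLength counts the 8-byte field count
    header.putLong(2); // field count; the direct read path discards it
    out.write(header.array());
    out.write(body);
  }

  /** Mimics hasNext(): returns the body length of the next block, or -1 at clean EOF. */
  public static int nextBlockLength(DataInputStream in, byte[] scratch) throws IOException {
    byte[] header = new byte[16];
    int first = in.read(header, 0, 16);
    if (first < 0) {
      return -1; // stream exhausted before a header: normal end of partition
    }
    in.readFully(header, first, 16 - first);
    ByteBuffer buf = ByteBuffer.wrap(header).order(ByteOrder.LITTLE_ENDIAN);
    int bodyLen = (int) (buf.getLong() - 8);
    buf.getLong(); // field count, discarded
    in.readFully(scratch, 0, bodyLen); // one scratch buffer reused across blocks
    return bodyLen;
  }

  public static void main(String[] args) throws IOException {
    ByteArrayOutputStream out = new ByteArrayOutputStream();
    writeBlock(out, new byte[] {0, 0, 0, 1, 10, 11, 12}); // 4-byte codec prefix + 3 bytes
    writeBlock(out, new byte[] {0, 0, 0, 1, 42});         // 4-byte codec prefix + 1 byte

    DataInputStream in = new DataInputStream(new ByteArrayInputStream(out.toByteArray()));
    byte[] scratch = new byte[1024];
    System.out.println(nextBlockLength(in, scratch)); // prints 7
    System.out.println(nextBlockLength(in, scratch)); // prints 5
    System.out.println(nextBlockLength(in, scratch)); // prints -1
  }
}
```

The scratch-buffer reuse mirrors the design's invariant that the buffer returned by `getBuffer()` is only valid until the next `hasNext()` call.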
From 3a3edb48a0feded6af9c06648a8bfce9aee24f77 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:27:46 -0600 Subject: [PATCH 02/16] docs: add implementation plan for shuffle direct read --- .../plans/2026-03-18-shuffle-direct-read.md | 1011 +++++++++++++++++ 1 file changed, 1011 insertions(+) create mode 100644 docs/superpowers/plans/2026-03-18-shuffle-direct-read.md diff --git a/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md b/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md new file mode 100644 index 0000000000..647f122cc4 --- /dev/null +++ b/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md @@ -0,0 +1,1011 @@ +# Shuffle Direct Read Implementation Plan + +> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. + +**Goal:** Eliminate double Arrow FFI crossing at shuffle boundaries by having native code consume compressed IPC blocks directly from JVM-provided byte buffers. + +**Architecture:** A new `ShuffleScanExec` Rust operator pulls raw compressed bytes from a JVM `CometShuffleBlockIterator` via JNI, decompresses and decodes them in native code, and feeds `RecordBatch` directly into the execution plan. This bypasses the current path where data is decoded to JVM `ColumnarBatch` (FFI export), then re-exported back to native (FFI import). 
+ +**Tech Stack:** Scala, Java, Rust, Protobuf, JNI, Arrow IPC + +**Spec:** `docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md` + +--- + +### Task 1: Add config flag + +**Files:** +- Modify: `common/src/main/scala/org/apache/comet/CometConf.scala` + +- [ ] **Step 1: Add the config entry** + +Find the existing shuffle config entries (search for `COMET_EXEC_SHUFFLE_ENABLED`) and add nearby: + +```scala +val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.shuffle.directRead.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, native operators that consume shuffle output will read " + + "compressed shuffle blocks directly in native code, bypassing Arrow FFI. " + + "Only applies to native shuffle (not JVM columnar shuffle). " + + "Requires spark.comet.exec.shuffle.enabled to be true.") + .booleanConf + .createWithDefault(true) +``` + +- [ ] **Step 2: Verify it compiles** + +Run: `./mvnw compile -DskipTests -pl common` +Expected: BUILD SUCCESS + +- [ ] **Step 3: Commit** + +```bash +git add common/src/main/scala/org/apache/comet/CometConf.scala +git commit -m "feat: add spark.comet.shuffle.directRead.enabled config" +``` + +--- + +### Task 2: Add ShuffleScan protobuf message + +**Files:** +- Modify: `native/proto/src/proto/operator.proto` + +- [ ] **Step 1: Add ShuffleScan message** + +Add after the existing `Scan` message (after line 86): + +```protobuf +message ShuffleScan { + repeated spark.spark_expression.DataType fields = 1; + // Informational label for debug output (e.g., "CometShuffleExchangeExec [id=5]") + string source = 2; +} +``` + +- [ ] **Step 2: Add shuffle_scan to the Operator oneof** + +In the `oneof op_struct` block (lines 38-55), add after `csv_scan = 115`: + +```protobuf + ShuffleScan shuffle_scan = 116; +``` + +- [ ] **Step 3: Rebuild protobuf and verify** + +Run: `make core` +Expected: Successful build with generated protobuf code. 
+ +- [ ] **Step 4: Commit** + +```bash +git add native/proto/src/proto/operator.proto +git commit -m "feat: add ShuffleScan protobuf message" +``` + +--- + +### Task 3: Create CometShuffleBlockIterator (Java) + +**Files:** +- Create: `spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java` + +- [ ] **Step 1: Create the class** + +```java +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; + +/** + * Provides raw compressed shuffle blocks to native code via JNI. + * + *
 * <p>Reads block headers (compressed length + field count) from a shuffle InputStream and loads
 * the compressed body into a DirectByteBuffer. Native code pulls blocks by calling hasNext()
 * and getBuffer().
 *
 * <p>The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call.
 * Native code must fully consume it (via read_ipc_compressed which allocates new memory for
 * the decompressed data) before pulling the next block.
 */
public class CometShuffleBlockIterator implements Closeable {

  private static final int INITIAL_BUFFER_SIZE = 128 * 1024;

  private final ReadableByteChannel channel;
  private final InputStream inputStream;
  private final ByteBuffer headerBuf =
      ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN);
  private ByteBuffer dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE);
  private boolean closed = false;
  private int currentBlockLength = 0;

  public CometShuffleBlockIterator(InputStream in) {
    this.inputStream = in;
    this.channel = Channels.newChannel(in);
  }

  /**
   * Reads the next block header and loads the compressed body into the internal buffer.
   * Called by native code via JNI.
   *
   * <p>Header format: 8-byte compressedLength (includes field count but not itself) +
   * 8-byte fieldCount (discarded, schema comes from protobuf).
   *
   * @return the compressed body length in bytes (codec prefix + compressed IPC), or -1 if EOF
   */
  public int hasNext() throws IOException {
    if (closed) {
      return -1;
    }

    // Read 16-byte header
    headerBuf.clear();
    while (headerBuf.hasRemaining()) {
      int bytesRead = channel.read(headerBuf);
      if (bytesRead < 0) {
        if (headerBuf.position() == 0) {
          return -1;
        }
        throw new EOFException(
            "Data corrupt: unexpected EOF while reading batch header");
      }
    }
    headerBuf.flip();
    long compressedLength = headerBuf.getLong();
    // Field count discarded - schema determined by ShuffleScan protobuf fields
    headerBuf.getLong();

    long bytesToRead = compressedLength - 8;
    if (bytesToRead > Integer.MAX_VALUE) {
      throw new IllegalStateException(
          "Native shuffle block size of " + bytesToRead + " exceeds maximum of "
              + Integer.MAX_VALUE + ". Try reducing shuffle batch size.");
    }

    if (dataBuf.capacity() < bytesToRead) {
      int newCapacity = (int) Math.min(bytesToRead * 2L, Integer.MAX_VALUE);
      dataBuf = ByteBuffer.allocateDirect(newCapacity);
    }

    dataBuf.clear();
    dataBuf.limit((int) bytesToRead);
    while (dataBuf.hasRemaining()) {
      int bytesRead = channel.read(dataBuf);
      if (bytesRead < 0) {
        throw new EOFException(
            "Data corrupt: unexpected EOF while reading compressed batch");
      }
    }
    // Note: native side uses get_direct_buffer_address (base pointer) + currentBlockLength,
    // not the buffer's position/limit. No flip needed.

    currentBlockLength = (int) bytesToRead;
    return currentBlockLength;
  }

  /**
   * Returns the DirectByteBuffer containing the current block's compressed bytes
   * (4-byte codec prefix + compressed IPC data).
   * Called by native code via JNI.
   */
  public ByteBuffer getBuffer() {
    return dataBuf;
  }

  /**
   * Returns the length of the current block in bytes.
+ * Called by native code via JNI. + */ + public int getCurrentBlockLength() { + return currentBlockLength; + } + + @Override + public void close() throws IOException { + if (!closed) { + closed = true; + inputStream.close(); + if (dataBuf.capacity() > INITIAL_BUFFER_SIZE) { + dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); + } + } + } +} +``` + +- [ ] **Step 2: Verify it compiles** + +Run: `./mvnw compile -DskipTests` +Expected: BUILD SUCCESS + +- [ ] **Step 3: Commit** + +```bash +git add spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java +git commit -m "feat: add CometShuffleBlockIterator for raw shuffle block access" +``` + +--- + +### Task 4: Add JNI bridge for CometShuffleBlockIterator (Rust) + +**Files:** +- Create: `native/core/src/jvm_bridge/shuffle_block_iterator.rs` +- Modify: `native/core/src/jvm_bridge/mod.rs` + +- [ ] **Step 1: Create the JNI bridge struct** + +Create `native/core/src/jvm_bridge/shuffle_block_iterator.rs`: + +```rust +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 

use jni::signature::Primitive;
use jni::{
    errors::Result as JniResult,
    objects::{JClass, JMethodID},
    signature::ReturnType,
    JNIEnv,
};

/// JNI method IDs for `CometShuffleBlockIterator`.
#[allow(dead_code)]
pub struct CometShuffleBlockIterator<'a> {
    pub class: JClass<'a>,
    pub method_has_next: JMethodID,
    pub method_has_next_ret: ReturnType,
    pub method_get_buffer: JMethodID,
    pub method_get_buffer_ret: ReturnType,
    pub method_get_current_block_length: JMethodID,
    pub method_get_current_block_length_ret: ReturnType,
}

impl<'a> CometShuffleBlockIterator<'a> {
    pub const JVM_CLASS: &'static str = "org/apache/comet/CometShuffleBlockIterator";

    pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometShuffleBlockIterator<'a>> {
        let class = env.find_class(Self::JVM_CLASS)?;

        Ok(CometShuffleBlockIterator {
            class,
            method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?,
            method_has_next_ret: ReturnType::Primitive(Primitive::Int),
            method_get_buffer: env.get_method_id(
                Self::JVM_CLASS,
                "getBuffer",
                "()Ljava/nio/ByteBuffer;",
            )?,
            method_get_buffer_ret: ReturnType::Object,
            method_get_current_block_length: env.get_method_id(
                Self::JVM_CLASS,
                "getCurrentBlockLength",
                "()I",
            )?,
            method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int),
        })
    }
}
```

- [ ] **Step 2: Register in mod.rs**

In `native/core/src/jvm_bridge/mod.rs`:

Add `mod shuffle_block_iterator;` alongside the existing `mod batch_iterator;` (line 174).

Add `use shuffle_block_iterator::CometShuffleBlockIterator as CometShuffleBlockIteratorBridge;` (to avoid name collision with the operator).

Add a field to the `JVMClasses` struct (around line 206):
```rust
pub comet_shuffle_block_iterator: CometShuffleBlockIteratorBridge<'a>,
```

Initialize it in `JVMClasses::init` alongside the existing `comet_batch_iterator` init (around line 259):
```rust
comet_shuffle_block_iterator: CometShuffleBlockIteratorBridge::new(env).unwrap(),
```

- [ ] **Step 3: Add a `jni_call!` compatible accessor**

Check how `comet_batch_iterator` is called in `scan.rs`. The `jni_call!` macro uses the field name from `JVMClasses`. Ensure `comet_shuffle_block_iterator` follows the same pattern. You may need to add a module in the `jni_bridge` macros — look at how `jni_call!(&mut env, comet_batch_iterator(iter).has_next() -> i32)` is defined and add equivalent patterns for `comet_shuffle_block_iterator`.

Check `native/core/src/jvm_bridge/` for macro definitions (likely in a separate file or in `mod.rs`) that define the `jni_call!` dispatch for each class.

- [ ] **Step 4: Verify it compiles**

Run: `cd native && cargo build`
Expected: Successful build.

- [ ] **Step 5: Commit**

```bash
git add native/core/src/jvm_bridge/shuffle_block_iterator.rs
git add native/core/src/jvm_bridge/mod.rs
git commit -m "feat: add JNI bridge for CometShuffleBlockIterator"
```

---

### Task 5: Create ShuffleScanExec (Rust)

**Files:**
- Create: `native/core/src/execution/operators/shuffle_scan.rs`
- Modify: `native/core/src/execution/operators/mod.rs`

**Design decision — pre-pull pattern:** `ShuffleScanExec` MUST use the pre-pull pattern (same as `ScanExec`). The comment at `jni_api.rs:483-488` explains why: JNI calls cannot happen from within `poll_next` on tokio threads. So `ShuffleScanExec` stores a `batch: Arc<Mutex<Option<InputBatch>>>` and `get_next_batch()` is called from `pull_input_batches` before each `poll_next`.

- [ ] **Step 1: Create shuffle_scan.rs**

Use `scan.rs` as the template.
The key differences:
- `get_next_batch` calls `hasNext()`/`getBuffer()`/`getCurrentBlockLength()` on `CometShuffleBlockIterator` instead of Arrow FFI methods on `CometBatchIterator`
- After getting the `DirectByteBuffer`, call `read_ipc_compressed()` to decode
- No `arrow_ffi_safe` flag, no selection vectors, no `copy_or_unpack_array`
- Track `decode_time` metric

The core `get_next` method:

```rust
fn get_next(
    exec_context_id: i64,
    iter: &JObject,
    data_types: &[DataType],
) -> Result<InputBatch, CometError> {
    let mut env = JVMClasses::get_env()?;

    // Call hasNext() — returns block length or -1 for EOF
    let block_length: i32 = unsafe {
        jni_call!(&mut env, comet_shuffle_block_iterator(iter).has_next() -> i32)?
    };

    if block_length < 0 {
        return Ok(InputBatch::EOF);
    }

    // Get the DirectByteBuffer
    let buffer: JByteBuffer = unsafe {
        jni_call!(&mut env, comet_shuffle_block_iterator(iter).get_buffer() -> JObject)?
    }.into();

    // Get raw pointer to the buffer data
    let raw_pointer = env.get_direct_buffer_address(&buffer)?;
    let length = block_length as usize;
    let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };

    // Decompress and decode the IPC block
    let batch = read_ipc_compressed(slice)?;

    // Convert RecordBatch columns to InputBatch
    let arrays: Vec<ArrayRef> = batch.columns().to_vec();
    let num_rows = batch.num_rows();

    Ok(InputBatch::new(arrays, Some(num_rows)))
}
```

For the `ExecutionPlan` trait implementation, follow `ScanExec` closely:
- `schema()` returns schema built from `data_types`
- `execute()` returns a `ScanStream` (reuse the same stream type from `scan.rs`)
- The `ScanStream` checks `self.batch` mutex on each `poll_next`, takes the batch if available

- [ ] **Step 2: Register the module**

In `native/core/src/execution/operators/mod.rs`, add:

```rust
mod shuffle_scan;
pub use shuffle_scan::ShuffleScanExec;
```

- [ ] **Step 3: Verify it compiles**

Run: `cd native && cargo build`
Expected: Successful build.

- [ ] **Step 4: Commit**

```bash
git add native/core/src/execution/operators/shuffle_scan.rs
git add native/core/src/execution/operators/mod.rs
git commit -m "feat: add ShuffleScanExec native operator for direct shuffle read"
```

---

### Task 6: Wire ShuffleScanExec into the native planner and pre-pull

**Files:**
- Modify: `native/core/src/execution/planner.rs`
- Modify: `native/core/src/execution/jni_api.rs`

**Design decision — separate scan vectors:** The planner's `create_plan` currently returns `(Vec<ScanExec>, Arc<SparkPlan>)`. Change the return type to include shuffle scans: `(Vec<ScanExec>, Vec<ShuffleScanExec>, Arc<SparkPlan>)`. All intermediate operators pass both vectors through. `ExecutionContext` gets a new `shuffle_scans: Vec<ShuffleScanExec>` field, and `pull_input_batches` iterates both.

- [ ] **Step 1: Update create_plan return type**

In `planner.rs`, change the `create_plan` return type (line 915):

```rust
) -> Result<(Vec<ScanExec>, Vec<ShuffleScanExec>, Arc<SparkPlan>), ExecutionError>
```

Update every match arm that calls `create_plan` recursively or returns results:
- Single-child operators (Filter, Project, Sort, etc.): destructure as `let (scans, shuffle_scans, child) = ...` and pass both through
- Multi-child operators (joins via `parse_join_parameters`): concatenate both scan vectors from left and right children
- `Scan` arm: returns `(vec![scan.clone()], vec![], ...)`
- Add `ShuffleScan` arm (see step 2)

This is a mechanical change across many match arms. Each `Ok((scans, ...))` becomes `Ok((scans, shuffle_scans, ...))`.

Also update `parse_join_parameters` return type similarly.
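The shape of this refactor can be sketched in miniature. The toy tree below is illustrative Java (the real code is the Rust match in `planner.rs`): each leaf contributes to exactly one of the two collections, and interior nodes thread both through unchanged:

```java
import java.util.List;

public class TwoScanVectors {

  public interface Node {}
  public record Scan(String name) implements Node {}
  public record ShuffleScan(String name) implements Node {}
  public record Project(Node child) implements Node {}

  public record PlanResult(List<String> scans, List<String> shuffleScans, String plan) {}

  public static PlanResult createPlan(Node node) {
    if (node instanceof Scan s) {
      // Existing arm: leaf goes in the first collection.
      return new PlanResult(List.of(s.name()), List.of(), "Scan(" + s.name() + ")");
    }
    if (node instanceof ShuffleScan s) {
      // New arm: leaf goes in the second collection, first stays empty.
      return new PlanResult(List.of(), List.of(s.name()), "ShuffleScan(" + s.name() + ")");
    }
    // Single-child operator: destructure and pass BOTH collections through.
    Project p = (Project) node;
    PlanResult child = createPlan(p.child());
    return new PlanResult(child.scans(), child.shuffleScans(), "Project(" + child.plan() + ")");
  }

  public static void main(String[] args) {
    PlanResult r = createPlan(new Project(new ShuffleScan("exchange_5")));
    System.out.println(r.shuffleScans()); // prints [exchange_5]
    System.out.println(r.scans());        // prints []
  }
}
```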

- [ ] **Step 2: Add ShuffleScan match arm**

```rust
OpStruct::ShuffleScan(scan) => {
    let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec();

    if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
        return Err(GeneralError("No input for shuffle scan".to_string()));
    }

    let input_source =
        if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
            None
        } else {
            Some(inputs.remove(0))
        };

    let shuffle_scan = ShuffleScanExec::new(
        self.exec_context_id,
        input_source,
        &scan.source,
        data_types,
    )?;

    Ok((
        vec![],
        vec![shuffle_scan.clone()],
        Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(shuffle_scan), vec![])),
    ))
}
```

- [ ] **Step 3: Update ExecutionContext and pull_input_batches**

In `jni_api.rs`:

Add `shuffle_scans: Vec<ShuffleScanExec>` field to `ExecutionContext` struct (after `scans` on line 153). Initialize as `shuffle_scans: vec![]` in the constructor (line 313).

Where `create_plan` results are stored (line 542-550):

```rust
let (scans, shuffle_scans, root_op) = planner.create_plan(...)?;
exec_context.scans = scans;
exec_context.shuffle_scans = shuffle_scans;
```

Update `pull_input_batches` (line 490):

```rust
fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometError> {
    exec_context.scans.iter_mut().try_for_each(|scan| {
        scan.get_next_batch()?;
        Ok::<(), CometError>(())
    })?;
    exec_context.shuffle_scans.iter_mut().try_for_each(|scan| {
        scan.get_next_batch()?;
        Ok::<(), CometError>(())
    })
}
```

Also update the `exec_context.scans.is_empty()` check (line 563) to also check `shuffle_scans`:

```rust
if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() {
```

- [ ] **Step 4: Verify it compiles**

Run: `cd native && cargo build`
Expected: Successful build.
+ +- [ ] **Step 5: Commit** + +```bash +git add native/core/src/execution/planner.rs +git add native/core/src/execution/jni_api.rs +git commit -m "feat: wire ShuffleScanExec into planner and pre-pull mechanism" +``` + +--- + +### Task 7: Emit ShuffleScan from JVM serde + +**Files:** +- Modify: `spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala` + +The `CometExchangeSink.convert()` receives the outer operator (e.g., `ShuffleQueryStageExec`) not the inner `CometShuffleExchangeExec`. We must unwrap to check `shuffleType`. + +- [ ] **Step 1: Override convert in CometExchangeSink** + +Replace the `CometExchangeSink` object (lines 87-100) with: + +```scala +object CometExchangeSink extends CometSink[SparkPlan] { + + override def isFfiSafe: Boolean = true + + override def convert( + op: SparkPlan, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (shouldUseShuffleScan(op)) { + convertToShuffleScan(op, builder) + } else { + super.convert(op, builder, childOp: _*) + } + } + + private def shouldUseShuffleScan(op: SparkPlan): Boolean = { + if (!CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.get()) return false + + // Extract the CometShuffleExchangeExec from the wrapper + val shuffleExec = op match { + case ShuffleQueryStageExec(_, s: CometShuffleExchangeExec, _) => Some(s) + case ShuffleQueryStageExec(_, ReusedExchangeExec(_, s: CometShuffleExchangeExec), _) => + Some(s) + case s: CometShuffleExchangeExec => Some(s) + case _ => None + } + + shuffleExec.exists(_.shuffleType == CometNativeShuffle) + } + + private def convertToShuffleScan( + op: SparkPlan, + builder: Operator.Builder): Option[OperatorOuterClass.Operator] = { + val supportedTypes = + op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) + + if (!supportedTypes) { + withInfo(op, "Unsupported data type for shuffle direct read") + return None + } + + val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() + 
val source = op.simpleStringWithNodeId() + if (source.isEmpty) { + scanBuilder.setSource(op.getClass.getSimpleName) + } else { + scanBuilder.setSource(source) + } + + val scanTypes = op.output.flatMap { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == op.output.length) { + scanBuilder.addAllFields(scanTypes.asJava) + builder.clearChildren() + Some(builder.setShuffleScan(scanBuilder).build()) + } else { + withInfo(op, "unsupported data types for shuffle direct read") + // Fall back to regular Scan + None + } + } + + override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = + CometSinkPlaceHolder(nativeOp, op, op) +} +``` + +Add necessary imports at the top of the file: +```scala +import org.apache.spark.sql.comet.execution.shuffle.{CometNativeShuffle, CometShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.comet.CometConf +``` + +- [ ] **Step 2: Verify it compiles** + +Run: `./mvnw compile -DskipTests` +Expected: BUILD SUCCESS + +- [ ] **Step 3: Commit** + +```bash +git add spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +git commit -m "feat: emit ShuffleScan protobuf for native shuffle with direct read" +``` + +--- + +### Task 8: Wire CometShuffleBlockIterator into JVM execution path + +**Files:** +- Modify: `spark/src/main/scala/org/apache/comet/Native.scala` +- Modify: `spark/src/main/scala/org/apache/comet/CometExecIterator.scala` +- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala` +- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/operators.scala` +- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala` + +This task connects the JVM plumbing so that `ShuffleScan` inputs get `CometShuffleBlockIterator` (wrapping raw `InputStream`) instead of 
`CometBatchIterator` (wrapping decoded `ColumnarBatch`).
+
+**Key insight**: Currently all inputs flow through `RDD[ColumnarBatch]`. For shuffle direct read, we need the raw `InputStream` before decoding. The approach: add a parallel input channel for raw shuffle streams alongside the existing `ColumnarBatch` inputs.
+
+- [ ] **Step 1: Change Native.scala createPlan signature**
+
+In `spark/src/main/scala/org/apache/comet/Native.scala` (line 57), change:
+
+```scala
+iterators: Array[CometBatchIterator],
+```
+to:
+```scala
+iterators: Array[Object],
+```
+
+The JNI side (`jni_api.rs:190`) already uses `JObjectArray`, so no Rust changes are needed.
+
+- [ ] **Step 2: Add shuffle stream inputs to CometExecIterator**
+
+In `spark/src/main/scala/org/apache/comet/CometExecIterator.scala`, add a parameter for shuffle block iterators that should be used instead of regular batch iterators at specific input positions:
+
+```scala
+class CometExecIterator(
+    val id: Long,
+    inputs: Seq[Iterator[ColumnarBatch]],
+    numOutputCols: Int,
+    protobufQueryPlan: Array[Byte],
+    nativeMetrics: CometMetricNode,
+    numParts: Int,
+    partitionIndex: Int,
+    broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None,
+    encryptedFilePaths: Seq[String] = Seq.empty,
+    shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty)
+```
+
+Replace the `cometBatchIterators` construction (lines 81-83):
+
+```scala
+private val nativeIterators: Array[Object] = {
+  val result = new Array[Object](inputs.size)
+  inputs.zipWithIndex.foreach { case (iterator, idx) =>
+    result(idx) = shuffleBlockIterators.getOrElse(
+      idx,
+      new CometBatchIterator(iterator, nativeUtil))
+  }
+  result
+}
+```
+
+Change `nativeLib.createPlan(id, cometBatchIterators, ...)` (line 109) to use `nativeIterators`.
+
+In the `close()` method, also close `CometShuffleBlockIterator` instances:
+```scala
+shuffleBlockIterators.values.foreach { iter =>
+  try { iter.close() } catch { case _: Exception => }
+}
+```
+
+- [ ] **Step 3: Add shuffle stream support to CometExecRDD**
+
+In `spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala`, add a parameter to carry shuffle block iterator factories:
+
+```scala
+private[spark] class CometExecRDD(
+    sc: SparkContext,
+    var inputRDDs: Seq[RDD[ColumnarBatch]],
+    ...
+    encryptedFilePaths: Seq[String] = Seq.empty,
+    shuffleBlockIteratorFactories: Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty)
+```
+
+In the `compute` method (line 112), pass them to `CometExecIterator`:
+
+```scala
+// Create shuffle block iterators for this partition
+val shuffleBlockIters = shuffleBlockIteratorFactories.map { case (idx, factory) =>
+  idx -> factory(context, partition.inputPartitions(idx))
+}
+
+val it = new CometExecIterator(
+  CometExec.newIterId,
+  inputs,
+  numOutputCols,
+  actualPlan,
+  nativeMetrics,
+  numPartitions,
+  partition.index,
+  broadcastedHadoopConfForEncryption,
+  encryptedFilePaths,
+  shuffleBlockIters)
+```
+
+- [ ] **Step 4: Identify ShuffleScan inputs in operators.scala**
+
+In `spark/src/main/scala/org/apache/spark/sql/comet/operators.scala`, in `CometNativeExec.doExecuteColumnar` (around line 480):
+
+After `foreachUntilCometInput(this)(sparkPlans += _)`, determine which inputs correspond to `ShuffleScan` operators.
Parse the serialized protobuf plan to find `ShuffleScan` leaf positions:
+
+```scala
+import org.apache.comet.serde.OperatorOuterClass
+
+// Find which input indices correspond to ShuffleScan operators
+val shuffleScanIndices: Set[Int] = {
+  val plan = OperatorOuterClass.Operator.parseFrom(serializedPlanCopy)
+  var scanIndex = 0
+  val indices = scala.collection.mutable.Set.empty[Int]
+  def walk(op: OperatorOuterClass.Operator): Unit = {
+    if (op.hasShuffleScan) {
+      indices += scanIndex
+      scanIndex += 1
+    } else if (op.hasScan) {
+      scanIndex += 1
+    } else {
+      // Recurse into children in order
+      (0 until op.getChildrenCount).foreach(i => walk(op.getChildren(i)))
+    }
+  }
+  walk(plan)
+  indices.toSet
+}
+```
+
+Then in the `sparkPlans.zipWithIndex.foreach` loop (line 523), for plans at shuffle scan indices, create a factory that produces `CometShuffleBlockIterator`:
+
+```scala
+val shuffleBlockIteratorFactories = scala.collection.mutable.Map.empty[Int, (TaskContext, Partition) => CometShuffleBlockIterator]
+
+sparkPlans.zipWithIndex.foreach { case (plan, idx) =>
+  plan match {
+    // ... existing cases ...
+    case _ if shuffleScanIndices.contains(inputIndexForPlan(idx)) =>
+      // Still add the RDD for partition tracking, but also register
+      // a factory for the raw InputStream
+      val rdd = plan.executeColumnar()
+      inputs += rdd
+      // The factory creates a CometShuffleBlockIterator from the raw shuffle stream
+      // We need to get the raw InputStream - see Step 5
+      shuffleBlockIteratorFactories(inputs.size - 1) = ...
+    // ... remaining cases ...
+  }
+}
+```
+
+The tricky part is getting the raw `InputStream` from the shuffle read. See Step 5.
+
+- [ ] **Step 5: Add raw InputStream mode to CometBlockStoreShuffleReader**
+
+In `spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala`:
+
+The current `read()` method creates `NativeBatchDecoderIterator`, which decodes blocks. For direct read, we need a mode that yields the raw `InputStream` wrapped in `CometShuffleBlockIterator`.
+
+Add a method:
+
+```scala
+def readRawStreams(): Iterator[CometShuffleBlockIterator] = {
+  fetchIterator.map { case (_, inputStream) =>
+    new CometShuffleBlockIterator(inputStream)
+  }
+}
+```
+
+The challenge is that `CometShuffledBatchRDD` calls `reader.read()`, which returns `Iterator[Product2[Int, ColumnarBatch]]`. For the direct read path, we need a different RDD that calls `readRawStreams()` instead.
+
+**Approach**: Create `CometShuffledRawStreamRDD` — a simple RDD that wraps the shuffle reader and yields `CometShuffleBlockIterator` objects per block. Then in `operators.scala`, instead of using the ColumnarBatch RDD, create a `CometShuffledRawStreamRDD` and pass its iterator-producing factory to `CometExecRDD`.
+
+Alternatively, since `CometShuffleBlockIterator` wraps a single `InputStream` that may contain multiple blocks, and `fetchIterator` yields one `InputStream` per shuffle block, the simplest approach is to **concatenate all InputStreams into one** per partition:
+
+```scala
+def readAsRawStream(): InputStream = {
+  val streams = fetchIterator.map(_._2)
+  new SequenceInputStream(java.util.Collections.enumeration(
+    streams.toList.asJava))
+}
+```
+
+Then in the factory: `(ctx, part) => new CometShuffleBlockIterator(reader.readAsRawStream())`
+
+But the reader is created per-partition in `CometShuffledBatchRDD.compute()`, so the factory approach means reader creation must be deferred.
+
+**Simplest concrete approach**: Instead of a factory, create a new RDD `CometShuffledRawRDD` that returns `Iterator[CometShuffleBlockIterator]`. Pass this as a separate input alongside the regular `ColumnarBatch` inputs:
+
+```scala
+// In CometExecRDD, add:
+shuffleRawInputRDDs: Seq[(Int, RDD[CometShuffleBlockIterator])]
+```
+
+In `compute`, create iterators from these RDDs and pass them to `CometExecIterator` via the `shuffleBlockIterators` map.
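The stream-concatenation idea can be tried out in plain JDK terms, independent of Spark. The sketch below mirrors the `readAsRawStream()` idea using the same `SequenceInputStream` mechanism; the `concat` helper, class name, and demo byte arrays are illustrative only and not part of the Comet codebase:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

public class ConcatStreamsDemo {
    // Concatenate per-block shuffle streams into one partition-level stream,
    // using the same SequenceInputStream mechanism as the readAsRawStream() sketch.
    static InputStream concat(List<InputStream> streams) {
        return new SequenceInputStream(Collections.enumeration(streams));
    }

    public static void main(String[] args) throws IOException {
        // Two fake "blocks" of 2 and 1 bytes stand in for fetched shuffle streams
        InputStream s = concat(Arrays.asList(
                new ByteArrayInputStream(new byte[] {1, 2}),
                new ByteArrayInputStream(new byte[] {3})));
        int count = 0;
        while (s.read() != -1) {
            count++;
        }
        System.out.println(count); // prints 3
    }
}
```

Because `SequenceInputStream` advances to the next underlying stream transparently on EOF, the downstream block iterator sees one continuous byte sequence per partition, which is exactly what the single-`InputStream` constructor of `CometShuffleBlockIterator` expects.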
+
+This is the most invasive part of the implementation. The exact approach should be determined by reading the code at implementation time, as there are multiple valid paths. The key constraint: the raw `InputStream` from `fetchIterator` must reach `CometShuffleBlockIterator` without going through `NativeBatchDecoderIterator`.
+
+- [ ] **Step 6: Verify it compiles**
+
+Run: `./mvnw compile -DskipTests`
+Expected: BUILD SUCCESS
+
+- [ ] **Step 7: Commit**
+
+```bash
+git add spark/src/main/scala/org/apache/comet/Native.scala
+git add spark/src/main/scala/org/apache/comet/CometExecIterator.scala
+git add spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
+git add spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+git add spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
+git commit -m "feat: wire CometShuffleBlockIterator into JVM execution path"
+```
+
+---
+
+### Task 9: End-to-end testing
+
+**Files:**
+- Modify: Appropriate test suite (find the right suite by searching for existing shuffle tests)
+
+- [ ] **Step 1: Build everything**
+
+Run: `make`
+Expected: Successful build of both native and JVM code.
+
+- [ ] **Step 2: Run existing shuffle tests**
+
+Run: `./mvnw test -Dsuites="org.apache.comet.exec.CometShuffleSuite"`
+Expected: All existing tests pass (they now use the new direct read path by default).
+
+If tests fail, set `spark.comet.shuffle.directRead.enabled=false` to confirm the old path still works, then investigate the new path.
+
+- [ ] **Step 3: Add comparison test**
+
+Add a test that runs the same queries with direct read enabled and disabled:
+
+```scala
+test("shuffle direct read produces same results as FFI path") {
+  Seq(true, false).foreach { directRead =>
+    withSQLConf(
+      CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.key -> directRead.toString) {
+      val df = spark.range(1000)
+        .selectExpr("id", "id % 10 as key", "cast(id as string) as value")
+        .repartition(4, col("key"))
+        .groupBy("key")
+        .agg(sum("id").as("total"), count("value").as("cnt"))
+        .orderBy("key")
+      checkSparkAnswer(df)
+    }
+  }
+}
+```
+
+- [ ] **Step 4: Add Rust unit test for ShuffleScanExec**
+
+In `native/core/src/execution/operators/shuffle_scan.rs`, add a `#[cfg(test)]` module:
+
+```rust
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::execution::shuffle::codec::{CompressionCodec, ShuffleBlockWriter};
+    use arrow::array::{Int32Array, StringArray};
+    use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::record_batch::RecordBatch;
+    use std::io::Cursor;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_read_compressed_ipc_block() {
+        // Create a test RecordBatch
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, true),
+        ]));
+        let batch = RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3])),
+                Arc::new(StringArray::from(vec!["a", "b", "c"])),
+            ],
+        )
+        .unwrap();
+
+        // Write it as compressed IPC using ShuffleBlockWriter
+        let writer =
+            ShuffleBlockWriter::try_new(&batch.schema(), CompressionCodec::Zstd(1)).unwrap();
+        let mut buf = Cursor::new(Vec::new());
+        let ipc_time = datafusion::physical_plan::metrics::Time::new();
+        writer.write_batch(&batch, &mut buf, &ipc_time).unwrap();
+
+        // Read back the body (skip the 16-byte header)
+        let bytes = buf.into_inner();
+        let body = &bytes[16..]; // Skip compressed_length(8) + field_count(8)
+
+        // Decode using read_ipc_compressed
+        let decoded = read_ipc_compressed(body).unwrap();
+        assert_eq!(decoded.num_rows(), 3);
+        assert_eq!(decoded.num_columns(), 2);
+    }
+}
+```
+
+- [ ] **Step 5: Run all tests**
+
+Run: `make test`
+
+- [ ] **Step 6: Run clippy**
+
+Run: `cd native && cargo clippy --all-targets --workspace -- -D warnings`
+Expected: No warnings.
+
+- [ ] **Step 7: Format**
+
+Run: `make format`
+
+- [ ] **Step 8: Commit**
+
+```bash
+git add -A
+git commit -m "test: add shuffle direct read tests"
+```
+
+---
+
+## Implementation Notes
+
+### Task 8 is the hardest
+
+The core challenge is routing the raw `InputStream` from Spark's shuffle infrastructure through to `CometShuffleBlockIterator` without going through the decode path. The current RDD pipeline (`CometShuffledBatchRDD` → `CometBlockStoreShuffleReader.read()` → `NativeBatchDecoderIterator`) always decodes. You need to intercept before `NativeBatchDecoderIterator` is created.
+
+The most surgical approach: in `CometBlockStoreShuffleReader`, add a `readRaw()` method that returns the raw `InputStream` (or a `CometShuffleBlockIterator` wrapping it) instead of decoded batches. Then create a parallel RDD (`CometShuffledRawRDD`) that calls `readRaw()` in its `compute` method and pass it through to `CometExecIterator`.
+
+### Metrics
+
+`ShuffleScanExec` should track `decode_time` using DataFusion's `Time` metric. Register it in `ShuffleScanExec::new` via `MetricBuilder`, following the pattern in `ScanExec`.
+
+### Order of tasks
+
+Tasks 1-7 can be done sequentially. Task 8 depends on all previous tasks. Task 9 validates everything.
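As a closing aid for Task 8, the shuffle block framing that the block iterator must parse (per the 16-byte header described in the Rust test above: an 8-byte little-endian `compressedLength` that counts the 8-byte field count but not itself, then the field count, then the compressed body) can be sketched in plain Java. The buffer contents here are made-up demo values, not real compressed IPC data:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class BlockHeaderDemo {
    public static void main(String[] args) {
        byte[] body = new byte[] {10, 20, 30, 40}; // stand-in for codec prefix + compressed IPC
        // Writer side: 8-byte compressedLength (field count + body, excluding itself),
        // then 8-byte fieldCount, both little-endian, then the body.
        ByteBuffer block = ByteBuffer.allocate(16 + body.length).order(ByteOrder.LITTLE_ENDIAN);
        block.putLong(8 + body.length); // compressedLength = 8 (fieldCount) + body length
        block.putLong(2);               // fieldCount (discarded; schema comes from protobuf)
        block.put(body);
        block.flip();

        // Reader side, mirroring the header arithmetic the block iterator performs:
        long compressedLength = block.getLong();
        block.getLong(); // skip fieldCount
        long bytesToRead = compressedLength - 8;
        System.out.println(bytesToRead); // prints 4
    }
}
```

Getting this arithmetic wrong by one field width desynchronizes every subsequent block in the stream, which is why the comparison test in Task 9 exercises multiple batches per partition.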
From cb2fe12a887edd863a3b707caa2280500150ee7f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:30:11 -0600 Subject: [PATCH 03/16] feat: add spark.comet.shuffle.directRead.enabled config --- .../src/main/scala/org/apache/comet/CometConf.scala | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 4d2e37924a..ad3774567c 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -343,6 +343,17 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.shuffle.directRead.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, native operators that consume shuffle output will read " + + "compressed shuffle blocks directly in native code, bypassing Arrow FFI. " + + "Only applies to native shuffle (not JVM columnar shuffle). 
" + + "Requires spark.comet.exec.shuffle.enabled to be true.") + .booleanConf + .createWithDefault(true) + val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode") .category(CATEGORY_SHUFFLE) .doc( From 191bbe1a32033ca2664baf274340f2862d9d6597 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:31:29 -0600 Subject: [PATCH 04/16] feat: add ShuffleScan protobuf message --- native/core/src/execution/planner/operator_registry.rs | 1 + native/proto/src/proto/operator.proto | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/native/core/src/execution/planner/operator_registry.rs b/native/core/src/execution/planner/operator_registry.rs index b34a80df95..e20624b6c9 100644 --- a/native/core/src/execution/planner/operator_registry.rs +++ b/native/core/src/execution/planner/operator_registry.rs @@ -153,5 +153,6 @@ fn get_operator_type(spark_operator: &Operator) -> Option { OpStruct::Window(_) => Some(OperatorType::Window), OpStruct::Explode(_) => None, // Not yet in OperatorType enum OpStruct::CsvScan(_) => Some(OperatorType::CsvScan), + OpStruct::ShuffleScan(_) => None, // Not yet in OperatorType enum } } diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 4afc1fefb7..344b9f0f21 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -52,6 +52,7 @@ message Operator { ParquetWriter parquet_writer = 113; Explode explode = 114; CsvScan csv_scan = 115; + ShuffleScan shuffle_scan = 116; } } @@ -85,6 +86,12 @@ message Scan { bool arrow_ffi_safe = 3; } +message ShuffleScan { + repeated spark.spark_expression.DataType fields = 1; + // Informational label for debug output (e.g., "CometShuffleExchangeExec [id=5]") + string source = 2; +} + // Common data shared by all partitions in split mode (sent once at planning) message NativeScanCommon { repeated SparkStructField required_schema = 1; From 7ac1d93595607970d038e2858d68f584a3aed3c3 Mon Sep 
17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:33:22 -0600 Subject: [PATCH 05/16] feat: add CometShuffleBlockIterator for raw shuffle block access --- .../comet/CometShuffleBlockIterator.java | 141 ++++++++++++++++++ 1 file changed, 141 insertions(+) create mode 100644 spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java diff --git a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java new file mode 100644 index 0000000000..5de5e05c4e --- /dev/null +++ b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; + +/** + * Provides raw compressed shuffle blocks to native code via JNI. + * + *

Reads block headers (compressed length + field count) from a shuffle InputStream and loads the + * compressed body into a DirectByteBuffer. Native code pulls blocks by calling hasNext() and + * getBuffer(). + * + *

The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call. + * Native code must fully consume it (via read_ipc_compressed which allocates new memory for the + * decompressed data) before pulling the next block. + */ +public class CometShuffleBlockIterator implements Closeable { + + private static final int INITIAL_BUFFER_SIZE = 128 * 1024; + + private final ReadableByteChannel channel; + private final InputStream inputStream; + private final ByteBuffer headerBuf = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN); + private ByteBuffer dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); + private boolean closed = false; + private int currentBlockLength = 0; + + public CometShuffleBlockIterator(InputStream in) { + this.inputStream = in; + this.channel = Channels.newChannel(in); + } + + /** + * Reads the next block header and loads the compressed body into the internal buffer. Called by + * native code via JNI. + * + *

Header format: 8-byte compressedLength (includes field count but not itself) + 8-byte + * fieldCount (discarded, schema comes from protobuf). + * + * @return the compressed body length in bytes (codec prefix + compressed IPC), or -1 if EOF + */ + public int hasNext() throws IOException { + if (closed) { + return -1; + } + + // Read 16-byte header + headerBuf.clear(); + while (headerBuf.hasRemaining()) { + int bytesRead = channel.read(headerBuf); + if (bytesRead < 0) { + if (headerBuf.position() == 0) { + return -1; + } + throw new EOFException("Data corrupt: unexpected EOF while reading batch header"); + } + } + headerBuf.flip(); + long compressedLength = headerBuf.getLong(); + // Field count discarded - schema determined by ShuffleScan protobuf fields + headerBuf.getLong(); + + long bytesToRead = compressedLength - 8; + if (bytesToRead > Integer.MAX_VALUE) { + throw new IllegalStateException( + "Native shuffle block size of " + + bytesToRead + + " exceeds maximum of " + + Integer.MAX_VALUE + + ". Try reducing shuffle batch size."); + } + + if (dataBuf.capacity() < bytesToRead) { + int newCapacity = (int) Math.min(bytesToRead * 2L, Integer.MAX_VALUE); + dataBuf = ByteBuffer.allocateDirect(newCapacity); + } + + dataBuf.clear(); + dataBuf.limit((int) bytesToRead); + while (dataBuf.hasRemaining()) { + int bytesRead = channel.read(dataBuf); + if (bytesRead < 0) { + throw new EOFException("Data corrupt: unexpected EOF while reading compressed batch"); + } + } + // Note: native side uses get_direct_buffer_address (base pointer) + currentBlockLength, + // not the buffer's position/limit. No flip needed. + + currentBlockLength = (int) bytesToRead; + return currentBlockLength; + } + + /** + * Returns the DirectByteBuffer containing the current block's compressed bytes (4-byte codec + * prefix + compressed IPC data). Called by native code via JNI. + */ + public ByteBuffer getBuffer() { + return dataBuf; + } + + /** Returns the length of the current block in bytes. 
Called by native code via JNI. */ + public int getCurrentBlockLength() { + return currentBlockLength; + } + + @Override + public void close() throws IOException { + if (!closed) { + closed = true; + inputStream.close(); + if (dataBuf.capacity() > INITIAL_BUFFER_SIZE) { + dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); + } + } + } +} From 98bab7348af98282241570056a0a1234b9363da9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:34:38 -0600 Subject: [PATCH 06/16] feat: add JNI bridge for CometShuffleBlockIterator --- native/core/src/jvm_bridge/mod.rs | 5 ++ .../src/jvm_bridge/shuffle_block_iterator.rs | 56 +++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 native/core/src/jvm_bridge/shuffle_block_iterator.rs diff --git a/native/core/src/jvm_bridge/mod.rs b/native/core/src/jvm_bridge/mod.rs index 00fe7b33c3..85c2ae7577 100644 --- a/native/core/src/jvm_bridge/mod.rs +++ b/native/core/src/jvm_bridge/mod.rs @@ -174,11 +174,13 @@ pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; mod comet_task_memory_manager; +mod shuffle_block_iterator; use crate::{errors::CometError, JAVA_VM}; use batch_iterator::CometBatchIterator; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; +use shuffle_block_iterator::CometShuffleBlockIterator; /// The JVM classes that are used in the JNI calls. #[allow(dead_code)] // we need to keep references to Java items to prevent GC @@ -204,6 +206,8 @@ pub struct JVMClasses<'a> { pub comet_exec: CometExec<'a>, /// The CometBatchIterator class. Used for iterating over the batches. pub comet_batch_iterator: CometBatchIterator<'a>, + /// The CometShuffleBlockIterator class. Used for iterating over shuffle blocks. + pub comet_shuffle_block_iterator: CometShuffleBlockIterator<'a>, /// The CometTaskMemoryManager used for interacting with JVM side to /// acquire & release native memory. 
pub comet_task_memory_manager: CometTaskMemoryManager<'a>, @@ -257,6 +261,7 @@ impl JVMClasses<'_> { comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), comet_batch_iterator: CometBatchIterator::new(env).unwrap(), + comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), } }); diff --git a/native/core/src/jvm_bridge/shuffle_block_iterator.rs b/native/core/src/jvm_bridge/shuffle_block_iterator.rs new file mode 100644 index 0000000000..02fcf8ca27 --- /dev/null +++ b/native/core/src/jvm_bridge/shuffle_block_iterator.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use jni::signature::Primitive; +use jni::{ + errors::Result as JniResult, + objects::{JClass, JMethodID}, + signature::ReturnType, + JNIEnv, +}; + +/// A struct that holds all the JNI methods and fields for JVM `CometShuffleBlockIterator` class. 
+#[allow(dead_code)] // we need to keep references to Java items to prevent GC +pub struct CometShuffleBlockIterator<'a> { + pub class: JClass<'a>, + pub method_has_next: JMethodID, + pub method_has_next_ret: ReturnType, + pub method_get_buffer: JMethodID, + pub method_get_buffer_ret: ReturnType, + pub method_get_current_block_length: JMethodID, + pub method_get_current_block_length_ret: ReturnType, +} + +impl<'a> CometShuffleBlockIterator<'a> { + pub const JVM_CLASS: &'static str = "org/apache/comet/CometShuffleBlockIterator"; + + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { + let class = env.find_class(Self::JVM_CLASS)?; + + Ok(CometShuffleBlockIterator { + class, + method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, + method_has_next_ret: ReturnType::Primitive(Primitive::Int), + method_get_buffer: env + .get_method_id(Self::JVM_CLASS, "getBuffer", "()Ljava/nio/ByteBuffer;")?, + method_get_buffer_ret: ReturnType::Object, + method_get_current_block_length: env + .get_method_id(Self::JVM_CLASS, "getCurrentBlockLength", "()I")?, + method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int), + }) + } +} From c01cf1d47f1d12d2d1a6222d7468a3be9a176c96 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 15:58:35 -0600 Subject: [PATCH 07/16] feat: add ShuffleScanExec native operator for direct shuffle read Add a new ShuffleScanExec operator that pulls compressed shuffle blocks from JVM via CometShuffleBlockIterator and decodes them natively using read_ipc_compressed(). Uses the pre-pull pattern (get_next_batch called externally before poll_next) to avoid JNI calls on tokio threads. 
--- native/core/src/execution/operators/mod.rs | 2 + .../src/execution/operators/shuffle_scan.rs | 348 ++++++++++++++++++ 2 files changed, 350 insertions(+) create mode 100644 native/core/src/execution/operators/shuffle_scan.rs diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 07ee995367..ad3ec3f08b 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -34,7 +34,9 @@ pub use parquet_writer::ParquetWriterExec; mod csv_scan; pub mod projection; mod scan; +mod shuffle_scan; pub use csv_scan::init_csv_datasource_exec; +pub use shuffle_scan::ShuffleScanExec; /// Error returned during executing operators. #[derive(thiserror::Error, Debug)] diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs new file mode 100644 index 0000000000..4a8d09111b --- /dev/null +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -0,0 +1,348 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::{ + errors::CometError, + execution::{ + operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, + shuffle::codec::read_ipc_compressed, + }, + jvm_bridge::{jni_call, JVMClasses}, +}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::common::{arrow_datafusion_err, Result as DataFusionResult}; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, Time, +}; +use datafusion::{ + execution::TaskContext, + physical_expr::*, + physical_plan::{ExecutionPlan, *}, +}; +use futures::Stream; +use jni::objects::{GlobalRef, JByteBuffer, JObject}; +use std::{ + any::Any, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use super::scan::InputBatch; + +/// ShuffleScanExec reads compressed shuffle blocks from JVM via JNI and decodes them natively. +/// Unlike ScanExec which receives Arrow arrays via FFI, ShuffleScanExec receives raw compressed +/// bytes from CometShuffleBlockIterator and decodes them using read_ipc_compressed(). +#[derive(Debug, Clone)] +pub struct ShuffleScanExec { + /// The ID of the execution context that owns this subquery. + pub exec_context_id: i64, + /// The input source: a global reference to a JVM CometShuffleBlockIterator object. + pub input_source: Option>, + /// The data types of columns in the shuffle output. + pub data_types: Vec, + /// Schema of the shuffle output. + pub schema: SchemaRef, + /// The current input batch, populated by get_next_batch() before poll_next(). + pub batch: Arc>>, + /// Cache of plan properties. + cache: PlanProperties, + /// Metrics collector. + metrics: ExecutionPlanMetricsSet, + /// Baseline metrics. + baseline_metrics: BaselineMetrics, + /// Time spent decoding compressed shuffle blocks. 
+ decode_time: Time, +} + +impl ShuffleScanExec { + pub fn new( + exec_context_id: i64, + input_source: Option>, + data_types: Vec, + ) -> Result { + let metrics_set = ExecutionPlanMetricsSet::default(); + let baseline_metrics = BaselineMetrics::new(&metrics_set, 0); + let decode_time = MetricBuilder::new(&metrics_set).subset_time("decode_time", 0); + + let schema = schema_from_data_types(&data_types); + + let cache = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Final, + Boundedness::Bounded, + ); + + Ok(Self { + exec_context_id, + input_source, + data_types, + batch: Arc::new(Mutex::new(None)), + cache, + metrics: metrics_set, + baseline_metrics, + schema, + decode_time, + }) + } + + /// Feeds input batch into this scan. Only used in unit tests. + pub fn set_input_batch(&mut self, input: InputBatch) { + *self.batch.try_lock().unwrap() = Some(input); + } + + /// Pull next input batch from JVM. Called externally before poll_next() + /// because JNI calls cannot happen from within poll_next on tokio threads. + pub fn get_next_batch(&mut self) -> Result<(), CometError> { + if self.input_source.is_none() { + // Unit test mode - no JNI calls needed. + return Ok(()); + } + let mut timer = self.baseline_metrics.elapsed_compute().timer(); + + let mut current_batch = self.batch.try_lock().unwrap(); + if current_batch.is_none() { + let next_batch = Self::get_next( + self.exec_context_id, + self.input_source.as_ref().unwrap().as_obj(), + &self.data_types, + &self.decode_time, + )?; + *current_batch = Some(next_batch); + } + + timer.stop(); + + Ok(()) + } + + /// Invokes JNI calls to get the next compressed shuffle block and decode it. 
+ fn get_next( + exec_context_id: i64, + iter: &JObject, + data_types: &[DataType], + decode_time: &Time, + ) -> Result { + if exec_context_id == TEST_EXEC_CONTEXT_ID { + return Ok(InputBatch::EOF); + } + + if iter.is_null() { + return Err(CometError::from(ExecutionError::GeneralError(format!( + "Null shuffle block iterator object. Plan id: {exec_context_id}" + )))); + } + + let mut env = JVMClasses::get_env()?; + + // has_next() returns block length or -1 if no more blocks + let block_length: i32 = unsafe { + jni_call!(&mut env, + comet_shuffle_block_iterator(iter).has_next() -> i32)? + }; + + if block_length == -1 { + return Ok(InputBatch::EOF); + } + + // Get the DirectByteBuffer containing the compressed shuffle block + let buffer: JObject = unsafe { + jni_call!(&mut env, + comet_shuffle_block_iterator(iter).get_buffer() -> JObject)? + }; + + // Get the actual block length (may differ from has_next return value) + let length: i32 = unsafe { + jni_call!(&mut env, + comet_shuffle_block_iterator(iter).get_current_block_length() -> i32)? + }; + + let byte_buffer = JByteBuffer::from(buffer); + let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; + let length = length as usize; + let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; + + // Decode the compressed IPC data + let mut timer = decode_time.timer(); + let batch = read_ipc_compressed(slice)?; + timer.stop(); + + let num_rows = batch.num_rows(); + + // The read_ipc_compressed already produces owned arrays, so we skip the + // header (field count + codec) that was already consumed by read_ipc_compressed. + // Extract column arrays from the RecordBatch. 
+        let columns: Vec<ArrayRef> = batch.columns().to_vec();
+
+        debug_assert_eq!(
+            columns.len(),
+            data_types.len(),
+            "Shuffle block column count mismatch: got {} but expected {}",
+            columns.len(),
+            data_types.len()
+        );
+
+        Ok(InputBatch::new(columns, Some(num_rows)))
+    }
+}
+
+fn schema_from_data_types(data_types: &[DataType]) -> SchemaRef {
+    let fields = data_types
+        .iter()
+        .enumerate()
+        .map(|(idx, dt)| Field::new(format!("col_{idx}"), dt.clone(), true))
+        .collect::<Vec<Field>>();
+
+    Arc::new(Schema::new(fields))
+}
+
+impl ExecutionPlan for ShuffleScanExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
+        Ok(self)
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        _: Arc<TaskContext>,
+    ) -> datafusion::common::Result<SendableRecordBatchStream> {
+        Ok(Box::pin(ShuffleScanStream::new(
+            self.clone(),
+            self.schema(),
+            partition,
+            self.baseline_metrics.clone(),
+        )))
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn name(&self) -> &str {
+        "ShuffleScanExec"
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+}
+
+impl DisplayAs for ShuffleScanExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                let fields: Vec<String> = self
+                    .data_types
+                    .iter()
+                    .enumerate()
+                    .map(|(idx, dt)| format!("col_{idx}: {dt}"))
+                    .collect();
+                write!(f, "ShuffleScanExec: schema=[{}]", fields.join(", "))?;
+            }
+            DisplayFormatType::TreeRender => unimplemented!(),
+        }
+        Ok(())
+    }
+}
+
+/// An async stream that feeds decoded shuffle batches into the DataFusion plan.
+struct ShuffleScanStream {
+    /// The ShuffleScanExec producing input batches.
+    shuffle_scan: ShuffleScanExec,
+    /// Schema of the output.
+    schema: SchemaRef,
+    /// Metrics.
+    baseline_metrics: BaselineMetrics,
+}
+
+impl ShuffleScanStream {
+    pub fn new(
+        shuffle_scan: ShuffleScanExec,
+        schema: SchemaRef,
+        _partition: usize,
+        baseline_metrics: BaselineMetrics,
+    ) -> Self {
+        Self {
+            shuffle_scan,
+            schema,
+            baseline_metrics,
+        }
+    }
+}
+
+impl Stream for ShuffleScanStream {
+    type Item = DataFusionResult<RecordBatch>;
+
+    fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let mut timer = self.baseline_metrics.elapsed_compute().timer();
+        let mut scan_batch = self.shuffle_scan.batch.try_lock().unwrap();
+
+        let input_batch = &*scan_batch;
+        let input_batch = if let Some(batch) = input_batch {
+            batch
+        } else {
+            timer.stop();
+            return Poll::Pending;
+        };
+
+        let result = match input_batch {
+            InputBatch::EOF => Poll::Ready(None),
+            InputBatch::Batch(columns, num_rows) => {
+                self.baseline_metrics.record_output(*num_rows);
+                let options = arrow::array::RecordBatchOptions::new()
+                    .with_row_count(Some(*num_rows));
+                let maybe_batch = arrow::array::RecordBatch::try_new_with_options(
+                    Arc::clone(&self.schema),
+                    columns.clone(),
+                    &options,
+                )
+                .map_err(|e| arrow_datafusion_err!(e));
+                Poll::Ready(Some(maybe_batch))
+            }
+        };
+
+        *scan_batch = None;
+
+        timer.stop();
+
+        result
+    }
+}
+
+impl RecordBatchStream for ShuffleScanStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}

From e1c9111203d88214802d82c35d86075f6ef61861 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Wed, 18 Mar 2026 16:06:22 -0600
Subject: [PATCH 08/16] feat: wire ShuffleScanExec into planner and pre-pull
 mechanism

---
 native/core/src/execution/jni_api.rs          |  14 ++-
 .../src/execution/operators/projection.rs     |   9 +-
 native/core/src/execution/planner.rs          | 114 ++++++++++++++----
 .../execution/planner/operator_registry.rs    |  11 +-
 4 files changed, 110 insertions(+), 38 deletions(-)

diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 361deae182..d20cf128b5 100644
---
a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -82,7 +82,7 @@ use tokio::sync::mpsc;
 use crate::execution::memory_pools::{
     create_memory_pool, handle_task_shared_pool_release, parse_memory_pool_config,
     MemoryPoolConfig,
 };
-use crate::execution::operators::ScanExec;
+use crate::execution::operators::{ScanExec, ShuffleScanExec};
 use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec};
 use crate::execution::spark_plan::SparkPlan;

@@ -151,6 +151,8 @@ struct ExecutionContext {
     pub root_op: Option<Arc<SparkPlan>>,
     /// The input sources for the DataFusion plan
     pub scans: Vec<ScanExec>,
+    /// The shuffle scan input sources for the DataFusion plan
+    pub shuffle_scans: Vec<ShuffleScanExec>,
     /// The global reference of input sources for the DataFusion plan
     pub input_sources: Vec<Arc<GlobalRef>>,
     /// The record batch stream to pull results from
@@ -311,6 +313,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         partition_count: partition_count as usize,
         root_op: None,
         scans: vec![],
+        shuffle_scans: vec![],
         input_sources,
         stream: None,
         batch_receiver: None,
@@ -491,6 +494,10 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometError> {
     exec_context.scans.iter_mut().try_for_each(|scan| {
         scan.get_next_batch()?;
         Ok::<(), CometError>(())
+    })?;
+    exec_context.shuffle_scans.iter_mut().try_for_each(|scan| {
+        scan.get_next_batch()?;
+        Ok::<(), CometError>(())
     })
 }

@@ -539,7 +546,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
         let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
             .with_exec_id(exec_context_id);

-        let (scans, root_op) = planner.create_plan(
+        let (scans, shuffle_scans, root_op) = planner.create_plan(
             &exec_context.spark_plan,
             &mut exec_context.input_sources.clone(),
             exec_context.partition_count,
@@ -548,6 +555,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
         exec_context.plan_creation_time += physical_plan_time;

         exec_context.scans = scans;
+        exec_context.shuffle_scans = shuffle_scans;

         if exec_context.explain_native {
             let formatted_plan_str =
@@ -560,7 +568,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
         // so we should always execute partition 0.
         let stream = root_op.native_plan.execute(0, task_ctx)?;

-        if exec_context.scans.is_empty() {
+        if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() {
             // No JVM data sources — spawn onto tokio so the executor
             // thread parks in blocking_recv instead of busy-polling.
             //
diff --git a/native/core/src/execution/operators/projection.rs b/native/core/src/execution/operators/projection.rs
index 6ba1bb5d59..4169ed8d40 100644
--- a/native/core/src/execution/operators/projection.rs
+++ b/native/core/src/execution/operators/projection.rs
@@ -25,8 +25,7 @@ use jni::objects::GlobalRef;

 use crate::{
     execution::{
-        operators::{ExecutionError, ScanExec},
-        planner::{operator_registry::OperatorBuilder, PhysicalPlanner},
+        planner::{operator_registry::OperatorBuilder, PlanCreationResult, PhysicalPlanner},
         spark_plan::SparkPlan,
     },
     extract_op,
@@ -42,12 +41,13 @@ impl OperatorBuilder for ProjectionBuilder {
         inputs: &mut Vec<Arc<GlobalRef>>,
         partition_count: usize,
         planner: &PhysicalPlanner,
-    ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError> {
+    ) -> PlanCreationResult {
         let project = extract_op!(spark_plan, Projection);
         let children = &spark_plan.children;
         assert_eq!(children.len(), 1);
-        let (scans, child) = planner.create_plan(&children[0], inputs, partition_count)?;
+        let (scans, shuffle_scans, child) =
+            planner.create_plan(&children[0], inputs, partition_count)?;

         // Create projection expressions
         let exprs: Result<Vec<_>, _> = project
@@ -68,6 +68,7 @@ impl OperatorBuilder for ProjectionBuilder {

         Ok((
             scans,
+            shuffle_scans,
             Arc::new(SparkPlan::new(spark_plan.plan_id, projection, vec![child])),
         ))
     }
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index bd37755922..e19891a0d6 100644
---
a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -27,7 +27,7 @@ use crate::{
     errors::ExpressionError,
     execution::{
         expressions::subquery::Subquery,
-        operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec},
+        operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, ShuffleScanExec},
         planner::expression_registry::ExpressionRegistry,
         planner::operator_registry::OperatorRegistry,
         serde::to_arrow_datatype,
@@ -141,6 +141,8 @@ use url::Url;

 type PhyAggResult = Result<Vec<AggregateFunctionExpr>, ExecutionError>;
 type PhyExprResult = Result<Vec<(Arc<dyn PhysicalExpr>, String)>, ExecutionError>;
 type PartitionPhyExprResult = Result<Vec<Arc<dyn PhysicalExpr>>, ExecutionError>;
+pub type PlanCreationResult =
+    Result<(Vec<ScanExec>, Vec<ShuffleScanExec>, Arc<SparkPlan>), ExecutionError>;

 struct JoinParameters {
     pub left: Arc<SparkPlan>,
@@ -913,7 +915,7 @@ impl PhysicalPlanner {
         spark_plan: &'a Operator,
         inputs: &mut Vec<Arc<GlobalRef>>,
         partition_count: usize,
-    ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError> {
+    ) -> PlanCreationResult {
         // Try to use the modular registry first - this automatically handles any registered operator types
         if OperatorRegistry::global().can_handle(spark_plan) {
             return OperatorRegistry::global().create_plan(
@@ -929,7 +931,8 @@ impl PhysicalPlanner {
         match spark_plan.op_struct.as_ref().unwrap() {
             OpStruct::Filter(filter) => {
                 assert_eq!(children.len(), 1);
-                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
+                let (scans, shuffle_scans, child) =
+                    self.create_plan(&children[0], inputs, partition_count)?;
                 let predicate =
                     self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?;
@@ -940,12 +943,14 @@ impl PhysicalPlanner {

                 Ok((
                     scans,
+                    shuffle_scans,
                     Arc::new(SparkPlan::new(spark_plan.plan_id, filter, vec![child])),
                 ))
             }
             OpStruct::HashAgg(agg) => {
                 assert_eq!(children.len(), 1);
-                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
+                let (scans, shuffle_scans, child) =
+                    self.create_plan(&children[0], inputs, partition_count)?;

                 let group_exprs: PhyExprResult = agg
.grouping_exprs
@@ -996,6 +1001,7 @@ impl PhysicalPlanner {
                 if agg.result_exprs.is_empty() {
                     Ok((
                         scans,
+                        shuffle_scans,
                         Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])),
                     ))
                 } else {
@@ -1012,6 +1018,7 @@ impl PhysicalPlanner {
                     )?);
                     Ok((
                         scans,
+                        shuffle_scans,
                         Arc::new(SparkPlan::new_with_additional(
                             spark_plan.plan_id,
                             projection,
@@ -1030,7 +1037,8 @@ impl PhysicalPlanner {
                         "Invalid limit/offset combination: [{num}. {offset}]"
                     )));
                 }
-                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
+                let (scans, shuffle_scans, child) =
+                    self.create_plan(&children[0], inputs, partition_count)?;
                 let limit: Arc<dyn ExecutionPlan> = if offset == 0 {
                     Arc::new(LocalLimitExec::new(
                         Arc::clone(&child.native_plan),
@@ -1050,12 +1058,14 @@ impl PhysicalPlanner {
                 };
                 Ok((
                     scans,
+                    shuffle_scans,
                     Arc::new(SparkPlan::new(spark_plan.plan_id, limit, vec![child])),
                 ))
             }
             OpStruct::Sort(sort) => {
                 assert_eq!(children.len(), 1);
-                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
+                let (scans, shuffle_scans, child) =
+                    self.create_plan(&children[0], inputs, partition_count)?;

                 let exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = sort
                     .sort_orders
@@ -1079,6 +1089,7 @@ impl PhysicalPlanner {

                 Ok((
                     scans,
+                    shuffle_scans,
                     Arc::new(SparkPlan::new(
                         spark_plan.plan_id,
                         sort_exec,
@@ -1115,6 +1126,7 @@ impl PhysicalPlanner {
                 if partition_files.partitioned_file.is_empty() {
                     let empty_exec = Arc::new(EmptyExec::new(required_schema));
                     return Ok((
+                        vec![],
                         vec![],
                         Arc::new(SparkPlan::new(spark_plan.plan_id, empty_exec, vec![])),
                     ));
                 }
@@ -1205,6 +1217,7 @@ impl PhysicalPlanner {
                     common.encryption_enabled,
                 )?;
                 Ok((
+                    vec![],
                     vec![],
                     Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])),
                 ))
@@ -1243,6 +1256,7 @@ impl PhysicalPlanner {
                     &scan.csv_options.clone().unwrap(),
                 )?;
                 Ok((
+                    vec![],
                     vec![],
                     Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])),
                 ))
@@ -1276,6 +1290,7 @@ impl PhysicalPlanner {

                 Ok((
                     vec![scan.clone()],
+                    vec![],
Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])), )) } @@ -1307,6 +1322,7 @@ impl PhysicalPlanner { )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new( spark_plan.plan_id, @@ -1317,7 +1333,8 @@ impl PhysicalPlanner { } OpStruct::ShuffleWriter(writer) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let partitioning = self.create_partitioning( writer.partitioning.as_ref().unwrap(), @@ -1350,6 +1367,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, shuffle_writer, @@ -1359,7 +1377,8 @@ impl PhysicalPlanner { } OpStruct::ParquetWriter(writer) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let codec = match writer.compression.try_into() { Ok(SparkCompressionCodec::None) => Ok(CompressionCodec::None), @@ -1396,6 +1415,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, parquet_writer, @@ -1405,7 +1425,8 @@ impl PhysicalPlanner { } OpStruct::Expand(expand) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let mut projections = vec![]; let mut projection = vec![]; @@ -1448,12 +1469,14 @@ impl PhysicalPlanner { let expand = Arc::new(ExpandExec::new(projections, input, schema)); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, expand, vec![child])), )) } OpStruct::Explode(explode) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + 
self.create_plan(&children[0], inputs, partition_count)?;

                 // Create the expression for the array to explode
                 let child_expr = if let Some(child_expr) = &explode.child {
@@ -1559,11 +1582,12 @@ impl PhysicalPlanner {

                 Ok((
                     scans,
+                    shuffle_scans,
                     Arc::new(SparkPlan::new(spark_plan.plan_id, unnest_exec, vec![child])),
                 ))
             }
             OpStruct::SortMergeJoin(join) => {
-                let (join_params, scans) = self.parse_join_parameters(
+                let (join_params, scans, shuffle_scans) = self.parse_join_parameters(
                     inputs,
                     children,
                     &join.left_join_keys,
@@ -1615,6 +1639,7 @@ impl PhysicalPlanner {
                     ));
                     Ok((
                         scans,
+                        shuffle_scans,
                         Arc::new(SparkPlan::new_with_additional(
                             spark_plan.plan_id,
                             coalesce_batches,
@@ -1628,6 +1653,7 @@ impl PhysicalPlanner {
                 } else {
                     Ok((
                         scans,
+                        shuffle_scans,
                         Arc::new(SparkPlan::new(
                             spark_plan.plan_id,
                             join,
@@ -1640,7 +1666,7 @@ impl PhysicalPlanner {
                 }
             }
             OpStruct::HashJoin(join) => {
-                let (join_params, scans) = self.parse_join_parameters(
+                let (join_params, scans, shuffle_scans) = self.parse_join_parameters(
                     inputs,
                     children,
                     &join.left_join_keys,
@@ -1670,6 +1696,7 @@ impl PhysicalPlanner {
                 if join.build_side == BuildSide::BuildLeft as i32 {
                     Ok((
                         scans,
+                        shuffle_scans,
                         Arc::new(SparkPlan::new(
                             spark_plan.plan_id,
                             hash_join,
@@ -1688,6 +1715,7 @@ impl PhysicalPlanner {

                     Ok((
                         scans,
+                        shuffle_scans,
                         Arc::new(SparkPlan::new_with_additional(
                             spark_plan.plan_id,
                             swapped_hash_join,
@@ -1698,7 +1726,8 @@ impl PhysicalPlanner {
                 }
             }
             OpStruct::Window(wnd) => {
-                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
+                let (scans, shuffle_scans, child) =
+                    self.create_plan(&children[0], inputs, partition_count)?;
                 let input_schema = child.schema();
                 let sort_exprs: Result<Vec<PhysicalSortExpr>, ExecutionError> = wnd
                     .order_by_list
@@ -1736,9 +1765,42 @@ impl PhysicalPlanner {
                 )?);
                 Ok((
                     scans,
+                    shuffle_scans,
                     Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])),
                 ))
             }
+            OpStruct::ShuffleScan(scan) => {
+                let data_types =
scan.fields.iter().map(to_arrow_datatype).collect_vec();
+
+                if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
+                    return Err(GeneralError(
+                        "No input for shuffle scan".to_string(),
+                    ));
+                }
+
+                let input_source =
+                    if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
+                        None
+                    } else {
+                        Some(inputs.remove(0))
+                    };
+
+                let shuffle_scan = ShuffleScanExec::new(
+                    self.exec_context_id,
+                    input_source,
+                    data_types,
+                )?;
+
+                Ok((
+                    vec![],
+                    vec![shuffle_scan.clone()],
+                    Arc::new(SparkPlan::new(
+                        spark_plan.plan_id,
+                        Arc::new(shuffle_scan),
+                        vec![],
+                    )),
+                ))
+            }
             _ => Err(GeneralError(format!(
                 "Unsupported or unregistered operator type: {:?}",
                 spark_plan.op_struct
@@ -1756,12 +1818,15 @@ impl PhysicalPlanner {
         join_type: i32,
         condition: &Option<Expr>,
         partition_count: usize,
-    ) -> Result<(JoinParameters, Vec<ScanExec>), ExecutionError> {
+    ) -> Result<(JoinParameters, Vec<ScanExec>, Vec<ShuffleScanExec>), ExecutionError> {
         assert_eq!(children.len(), 2);
-        let (mut left_scans, left) = self.create_plan(&children[0], inputs, partition_count)?;
-        let (mut right_scans, right) = self.create_plan(&children[1], inputs, partition_count)?;
+        let (mut left_scans, mut left_shuffle_scans, left) =
+            self.create_plan(&children[0], inputs, partition_count)?;
+        let (mut right_scans, mut right_shuffle_scans, right) =
+            self.create_plan(&children[1], inputs, partition_count)?;
         left_scans.append(&mut right_scans);
+        left_shuffle_scans.append(&mut right_shuffle_scans);

         let left_join_exprs: Vec<_> = left_join_keys
             .iter()
@@ -1882,6 +1947,7 @@ impl PhysicalPlanner {
                 join_filter,
             },
             left_scans,
+            left_shuffle_scans,
         ))
     }

@@ -3670,7 +3736,7 @@ mod tests {
         let input_array = DictionaryArray::new(keys, Arc::new(values));
         let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count);

-        let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap();
+        let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap();
scans[0].set_input_batch(input_batch); let session_ctx = SessionContext::new(); @@ -3744,7 +3810,7 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); // Scan's schema is determined by the input batch, so we need to set it before execution. scans[0].set_input_batch(input_batch); @@ -3791,7 +3857,7 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); let scan = &mut scans[0]; scan.set_input_batch(InputBatch::EOF); @@ -3876,7 +3942,7 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap(); assert_eq!("FilterExec", filter_exec.native_plan.name()); assert_eq!(1, filter_exec.children.len()); @@ -3900,7 +3966,7 @@ mod tests { let planner = PhysicalPlanner::default(); - let (_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap(); assert_eq!("HashJoinExec", hash_join_exec.native_plan.name()); assert_eq!(2, hash_join_exec.children.len()); @@ -4014,7 +4080,7 @@ mod tests { })), }; - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap(); @@ -4140,7 
+4206,7 @@ mod tests {
         };

         // Create a physical plan
-        let (mut scans, datafusion_plan) =
+        let (mut scans, _shuffle_scans, datafusion_plan) =
             planner.create_plan(&projection, &mut vec![], 1).unwrap();

         // Start executing the plan in a separate thread
@@ -4631,7 +4697,7 @@ mod tests {
         };

         // Create the physical plan
-        let (mut scans, datafusion_plan) =
+        let (mut scans, _shuffle_scans, datafusion_plan) =
             planner.create_plan(&projection, &mut vec![], 1).unwrap();

         // Create test data: Date32 and Int8 columns
diff --git a/native/core/src/execution/planner/operator_registry.rs b/native/core/src/execution/planner/operator_registry.rs
index e20624b6c9..cad5df40c5 100644
--- a/native/core/src/execution/planner/operator_registry.rs
+++ b/native/core/src/execution/planner/operator_registry.rs
@@ -25,11 +25,8 @@ use std::{

 use datafusion_comet_proto::spark_operator::Operator;
 use jni::objects::GlobalRef;

-use super::PhysicalPlanner;
-use crate::execution::{
-    operators::{ExecutionError, ScanExec},
-    spark_plan::SparkPlan,
-};
+use super::{PhysicalPlanner, PlanCreationResult};
+use crate::execution::operators::ExecutionError;

 /// Trait for building physical operators from Spark protobuf operators
 pub trait OperatorBuilder: Send + Sync {
@@ -40,7 +37,7 @@ pub trait OperatorBuilder: Send + Sync {
         inputs: &mut Vec<Arc<GlobalRef>>,
         partition_count: usize,
         planner: &PhysicalPlanner,
-    ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError>;
+    ) -> PlanCreationResult;
 }

 /// Enum to identify different operator types for registry dispatch
@@ -100,7 +97,7 @@ impl OperatorRegistry {
         inputs: &mut Vec<Arc<GlobalRef>>,
         partition_count: usize,
         planner: &PhysicalPlanner,
-    ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError> {
+    ) -> PlanCreationResult {
         let operator_type = get_operator_type(spark_operator).ok_or_else(|| {
             ExecutionError::GeneralError(format!(
                 "Unsupported operator type: {:?}",

From 9a9812a0bcd877327cf165a269266480c7ee2eb4 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Wed, 18 Mar 2026 16:09:31 -0600
Subject: [PATCH 09/16] feat: emit
ShuffleScan protobuf for native shuffle with direct read --- .../comet/serde/operator/CometSink.scala | 70 +++++++++++++++++-- 1 file changed, 63 insertions(+), 7 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala index ca9dbdad7c..dde36d9789 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala @@ -22,8 +22,12 @@ package org.apache.comet.serde.operator import scala.jdk.CollectionConverters._ import org.apache.spark.sql.comet.{CometNativeExec, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.execution.shuffle.{CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.comet.CometConf import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.ConfigEntry import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass} @@ -86,15 +90,67 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { object CometExchangeSink extends CometSink[SparkPlan] { - /** - * Exchange data is FFI safe because there is no use of mutable buffers involved. - * - * Source of broadcast exchange batches is ArrowStreamReader. - * - * Source of shuffle exchange batches is NativeBatchDecoderIterator. 
- */ override def isFfiSafe: Boolean = true + override def convert( + op: SparkPlan, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (shouldUseShuffleScan(op)) { + convertToShuffleScan(op, builder) + } else { + super.convert(op, builder, childOp: _*) + } + } + + private def shouldUseShuffleScan(op: SparkPlan): Boolean = { + if (!CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.get()) return false + + // Extract the CometShuffleExchangeExec from the wrapper + val shuffleExec = op match { + case ShuffleQueryStageExec(_, s: CometShuffleExchangeExec, _) => Some(s) + case ShuffleQueryStageExec(_, ReusedExchangeExec(_, s: CometShuffleExchangeExec), _) => + Some(s) + case s: CometShuffleExchangeExec => Some(s) + case _ => None + } + + shuffleExec.exists(_.shuffleType == CometNativeShuffle) + } + + private def convertToShuffleScan( + op: SparkPlan, + builder: Operator.Builder): Option[OperatorOuterClass.Operator] = { + val supportedTypes = + op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) + + if (!supportedTypes) { + withInfo(op, "Unsupported data type for shuffle direct read") + return None + } + + val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() + val source = op.simpleStringWithNodeId() + if (source.isEmpty) { + scanBuilder.setSource(op.getClass.getSimpleName) + } else { + scanBuilder.setSource(source) + } + + val scanTypes = op.output.flatMap { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == op.output.length) { + scanBuilder.addAllFields(scanTypes.asJava) + builder.clearChildren() + Some(builder.setShuffleScan(scanBuilder).build()) + } else { + withInfo(op, "unsupported data types for shuffle direct read") + None + } + } + override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = CometSinkPlaceHolder(nativeOp, op, op) } From e098cd5df5d5e03b3f50443e22b64ca108089b2f Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 
Mar 2026 16:17:38 -0600 Subject: [PATCH 10/16] feat: wire CometShuffleBlockIterator into JVM execution path --- .../org/apache/comet/CometExecIterator.scala | 15 +- .../main/scala/org/apache/comet/Native.scala | 2 +- .../apache/spark/sql/comet/CometExecRDD.scala | 24 +++- .../CometBlockStoreShuffleReader.scala | 11 ++ .../apache/spark/sql/comet/operators.scala | 133 +++++++++++++++++- 5 files changed, 171 insertions(+), 14 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 44ebf7e36e..e198ac99ff 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -67,7 +67,8 @@ class CometExecIterator( numParts: Int, partitionIndex: Int, broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty) + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty) extends Iterator[ColumnarBatch] with Logging { @@ -78,8 +79,13 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) - private val cometBatchIterators = inputs.map { iterator => - new CometBatchIterator(iterator, nativeUtil) + // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle + // scan indices, CometBatchIterator for regular scan indices. 
+ private val inputIterators: Array[Object] = inputs.zipWithIndex.map { + case (_, idx) if shuffleBlockIterators.contains(idx) => + shuffleBlockIterators(idx).asInstanceOf[Object] + case (iterator, _) => + new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] }.toArray private val plan = { @@ -106,7 +112,7 @@ class CometExecIterator( nativeLib.createPlan( id, - cometBatchIterators, + inputIterators, protobufQueryPlan, protobufSparkConfigs, numParts, @@ -229,6 +235,7 @@ class CometExecIterator( currentBatch = null } nativeUtil.close() + shuffleBlockIterators.values.foreach(_.close()) nativeLib.releasePlan(plan) if (tracingEnabled) { diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 55e0c70e72..f6800626d6 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -54,7 +54,7 @@ class Native extends NativeBase { // scalastyle:off @native def createPlan( id: Long, - iterators: Array[CometBatchIterator], + iterators: Array[Object], plan: Array[Byte], configMapProto: Array[Byte], partitionCount: Int, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index ad0c4f2afe..cb8652507f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration -import org.apache.comet.CometExecIterator +import org.apache.comet.{CometExecIterator, CometShuffleBlockIterator} import org.apache.comet.serde.OperatorOuterClass /** @@ -64,7 +64,10 @@ private[spark] class CometExecRDD( nativeMetrics: CometMetricNode, subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: 
Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty) + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIteratorFactories: Map[ + Int, + (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty) extends RDD[ColumnarBatch](sc, inputRDDs.map(rdd => new OneToOneDependency(rdd))) { // Determine partition count: from inputs if available, otherwise from parameter @@ -109,6 +112,12 @@ private[spark] class CometExecRDD( serializedPlan } + // Create shuffle block iterators for indices that have factories + val shuffleBlockIters = shuffleBlockIteratorFactories.map { case (idx, factory) => + val inputPart = partition.inputPartitions(idx) + idx -> factory(context, inputPart) + } + val it = new CometExecIterator( CometExec.newIterId, inputs, @@ -118,7 +127,8 @@ private[spark] class CometExecRDD( numPartitions, partition.index, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIters) // Register ScalarSubqueries so native code can look them up subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub)) @@ -167,7 +177,10 @@ object CometExecRDD { nativeMetrics: CometMetricNode, subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty): CometExecRDD = { + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIteratorFactories: Map[ + Int, + (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty): CometExecRDD = { // scalastyle:on new CometExecRDD( @@ -181,6 +194,7 @@ object CometExecRDD { nativeMetrics, subqueries, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIteratorFactories) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index e95eb92d21..647d4a0856 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -21,6 +21,8 @@ package org.apache.spark.sql.comet.execution.shuffle import java.io.InputStream +import scala.jdk.CollectionConverters._ + import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec @@ -153,6 +155,15 @@ class CometBlockStoreShuffleReader[K, C]( } } + /** + * Returns the raw concatenated InputStream of all shuffle blocks, bypassing the decode step. + * Used by ShuffleScan direct read path. + */ + def readAsRawStream(): InputStream = { + val streams = fetchIterator.map(_._2).toList + new java.io.SequenceInputStream(java.util.Collections.enumeration(streams.asJava)) + } + private def fetchContinuousBlocksInBatch: Boolean = { val conf = SparkEnv.get.conf val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..a0cb14bbd0 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ +import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -33,14 +34,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import 
org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.execution.shuffle.{CometBlockStoreShuffleReader, CometShuffledBatchRDD, CometShuffleExchangeExec, ShuffledRowRDDPartition} import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -50,7 +51,7 @@ import org.apache.spark.util.io.ChunkedByteBuffer import com.google.common.base.Objects import com.google.protobuf.CodedOutputStream -import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, ConfigEntry} +import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, CometShuffleBlockIterator, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo} import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported} @@ 
-553,6 +554,11 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } + // Detect ShuffleScan indices and create factories for direct read + val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) + val shuffleBlockIteratorFactories = + buildShuffleBlockIteratorFactories(sparkPlans, inputs, shuffleScanIndices) + // Unified RDD creation - CometExecRDD handles all cases val subqueries = collectSubqueries(this) CometExecRDD( @@ -566,7 +572,8 @@ abstract class CometNativeExec extends CometExec { nativeMetrics, subqueries, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIteratorFactories) } } @@ -606,6 +613,124 @@ abstract class CometNativeExec extends CometExec { } } + /** + * Walk the serialized protobuf plan depth-first to find which input indices correspond to + * ShuffleScan vs Scan leaf nodes. Each Scan or ShuffleScan leaf consumes one input in order. + */ + private def findShuffleScanIndices(planBytes: Array[Byte]): Set[Int] = { + val plan = OperatorOuterClass.Operator.parseFrom(planBytes) + var scanIndex = 0 + val indices = mutable.Set.empty[Int] + def walk(op: OperatorOuterClass.Operator): Unit = { + if (op.hasShuffleScan) { + indices += scanIndex + scanIndex += 1 + } else if (op.hasScan) { + scanIndex += 1 + } else { + op.getChildrenList.asScala.foreach(walk) + } + } + walk(plan) + indices.toSet + } + + /** + * Build factory functions that produce CometShuffleBlockIterator for each input index that is a + * ShuffleScan. Maps from input index to a factory that, given TaskContext and Partition, + * creates the iterator. 
+ */ + private def buildShuffleBlockIteratorFactories( + sparkPlans: ArrayBuffer[SparkPlan], + inputs: ArrayBuffer[RDD[ColumnarBatch]], + shuffleScanIndices: Set[Int]) + : Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = { + if (shuffleScanIndices.isEmpty) return Map.empty + + // Build the mapping from sparkPlans index to inputs index + // (CometNativeExec entries are skipped in inputs) + var inputIdx = 0 + val sparkPlanToInputIdx = mutable.Map.empty[Int, Int] + sparkPlans.zipWithIndex.foreach { case (plan, spIdx) => + plan match { + case _: CometNativeExec => // skipped, no input + case _ => + sparkPlanToInputIdx(spIdx) = inputIdx + inputIdx += 1 + } + } + + val factories = mutable.Map.empty[Int, (TaskContext, Partition) => CometShuffleBlockIterator] + + shuffleScanIndices.foreach { scanIdx => + if (scanIdx < inputs.length) { + inputs(scanIdx) match { + case rdd: CometShuffledBatchRDD => + val dep = rdd.dependency + factories(scanIdx) = (context, part) => { + val shufflePart = + part + .asInstanceOf[CometExecPartition] + .inputPartitions(scanIdx) + .asInstanceOf[ShuffledRowRDDPartition] + val tempMetrics = + context.taskMetrics().createTempShuffleReadMetrics() + val sqlMetricsReporter = + new SQLShuffleReadMetricsReporter(tempMetrics, Map.empty) + val reader = shufflePart.spec match { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startMapIndex, + endMapIndex, + reducerIndex, + reducerIndex + 1, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) => + SparkEnv.get.shuffleManager 
+ .getReader( + dep.shuffleHandle, + mapIndex, + mapIndex + 1, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startMapIndex, + endMapIndex, + 0, + numReducers, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + } + val rawStream = reader.readAsRawStream() + new CometShuffleBlockIterator(rawStream) + } + case _ => // Not a CometShuffledBatchRDD, skip + } + } + } + factories.toMap + } + /** * Find all plan nodes with per-partition planning data in the plan tree. Returns two maps keyed * by a unique identifier: one for common data (shared across partitions) and one for From bf7040ffbf1bf93d0b6672a45506fc8dc7a6756c Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 16:37:44 -0600 Subject: [PATCH 11/16] test: add shuffle direct read tests Fix two bugs discovered during testing: - ClassCastException: factory closure incorrectly cast Partition to CometExecPartition before extracting ShuffledRowRDDPartition; the partition passed to the factory is already the unwrapped partition from the input RDD - NoSuchElementException in SQLShuffleReadMetricsReporter: metrics field in CometShuffledBatchRDD was not exposed as a val, causing Map.empty to be used instead of the real shuffle metrics map Add Scala integration test that runs a repartition+aggregate query with direct read enabled and disabled to verify result parity. Add Rust unit test for read_ipc_compressed codec round-trip. 
--- .../src/execution/operators/projection.rs | 2 +- .../src/execution/operators/shuffle_scan.rs | 58 ++++++++++++++++++- native/core/src/execution/planner.rs | 26 ++++----- .../src/jvm_bridge/shuffle_block_iterator.rs | 14 +++-- .../shuffle/CometShuffledRowRDD.scala | 2 +- .../apache/spark/sql/comet/operators.scala | 9 +-- .../comet/exec/CometNativeShuffleSuite.scala | 17 +++++- 7 files changed, 100 insertions(+), 28 deletions(-) diff --git a/native/core/src/execution/operators/projection.rs b/native/core/src/execution/operators/projection.rs index 4169ed8d40..194fa6769a 100644 --- a/native/core/src/execution/operators/projection.rs +++ b/native/core/src/execution/operators/projection.rs @@ -25,7 +25,7 @@ use jni::objects::GlobalRef; use crate::{ execution::{ - planner::{operator_registry::OperatorBuilder, PlanCreationResult, PhysicalPlanner}, + planner::{operator_registry::OperatorBuilder, PhysicalPlanner, PlanCreationResult}, spark_plan::SparkPlan, }, extract_op, diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index 4a8d09111b..567a6e22f4 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -321,8 +321,8 @@ impl Stream for ShuffleScanStream { InputBatch::EOF => Poll::Ready(None), InputBatch::Batch(columns, num_rows) => { self.baseline_metrics.record_output(*num_rows); - let options = arrow::array::RecordBatchOptions::new() - .with_row_count(Some(*num_rows)); + let options = + arrow::array::RecordBatchOptions::new().with_row_count(Some(*num_rows)); let maybe_batch = arrow::array::RecordBatch::try_new_with_options( Arc::clone(&self.schema), columns.clone(), @@ -346,3 +346,57 @@ impl RecordBatchStream for ShuffleScanStream { Arc::clone(&self.schema) } } + +#[cfg(test)] +mod tests { + use crate::execution::shuffle::codec::{CompressionCodec, ShuffleBlockWriter}; + use arrow::array::{Int32Array, StringArray}; + use 
arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion::physical_plan::metrics::Time; + use std::io::Cursor; + use std::sync::Arc; + + use crate::execution::shuffle::codec::read_ipc_compressed; + + #[test] + fn test_read_compressed_ipc_block() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + + // Write as compressed IPC + let writer = + ShuffleBlockWriter::try_new(&batch.schema(), CompressionCodec::Zstd(1)).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + + // Read back (skip 16-byte header: 8 compressed_length + 8 field_count) + let bytes = buf.into_inner(); + let body = &bytes[16..]; + + let decoded = read_ipc_compressed(body).unwrap(); + assert_eq!(decoded.num_rows(), 3); + assert_eq!(decoded.num_columns(), 2); + + // Verify data + let col0 = decoded + .column(0) + .as_any() + .downcast_ref::<Int32Array>() + .unwrap(); + assert_eq!(col0.value(0), 1); + assert_eq!(col0.value(1), 2); + assert_eq!(col0.value(2), 3); + } +} diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index e19891a0d6..b5892d763c 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1773,9 +1773,7 @@ impl PhysicalPlanner { let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { - return Err(GeneralError( - "No input for shuffle scan".to_string(), - )); + return Err(GeneralError("No input for shuffle scan".to_string())); } let input_source = @@ -1785,11 +1783,8 @@ impl PhysicalPlanner { Some(inputs.remove(0)) }; - let
shuffle_scan = ShuffleScanExec::new( - self.exec_context_id, - input_source, - data_types, - )?; + let shuffle_scan = + ShuffleScanExec::new(self.exec_context_id, input_source, data_types)?; Ok(( vec![], @@ -3736,7 +3731,8 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); scans[0].set_input_batch(input_batch); let session_ctx = SessionContext::new(); @@ -3810,7 +3806,8 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); // Scan's schema is determined by the input batch, so we need to set it before execution. 
scans[0].set_input_batch(input_batch); @@ -3857,7 +3854,8 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); let scan = &mut scans[0]; scan.set_input_batch(InputBatch::EOF); @@ -3942,7 +3940,8 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (_scans, _shuffle_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, filter_exec) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); assert_eq!("FilterExec", filter_exec.native_plan.name()); assert_eq!(1, filter_exec.children.len()); @@ -3966,7 +3965,8 @@ mod tests { let planner = PhysicalPlanner::default(); - let (_scans, _shuffle_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, hash_join_exec) = + planner.create_plan(&op_join, &mut vec![], 1).unwrap(); assert_eq!("HashJoinExec", hash_join_exec.native_plan.name()); assert_eq!(2, hash_join_exec.children.len()); diff --git a/native/core/src/jvm_bridge/shuffle_block_iterator.rs b/native/core/src/jvm_bridge/shuffle_block_iterator.rs index 02fcf8ca27..c3bb5af5fb 100644 --- a/native/core/src/jvm_bridge/shuffle_block_iterator.rs +++ b/native/core/src/jvm_bridge/shuffle_block_iterator.rs @@ -45,11 +45,17 @@ impl<'a> CometShuffleBlockIterator<'a> { class, method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, method_has_next_ret: ReturnType::Primitive(Primitive::Int), - method_get_buffer: env - .get_method_id(Self::JVM_CLASS, "getBuffer", "()Ljava/nio/ByteBuffer;")?, + method_get_buffer: env.get_method_id( + Self::JVM_CLASS, + "getBuffer", + "()Ljava/nio/ByteBuffer;", + )?, method_get_buffer_ret: ReturnType::Object, - method_get_current_block_length: env 
- .get_method_id(Self::JVM_CLASS, "getCurrentBlockLength", "()I")?, + method_get_current_block_length: env.get_method_id( + Self::JVM_CLASS, + "getCurrentBlockLength", + "()I", + )?, method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int), }) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala index ba6fc588e2..6594982c85 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch */ class CometShuffledBatchRDD( var dependency: ShuffleDependency[Int, _, _], - metrics: Map[String, SQLMetric], + val metrics: Map[String, SQLMetric], partitionSpecs: Array[ShufflePartitionSpec]) extends RDD[ColumnarBatch](dependency.rdd.context, Nil) { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index a0cb14bbd0..9edaf447c5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -667,16 +667,13 @@ abstract class CometNativeExec extends CometExec { inputs(scanIdx) match { case rdd: CometShuffledBatchRDD => val dep = rdd.dependency + val rddMetrics = rdd.metrics factories(scanIdx) = (context, part) => { - val shufflePart = - part - .asInstanceOf[CometExecPartition] - .inputPartitions(scanIdx) - .asInstanceOf[ShuffledRowRDDPartition] + val shufflePart = part.asInstanceOf[ShuffledRowRDDPartition] val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() val sqlMetricsReporter = - new SQLShuffleReadMetricsReporter(tempMetrics, Map.empty) + new SQLShuffleReadMetricsReporter(tempMetrics, rddMetrics) val reader = 
shufflePart.spec match { case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => SparkEnv.get.shuffleManager diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index 1cf43ea598..11f825e70d 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.{CometTestBase, DataFrame, Dataset, Row} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, count, sum} import org.apache.comet.CometConf @@ -437,4 +437,19 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } } } + + test("shuffle direct read produces same results as FFI path") { + Seq(true, false).foreach { directRead => + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.key -> directRead.toString) { + val df = spark + .range(1000) + .selectExpr("id", "id % 10 as key", "cast(id as string) as value") + .repartition(4, col("key")) + .groupBy("key") + .agg(sum("id").as("total"), count("value").as("cnt")) + .orderBy("key") + checkSparkAnswer(df) + } + } + } } From c91d1b9af8826315e1d419859f85eef8c8a54356 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 16:45:08 -0600 Subject: [PATCH 12/16] refactor: simplify shuffle direct read code - Remove redundant getCurrentBlockLength() JNI call (reuse hasNext() return value) - Make readAsRawStream() lazy instead of materializing all streams to a List - Remove pointless DirectByteBuffer re-allocation in close() - Remove dead sparkPlanToInputIdx map --- native/core/src/execution/operators/shuffle_scan.rs | 10 ++-------- 
.../org/apache/comet/CometShuffleBlockIterator.java | 3 --- .../shuffle/CometBlockStoreShuffleReader.scala | 7 +++++-- .../org/apache/spark/sql/comet/operators.scala | 13 ------------- 4 files changed, 7 insertions(+), 26 deletions(-) diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs index 567a6e22f4..80c72a6d31 100644 --- a/native/core/src/execution/operators/shuffle_scan.rs +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -152,7 +152,7 @@ impl ShuffleScanExec { let mut env = JVMClasses::get_env()?; - // has_next() returns block length or -1 if no more blocks + // has_next() reads the next block and returns its length, or -1 if EOF let block_length: i32 = unsafe { jni_call!(&mut env, comet_shuffle_block_iterator(iter).has_next() -> i32)? @@ -168,15 +168,9 @@ impl ShuffleScanExec { comet_shuffle_block_iterator(iter).get_buffer() -> JObject)? }; - // Get the actual block length (may differ from has_next return value) - let length: i32 = unsafe { - jni_call!(&mut env, - comet_shuffle_block_iterator(iter).get_current_block_length() -> i32)? 
- }; - let byte_buffer = JByteBuffer::from(buffer); let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; - let length = length as usize; + let length = block_length as usize; let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; // Decode the compressed IPC data diff --git a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java index 5de5e05c4e..f9abef1c36 100644 --- a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java +++ b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java @@ -133,9 +133,6 @@ public void close() throws IOException { if (!closed) { closed = true; inputStream.close(); - if (dataBuf.capacity() > INITIAL_BUFFER_SIZE) { - dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); - } } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index 647d4a0856..ecf37efe2a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -160,8 +160,11 @@ class CometBlockStoreShuffleReader[K, C]( * Used by ShuffleScan direct read path. 
*/ def readAsRawStream(): InputStream = { - val streams = fetchIterator.map(_._2).toList - new java.io.SequenceInputStream(java.util.Collections.enumeration(streams.asJava)) + val streams = fetchIterator.map(_._2) + new java.io.SequenceInputStream(new java.util.Enumeration[InputStream] { + override def hasMoreElements: Boolean = streams.hasNext + override def nextElement(): InputStream = streams.next() + }) } private def fetchContinuousBlocksInBatch: Boolean = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 9edaf447c5..2e195e73eb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -647,19 +647,6 @@ abstract class CometNativeExec extends CometExec { : Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = { if (shuffleScanIndices.isEmpty) return Map.empty - // Build the mapping from sparkPlans index to inputs index - // (CometNativeExec entries are skipped in inputs) - var inputIdx = 0 - val sparkPlanToInputIdx = mutable.Map.empty[Int, Int] - sparkPlans.zipWithIndex.foreach { case (plan, spIdx) => - plan match { - case _: CometNativeExec => // skipped, no input - case _ => - sparkPlanToInputIdx(spIdx) = inputIdx - inputIdx += 1 - } - } - val factories = mutable.Map.empty[Int, (TaskContext, Partition) => CometShuffleBlockIterator] shuffleScanIndices.foreach { scanIdx => From b41889dcd6f466b35769692b71517ed1f56f5779 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 17:03:26 -0600 Subject: [PATCH 13/16] style: remove unused import --- .../comet/execution/shuffle/CometBlockStoreShuffleReader.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala 
index ecf37efe2a..14e656f038 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -21,8 +21,6 @@ package org.apache.spark.sql.comet.execution.shuffle import java.io.InputStream -import scala.jdk.CollectionConverters._ - import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.internal.{config, Logging} import org.apache.spark.io.CompressionCodec From 6e24a270f400588d1035e307ad50fac4a85164f4 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 18:21:20 -0600 Subject: [PATCH 14/16] remove design doc --- .../2026-03-18-shuffle-direct-read-design.md | 163 ------------------ 1 file changed, 163 deletions(-) delete mode 100644 docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md diff --git a/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md b/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md deleted file mode 100644 index 2f002a2d89..0000000000 --- a/docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md +++ /dev/null @@ -1,163 +0,0 @@ -# Shuffle Direct Read: Bypass FFI for Native Shuffle Read Path - -## Problem - -When a native shuffle exchange feeds into a downstream native operator, shuffle data crosses the JVM/native FFI boundary twice: - -1. **Native to JVM**: `decodeShuffleBlock` JNI call decompresses Arrow IPC, creates a `RecordBatch`, and exports it via Arrow C Data Interface (per-column `FFI_ArrowArray` + `FFI_ArrowSchema` allocation, export, and import). -2. **JVM to Native**: `CometBatchIterator` re-exports the `ColumnarBatch` via Arrow C Data Interface back to native, where `ScanExec` imports and copies/unpacks the arrays. - -Each crossing involves per-column schema serialization, struct allocation, and array copying. 
For queries with many shuffle stages or wide schemas, this overhead is significant. - -## Solution - -Introduce a direct read path where native code consumes compressed shuffle blocks directly, bypassing Arrow FFI entirely. The JVM reads raw bytes from Spark's shuffle infrastructure and hands them to native via a `DirectByteBuffer` (zero-copy pointer access). Native decompresses and decodes in-place, feeding `RecordBatch` directly into the execution plan. - -### Data Flow Comparison - -**Current path (double FFI):** - -``` -Shuffle stream - -> NativeBatchDecoderIterator (JVM) - -> JNI: decodeShuffleBlock - -> FFI export: RecordBatch -> ArrowArray/Schema (native -> JVM) - -> ColumnarBatch on JVM - -> CometBatchIterator - -> FFI export: ColumnarBatch -> ArrowArray/Schema (JVM -> native) - -> ScanExec imports + copies arrays - -> Native operators -``` - -**New path (zero FFI):** - -``` -Shuffle stream - -> CometShuffleBlockIterator (JVM) - -> reads header + compressed body into DirectByteBuffer - -> holds bytes, waits for native pull - -ShuffleScanExec (native, pull-based) - -> JNI callback: iterator.hasNext()/getBuffer() - -> read_ipc_compressed() -> RecordBatch - -> feeds directly into native execution plan -``` - -## Scope - -- Native shuffle (`CometNativeShuffle`) only. JVM columnar shuffle is excluded because its per-batch dictionary encoding decisions can change the schema between batches. -- Both paths (old and new) are retained. A config flag controls which is used. - -## Components - -### New JVM Components - -#### `CometShuffleBlockIterator` (Java) - -A new class that wraps a shuffle `InputStream` and exposes raw compressed blocks for native consumption. Absorbs the header-reading and buffer-management logic from `NativeBatchDecoderIterator`, but does not decode. - -JNI-callable interface: - -- `hasNext() -> int`: Reads the next block's header from the stream. 
The header is 16 bytes: 8-byte compressed length (includes the 8-byte field count but not itself) + 8-byte field count. The field count from the header is discarded — the schema is determined by the `ShuffleScan` protobuf's `fields` list, which is authoritative. Returns the compressed body length in bytes (i.e., `compressedLength - 8`, which includes the 4-byte codec prefix + compressed IPC data), or -1 for EOF. -- `getBuffer() -> ByteBuffer`: Returns the `DirectByteBuffer` containing the current block's compressed bytes (4-byte codec prefix + compressed IPC data). This buffer is only valid until the next `hasNext()` call — the caller must fully consume it (via `read_ipc_compressed()`, which decompresses into a new allocation) before pulling the next block. - -Uses its own `DirectByteBuffer` instance (not shared with `NativeBatchDecoderIterator`) with the same pooling strategy: initial 128KB, grows as needed, reset on close. - -**Lifecycle**: Implements `Closeable`. `close()` closes the underlying shuffle `InputStream` and resets the buffer. `CometBlockStoreShuffleReader` registers a task completion listener to close it, matching the existing pattern for `NativeBatchDecoderIterator`. - -### New Native Components - -#### `ShuffleScanExec` (Rust) - -Location: `native/core/src/execution/operators/shuffle_scan.rs` - -A new `ExecutionPlan` operator that replaces `ScanExec` at shuffle boundaries. On each `poll_next`: - -1. Calls JNI into `CometShuffleBlockIterator.hasNext()` to get the next block's byte length (or -1 for EOF). -2. Calls `CometShuffleBlockIterator.getBuffer()` to get a `DirectByteBuffer`. -3. Obtains the buffer's raw pointer via `JNIEnv::get_direct_buffer_address()` and creates a slice over it (zero-copy, same pattern as `decodeShuffleBlock`). -4. Calls `read_ipc_compressed()` to decompress and decode into a `RecordBatch`. This allocates new memory for the decompressed data — the `DirectByteBuffer` can be safely reused afterward. -5. 
Returns the `RecordBatch` directly to the downstream native operator. - -No `FFI_ArrowArray`, `FFI_ArrowSchema`, `ArrowImporter`, or `CometVector` involved. - -Implements `on_close` for cleanup (releasing the JNI `GlobalRef`), matching the `ScanExec` pattern. - -#### `ShuffleScan` Protobuf Message - -Location: `native/proto/src/proto/operator.proto` - -New message alongside existing `Scan`: - -```protobuf -message ShuffleScan { - repeated spark.spark_expression.DataType fields = 1; - string source = 2; // Informational label (e.g., "CometShuffleExchangeExec [id=5]") -} -``` - -The `Operator` message gains a new `shuffle_scan` field in its oneof. - -### Modified JVM Components - -#### `CometExchangeSink` / `CometExecRule` - -The decision to use `ShuffleScan` vs `Scan` is made when `CometNativeExec` is constructed (not during the bottom-up conversion pass). At that point, the operator tree is already converted: `CometExecRule.convertBlock()` wraps a contiguous group of native operators into `CometNativeExec` and serializes the protobuf plan. The children (including `CometSinkPlaceHolder` wrapping shuffle exchanges) are already known. So the check is: when serializing a `CometSinkPlaceHolder` whose `originalPlan` is a `CometShuffleExchangeExec` with `shuffleType == CometNativeShuffle`, and the config flag is enabled, emit `ShuffleScan` instead of `Scan`. - -Conditions for `ShuffleScan`: - -1. Shuffle type is `CometNativeShuffle` -2. The sink is inside a `CometNativeExec` block (always true at serialization time — this is where sinks get serialized) -3. 
Config `spark.comet.shuffle.directRead.enabled` is true (default: true) - -#### `CometNativeExec` (operators.scala) - -When collecting input RDDs and creating iterators, distinguish the two cases: - -- `ShuffleScan` input: Wrap the shuffle RDD's `Iterator[ColumnarBatch]` stream in `CometShuffleBlockIterator` — but note that `CometShuffleBlockIterator` wraps the raw `InputStream` from shuffle blocks, not decoded `ColumnarBatch`. This means the RDD must provide the raw shuffle `InputStream` rather than going through `NativeBatchDecoderIterator`. The `CometShuffledBatchRDD` / `CometBlockStoreShuffleReader` needs a mode where it yields raw `InputStream` objects per block instead of decoded batches. -- `Scan` input: Wrap in `CometBatchIterator` (existing behavior) - -#### `CometExecIterator` — JNI Input Contract - -Currently `CometExecIterator` wraps all inputs as `CometBatchIterator` and passes them to `Native.createPlan()` as `Array[CometBatchIterator]`. To support `CometShuffleBlockIterator`: - -- Change the JNI parameter from `Array[CometBatchIterator]` to `Array[Object]`. On the native side in `createPlan`, the planner already knows from the protobuf whether each input is a `Scan` or `ShuffleScan`, so it knows which JNI methods to call on each `GlobalRef` — no type checking needed at runtime. -- `CometExecIterator` populates the array with either `CometBatchIterator` or `CometShuffleBlockIterator` based on whether the corresponding leaf in the protobuf plan is `Scan` or `ShuffleScan`. - -### Native Planner Changes - -In `planner.rs`, handle the `ShuffleScan` protobuf variant: - -- Consume an input from `inputs.remove(0)` (same pattern as `Scan`) -- Create `ShuffleScanExec` instead of `ScanExec` -- The `GlobalRef` points to a `CometShuffleBlockIterator` Java object - -## Fallback Behavior - -The new path is used only when all conditions above are met. Otherwise, the existing path is used unchanged. 
The most common fallback case is a shuffle whose output is consumed by a non-native Spark operator (e.g., `collect()`, or an unsupported operator), where the JVM needs a materialized `ColumnarBatch`. - -## Configuration - -| Config | Default | Description | -|--------|---------|-------------| -| `spark.comet.shuffle.directRead.enabled` | `true` | Use direct native read path for native shuffle when downstream operator is native | - -## Error Handling - -- `ShuffleScanExec` reuses `read_ipc_compressed()`, which handles corrupt data and unsupported codecs. -- JNI errors from `CometShuffleBlockIterator` (stream closed, EOF, I/O errors) propagate through the existing `try_unwrap_or_throw` pattern. -- If the JVM iterator throws, the exception surfaces as a Rust error and propagates through DataFusion's error handling. -- Empty batches (zero rows): `read_ipc_compressed()` calls `reader.next().unwrap()` which panics if the stream contains no batches. The shuffle writer never writes zero-row blocks (guarded by `if batch.num_rows() == 0 { return Ok(0) }` in `ShuffleBlockWriter.write_batch`), so this case does not arise. - -## Metrics - -`ShuffleScanExec` tracks and reports: - -- `decodeTime`: Time spent in `read_ipc_compressed()` (decompression + IPC decode). Same metric as `NativeBatchDecoderIterator` reports today. -- Shuffle read metrics (`recordsRead`, `bytesRead`) continue to be reported by `CometBlockStoreShuffleReader` and the `ShuffleBlockFetcherIterator`, which are upstream of the new code and unchanged. - -## Testing - -- Existing shuffle tests (`CometShuffleSuite`) run with the config defaulting to true, automatically covering the new path. -- Add a test that runs the same queries with the config flag on and off, asserting identical results. -- Add a Rust unit test for `ShuffleScanExec` with pre-built compressed IPC blocks (no JNI), using the `TEST_EXEC_CONTEXT_ID` pattern from `ScanExec` tests. 
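The JNI input contract above (the Nth `Scan`/`ShuffleScan` leaf of the protobuf plan, in left-to-right order, consumes the Nth entry of the inputs array) can be sketched as a toy tree walk. This is illustrative Java only, not Comet code; `LeafOrder`, `Node`, and `Kind` are hypothetical names:

```java
import java.util.ArrayList;
import java.util.List;

// Toy model of the input-matching rule: walk the plan tree depth-first,
// count Scan/ShuffleScan leaves left to right, and record which input
// positions must be fed a shuffle block iterator instead of a batch iterator.
public class LeafOrder {
  public enum Kind { SCAN, SHUFFLE_SCAN, OTHER }

  public static final class Node {
    final Kind kind;
    final List<Node> children;

    public Node(Kind kind, Node... children) {
      this.kind = kind;
      this.children = List.of(children);
    }
  }

  /** Returns the 0-based input positions that correspond to ShuffleScan leaves. */
  public static List<Integer> shuffleScanPositions(Node root) {
    List<Integer> out = new ArrayList<>();
    walk(root, new int[] {0}, out);
    return out;
  }

  private static void walk(Node n, int[] nextIndex, List<Integer> out) {
    switch (n.kind) {
      case SCAN -> nextIndex[0]++;
      case SHUFFLE_SCAN -> out.add(nextIndex[0]++);
      case OTHER -> n.children.forEach(c -> walk(c, nextIndex, out));
    }
  }
}
```

Because the native planner replays the same walk over the same protobuf, both sides agree on the ordering without any runtime type checks.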
From 33c2f11a00e69b41133edb8e1c7b9c495fd607ab Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 18 Mar 2026 18:23:24 -0600 Subject: [PATCH 15/16] Remove doc --- .../plans/2026-03-18-shuffle-direct-read.md | 1011 ----------------- 1 file changed, 1011 deletions(-) delete mode 100644 docs/superpowers/plans/2026-03-18-shuffle-direct-read.md diff --git a/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md b/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md deleted file mode 100644 index 647f122cc4..0000000000 --- a/docs/superpowers/plans/2026-03-18-shuffle-direct-read.md +++ /dev/null @@ -1,1011 +0,0 @@ -# Shuffle Direct Read Implementation Plan - -> **For agentic workers:** REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (`- [ ]`) syntax for tracking. - -**Goal:** Eliminate double Arrow FFI crossing at shuffle boundaries by having native code consume compressed IPC blocks directly from JVM-provided byte buffers. - -**Architecture:** A new `ShuffleScanExec` Rust operator pulls raw compressed bytes from a JVM `CometShuffleBlockIterator` via JNI, decompresses and decodes them in native code, and feeds `RecordBatch` directly into the execution plan. This bypasses the current path where data is decoded to JVM `ColumnarBatch` (FFI export), then re-exported back to native (FFI import). 
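To make the block format concrete before the tasks begin: the framing the new reader consumes is an 8-byte little-endian compressed length (which counts the field-count word but not itself), an 8-byte field count, then the compressed body. A stand-alone sketch of that framing, with `FramedBlockReader` as an illustrative name rather than a Comet class:

```java
import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Stand-alone sketch of the shuffle block framing: each block is an 8-byte
// little-endian compressedLength (counting the field-count word but not
// itself), an 8-byte field count, then (compressedLength - 8) bytes of body.
public class FramedBlockReader {
  private final InputStream in;

  public FramedBlockReader(InputStream in) {
    this.in = in;
  }

  /** Returns the next block body, or null at a clean end-of-stream. */
  public byte[] nextBlock() {
    byte[] header = readFully(16, true);
    if (header == null) {
      return null; // clean EOF before a header: no more blocks
    }
    ByteBuffer buf = ByteBuffer.wrap(header).order(ByteOrder.LITTLE_ENDIAN);
    long compressedLength = buf.getLong();
    buf.getLong(); // field count: discarded, the schema comes from the plan
    return readFully((int) (compressedLength - 8), false);
  }

  private byte[] readFully(int n, boolean eofOk) {
    byte[] out = new byte[n];
    int off = 0;
    try {
      while (off < n) {
        int read = in.read(out, off, n - off);
        if (read < 0) {
          if (off == 0 && eofOk) {
            return null;
          }
          throw new UncheckedIOException(
              new EOFException("unexpected EOF inside shuffle block"));
        }
        off += read;
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
    return out;
  }

  /** Frames one body the same way, for round-trip testing. */
  public static byte[] frame(byte[] body) {
    ByteBuffer buf =
        ByteBuffer.allocate(16 + body.length).order(ByteOrder.LITTLE_ENDIAN);
    buf.putLong(body.length + 8L); // counts the field-count word, not itself
    buf.putLong(0L); // field count: unused in this sketch
    buf.put(body);
    return buf.array();
  }
}
```

Task 3's `CometShuffleBlockIterator` implements the same loop, but loads the body into a reusable `DirectByteBuffer` instead of a fresh array.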
- -**Tech Stack:** Scala, Java, Rust, Protobuf, JNI, Arrow IPC - -**Spec:** `docs/superpowers/specs/2026-03-18-shuffle-direct-read-design.md` - ---- - -### Task 1: Add config flag - -**Files:** -- Modify: `common/src/main/scala/org/apache/comet/CometConf.scala` - -- [ ] **Step 1: Add the config entry** - -Find the existing shuffle config entries (search for `COMET_EXEC_SHUFFLE_ENABLED`) and add nearby: - -```scala -val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = - conf("spark.comet.shuffle.directRead.enabled") - .category(CATEGORY_EXEC) - .doc( - "When enabled, native operators that consume shuffle output will read " + - "compressed shuffle blocks directly in native code, bypassing Arrow FFI. " + - "Only applies to native shuffle (not JVM columnar shuffle). " + - "Requires spark.comet.exec.shuffle.enabled to be true.") - .booleanConf - .createWithDefault(true) -``` - -- [ ] **Step 2: Verify it compiles** - -Run: `./mvnw compile -DskipTests -pl common` -Expected: BUILD SUCCESS - -- [ ] **Step 3: Commit** - -```bash -git add common/src/main/scala/org/apache/comet/CometConf.scala -git commit -m "feat: add spark.comet.shuffle.directRead.enabled config" -``` - ---- - -### Task 2: Add ShuffleScan protobuf message - -**Files:** -- Modify: `native/proto/src/proto/operator.proto` - -- [ ] **Step 1: Add ShuffleScan message** - -Add after the existing `Scan` message (after line 86): - -```protobuf -message ShuffleScan { - repeated spark.spark_expression.DataType fields = 1; - // Informational label for debug output (e.g., "CometShuffleExchangeExec [id=5]") - string source = 2; -} -``` - -- [ ] **Step 2: Add shuffle_scan to the Operator oneof** - -In the `oneof op_struct` block (lines 38-55), add after `csv_scan = 115`: - -```protobuf - ShuffleScan shuffle_scan = 116; -``` - -- [ ] **Step 3: Rebuild protobuf and verify** - -Run: `make core` -Expected: Successful build with generated protobuf code. 
- -- [ ] **Step 4: Commit** - -```bash -git add native/proto/src/proto/operator.proto -git commit -m "feat: add ShuffleScan protobuf message" -``` - ---- - -### Task 3: Create CometShuffleBlockIterator (Java) - -**Files:** -- Create: `spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java` - -- [ ] **Step 1: Create the class** - -```java -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet; - -import java.io.Closeable; -import java.io.EOFException; -import java.io.IOException; -import java.io.InputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.channels.Channels; -import java.nio.channels.ReadableByteChannel; - -/** - * Provides raw compressed shuffle blocks to native code via JNI. - * - *

 * <p>Reads block headers (compressed length + field count) from a shuffle InputStream and loads
 * the compressed body into a DirectByteBuffer. Native code pulls blocks by calling hasNext()
 * and getBuffer().
 *

 * <p>The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call.
 * Native code must fully consume it (via read_ipc_compressed which allocates new memory for
 * the decompressed data) before pulling the next block.
 */
public class CometShuffleBlockIterator implements Closeable {

  private static final int INITIAL_BUFFER_SIZE = 128 * 1024;

  private final ReadableByteChannel channel;
  private final InputStream inputStream;
  private final ByteBuffer headerBuf =
      ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN);
  private ByteBuffer dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE);
  private boolean closed = false;
  private int currentBlockLength = 0;

  public CometShuffleBlockIterator(InputStream in) {
    this.inputStream = in;
    this.channel = Channels.newChannel(in);
  }

  /**
   * Reads the next block header and loads the compressed body into the internal buffer.
   * Called by native code via JNI.
   *

   * <p>Header format: 8-byte compressedLength (includes field count but not itself) +
   * 8-byte fieldCount (discarded, schema comes from protobuf).
   *
   * @return the compressed body length in bytes (codec prefix + compressed IPC), or -1 if EOF
   */
  public int hasNext() throws IOException {
    if (closed) {
      return -1;
    }

    // Read 16-byte header
    headerBuf.clear();
    while (headerBuf.hasRemaining()) {
      int bytesRead = channel.read(headerBuf);
      if (bytesRead < 0) {
        if (headerBuf.position() == 0) {
          return -1;
        }
        throw new EOFException(
            "Data corrupt: unexpected EOF while reading batch header");
      }
    }
    headerBuf.flip();
    long compressedLength = headerBuf.getLong();
    // Field count discarded - schema determined by ShuffleScan protobuf fields
    headerBuf.getLong();

    long bytesToRead = compressedLength - 8;
    if (bytesToRead > Integer.MAX_VALUE) {
      throw new IllegalStateException(
          "Native shuffle block size of " + bytesToRead + " exceeds maximum of "
              + Integer.MAX_VALUE + ". Try reducing shuffle batch size.");
    }

    if (dataBuf.capacity() < bytesToRead) {
      int newCapacity = (int) Math.min(bytesToRead * 2L, Integer.MAX_VALUE);
      dataBuf = ByteBuffer.allocateDirect(newCapacity);
    }

    dataBuf.clear();
    dataBuf.limit((int) bytesToRead);
    while (dataBuf.hasRemaining()) {
      int bytesRead = channel.read(dataBuf);
      if (bytesRead < 0) {
        throw new EOFException(
            "Data corrupt: unexpected EOF while reading compressed batch");
      }
    }
    // Note: native side uses get_direct_buffer_address (base pointer) + currentBlockLength,
    // not the buffer's position/limit. No flip needed.

    currentBlockLength = (int) bytesToRead;
    return currentBlockLength;
  }

  /**
   * Returns the DirectByteBuffer containing the current block's compressed bytes
   * (4-byte codec prefix + compressed IPC data).
   * Called by native code via JNI.
   */
  public ByteBuffer getBuffer() {
    return dataBuf;
  }

  /**
   * Returns the length of the current block in bytes.
- * Called by native code via JNI. - */ - public int getCurrentBlockLength() { - return currentBlockLength; - } - - @Override - public void close() throws IOException { - if (!closed) { - closed = true; - inputStream.close(); - if (dataBuf.capacity() > INITIAL_BUFFER_SIZE) { - dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); - } - } - } -} -``` - -- [ ] **Step 2: Verify it compiles** - -Run: `./mvnw compile -DskipTests` -Expected: BUILD SUCCESS - -- [ ] **Step 3: Commit** - -```bash -git add spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java -git commit -m "feat: add CometShuffleBlockIterator for raw shuffle block access" -``` - ---- - -### Task 4: Add JNI bridge for CometShuffleBlockIterator (Rust) - -**Files:** -- Create: `native/core/src/jvm_bridge/shuffle_block_iterator.rs` -- Modify: `native/core/src/jvm_bridge/mod.rs` - -- [ ] **Step 1: Create the JNI bridge struct** - -Create `native/core/src/jvm_bridge/shuffle_block_iterator.rs`: - -```rust -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 

use jni::signature::Primitive;
use jni::{
    errors::Result as JniResult,
    objects::{JClass, JMethodID},
    signature::ReturnType,
    JNIEnv,
};

/// JNI method IDs for `CometShuffleBlockIterator`.
#[allow(dead_code)]
pub struct CometShuffleBlockIterator<'a> {
    pub class: JClass<'a>,
    pub method_has_next: JMethodID,
    pub method_has_next_ret: ReturnType,
    pub method_get_buffer: JMethodID,
    pub method_get_buffer_ret: ReturnType,
    pub method_get_current_block_length: JMethodID,
    pub method_get_current_block_length_ret: ReturnType,
}

impl<'a> CometShuffleBlockIterator<'a> {
    pub const JVM_CLASS: &'static str = "org/apache/comet/CometShuffleBlockIterator";

    pub fn new(env: &mut JNIEnv<'a>) -> JniResult<CometShuffleBlockIterator<'a>> {
        let class = env.find_class(Self::JVM_CLASS)?;

        Ok(CometShuffleBlockIterator {
            class,
            method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?,
            method_has_next_ret: ReturnType::Primitive(Primitive::Int),
            method_get_buffer: env.get_method_id(
                Self::JVM_CLASS,
                "getBuffer",
                "()Ljava/nio/ByteBuffer;",
            )?,
            method_get_buffer_ret: ReturnType::Object,
            method_get_current_block_length: env.get_method_id(
                Self::JVM_CLASS,
                "getCurrentBlockLength",
                "()I",
            )?,
            method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int),
        })
    }
}
```

- [ ] **Step 2: Register in mod.rs**

In `native/core/src/jvm_bridge/mod.rs`:

Add `mod shuffle_block_iterator;` alongside the existing `mod batch_iterator;` (line 174).

Add `use shuffle_block_iterator::CometShuffleBlockIterator as CometShuffleBlockIteratorBridge;` (to avoid name collision with the operator).

Add a field to the `JVMClasses` struct (around line 206):

```rust
pub comet_shuffle_block_iterator: CometShuffleBlockIteratorBridge<'a>,
```

Initialize it in `JVMClasses::init` alongside the existing `comet_batch_iterator` init (around line 259):

```rust
comet_shuffle_block_iterator: CometShuffleBlockIteratorBridge::new(env).unwrap(),
```

- [ ] **Step 3: Add a `jni_call!` compatible accessor**

Check how `comet_batch_iterator` is called in `scan.rs`. The `jni_call!` macro uses the field name from `JVMClasses`. Ensure `comet_shuffle_block_iterator` follows the same pattern. You may need to add a module in the `jni_bridge` macros — look at how `jni_call!(&mut env, comet_batch_iterator(iter).has_next() -> i32)` is defined and add equivalent patterns for `comet_shuffle_block_iterator`.

Check `native/core/src/jvm_bridge/` for macro definitions (likely in a separate file or in `mod.rs`) that define the `jni_call!` dispatch for each class.

- [ ] **Step 4: Verify it compiles**

Run: `cd native && cargo build`
Expected: Successful build.

- [ ] **Step 5: Commit**

```bash
git add native/core/src/jvm_bridge/shuffle_block_iterator.rs
git add native/core/src/jvm_bridge/mod.rs
git commit -m "feat: add JNI bridge for CometShuffleBlockIterator"
```

---

### Task 5: Create ShuffleScanExec (Rust)

**Files:**
- Create: `native/core/src/execution/operators/shuffle_scan.rs`
- Modify: `native/core/src/execution/operators/mod.rs`

**Design decision — pre-pull pattern:** `ShuffleScanExec` MUST use the pre-pull pattern (same as `ScanExec`). The comment at `jni_api.rs:483-488` explains why: JNI calls cannot happen from within `poll_next` on tokio threads. So `ShuffleScanExec` stores a `batch: Arc<Mutex<Option<InputBatch>>>` and `get_next_batch()` is called from `pull_input_batches` before each `poll_next`.

- [ ] **Step 1: Create shuffle_scan.rs**

Use `scan.rs` as the template.
The key differences:
- `get_next_batch` calls `hasNext()`/`getBuffer()`/`getCurrentBlockLength()` on `CometShuffleBlockIterator` instead of Arrow FFI methods on `CometBatchIterator`
- After getting the `DirectByteBuffer`, call `read_ipc_compressed()` to decode
- No `arrow_ffi_safe` flag, no selection vectors, no `copy_or_unpack_array`
- Track `decode_time` metric

The core `get_next` method:

```rust
fn get_next(
    exec_context_id: i64,
    iter: &JObject,
    data_types: &[DataType],
) -> Result<InputBatch, CometError> {
    let mut env = JVMClasses::get_env()?;

    // Call hasNext() — returns block length or -1 for EOF
    let block_length: i32 = unsafe {
        jni_call!(&mut env, comet_shuffle_block_iterator(iter).has_next() -> i32)?
    };

    if block_length < 0 {
        return Ok(InputBatch::EOF);
    }

    // Get the DirectByteBuffer
    let buffer: JByteBuffer = unsafe {
        jni_call!(&mut env, comet_shuffle_block_iterator(iter).get_buffer() -> JObject)?
    }.into();

    // Get raw pointer to the buffer data
    let raw_pointer = env.get_direct_buffer_address(&buffer)?;
    let length = block_length as usize;
    let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) };

    // Decompress and decode the IPC block
    let batch = read_ipc_compressed(slice)?;

    // Convert RecordBatch columns to InputBatch
    let arrays: Vec<ArrayRef> = batch.columns().to_vec();
    let num_rows = batch.num_rows();

    Ok(InputBatch::new(arrays, Some(num_rows)))
}
```

For the `ExecutionPlan` trait implementation, follow `ScanExec` closely:
- `schema()` returns schema built from `data_types`
- `execute()` returns a `ScanStream` (reuse the same stream type from `scan.rs`)
- The `ScanStream` checks `self.batch` mutex on each `poll_next`, takes the batch if available

- [ ] **Step 2: Register the module**

In `native/core/src/execution/operators/mod.rs`, add:

```rust
mod shuffle_scan;
pub use shuffle_scan::ShuffleScanExec;
```

- [ ] **Step 3: Verify it compiles**

Run: `cd native && cargo build`
Expected: Successful build.

- [ ] **Step 4: Commit**

```bash
git add native/core/src/execution/operators/shuffle_scan.rs
git add native/core/src/execution/operators/mod.rs
git commit -m "feat: add ShuffleScanExec native operator for direct shuffle read"
```

---

### Task 6: Wire ShuffleScanExec into the native planner and pre-pull

**Files:**
- Modify: `native/core/src/execution/planner.rs`
- Modify: `native/core/src/execution/jni_api.rs`

**Design decision — separate scan vectors:** The planner's `create_plan` currently returns `(Vec<ScanExec>, Arc<SparkPlan>)`. Change the return type to include shuffle scans: `(Vec<ScanExec>, Vec<ShuffleScanExec>, Arc<SparkPlan>)`. All intermediate operators pass both vectors through. `ExecutionContext` gets a new `shuffle_scans: Vec<ShuffleScanExec>` field, and `pull_input_batches` iterates both.

- [ ] **Step 1: Update create_plan return type**

In `planner.rs`, change the `create_plan` return type (line 915):

```rust
) -> Result<(Vec<ScanExec>, Vec<ShuffleScanExec>, Arc<SparkPlan>), ExecutionError>
```

Update every match arm that calls `create_plan` recursively or returns results:
- Single-child operators (Filter, Project, Sort, etc.): destructure as `let (scans, shuffle_scans, child) = ...` and pass both through
- Multi-child operators (joins via `parse_join_parameters`): concatenate both scan vectors from left and right children
- `Scan` arm: returns `(vec![scan.clone()], vec![], ...)`
- Add `ShuffleScan` arm (see step 2)

This is a mechanical change across many match arms. Each `Ok((scans, ...))` becomes `Ok((scans, shuffle_scans, ...))`.

Also update `parse_join_parameters` return type similarly.
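The mechanical shape of this change, in miniature: every recursion level destructures its children's results and concatenates both collections upward in order. A toy Java sketch (the types are illustrative stand-ins, not the real planner types):

```java
import java.util.ArrayList;
import java.util.List;

// Miniature of the create_plan change: a recursive tree walk that threads two
// separate collections upward, concatenating children's results left to right.
public class TwoVectorBuild {
  public sealed interface Op permits Scan, ShuffleScan, Inner {}

  public record Scan(String name) implements Op {}

  public record ShuffleScan(String name) implements Op {}

  public record Inner(List<Op> children) implements Op {}

  /** What each recursion level returns: both collections together. */
  public record Result(List<String> scans, List<String> shuffleScans) {}

  public static Result collect(Op op) {
    if (op instanceof Scan s) {
      return new Result(List.of(s.name()), List.of()); // the "Scan arm"
    }
    if (op instanceof ShuffleScan s) {
      return new Result(List.of(), List.of(s.name())); // the "ShuffleScan arm"
    }
    // Intermediate operator: destructure each child's result and pass both
    // collections through.
    List<String> scans = new ArrayList<>();
    List<String> shuffleScans = new ArrayList<>();
    for (Op child : ((Inner) op).children()) {
      Result r = collect(child);
      scans.addAll(r.scans());
      shuffleScans.addAll(r.shuffleScans());
    }
    return new Result(scans, shuffleScans);
  }
}
```

Keeping the two collections separate (rather than one `Vec` of an enum) lets the pre-pull loop call each operator's concrete `get_next_batch()` without dynamic dispatch.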

- [ ] **Step 2: Add ShuffleScan match arm**

```rust
OpStruct::ShuffleScan(scan) => {
    let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec();

    if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
        return Err(GeneralError("No input for shuffle scan".to_string()));
    }

    let input_source =
        if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() {
            None
        } else {
            Some(inputs.remove(0))
        };

    let shuffle_scan = ShuffleScanExec::new(
        self.exec_context_id,
        input_source,
        &scan.source,
        data_types,
    )?;

    Ok((
        vec![],
        vec![shuffle_scan.clone()],
        Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(shuffle_scan), vec![])),
    ))
}
```

- [ ] **Step 3: Update ExecutionContext and pull_input_batches**

In `jni_api.rs`:

Add `shuffle_scans: Vec<ShuffleScanExec>` field to `ExecutionContext` struct (after `scans` on line 153). Initialize as `shuffle_scans: vec![]` in the constructor (line 313).

Where `create_plan` results are stored (line 542-550):

```rust
let (scans, shuffle_scans, root_op) = planner.create_plan(...)?;
exec_context.scans = scans;
exec_context.shuffle_scans = shuffle_scans;
```

Update `pull_input_batches` (line 490):

```rust
fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometError> {
    exec_context.scans.iter_mut().try_for_each(|scan| {
        scan.get_next_batch()?;
        Ok::<(), CometError>(())
    })?;
    exec_context.shuffle_scans.iter_mut().try_for_each(|scan| {
        scan.get_next_batch()?;
        Ok::<(), CometError>(())
    })
}
```

Also update the `exec_context.scans.is_empty()` check (line 563) to also check `shuffle_scans`:

```rust
if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() {
```

- [ ] **Step 4: Verify it compiles**

Run: `cd native && cargo build`
Expected: Successful build.
- -- [ ] **Step 5: Commit** - -```bash -git add native/core/src/execution/planner.rs -git add native/core/src/execution/jni_api.rs -git commit -m "feat: wire ShuffleScanExec into planner and pre-pull mechanism" -``` - ---- - -### Task 7: Emit ShuffleScan from JVM serde - -**Files:** -- Modify: `spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala` - -The `CometExchangeSink.convert()` receives the outer operator (e.g., `ShuffleQueryStageExec`) not the inner `CometShuffleExchangeExec`. We must unwrap to check `shuffleType`. - -- [ ] **Step 1: Override convert in CometExchangeSink** - -Replace the `CometExchangeSink` object (lines 87-100) with: - -```scala -object CometExchangeSink extends CometSink[SparkPlan] { - - override def isFfiSafe: Boolean = true - - override def convert( - op: SparkPlan, - builder: Operator.Builder, - childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { - if (shouldUseShuffleScan(op)) { - convertToShuffleScan(op, builder) - } else { - super.convert(op, builder, childOp: _*) - } - } - - private def shouldUseShuffleScan(op: SparkPlan): Boolean = { - if (!CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.get()) return false - - // Extract the CometShuffleExchangeExec from the wrapper - val shuffleExec = op match { - case ShuffleQueryStageExec(_, s: CometShuffleExchangeExec, _) => Some(s) - case ShuffleQueryStageExec(_, ReusedExchangeExec(_, s: CometShuffleExchangeExec), _) => - Some(s) - case s: CometShuffleExchangeExec => Some(s) - case _ => None - } - - shuffleExec.exists(_.shuffleType == CometNativeShuffle) - } - - private def convertToShuffleScan( - op: SparkPlan, - builder: Operator.Builder): Option[OperatorOuterClass.Operator] = { - val supportedTypes = - op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) - - if (!supportedTypes) { - withInfo(op, "Unsupported data type for shuffle direct read") - return None - } - - val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() - 
val source = op.simpleStringWithNodeId() - if (source.isEmpty) { - scanBuilder.setSource(op.getClass.getSimpleName) - } else { - scanBuilder.setSource(source) - } - - val scanTypes = op.output.flatMap { attr => - serializeDataType(attr.dataType) - } - - if (scanTypes.length == op.output.length) { - scanBuilder.addAllFields(scanTypes.asJava) - builder.clearChildren() - Some(builder.setShuffleScan(scanBuilder).build()) - } else { - withInfo(op, "unsupported data types for shuffle direct read") - // Fall back to regular Scan - None - } - } - - override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = - CometSinkPlaceHolder(nativeOp, op, op) -} -``` - -Add necessary imports at the top of the file: -```scala -import org.apache.spark.sql.comet.execution.shuffle.{CometNativeShuffle, CometShuffleExchangeExec} -import org.apache.spark.sql.execution.exchange.{ReusedExchangeExec, ShuffleExchangeExec} -import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec -import org.apache.comet.CometConf -``` - -- [ ] **Step 2: Verify it compiles** - -Run: `./mvnw compile -DskipTests` -Expected: BUILD SUCCESS - -- [ ] **Step 3: Commit** - -```bash -git add spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala -git commit -m "feat: emit ShuffleScan protobuf for native shuffle with direct read" -``` - ---- - -### Task 8: Wire CometShuffleBlockIterator into JVM execution path - -**Files:** -- Modify: `spark/src/main/scala/org/apache/comet/Native.scala` -- Modify: `spark/src/main/scala/org/apache/comet/CometExecIterator.scala` -- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala` -- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/operators.scala` -- Modify: `spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala` - -This task connects the JVM plumbing so that `ShuffleScan` inputs get `CometShuffleBlockIterator` (wrapping raw `InputStream`) instead of 
`CometBatchIterator` (wrapping decoded `ColumnarBatch`). - -**Key insight**: Currently all inputs flow through `RDD[ColumnarBatch]`. For shuffle direct read, we need the raw `InputStream` before decoding. The approach: add a parallel input channel for raw shuffle streams alongside the existing `ColumnarBatch` inputs. - -- [ ] **Step 1: Change Native.scala createPlan signature** - -In `spark/src/main/scala/org/apache/comet/Native.scala` (line 57), change: - -```scala -iterators: Array[CometBatchIterator], -``` -to: -```scala -iterators: Array[Object], -``` - -The JNI side (`jni_api.rs:190`) already uses `JObjectArray`, so no Rust changes needed. - -- [ ] **Step 2: Add shuffle stream inputs to CometExecIterator** - -In `spark/src/main/scala/org/apache/comet/CometExecIterator.scala`, add a parameter for shuffle block iterators that should be used instead of regular batch iterators at specific input positions: - -```scala -class CometExecIterator( - val id: Long, - inputs: Seq[Iterator[ColumnarBatch]], - numOutputCols: Int, - protobufQueryPlan: Array[Byte], - nativeMetrics: CometMetricNode, - numParts: Int, - partitionIndex: Int, - broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty, - shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty) -``` - -Replace the `cometBatchIterators` construction (lines 81-83): - -```scala -private val nativeIterators: Array[Object] = { - val result = new Array[Object](inputs.size) - inputs.zipWithIndex.foreach { case (iterator, idx) => - result(idx) = shuffleBlockIterators.getOrElse( - idx, - new CometBatchIterator(iterator, nativeUtil)) - } - result -} -``` - -Change `nativeLib.createPlan(id, cometBatchIterators, ...)` (line 109) to use `nativeIterators`. 
- -In the `close()` method, also close `CometShuffleBlockIterator` instances: -```scala -shuffleBlockIterators.values.foreach { iter => - try { iter.close() } catch { case _: Exception => } -} -``` - -- [ ] **Step 3: Add shuffle stream support to CometExecRDD** - -In `spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala`, add a parameter to carry shuffle block iterator factories: - -```scala -private[spark] class CometExecRDD( - sc: SparkContext, - var inputRDDs: Seq[RDD[ColumnarBatch]], - ... - encryptedFilePaths: Seq[String] = Seq.empty, - shuffleBlockIteratorFactories: Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty) -``` - -In the `compute` method (line 112), pass them to `CometExecIterator`: - -```scala -// Create shuffle block iterators for this partition -val shuffleBlockIters = shuffleBlockIteratorFactories.map { case (idx, factory) => - idx -> factory(context, partition.inputPartitions(idx)) -} - -val it = new CometExecIterator( - CometExec.newIterId, - inputs, - numOutputCols, - actualPlan, - nativeMetrics, - numPartitions, - partition.index, - broadcastedHadoopConfForEncryption, - encryptedFilePaths, - shuffleBlockIters) -``` - -- [ ] **Step 4: Identify ShuffleScan inputs in operators.scala** - -In `spark/src/main/scala/org/apache/spark/sql/comet/operators.scala`, in `CometNativeExec.doExecuteColumnar` (around line 480): - -After `foreachUntilCometInput(this)(sparkPlans += _)`, determine which inputs correspond to `ShuffleScan` operators. 
Parse the serialized protobuf plan to find `ShuffleScan` leaf positions: - -```scala -import org.apache.comet.serde.OperatorOuterClass - -// Find which input indices correspond to ShuffleScan operators -val shuffleScanIndices: Set[Int] = { - val plan = OperatorOuterClass.Operator.parseFrom(serializedPlanCopy) - var scanIndex = 0 - val indices = scala.collection.mutable.Set.empty[Int] - def walk(op: OperatorOuterClass.Operator): Unit = { - if (op.hasShuffleScan) { - indices += scanIndex - scanIndex += 1 - } else if (op.hasScan) { - scanIndex += 1 - } else { - // Recurse into children in order - (0 until op.getChildrenCount).foreach(i => walk(op.getChildren(i))) - } - } - walk(plan) - indices.toSet -} -``` - -Then in the `sparkPlans.zipWithIndex.foreach` loop (line 523), for plans at shuffle scan indices, create a factory that produces `CometShuffleBlockIterator`: - -```scala -val shuffleBlockIteratorFactories = scala.collection.mutable.Map.empty[Int, (TaskContext, Partition) => CometShuffleBlockIterator] - -sparkPlans.zipWithIndex.foreach { case (plan, idx) => - plan match { - // ... existing cases ... - case _ if shuffleScanIndices.contains(inputIndexForPlan(idx)) => - // Still add the RDD for partition tracking, but also register - // a factory for the raw InputStream - val rdd = plan.executeColumnar() - inputs += rdd - // The factory creates a CometShuffleBlockIterator from the raw shuffle stream - // We need to get the raw InputStream - see Step 5 - shuffleBlockIteratorFactories(inputs.size - 1) = ... - // ... remaining cases ... - } -} -``` - -The tricky part is getting the raw `InputStream` from the shuffle read. See Step 5. - -- [ ] **Step 5: Add raw InputStream mode to CometBlockStoreShuffleReader** - -In `spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala`: - -The current `read()` method creates `NativeBatchDecoderIterator` which decodes blocks. 
For direct read, we need a mode that yields the raw `InputStream` wrapped in `CometShuffleBlockIterator`. - -Add a method: - -```scala -def readRawStreams(): Iterator[CometShuffleBlockIterator] = { - fetchIterator.map { case (_, inputStream) => - new CometShuffleBlockIterator(inputStream) - } -} -``` - -The challenge is that `CometShuffledBatchRDD` calls `reader.read()` which returns `Iterator[Product2[Int, ColumnarBatch]]`. For the direct read path, we need a different RDD that calls `readRawStreams()` instead. - -**Approach**: Create `CometShuffledRawStreamRDD` — a simple RDD that wraps the shuffle reader and yields `CometShuffleBlockIterator` objects per block. Then in `operators.scala`, instead of using the ColumnarBatch RDD, create a `CometShuffledRawStreamRDD` and pass its iterator-producing factory to `CometExecRDD`. - -Alternatively, since `CometShuffleBlockIterator` wraps a single `InputStream` that may contain multiple blocks, and `fetchIterator` yields one `InputStream` per shuffle block, the simplest approach is to **concatenate all InputStreams into one** per partition: - -```scala -def readAsRawStream(): InputStream = { - val streams = fetchIterator.map(_._2) - new SequenceInputStream(java.util.Collections.enumeration( - streams.toList.asJava)) -} -``` - -Then in the factory: `(ctx, part) => new CometShuffleBlockIterator(reader.readAsRawStream())` - -But the reader is created per-partition in `CometShuffledBatchRDD.compute()`. The factory approach means the reader creation must be deferred. - -**Simplest concrete approach**: Instead of a factory, create a new RDD `CometShuffledRawRDD` that returns `Iterator[CometShuffleBlockIterator]`. Pass this as a separate input alongside the regular `ColumnarBatch` inputs: - -```scala -// In CometExecRDD, add: -shuffleRawInputRDDs: Seq[(Int, RDD[CometShuffleBlockIterator])] -``` - -In `compute`, create iterators from these RDDs and pass them to `CometExecIterator` via the `shuffleBlockIterators` map. 
This is the most invasive part of the implementation. The exact approach should be determined by reading the code at implementation time, as there are multiple valid paths. The key constraint: the raw `InputStream` from `fetchIterator` must reach `CometShuffleBlockIterator` without going through `NativeBatchDecoderIterator`.

- [ ] **Step 6: Verify it compiles**

Run: `./mvnw compile -DskipTests`
Expected: BUILD SUCCESS

- [ ] **Step 7: Commit**

```bash
git add spark/src/main/scala/org/apache/comet/Native.scala
git add spark/src/main/scala/org/apache/comet/CometExecIterator.scala
git add spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
git add spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
git add spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala
git commit -m "feat: wire CometShuffleBlockIterator into JVM execution path"
```

---

### Task 9: End-to-end testing

**Files:**
- Modify: the appropriate test suite (find the right suite by searching for existing shuffle tests)

- [ ] **Step 1: Build everything**

Run: `make`
Expected: Successful build of both native and JVM code.

- [ ] **Step 2: Run existing shuffle tests**

Run: `./mvnw test -Dsuites="org.apache.comet.exec.CometShuffleSuite"`
Expected: All existing tests pass (they now use the new direct read path by default).

If tests fail, debug by setting `spark.comet.shuffle.directRead.enabled=false` to confirm the old path still works, then investigate the new path.
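
The Rust test in Step 4 below skips a 16-byte block header described as `compressed_length(8)` followed by `field_count(8)`. That framing can be sketched standalone; note that the little-endian byte order and the `BlockHeader`/`readHeader` names here are illustrative assumptions, not confirmed details of `ShuffleBlockWriter`:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

public class BlockHeader {
    // Hypothetical framing: [compressed_length: 8 bytes][field_count: 8 bytes][IPC body...]
    // Little-endian is an assumption; verify against ShuffleBlockWriter at implementation time.
    static long[] readHeader(byte[] block) {
        ByteBuffer buf = ByteBuffer.wrap(block).order(ByteOrder.LITTLE_ENDIAN);
        return new long[] {buf.getLong(), buf.getLong()};
    }

    public static void main(String[] args) {
        byte[] body = {10, 20, 30}; // stand-in for the compressed IPC body
        ByteBuffer block = ByteBuffer.allocate(16 + body.length).order(ByteOrder.LITTLE_ENDIAN);
        block.putLong(body.length); // compressed_length
        block.putLong(2);           // field_count
        block.put(body);

        long[] header = readHeader(block.array());
        System.out.println(header[0] + " " + header[1]); // 3 2
    }
}
```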
- [ ] **Step 3: Add comparison test**

Add a test that runs the same queries with direct read enabled and disabled:

```scala
test("shuffle direct read produces same results as FFI path") {
  Seq(true, false).foreach { directRead =>
    withSQLConf(
      CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.key -> directRead.toString) {
      val df = spark.range(1000)
        .selectExpr("id", "id % 10 as key", "cast(id as string) as value")
        .repartition(4, col("key"))
        .groupBy("key")
        .agg(sum("id").as("total"), count("value").as("cnt"))
        .orderBy("key")
      checkSparkAnswer(df)
    }
  }
}
```

- [ ] **Step 4: Add Rust unit test for ShuffleScanExec**

In `native/core/src/execution/operators/shuffle_scan.rs`, add a `#[cfg(test)]` module:

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::shuffle::codec::read_ipc_compressed;
    use crate::execution::shuffle::codec::{CompressionCodec, ShuffleBlockWriter};
    use arrow::array::{Int32Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;
    use std::io::Cursor;
    use std::sync::Arc;

    #[test]
    fn test_read_compressed_ipc_block() {
        // Create a test RecordBatch
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, true),
        ]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int32Array::from(vec![1, 2, 3])),
                Arc::new(StringArray::from(vec!["a", "b", "c"])),
            ],
        ).unwrap();

        // Write it as compressed IPC using ShuffleBlockWriter
        let writer = ShuffleBlockWriter::try_new(
            &batch.schema(), CompressionCodec::Zstd(1)
        ).unwrap();
        let mut buf = Cursor::new(Vec::new());
        let ipc_time = datafusion::physical_plan::metrics::Time::new();
        writer.write_batch(&batch, &mut buf, &ipc_time).unwrap();

        // Read back the body (skip the 16-byte header)
        let bytes = buf.into_inner();
        let body = &bytes[16..]; // Skip compressed_length(8) + field_count(8)

        // Decode using read_ipc_compressed
        let decoded = read_ipc_compressed(body).unwrap();
        assert_eq!(decoded.num_rows(), 3);
        assert_eq!(decoded.num_columns(), 2);
    }
}
```

- [ ] **Step 5: Run all tests**

Run: `make test`

- [ ] **Step 6: Run clippy**

Run: `cd native && cargo clippy --all-targets --workspace -- -D warnings`
Expected: No warnings.

- [ ] **Step 7: Format**

Run: `make format`

- [ ] **Step 8: Commit**

```bash
git add -A
git commit -m "test: add shuffle direct read tests"
```

---

## Implementation Notes

### Task 8 is the hardest

The core challenge is routing the raw `InputStream` from Spark's shuffle infrastructure through to `CometShuffleBlockIterator` without going through the decode path. The current RDD pipeline (`CometShuffledBatchRDD` → `CometBlockStoreShuffleReader.read()` → `NativeBatchDecoderIterator`) always decodes. You need to intercept before `NativeBatchDecoderIterator` is created.

The most surgical approach: in `CometBlockStoreShuffleReader`, add a `readRaw()` method that returns the raw `InputStream` (or a `CometShuffleBlockIterator` wrapping it) instead of decoded batches. Then create a parallel RDD (`CometShuffledRawRDD`) that calls `readRaw()` in its `compute` method and pass it through to `CometExecIterator`.

### Metrics

`ShuffleScanExec` should track `decode_time` using DataFusion's `Time` metric. Register it in `ShuffleScanExec::new` via `MetricBuilder`, following the pattern in `ScanExec`.

### Order of tasks

Tasks 1-7 can be done sequentially. Task 8 depends on all previous tasks. Task 9 validates everything.

From 19cb04b0d34af995046942005111873e757b7518 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Wed, 18 Mar 2026 19:11:43 -0600
Subject: [PATCH 16/16] test: skip miri-incompatible zstd FFI test

Skip test_read_compressed_ipc_block under Miri since it calls foreign zstd functions that Miri cannot execute.
---
 native/core/src/execution/operators/shuffle_scan.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs
index 80c72a6d31..163fc9992a 100644
--- a/native/core/src/execution/operators/shuffle_scan.rs
+++ b/native/core/src/execution/operators/shuffle_scan.rs
@@ -354,6 +354,7 @@ mod tests {
     use crate::execution::shuffle::codec::read_ipc_compressed;
 
     #[test]
+    #[cfg_attr(miri, ignore)] // Miri cannot call FFI functions (zstd)
     fn test_read_compressed_ipc_block() {
         let schema = Arc::new(Schema::new(vec![
             Field::new("id", DataType::Int32, false),