diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala index 4d2e37924a..ad3774567c 100644 --- a/common/src/main/scala/org/apache/comet/CometConf.scala +++ b/common/src/main/scala/org/apache/comet/CometConf.scala @@ -343,6 +343,17 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(true) + val COMET_SHUFFLE_DIRECT_READ_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.shuffle.directRead.enabled") + .category(CATEGORY_EXEC) + .doc( + "When enabled, native operators that consume shuffle output will read " + + "compressed shuffle blocks directly in native code, bypassing Arrow FFI. " + + "Only applies to native shuffle (not JVM columnar shuffle). " + + "Requires spark.comet.exec.shuffle.enabled to be true.") + .booleanConf + .createWithDefault(true) + val COMET_SHUFFLE_MODE: ConfigEntry[String] = conf(s"$COMET_EXEC_CONFIG_PREFIX.shuffle.mode") .category(CATEGORY_SHUFFLE) .doc( diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 361deae182..d20cf128b5 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -82,7 +82,7 @@ use tokio::sync::mpsc; use crate::execution::memory_pools::{ create_memory_pool, handle_task_shared_pool_release, parse_memory_pool_config, MemoryPoolConfig, }; -use crate::execution::operators::ScanExec; +use crate::execution::operators::{ScanExec, ShuffleScanExec}; use crate::execution::shuffle::{read_ipc_compressed, CompressionCodec}; use crate::execution::spark_plan::SparkPlan; @@ -151,6 +151,8 @@ struct ExecutionContext { pub root_op: Option>, /// The input sources for the DataFusion plan pub scans: Vec, + /// The shuffle scan input sources for the DataFusion plan + pub shuffle_scans: Vec, /// The global reference of input sources for the DataFusion plan pub input_sources: Vec>, /// The record batch stream to pull results from @@ -311,6 +313,7 @@ pub unsafe extern 
"system" fn Java_org_apache_comet_Native_createPlan( partition_count: partition_count as usize, root_op: None, scans: vec![], + shuffle_scans: vec![], input_sources, stream: None, batch_receiver: None, @@ -491,6 +494,10 @@ fn pull_input_batches(exec_context: &mut ExecutionContext) -> Result<(), CometEr exec_context.scans.iter_mut().try_for_each(|scan| { scan.get_next_batch()?; Ok::<(), CometError>(()) + })?; + exec_context.shuffle_scans.iter_mut().try_for_each(|scan| { + scan.get_next_batch()?; + Ok::<(), CometError>(()) }) } @@ -539,7 +546,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition) .with_exec_id(exec_context_id); - let (scans, root_op) = planner.create_plan( + let (scans, shuffle_scans, root_op) = planner.create_plan( &exec_context.spark_plan, &mut exec_context.input_sources.clone(), exec_context.partition_count, @@ -548,6 +555,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( exec_context.plan_creation_time += physical_plan_time; exec_context.scans = scans; + exec_context.shuffle_scans = shuffle_scans; if exec_context.explain_native { let formatted_plan_str = @@ -560,7 +568,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( // so we should always execute partition 0. let stream = root_op.native_plan.execute(0, task_ctx)?; - if exec_context.scans.is_empty() { + if exec_context.scans.is_empty() && exec_context.shuffle_scans.is_empty() { // No JVM data sources — spawn onto tokio so the executor // thread parks in blocking_recv instead of busy-polling. 
// diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 07ee995367..ad3ec3f08b 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -34,7 +34,9 @@ pub use parquet_writer::ParquetWriterExec; mod csv_scan; pub mod projection; mod scan; +mod shuffle_scan; pub use csv_scan::init_csv_datasource_exec; +pub use shuffle_scan::ShuffleScanExec; /// Error returned during executing operators. #[derive(thiserror::Error, Debug)] diff --git a/native/core/src/execution/operators/projection.rs b/native/core/src/execution/operators/projection.rs index 6ba1bb5d59..194fa6769a 100644 --- a/native/core/src/execution/operators/projection.rs +++ b/native/core/src/execution/operators/projection.rs @@ -25,8 +25,7 @@ use jni::objects::GlobalRef; use crate::{ execution::{ - operators::{ExecutionError, ScanExec}, - planner::{operator_registry::OperatorBuilder, PhysicalPlanner}, + planner::{operator_registry::OperatorBuilder, PhysicalPlanner, PlanCreationResult}, spark_plan::SparkPlan, }, extract_op, @@ -42,12 +41,13 @@ impl OperatorBuilder for ProjectionBuilder { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { let project = extract_op!(spark_plan, Projection); let children = &spark_plan.children; assert_eq!(children.len(), 1); - let (scans, child) = planner.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + planner.create_plan(&children[0], inputs, partition_count)?; // Create projection expressions let exprs: Result, _> = project @@ -68,6 +68,7 @@ impl OperatorBuilder for ProjectionBuilder { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, projection, vec![child])), )) } diff --git a/native/core/src/execution/operators/shuffle_scan.rs b/native/core/src/execution/operators/shuffle_scan.rs new file mode 100644 index 
0000000000..163fc9992a --- /dev/null +++ b/native/core/src/execution/operators/shuffle_scan.rs @@ -0,0 +1,397 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::{ + errors::CometError, + execution::{ + operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, + shuffle::codec::read_ipc_compressed, + }, + jvm_bridge::{jni_call, JVMClasses}, +}; +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion::common::{arrow_datafusion_err, Result as DataFusionResult}; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, Time, +}; +use datafusion::{ + execution::TaskContext, + physical_expr::*, + physical_plan::{ExecutionPlan, *}, +}; +use futures::Stream; +use jni::objects::{GlobalRef, JByteBuffer, JObject}; +use std::{ + any::Any, + pin::Pin, + sync::{Arc, Mutex}, + task::{Context, Poll}, +}; + +use super::scan::InputBatch; + +/// ShuffleScanExec reads compressed shuffle blocks from JVM via JNI and decodes them natively. 
+/// Unlike ScanExec which receives Arrow arrays via FFI, ShuffleScanExec receives raw compressed +/// bytes from CometShuffleBlockIterator and decodes them using read_ipc_compressed(). +#[derive(Debug, Clone)] +pub struct ShuffleScanExec { + /// The ID of the execution context that owns this shuffle scan. + pub exec_context_id: i64, + /// The input source: a global reference to a JVM CometShuffleBlockIterator object. + pub input_source: Option>, + /// The data types of columns in the shuffle output. + pub data_types: Vec, + /// Schema of the shuffle output. + pub schema: SchemaRef, + /// The current input batch, populated by get_next_batch() before poll_next(). + pub batch: Arc>>, + /// Cache of plan properties. + cache: PlanProperties, + /// Metrics collector. + metrics: ExecutionPlanMetricsSet, + /// Baseline metrics. + baseline_metrics: BaselineMetrics, + /// Time spent decoding compressed shuffle blocks. + decode_time: Time, +} + +impl ShuffleScanExec { + pub fn new( + exec_context_id: i64, + input_source: Option>, + data_types: Vec, + ) -> Result { + let metrics_set = ExecutionPlanMetricsSet::default(); + let baseline_metrics = BaselineMetrics::new(&metrics_set, 0); + let decode_time = MetricBuilder::new(&metrics_set).subset_time("decode_time", 0); + + let schema = schema_from_data_types(&data_types); + + let cache = PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&schema)), + Partitioning::UnknownPartitioning(1), + EmissionType::Final, + Boundedness::Bounded, + ); + + Ok(Self { + exec_context_id, + input_source, + data_types, + batch: Arc::new(Mutex::new(None)), + cache, + metrics: metrics_set, + baseline_metrics, + schema, + decode_time, + }) + } + + /// Feeds input batch into this scan. Only used in unit tests. + pub fn set_input_batch(&mut self, input: InputBatch) { + *self.batch.try_lock().unwrap() = Some(input); + } + + /// Pull next input batch from JVM.
Called externally before poll_next() + /// because JNI calls cannot happen from within poll_next on tokio threads. + pub fn get_next_batch(&mut self) -> Result<(), CometError> { + if self.input_source.is_none() { + // Unit test mode - no JNI calls needed. + return Ok(()); + } + let mut timer = self.baseline_metrics.elapsed_compute().timer(); + + let mut current_batch = self.batch.try_lock().unwrap(); + if current_batch.is_none() { + let next_batch = Self::get_next( + self.exec_context_id, + self.input_source.as_ref().unwrap().as_obj(), + &self.data_types, + &self.decode_time, + )?; + *current_batch = Some(next_batch); + } + + timer.stop(); + + Ok(()) + } + + /// Invokes JNI calls to get the next compressed shuffle block and decode it. + fn get_next( + exec_context_id: i64, + iter: &JObject, + data_types: &[DataType], + decode_time: &Time, + ) -> Result { + if exec_context_id == TEST_EXEC_CONTEXT_ID { + return Ok(InputBatch::EOF); + } + + if iter.is_null() { + return Err(CometError::from(ExecutionError::GeneralError(format!( + "Null shuffle block iterator object. Plan id: {exec_context_id}" + )))); + } + + let mut env = JVMClasses::get_env()?; + + // has_next() reads the next block and returns its length, or -1 if EOF + let block_length: i32 = unsafe { + jni_call!(&mut env, + comet_shuffle_block_iterator(iter).has_next() -> i32)? + }; + + if block_length == -1 { + return Ok(InputBatch::EOF); + } + + // Get the DirectByteBuffer containing the compressed shuffle block + let buffer: JObject = unsafe { + jni_call!(&mut env, + comet_shuffle_block_iterator(iter).get_buffer() -> JObject)? 
+ }; + + let byte_buffer = JByteBuffer::from(buffer); + let raw_pointer = env.get_direct_buffer_address(&byte_buffer)?; + let length = block_length as usize; + let slice: &[u8] = unsafe { std::slice::from_raw_parts(raw_pointer, length) }; + + // Decode the compressed IPC data + let mut timer = decode_time.timer(); + let batch = read_ipc_compressed(slice)?; + timer.stop(); + + let num_rows = batch.num_rows(); + + // The read_ipc_compressed already produces owned arrays, so we skip the + // header (field count + codec) that was already consumed by read_ipc_compressed. + // Extract column arrays from the RecordBatch. + let columns: Vec = batch.columns().to_vec(); + + debug_assert_eq!( + columns.len(), + data_types.len(), + "Shuffle block column count mismatch: got {} but expected {}", + columns.len(), + data_types.len() + ); + + Ok(InputBatch::new(columns, Some(num_rows))) + } +} + +fn schema_from_data_types(data_types: &[DataType]) -> SchemaRef { + let fields = data_types + .iter() + .enumerate() + .map(|(idx, dt)| Field::new(format!("col_{idx}"), dt.clone(), true)) + .collect::>(); + + Arc::new(Schema::new(fields)) +} + +impl ExecutionPlan for ShuffleScanExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _: Vec>, + ) -> datafusion::common::Result> { + Ok(self) + } + + fn execute( + &self, + partition: usize, + _: Arc, + ) -> datafusion::common::Result { + Ok(Box::pin(ShuffleScanStream::new( + self.clone(), + self.schema(), + partition, + self.baseline_metrics.clone(), + ))) + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn name(&self) -> &str { + "ShuffleScanExec" + } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } +} + +impl DisplayAs for ShuffleScanExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + 
DisplayFormatType::Default | DisplayFormatType::Verbose => { + let fields: Vec = self + .data_types + .iter() + .enumerate() + .map(|(idx, dt)| format!("col_{idx}: {dt}")) + .collect(); + write!(f, "ShuffleScanExec: schema=[{}]", fields.join(", "))?; + } + DisplayFormatType::TreeRender => unimplemented!(), + } + Ok(()) + } +} + +/// An async stream that feeds decoded shuffle batches into the DataFusion plan. +struct ShuffleScanStream { + /// The ShuffleScanExec producing input batches. + shuffle_scan: ShuffleScanExec, + /// Schema of the output. + schema: SchemaRef, + /// Metrics. + baseline_metrics: BaselineMetrics, +} + +impl ShuffleScanStream { + pub fn new( + shuffle_scan: ShuffleScanExec, + schema: SchemaRef, + _partition: usize, + baseline_metrics: BaselineMetrics, + ) -> Self { + Self { + shuffle_scan, + schema, + baseline_metrics, + } + } +} + +impl Stream for ShuffleScanStream { + type Item = DataFusionResult; + + fn poll_next(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + let mut timer = self.baseline_metrics.elapsed_compute().timer(); + let mut scan_batch = self.shuffle_scan.batch.try_lock().unwrap(); + + let input_batch = &*scan_batch; + let input_batch = if let Some(batch) = input_batch { + batch + } else { + timer.stop(); + return Poll::Pending; + }; + + let result = match input_batch { + InputBatch::EOF => Poll::Ready(None), + InputBatch::Batch(columns, num_rows) => { + self.baseline_metrics.record_output(*num_rows); + let options = + arrow::array::RecordBatchOptions::new().with_row_count(Some(*num_rows)); + let maybe_batch = arrow::array::RecordBatch::try_new_with_options( + Arc::clone(&self.schema), + columns.clone(), + &options, + ) + .map_err(|e| arrow_datafusion_err!(e)); + Poll::Ready(Some(maybe_batch)) + } + }; + + *scan_batch = None; + + timer.stop(); + + result + } +} + +impl RecordBatchStream for ShuffleScanStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } +} + +#[cfg(test)] +mod tests { + use 
crate::execution::shuffle::codec::{CompressionCodec, ShuffleBlockWriter}; + use arrow::array::{Int32Array, StringArray}; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use datafusion::physical_plan::metrics::Time; + use std::io::Cursor; + use std::sync::Arc; + + use crate::execution::shuffle::codec::read_ipc_compressed; + + #[test] + #[cfg_attr(miri, ignore)] // Miri cannot call FFI functions (zstd) + fn test_read_compressed_ipc_block() { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["a", "b", "c"])), + ], + ) + .unwrap(); + + // Write as compressed IPC + let writer = + ShuffleBlockWriter::try_new(&batch.schema(), CompressionCodec::Zstd(1)).unwrap(); + let mut buf = Cursor::new(Vec::new()); + let ipc_time = Time::new(); + writer.write_batch(&batch, &mut buf, &ipc_time).unwrap(); + + // Read back (skip 16-byte header: 8 compressed_length + 8 field_count) + let bytes = buf.into_inner(); + let body = &bytes[16..]; + + let decoded = read_ipc_compressed(body).unwrap(); + assert_eq!(decoded.num_rows(), 3); + assert_eq!(decoded.num_columns(), 2); + + // Verify data + let col0 = decoded + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col0.value(0), 1); + assert_eq!(col0.value(1), 2); + assert_eq!(col0.value(2), 3); + } +} diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index bd37755922..b5892d763c 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -27,7 +27,7 @@ use crate::{ errors::ExpressionError, execution::{ expressions::subquery::Subquery, - operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec}, + operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, 
ShuffleScanExec}, planner::expression_registry::ExpressionRegistry, planner::operator_registry::OperatorRegistry, serde::to_arrow_datatype, @@ -141,6 +141,8 @@ use url::Url; type PhyAggResult = Result, ExecutionError>; type PhyExprResult = Result, String)>, ExecutionError>; type PartitionPhyExprResult = Result>, ExecutionError>; +pub type PlanCreationResult = + Result<(Vec, Vec, Arc), ExecutionError>; struct JoinParameters { pub left: Arc, @@ -913,7 +915,7 @@ impl PhysicalPlanner { spark_plan: &'a Operator, inputs: &mut Vec>, partition_count: usize, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { // Try to use the modular registry first - this automatically handles any registered operator types if OperatorRegistry::global().can_handle(spark_plan) { return OperatorRegistry::global().create_plan( @@ -929,7 +931,8 @@ impl PhysicalPlanner { match spark_plan.op_struct.as_ref().unwrap() { OpStruct::Filter(filter) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let predicate = self.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?; @@ -940,12 +943,14 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, filter, vec![child])), )) } OpStruct::HashAgg(agg) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let group_exprs: PhyExprResult = agg .grouping_exprs @@ -996,6 +1001,7 @@ impl PhysicalPlanner { if agg.result_exprs.is_empty() { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])), )) } else { @@ -1012,6 +1018,7 @@ impl PhysicalPlanner { )?); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( 
spark_plan.plan_id, projection, @@ -1030,7 +1037,8 @@ impl PhysicalPlanner { "Invalid limit/offset combination: [{num}. {offset}]" ))); } - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let limit: Arc = if offset == 0 { Arc::new(LocalLimitExec::new( Arc::clone(&child.native_plan), @@ -1050,12 +1058,14 @@ impl PhysicalPlanner { }; Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, limit, vec![child])), )) } OpStruct::Sort(sort) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let exprs: Result, ExecutionError> = sort .sort_orders @@ -1079,6 +1089,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, sort_exec, @@ -1115,6 +1126,7 @@ impl PhysicalPlanner { if partition_files.partitioned_file.is_empty() { let empty_exec = Arc::new(EmptyExec::new(required_schema)); return Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, empty_exec, vec![])), )); @@ -1205,6 +1217,7 @@ impl PhysicalPlanner { common.encryption_enabled, )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])), )) @@ -1243,6 +1256,7 @@ impl PhysicalPlanner { &scan.csv_options.clone().unwrap(), )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, scan, vec![])), )) @@ -1276,6 +1290,7 @@ impl PhysicalPlanner { Ok(( vec![scan.clone()], + vec![], Arc::new(SparkPlan::new(spark_plan.plan_id, Arc::new(scan), vec![])), )) } @@ -1307,6 +1322,7 @@ impl PhysicalPlanner { )?; Ok(( + vec![], vec![], Arc::new(SparkPlan::new( spark_plan.plan_id, @@ -1317,7 +1333,8 @@ impl PhysicalPlanner { } OpStruct::ShuffleWriter(writer) => { assert_eq!(children.len(), 1); - let (scans, child) = 
self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let partitioning = self.create_partitioning( writer.partitioning.as_ref().unwrap(), @@ -1350,6 +1367,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, shuffle_writer, @@ -1359,7 +1377,8 @@ impl PhysicalPlanner { } OpStruct::ParquetWriter(writer) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let codec = match writer.compression.try_into() { Ok(SparkCompressionCodec::None) => Ok(CompressionCodec::None), @@ -1396,6 +1415,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, parquet_writer, @@ -1405,7 +1425,8 @@ impl PhysicalPlanner { } OpStruct::Expand(expand) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let mut projections = vec![]; let mut projection = vec![]; @@ -1448,12 +1469,14 @@ impl PhysicalPlanner { let expand = Arc::new(ExpandExec::new(projections, input, schema)); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, expand, vec![child])), )) } OpStruct::Explode(explode) => { assert_eq!(children.len(), 1); - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; // Create the expression for the array to explode let child_expr = if let Some(child_expr) = &explode.child { @@ -1559,11 +1582,12 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, unnest_exec, vec![child])), )) } 
OpStruct::SortMergeJoin(join) => { - let (join_params, scans) = self.parse_join_parameters( + let (join_params, scans, shuffle_scans) = self.parse_join_parameters( inputs, children, &join.left_join_keys, @@ -1615,6 +1639,7 @@ impl PhysicalPlanner { )); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, coalesce_batches, @@ -1628,6 +1653,7 @@ impl PhysicalPlanner { } else { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, join, @@ -1640,7 +1666,7 @@ impl PhysicalPlanner { } } OpStruct::HashJoin(join) => { - let (join_params, scans) = self.parse_join_parameters( + let (join_params, scans, shuffle_scans) = self.parse_join_parameters( inputs, children, &join.left_join_keys, @@ -1670,6 +1696,7 @@ impl PhysicalPlanner { if join.build_side == BuildSide::BuildLeft as i32 { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new( spark_plan.plan_id, hash_join, @@ -1688,6 +1715,7 @@ impl PhysicalPlanner { Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new_with_additional( spark_plan.plan_id, swapped_hash_join, @@ -1698,7 +1726,8 @@ impl PhysicalPlanner { } } OpStruct::Window(wnd) => { - let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let (scans, shuffle_scans, child) = + self.create_plan(&children[0], inputs, partition_count)?; let input_schema = child.schema(); let sort_exprs: Result, ExecutionError> = wnd .order_by_list @@ -1736,9 +1765,37 @@ impl PhysicalPlanner { )?); Ok(( scans, + shuffle_scans, Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])), )) } + OpStruct::ShuffleScan(scan) => { + let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); + + if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { + return Err(GeneralError("No input for shuffle scan".to_string())); + } + + let input_source = + if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { + None + } else { + Some(inputs.remove(0)) + }; + + let 
shuffle_scan = + ShuffleScanExec::new(self.exec_context_id, input_source, data_types)?; + + Ok(( + vec![], + vec![shuffle_scan.clone()], + Arc::new(SparkPlan::new( + spark_plan.plan_id, + Arc::new(shuffle_scan), + vec![], + )), + )) + } _ => Err(GeneralError(format!( "Unsupported or unregistered operator type: {:?}", spark_plan.op_struct @@ -1756,12 +1813,15 @@ impl PhysicalPlanner { join_type: i32, condition: &Option, partition_count: usize, - ) -> Result<(JoinParameters, Vec), ExecutionError> { + ) -> Result<(JoinParameters, Vec, Vec), ExecutionError> { assert_eq!(children.len(), 2); - let (mut left_scans, left) = self.create_plan(&children[0], inputs, partition_count)?; - let (mut right_scans, right) = self.create_plan(&children[1], inputs, partition_count)?; + let (mut left_scans, mut left_shuffle_scans, left) = + self.create_plan(&children[0], inputs, partition_count)?; + let (mut right_scans, mut right_shuffle_scans, right) = + self.create_plan(&children[1], inputs, partition_count)?; left_scans.append(&mut right_scans); + left_shuffle_scans.append(&mut right_shuffle_scans); let left_join_exprs: Vec<_> = left_join_keys .iter() @@ -1882,6 +1942,7 @@ impl PhysicalPlanner { join_filter, }, left_scans, + left_shuffle_scans, )) } @@ -3670,7 +3731,8 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); scans[0].set_input_batch(input_batch); let session_ctx = SessionContext::new(); @@ -3744,7 +3806,8 @@ mod tests { let input_array = DictionaryArray::new(keys, Arc::new(values)); let input_batch = InputBatch::Batch(vec![Arc::new(input_array)], row_count); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, 
datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); // Scan's schema is determined by the input batch, so we need to set it before execution. scans[0].set_input_batch(input_batch); @@ -3791,7 +3854,8 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (mut scans, datafusion_plan) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (mut scans, _shuffle_scans, datafusion_plan) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); let scan = &mut scans[0]; scan.set_input_batch(InputBatch::EOF); @@ -3876,7 +3940,8 @@ mod tests { let op = create_filter(op_scan, 0); let planner = PhysicalPlanner::default(); - let (_scans, filter_exec) = planner.create_plan(&op, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, filter_exec) = + planner.create_plan(&op, &mut vec![], 1).unwrap(); assert_eq!("FilterExec", filter_exec.native_plan.name()); assert_eq!(1, filter_exec.children.len()); @@ -3900,7 +3965,8 @@ mod tests { let planner = PhysicalPlanner::default(); - let (_scans, hash_join_exec) = planner.create_plan(&op_join, &mut vec![], 1).unwrap(); + let (_scans, _shuffle_scans, hash_join_exec) = + planner.create_plan(&op_join, &mut vec![], 1).unwrap(); assert_eq!("HashJoinExec", hash_join_exec.native_plan.name()); assert_eq!(2, hash_join_exec.children.len()); @@ -4014,7 +4080,7 @@ mod tests { })), }; - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); let mut stream = datafusion_plan.native_plan.execute(0, task_ctx).unwrap(); @@ -4140,7 +4206,7 @@ mod tests { }; // Create a physical plan - let (mut scans, datafusion_plan) = + let (mut scans, _shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); // Start executing the plan in a separate thread @@ -4631,7 +4697,7 @@ mod tests { }; // Create the physical plan - let (mut scans, datafusion_plan) = + let (mut scans, 
_shuffle_scans, datafusion_plan) = planner.create_plan(&projection, &mut vec![], 1).unwrap(); // Create test data: Date32 and Int8 columns diff --git a/native/core/src/execution/planner/operator_registry.rs b/native/core/src/execution/planner/operator_registry.rs index b34a80df95..cad5df40c5 100644 --- a/native/core/src/execution/planner/operator_registry.rs +++ b/native/core/src/execution/planner/operator_registry.rs @@ -25,11 +25,8 @@ use std::{ use datafusion_comet_proto::spark_operator::Operator; use jni::objects::GlobalRef; -use super::PhysicalPlanner; -use crate::execution::{ - operators::{ExecutionError, ScanExec}, - spark_plan::SparkPlan, -}; +use super::{PhysicalPlanner, PlanCreationResult}; +use crate::execution::operators::ExecutionError; /// Trait for building physical operators from Spark protobuf operators pub trait OperatorBuilder: Send + Sync { @@ -40,7 +37,7 @@ pub trait OperatorBuilder: Send + Sync { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError>; + ) -> PlanCreationResult; } /// Enum to identify different operator types for registry dispatch @@ -100,7 +97,7 @@ impl OperatorRegistry { inputs: &mut Vec>, partition_count: usize, planner: &PhysicalPlanner, - ) -> Result<(Vec, Arc), ExecutionError> { + ) -> PlanCreationResult { let operator_type = get_operator_type(spark_operator).ok_or_else(|| { ExecutionError::GeneralError(format!( "Unsupported operator type: {:?}", @@ -153,5 +150,6 @@ fn get_operator_type(spark_operator: &Operator) -> Option { OpStruct::Window(_) => Some(OperatorType::Window), OpStruct::Explode(_) => None, // Not yet in OperatorType enum OpStruct::CsvScan(_) => Some(OperatorType::CsvScan), + OpStruct::ShuffleScan(_) => None, // Not yet in OperatorType enum } } diff --git a/native/core/src/jvm_bridge/mod.rs b/native/core/src/jvm_bridge/mod.rs index 00fe7b33c3..85c2ae7577 100644 --- a/native/core/src/jvm_bridge/mod.rs +++ b/native/core/src/jvm_bridge/mod.rs @@ -174,11 
+174,13 @@ pub use comet_exec::*; mod batch_iterator; mod comet_metric_node; mod comet_task_memory_manager; +mod shuffle_block_iterator; use crate::{errors::CometError, JAVA_VM}; use batch_iterator::CometBatchIterator; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; +use shuffle_block_iterator::CometShuffleBlockIterator; /// The JVM classes that are used in the JNI calls. #[allow(dead_code)] // we need to keep references to Java items to prevent GC @@ -204,6 +206,8 @@ pub struct JVMClasses<'a> { pub comet_exec: CometExec<'a>, /// The CometBatchIterator class. Used for iterating over the batches. pub comet_batch_iterator: CometBatchIterator<'a>, + /// The CometShuffleBlockIterator class. Used for iterating over shuffle blocks. + pub comet_shuffle_block_iterator: CometShuffleBlockIterator<'a>, /// The CometTaskMemoryManager used for interacting with JVM side to /// acquire & release native memory. pub comet_task_memory_manager: CometTaskMemoryManager<'a>, @@ -257,6 +261,7 @@ impl JVMClasses<'_> { comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), comet_batch_iterator: CometBatchIterator::new(env).unwrap(), + comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), } }); diff --git a/native/core/src/jvm_bridge/shuffle_block_iterator.rs b/native/core/src/jvm_bridge/shuffle_block_iterator.rs new file mode 100644 index 0000000000..c3bb5af5fb --- /dev/null +++ b/native/core/src/jvm_bridge/shuffle_block_iterator.rs @@ -0,0 +1,62 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use jni::signature::Primitive; +use jni::{ + errors::Result as JniResult, + objects::{JClass, JMethodID}, + signature::ReturnType, + JNIEnv, +}; + +/// A struct that holds all the JNI methods and fields for JVM `CometShuffleBlockIterator` class. +#[allow(dead_code)] // we need to keep references to Java items to prevent GC +pub struct CometShuffleBlockIterator<'a> { + pub class: JClass<'a>, + pub method_has_next: JMethodID, + pub method_has_next_ret: ReturnType, + pub method_get_buffer: JMethodID, + pub method_get_buffer_ret: ReturnType, + pub method_get_current_block_length: JMethodID, + pub method_get_current_block_length_ret: ReturnType, +} + +impl<'a> CometShuffleBlockIterator<'a> { + pub const JVM_CLASS: &'static str = "org/apache/comet/CometShuffleBlockIterator"; + + pub fn new(env: &mut JNIEnv<'a>) -> JniResult> { + let class = env.find_class(Self::JVM_CLASS)?; + + Ok(CometShuffleBlockIterator { + class, + method_has_next: env.get_method_id(Self::JVM_CLASS, "hasNext", "()I")?, + method_has_next_ret: ReturnType::Primitive(Primitive::Int), + method_get_buffer: env.get_method_id( + Self::JVM_CLASS, + "getBuffer", + "()Ljava/nio/ByteBuffer;", + )?, + method_get_buffer_ret: ReturnType::Object, + method_get_current_block_length: env.get_method_id( + Self::JVM_CLASS, + "getCurrentBlockLength", + "()I", + )?, + method_get_current_block_length_ret: ReturnType::Primitive(Primitive::Int), + }) + } +} diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 4afc1fefb7..344b9f0f21 100644 --- 
a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -52,6 +52,7 @@ message Operator { ParquetWriter parquet_writer = 113; Explode explode = 114; CsvScan csv_scan = 115; + ShuffleScan shuffle_scan = 116; } } @@ -85,6 +86,12 @@ message Scan { bool arrow_ffi_safe = 3; } +message ShuffleScan { + repeated spark.spark_expression.DataType fields = 1; + // Informational label for debug output (e.g., "CometShuffleExchangeExec [id=5]") + string source = 2; +} + // Common data shared by all partitions in split mode (sent once at planning) message NativeScanCommon { repeated SparkStructField required_schema = 1; diff --git a/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java new file mode 100644 index 0000000000..f9abef1c36 --- /dev/null +++ b/spark/src/main/java/org/apache/comet/CometShuffleBlockIterator.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.comet; + +import java.io.Closeable; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; + +/** + * Provides raw compressed shuffle blocks to native code via JNI. + * + *

Reads block headers (compressed length + field count) from a shuffle InputStream and loads the + * compressed body into a DirectByteBuffer. Native code pulls blocks by calling hasNext() and + * getBuffer(). + * + *

The DirectByteBuffer returned by getBuffer() is only valid until the next hasNext() call. + * Native code must fully consume it (via read_ipc_compressed which allocates new memory for the + * decompressed data) before pulling the next block. + */ +public class CometShuffleBlockIterator implements Closeable { + + private static final int INITIAL_BUFFER_SIZE = 128 * 1024; + + private final ReadableByteChannel channel; + private final InputStream inputStream; + private final ByteBuffer headerBuf = ByteBuffer.allocate(16).order(ByteOrder.LITTLE_ENDIAN); + private ByteBuffer dataBuf = ByteBuffer.allocateDirect(INITIAL_BUFFER_SIZE); + private boolean closed = false; + private int currentBlockLength = 0; + + public CometShuffleBlockIterator(InputStream in) { + this.inputStream = in; + this.channel = Channels.newChannel(in); + } + + /** + * Reads the next block header and loads the compressed body into the internal buffer. Called by + * native code via JNI. + * + *

Header format: 8-byte compressedLength (includes field count but not itself) + 8-byte + * fieldCount (discarded, schema comes from protobuf). + * + * @return the compressed body length in bytes (codec prefix + compressed IPC), or -1 if EOF + */ + public int hasNext() throws IOException { + if (closed) { + return -1; + } + + // Read 16-byte header + headerBuf.clear(); + while (headerBuf.hasRemaining()) { + int bytesRead = channel.read(headerBuf); + if (bytesRead < 0) { + if (headerBuf.position() == 0) { + return -1; + } + throw new EOFException("Data corrupt: unexpected EOF while reading batch header"); + } + } + headerBuf.flip(); + long compressedLength = headerBuf.getLong(); + // Field count discarded - schema determined by ShuffleScan protobuf fields + headerBuf.getLong(); + + long bytesToRead = compressedLength - 8; + if (bytesToRead > Integer.MAX_VALUE) { + throw new IllegalStateException( + "Native shuffle block size of " + + bytesToRead + + " exceeds maximum of " + + Integer.MAX_VALUE + + ". Try reducing shuffle batch size."); + } + + if (dataBuf.capacity() < bytesToRead) { + int newCapacity = (int) Math.min(bytesToRead * 2L, Integer.MAX_VALUE); + dataBuf = ByteBuffer.allocateDirect(newCapacity); + } + + dataBuf.clear(); + dataBuf.limit((int) bytesToRead); + while (dataBuf.hasRemaining()) { + int bytesRead = channel.read(dataBuf); + if (bytesRead < 0) { + throw new EOFException("Data corrupt: unexpected EOF while reading compressed batch"); + } + } + // Note: native side uses get_direct_buffer_address (base pointer) + currentBlockLength, + // not the buffer's position/limit. No flip needed. + + currentBlockLength = (int) bytesToRead; + return currentBlockLength; + } + + /** + * Returns the DirectByteBuffer containing the current block's compressed bytes (4-byte codec + * prefix + compressed IPC data). Called by native code via JNI. + */ + public ByteBuffer getBuffer() { + return dataBuf; + } + + /** Returns the length of the current block in bytes. 
Called by native code via JNI. */ + public int getCurrentBlockLength() { + return currentBlockLength; + } + + @Override + public void close() throws IOException { + if (!closed) { + closed = true; + inputStream.close(); + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 44ebf7e36e..e198ac99ff 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -67,7 +67,8 @@ class CometExecIterator( numParts: Int, partitionIndex: Int, broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty) + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty) extends Iterator[ColumnarBatch] with Logging { @@ -78,8 +79,13 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) - private val cometBatchIterators = inputs.map { iterator => - new CometBatchIterator(iterator, nativeUtil) + // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle + // scan indices, CometBatchIterator for regular scan indices. 
+ private val inputIterators: Array[Object] = inputs.zipWithIndex.map { + case (_, idx) if shuffleBlockIterators.contains(idx) => + shuffleBlockIterators(idx).asInstanceOf[Object] + case (iterator, _) => + new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] }.toArray private val plan = { @@ -106,7 +112,7 @@ class CometExecIterator( nativeLib.createPlan( id, - cometBatchIterators, + inputIterators, protobufQueryPlan, protobufSparkConfigs, numParts, @@ -229,6 +235,7 @@ class CometExecIterator( currentBatch = null } nativeUtil.close() + shuffleBlockIterators.values.foreach(_.close()) nativeLib.releasePlan(plan) if (tracingEnabled) { diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 55e0c70e72..f6800626d6 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -54,7 +54,7 @@ class Native extends NativeBase { // scalastyle:off @native def createPlan( id: Long, - iterators: Array[CometBatchIterator], + iterators: Array[Object], plan: Array[Byte], configMapProto: Array[Byte], partitionCount: Int, diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala index ca9dbdad7c..dde36d9789 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala @@ -22,8 +22,12 @@ package org.apache.comet.serde.operator import scala.jdk.CollectionConverters._ import org.apache.spark.sql.comet.{CometNativeExec, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.execution.shuffle.{CometNativeShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec +import org.apache.spark.sql.execution.exchange.ReusedExchangeExec +import org.apache.comet.CometConf import 
org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.ConfigEntry import org.apache.comet.serde.{CometOperatorSerde, OperatorOuterClass} @@ -86,15 +90,67 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { object CometExchangeSink extends CometSink[SparkPlan] { - /** - * Exchange data is FFI safe because there is no use of mutable buffers involved. - * - * Source of broadcast exchange batches is ArrowStreamReader. - * - * Source of shuffle exchange batches is NativeBatchDecoderIterator. - */ override def isFfiSafe: Boolean = true + override def convert( + op: SparkPlan, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + if (shouldUseShuffleScan(op)) { + convertToShuffleScan(op, builder) + } else { + super.convert(op, builder, childOp: _*) + } + } + + private def shouldUseShuffleScan(op: SparkPlan): Boolean = { + if (!CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.get()) return false + + // Extract the CometShuffleExchangeExec from the wrapper + val shuffleExec = op match { + case ShuffleQueryStageExec(_, s: CometShuffleExchangeExec, _) => Some(s) + case ShuffleQueryStageExec(_, ReusedExchangeExec(_, s: CometShuffleExchangeExec), _) => + Some(s) + case s: CometShuffleExchangeExec => Some(s) + case _ => None + } + + shuffleExec.exists(_.shuffleType == CometNativeShuffle) + } + + private def convertToShuffleScan( + op: SparkPlan, + builder: Operator.Builder): Option[OperatorOuterClass.Operator] = { + val supportedTypes = + op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) + + if (!supportedTypes) { + withInfo(op, "Unsupported data type for shuffle direct read") + return None + } + + val scanBuilder = OperatorOuterClass.ShuffleScan.newBuilder() + val source = op.simpleStringWithNodeId() + if (source.isEmpty) { + scanBuilder.setSource(op.getClass.getSimpleName) + } else { + scanBuilder.setSource(source) + } + + val scanTypes = 
op.output.flatMap { attr => + serializeDataType(attr.dataType) + } + + if (scanTypes.length == op.output.length) { + scanBuilder.addAllFields(scanTypes.asJava) + builder.clearChildren() + Some(builder.setShuffleScan(scanBuilder).build()) + } else { + withInfo(op, "unsupported data types for shuffle direct read") + None + } + } + override def createExec(nativeOp: Operator, op: SparkPlan): CometNativeExec = CometSinkPlaceHolder(nativeOp, op, op) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index ad0c4f2afe..cb8652507f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration -import org.apache.comet.CometExecIterator +import org.apache.comet.{CometExecIterator, CometShuffleBlockIterator} import org.apache.comet.serde.OperatorOuterClass /** @@ -64,7 +64,10 @@ private[spark] class CometExecRDD( nativeMetrics: CometMetricNode, subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty) + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIteratorFactories: Map[ + Int, + (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty) extends RDD[ColumnarBatch](sc, inputRDDs.map(rdd => new OneToOneDependency(rdd))) { // Determine partition count: from inputs if available, otherwise from parameter @@ -109,6 +112,12 @@ private[spark] class CometExecRDD( serializedPlan } + // Create shuffle block iterators for indices that have factories + val shuffleBlockIters = shuffleBlockIteratorFactories.map { case (idx, factory) => + val inputPart = partition.inputPartitions(idx) + idx -> 
factory(context, inputPart) + } + val it = new CometExecIterator( CometExec.newIterId, inputs, @@ -118,7 +127,8 @@ private[spark] class CometExecRDD( numPartitions, partition.index, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIters) // Register ScalarSubqueries so native code can look them up subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub)) @@ -167,7 +177,10 @@ object CometExecRDD { nativeMetrics: CometMetricNode, subqueries: Seq[ScalarSubquery], broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None, - encryptedFilePaths: Seq[String] = Seq.empty): CometExecRDD = { + encryptedFilePaths: Seq[String] = Seq.empty, + shuffleBlockIteratorFactories: Map[ + Int, + (TaskContext, Partition) => CometShuffleBlockIterator] = Map.empty): CometExecRDD = { // scalastyle:on new CometExecRDD( @@ -181,6 +194,7 @@ object CometExecRDD { nativeMetrics, subqueries, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIteratorFactories) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala index e95eb92d21..14e656f038 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala @@ -153,6 +153,18 @@ class CometBlockStoreShuffleReader[K, C]( } } + /** + * Returns the raw concatenated InputStream of all shuffle blocks, bypassing the decode step. + * Used by ShuffleScan direct read path. 
+ */ + def readAsRawStream(): InputStream = { + val streams = fetchIterator.map(_._2) + new java.io.SequenceInputStream(new java.util.Enumeration[InputStream] { + override def hasMoreElements: Boolean = streams.hasNext + override def nextElement(): InputStream = streams.next() + }) + } + private def fetchContinuousBlocksInBatch: Boolean = { val conf = SparkEnv.get.conf val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala index ba6fc588e2..6594982c85 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch */ class CometShuffledBatchRDD( var dependency: ShuffleDependency[Int, _, _], - metrics: Map[String, SQLMetric], + val metrics: Map[String, SQLMetric], partitionSpecs: Array[ShufflePartitionSpec]) extends RDD[ColumnarBatch](dependency.rdd.context, Nil) { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index da2ae21a95..2e195e73eb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -25,6 +25,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ +import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -33,14 +34,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import 
org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.comet.execution.shuffle.{CometBlockStoreShuffleReader, CometShuffledBatchRDD, CometShuffleExchangeExec, ShuffledRowRDDPartition} import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, TimestampNTZType} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -50,7 +51,7 @@ import org.apache.spark.util.io.ChunkedByteBuffer import com.google.common.base.Objects import com.google.protobuf.CodedOutputStream -import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, ConfigEntry} +import org.apache.comet.{CometConf, CometExecIterator, CometRuntimeException, CometShuffleBlockIterator, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, withInfo} import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported} @@ 
-553,6 +554,11 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } + // Detect ShuffleScan indices and create factories for direct read + val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) + val shuffleBlockIteratorFactories = + buildShuffleBlockIteratorFactories(sparkPlans, inputs, shuffleScanIndices) + // Unified RDD creation - CometExecRDD handles all cases val subqueries = collectSubqueries(this) CometExecRDD( @@ -566,7 +572,8 @@ abstract class CometNativeExec extends CometExec { nativeMetrics, subqueries, broadcastedHadoopConfForEncryption, - encryptedFilePaths) + encryptedFilePaths, + shuffleBlockIteratorFactories) } } @@ -606,6 +613,108 @@ abstract class CometNativeExec extends CometExec { } } + /** + * Walk the serialized protobuf plan depth-first to find which input indices correspond to + * ShuffleScan vs Scan leaf nodes. Each Scan or ShuffleScan leaf consumes one input in order. + */ + private def findShuffleScanIndices(planBytes: Array[Byte]): Set[Int] = { + val plan = OperatorOuterClass.Operator.parseFrom(planBytes) + var scanIndex = 0 + val indices = mutable.Set.empty[Int] + def walk(op: OperatorOuterClass.Operator): Unit = { + if (op.hasShuffleScan) { + indices += scanIndex + scanIndex += 1 + } else if (op.hasScan) { + scanIndex += 1 + } else { + op.getChildrenList.asScala.foreach(walk) + } + } + walk(plan) + indices.toSet + } + + /** + * Build factory functions that produce CometShuffleBlockIterator for each input index that is a + * ShuffleScan. Maps from input index to a factory that, given TaskContext and Partition, + * creates the iterator. 
+ */ + private def buildShuffleBlockIteratorFactories( + sparkPlans: ArrayBuffer[SparkPlan], + inputs: ArrayBuffer[RDD[ColumnarBatch]], + shuffleScanIndices: Set[Int]) + : Map[Int, (TaskContext, Partition) => CometShuffleBlockIterator] = { + if (shuffleScanIndices.isEmpty) return Map.empty + + val factories = mutable.Map.empty[Int, (TaskContext, Partition) => CometShuffleBlockIterator] + + shuffleScanIndices.foreach { scanIdx => + if (scanIdx < inputs.length) { + inputs(scanIdx) match { + case rdd: CometShuffledBatchRDD => + val dep = rdd.dependency + val rddMetrics = rdd.metrics + factories(scanIdx) = (context, part) => { + val shufflePart = part.asInstanceOf[ShuffledRowRDDPartition] + val tempMetrics = + context.taskMetrics().createTempShuffleReadMetrics() + val sqlMetricsReporter = + new SQLShuffleReadMetricsReporter(tempMetrics, rddMetrics) + val reader = shufflePart.spec match { + case CoalescedPartitionSpec(startReducerIndex, endReducerIndex, _) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case PartialReducerPartitionSpec(reducerIndex, startMapIndex, endMapIndex, _) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startMapIndex, + endMapIndex, + reducerIndex, + reducerIndex + 1, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case PartialMapperPartitionSpec(mapIndex, startReducerIndex, endReducerIndex) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + mapIndex, + mapIndex + 1, + startReducerIndex, + endReducerIndex, + context, + sqlMetricsReporter) + .asInstanceOf[CometBlockStoreShuffleReader[_, _]] + case CoalescedMapperPartitionSpec(startMapIndex, endMapIndex, numReducers) => + SparkEnv.get.shuffleManager + .getReader( + dep.shuffleHandle, + startMapIndex, + endMapIndex, + 0, + numReducers, + context, + sqlMetricsReporter) + 
.asInstanceOf[CometBlockStoreShuffleReader[_, _]] + } + val rawStream = reader.readAsRawStream() + new CometShuffleBlockIterator(rawStream) + } + case _ => // Not a CometShuffledBatchRDD, skip + } + } + } + factories.toMap + } + /** * Find all plan nodes with per-partition planning data in the plan tree. Returns two maps keyed * by a unique identifier: one for common data (shared across partitions) and one for diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index 1cf43ea598..11f825e70d 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.SparkEnv import org.apache.spark.sql.{CometTestBase, DataFrame, Dataset, Row} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.functions.{col, count, sum} import org.apache.comet.CometConf @@ -437,4 +437,19 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } } } + + test("shuffle direct read produces same results as FFI path") { + Seq(true, false).foreach { directRead => + withSQLConf(CometConf.COMET_SHUFFLE_DIRECT_READ_ENABLED.key -> directRead.toString) { + val df = spark + .range(1000) + .selectExpr("id", "id % 10 as key", "cast(id as string) as value") + .repartition(4, col("key")) + .groupBy("key") + .agg(sum("id").as("total"), count("value").as("cnt")) + .orderBy("key") + checkSparkAnswer(df) + } + } + } }