diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 1030e30aaf..9ce3db4de3 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -21,8 +21,11 @@ use super::{serde, utils::SparkArrowConvert}; use crate::{ errors::{try_unwrap_or_throw, CometError, CometResult}, execution::{ - metrics::utils::update_comet_metric, planner::PhysicalPlanner, serde::to_arrow_datatype, - shuffle::spark_unsafe::row::process_sorted_row_partition, sort::RdxSort, + metrics::utils::{build_metric_layout, update_comet_metric, MetricLayout}, + planner::PhysicalPlanner, + serde::to_arrow_datatype, + shuffle::spark_unsafe::row::process_sorted_row_partition, + sort::RdxSort, }, jvm_bridge::{jni_new_global_ref, JVMClasses}, }; @@ -175,6 +178,8 @@ struct ExecutionContext { pub memory_pool_config: MemoryPoolConfig, /// Whether to log memory usage on each call to execute_plan pub tracing_enabled: bool, + /// Pre-computed metric layout for flat array metric updates + pub metric_layout: Option, } /// Accept serialized query plan and return the address of the native query plan. @@ -322,6 +327,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( explain_native, memory_pool_config, tracing_enabled, + metric_layout: None, }); Ok(Box::into_raw(exec_context) as i64) @@ -547,6 +553,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( exec_context.root_op = Some(Arc::clone(&root_op)); exec_context.scans = scans; + // Build the flat metric layout for efficient metric updates + let metrics = exec_context.metrics.as_obj(); + exec_context.metric_layout = Some(build_metric_layout(&mut env, metrics)?); + if exec_context.explain_native { let formatted_plan_str = DisplayableExecutionPlan::new(root_op.native_plan.as_ref()).indent(true); @@ -678,9 +688,14 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( /// Updates the metrics of the query plan. 
fn update_metrics(env: &mut JNIEnv, exec_context: &mut ExecutionContext) -> CometResult<()> { - if let Some(native_query) = &exec_context.root_op { + if let Some(ref native_query) = exec_context.root_op { + let native_query = Arc::clone(native_query); let metrics = exec_context.metrics.as_obj(); - update_comet_metric(env, metrics, native_query) + if let Some(ref mut layout) = exec_context.metric_layout { + update_comet_metric(env, metrics, &native_query, layout) + } else { + Ok(()) + } } else { Ok(()) } diff --git a/native/core/src/execution/metrics/utils.rs b/native/core/src/execution/metrics/utils.rs index 9ec35ad951..73a095ac03 100644 --- a/native/core/src/execution/metrics/utils.rs +++ b/native/core/src/execution/metrics/utils.rs @@ -18,38 +18,94 @@ use crate::execution::spark_plan::SparkPlan; use crate::{errors::CometError, jvm_bridge::jni_call}; use datafusion::physical_plan::metrics::MetricValue; -use datafusion_comet_proto::spark_metric::NativeMetricNode; -use jni::{objects::JObject, JNIEnv}; -use prost::Message; +use jni::objects::{GlobalRef, JIntArray, JLongArray, JObject, JObjectArray}; +use jni::JNIEnv; use std::collections::HashMap; use std::sync::Arc; -/// Updates the metrics of a CometMetricNode. This function is called recursively to -/// update the metrics of all the children nodes. The metrics are pulled from the -/// native execution plan and pushed to the Java side through JNI. -pub(crate) fn update_comet_metric( +/// Pre-computed layout mapping metric names to indices in a flat array. +/// Built once at plan creation, reused on every metric update. 
+pub(crate) struct MetricLayout { + /// Per SparkPlan node (DFS order), maps metric name to index in the flat values array + node_indices: Vec>, + /// Flat array of metric values, written by native and bulk-copied to JVM + values: Vec, + /// Global reference to the JVM long[] array (kept alive for the lifetime of the plan) + jarray: Arc, +} + +/// Builds a MetricLayout by calling JNI methods on the CometMetricNode to retrieve +/// the flattened metric names, node offsets, and a reference to the pre-allocated long[]. +pub(crate) fn build_metric_layout( env: &mut JNIEnv, metric_node: &JObject, - spark_plan: &Arc, -) -> Result<(), CometError> { - if metric_node.is_null() { - return Ok(()); +) -> Result { + // Get metric names array (String[]) + let names_obj: JObject = + unsafe { jni_call!(env, comet_metric_node(metric_node).get_metric_names() -> JObject) }?; + let names_array = JObjectArray::from(names_obj); + let num_metrics = env.get_array_length(&names_array)? as usize; + + let mut metric_names = Vec::with_capacity(num_metrics); + for i in 0..num_metrics { + let jstr = env.get_object_array_element(&names_array, i as i32)?; + let name: String = env.get_string((&jstr).into())?.into(); + metric_names.push(name); } - let native_metric = to_native_metric_node(spark_plan); - let jbytes = env.byte_array_from_slice(&native_metric?.encode_to_vec())?; + // Get node offsets array (int[]) + let offsets_obj: JObject = + unsafe { jni_call!(env, comet_metric_node(metric_node).get_node_offsets() -> JObject) }?; + let offsets_array = JIntArray::from(offsets_obj); + let num_offsets = env.get_array_length(&offsets_array)? 
as usize; + let mut offsets = vec![0i32; num_offsets]; + env.get_int_array_region(&offsets_array, 0, &mut offsets)?; + + // Get values array reference (long[]) + let values_obj: JObject = + unsafe { jni_call!(env, comet_metric_node(metric_node).get_values_array() -> JObject) }?; + let jarray = Arc::new(env.new_global_ref(values_obj)?); - unsafe { jni_call!(env, comet_metric_node(metric_node).set_all_from_bytes(&jbytes) -> ()) } + // Build per-node index maps + let num_nodes = num_offsets - 1; + let mut node_indices = Vec::with_capacity(num_nodes); + for node_idx in 0..num_nodes { + let start = offsets[node_idx] as usize; + let end = offsets[node_idx + 1] as usize; + let mut map = HashMap::with_capacity(end - start); + for (i, name) in metric_names.iter().enumerate().take(end).skip(start) { + map.insert(name.clone(), i); + } + node_indices.push(map); + } + + Ok(MetricLayout { + node_indices, + values: vec![-1i64; num_metrics], + jarray, + }) } -pub(crate) fn to_native_metric_node( +/// Recursively fills the values array from DataFusion metrics on the SparkPlan tree. 
+fn fill_metric_values( spark_plan: &Arc, -) -> Result { - let mut native_metric_node = NativeMetricNode { - metrics: HashMap::new(), - children: Vec::new(), - }; + layout: &mut MetricLayout, + node_idx: &mut usize, +) { + let current_node = *node_idx; + *node_idx += 1; + + if current_node >= layout.node_indices.len() { + // Skip if node index exceeds layout (shouldn't happen with correct setup) + for child in spark_plan.children() { + fill_metric_values(child, layout, node_idx); + } + return; + } + + let indices = &layout.node_indices[current_node]; + // Collect metrics from the native plan (and additional plans) let node_metrics = if spark_plan.additional_native_plans.is_empty() { spark_plan.native_plan.metrics() } else { @@ -59,7 +115,7 @@ pub(crate) fn to_native_metric_node( for c in additional_metrics.iter() { match c.value() { MetricValue::OutputRows(_) => { - // we do not want to double count output rows + // do not double count output rows } _ => metrics.push(c.to_owned()), } @@ -68,21 +124,43 @@ pub(crate) fn to_native_metric_node( Some(metrics.aggregate_by_name()) }; - // add metrics - node_metrics - .unwrap_or_default() - .iter() - .map(|m| m.value()) - .map(|m| (m.name(), m.as_usize() as i64)) - .for_each(|(name, value)| { - native_metric_node.metrics.insert(name.to_string(), value); - }); - - // add children - for child_plan in spark_plan.children() { - let child_node = to_native_metric_node(child_plan)?; - native_metric_node.children.push(child_node); + // Write metric values into their pre-assigned slots + if let Some(metrics) = node_metrics { + for m in metrics.iter() { + let value = m.value(); + let name = value.name(); + if let Some(&idx) = indices.get(name) { + layout.values[idx] = value.as_usize() as i64; + } + } + } + + // Recurse into children + for child in spark_plan.children() { + fill_metric_values(child, layout, node_idx); + } +} + +/// Updates metrics by filling the flat values array and bulk-copying to JVM. 
+pub(crate) fn update_comet_metric( + env: &mut JNIEnv, + metric_node: &JObject, + spark_plan: &Arc, + layout: &mut MetricLayout, +) -> Result<(), CometError> { + if metric_node.is_null() { + return Ok(()); } - Ok(native_metric_node) + // Fill values from native metrics + let mut node_idx = 0; + fill_metric_values(spark_plan, layout, &mut node_idx); + + // Bulk copy values to JVM long[] via SetLongArrayRegion + let local_ref = env.new_local_ref(layout.jarray.as_obj())?; + let jlong_array = JLongArray::from(local_ref); + env.set_long_array_region(&jlong_array, 0, &layout.values)?; + + // Call updateFromValues() on the JVM side + unsafe { jni_call!(env, comet_metric_node(metric_node).update_from_values() -> ()) } } diff --git a/native/core/src/jvm_bridge/comet_metric_node.rs b/native/core/src/jvm_bridge/comet_metric_node.rs index f1f0255845..5646ec59ad 100644 --- a/native/core/src/jvm_bridge/comet_metric_node.rs +++ b/native/core/src/jvm_bridge/comet_metric_node.rs @@ -26,12 +26,14 @@ use jni::{ #[allow(dead_code)] // we need to keep references to Java items to prevent GC pub struct CometMetricNode<'a> { pub class: JClass<'a>, - pub method_get_child_node: JMethodID, - pub method_get_child_node_ret: ReturnType, - pub method_set: JMethodID, - pub method_set_ret: ReturnType, - pub method_set_all_from_bytes: JMethodID, - pub method_set_all_from_bytes_ret: ReturnType, + pub method_get_metric_names: JMethodID, + pub method_get_metric_names_ret: ReturnType, + pub method_get_node_offsets: JMethodID, + pub method_get_node_offsets_ret: ReturnType, + pub method_get_values_array: JMethodID, + pub method_get_values_array_ret: ReturnType, + pub method_update_from_values: JMethodID, + pub method_update_from_values_ret: ReturnType, } impl<'a> CometMetricNode<'a> { @@ -41,20 +43,30 @@ impl<'a> CometMetricNode<'a> { let class = env.find_class(Self::JVM_CLASS)?; Ok(CometMetricNode { - method_get_child_node: env.get_method_id( + method_get_metric_names: env.get_method_id( 
Self::JVM_CLASS, - "getChildNode", - format!("(I)L{:};", Self::JVM_CLASS).as_str(), + "getMetricNames", + "()[Ljava/lang/String;", )?, - method_get_child_node_ret: ReturnType::Object, - method_set: env.get_method_id(Self::JVM_CLASS, "set", "(Ljava/lang/String;J)V")?, - method_set_ret: ReturnType::Primitive(Primitive::Void), - method_set_all_from_bytes: env.get_method_id( + method_get_metric_names_ret: ReturnType::Object, + method_get_node_offsets: env.get_method_id( Self::JVM_CLASS, - "set_all_from_bytes", - "([B)V", + "getNodeOffsets", + "()[I", )?, - method_set_all_from_bytes_ret: ReturnType::Primitive(Primitive::Void), + method_get_node_offsets_ret: ReturnType::Object, + method_get_values_array: env.get_method_id( + Self::JVM_CLASS, + "getValuesArray", + "()[J", + )?, + method_get_values_array_ret: ReturnType::Object, + method_update_from_values: env.get_method_id( + Self::JVM_CLASS, + "updateFromValues", + "()V", + )?, + method_update_from_values_ret: ReturnType::Primitive(Primitive::Void), class, }) } diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index f2b0e80ab2..e0d8eb1778 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -47,7 +47,7 @@ use jni::{ use self::util::jni::TypePromotionInfo; use crate::execution::jni_api::get_runtime; -use crate::execution::metrics::utils::update_comet_metric; +use crate::execution::metrics::utils::{build_metric_layout, update_comet_metric, MetricLayout}; use crate::execution::operators::ExecutionError; use crate::execution::planner::PhysicalPlanner; use crate::execution::serde; @@ -605,6 +605,7 @@ enum ParquetReaderState { struct BatchContext { native_plan: Arc, metrics_node: Arc, + metric_layout: MetricLayout, batch_stream: Option, current_batch: Option, reader_state: ParquetReaderState, @@ -780,9 +781,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat let partition_index: usize = 0; let batch_stream = 
scan.execute(partition_index, session_ctx.task_ctx())?; + let metrics_global_ref = Arc::new(jni_new_global_ref!(env, metrics_node)?); + let metric_layout = build_metric_layout(&mut env, metrics_global_ref.as_obj())?; + let ctx = BatchContext { native_plan: Arc::new(SparkPlan::new(0, scan, vec![])), - metrics_node: Arc::new(jni_new_global_ref!(env, metrics_node)?), + metrics_node: metrics_global_ref, + metric_layout, batch_stream: Some(batch_stream), current_batch: None, reader_state: ParquetReaderState::Init, @@ -825,6 +830,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_readNextRecordBatch( &mut env, context.metrics_node.as_obj(), &context.native_plan, + &mut context.metric_layout, )?; context.current_batch = None; diff --git a/native/proto/build.rs b/native/proto/build.rs index 634888e29b..2ef0fc9d33 100644 --- a/native/proto/build.rs +++ b/native/proto/build.rs @@ -30,7 +30,6 @@ fn main() -> Result<()> { prost_build::Config::new().out_dir(out_dir).compile_protos( &[ "src/proto/expr.proto", - "src/proto/metric.proto", "src/proto/partitioning.proto", "src/proto/operator.proto", "src/proto/config.proto", diff --git a/native/proto/src/lib.rs b/native/proto/src/lib.rs index a55657b7af..b67574a4e9 100644 --- a/native/proto/src/lib.rs +++ b/native/proto/src/lib.rs @@ -39,12 +39,6 @@ pub mod spark_operator { include!(concat!("generated", "/spark.spark_operator.rs")); } -// Include generated modules from .proto files. -#[allow(missing_docs)] -pub mod spark_metric { - include!(concat!("generated", "/spark.spark_metric.rs")); -} - // Include generated modules from .proto files. #[allow(missing_docs)] pub mod spark_config { diff --git a/native/proto/src/proto/metric.proto b/native/proto/src/proto/metric.proto deleted file mode 100644 index f026e505ae..0000000000 --- a/native/proto/src/proto/metric.proto +++ /dev/null @@ -1,29 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
// (remainder of the deleted native/proto/src/proto/metric.proto — the entire
// file, including its Apache license header and the NativeMetricNode message,
// was removed by this patch)

// ----------------------------------------------------------------------------
// spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala
// (post-patch; the file's license header above line 19 is unchanged context)
// ----------------------------------------------------------------------------

package org.apache.spark.sql.comet

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

/**
 * A node carrying SQL metrics from SparkPlan, and metrics of its children. Metrics are flattened
 * into parallel arrays for efficient bulk transfer from native code via JNI.
 *
 * @param metrics
 *   the mapping between metric name of native operator to `SQLMetric` of Spark operator.
 */
case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometMetricNode])
    extends Logging {

  /**
   * DFS flattening of the metric tree into parallel arrays. Within each node, metric names are
   * sorted alphabetically for deterministic ordering. Returns (names, sqlMetrics, offsets) where
   * offsets(i) is the start index of node i's metrics and offsets.length = numNodes + 1.
   */
  private lazy val (flatMetricNames, flatSQLMetrics, nodeOffsets) = {
    val names = new ArrayBuffer[String]()
    val sqlMetrics = new ArrayBuffer[SQLMetric]()
    val offsets = new ArrayBuffer[Int]()

    def walk(node: CometMetricNode): Unit = {
      // Node-then-children order must match the native side's
      // fill_metric_values traversal.
      offsets += names.length
      val sorted = node.metrics.toSeq.sortBy(_._1)
      sorted.foreach { case (name, metric) =>
        names += name
        sqlMetrics += metric
      }
      node.children.foreach(walk)
    }

    walk(this)
    offsets += names.length // sentinel
    (names.toArray, sqlMetrics.toArray, offsets.toArray)
  }

  /**
   * Pre-allocated array for native code to write metric values into. Initialized to -1 (sentinel)
   * so that JVM-only metrics that native doesn't produce are not overwritten.
   */
  lazy val metricValuesArray: Array[Long] = {
    val arr = new Array[Long](flatMetricNames.length)
    java.util.Arrays.fill(arr, -1L)
    arr
  }

  /** Returns the flattened metric names. Called from native once during setup. */
  def getMetricNames(): Array[String] = flatMetricNames

  /** Returns node start offsets into the flat arrays. Called from native once during setup. */
  def getNodeOffsets(): Array[Int] = nodeOffsets

  /** Returns the pre-allocated values array. Called from native once during setup. */
  def getValuesArray(): Array[Long] = metricValuesArray

  /**
   * Updates all SQLMetrics from the values array. Called from native after bulk-copying metric
   * values via SetLongArrayRegion. Negative slots (sentinel) are skipped so metrics the native
   * side did not produce keep their JVM-side values.
   */
  def updateFromValues(): Unit = {
    var i = 0
    while (i < flatSQLMetrics.length) {
      val v = metricValuesArray(i)
      if (v >= 0) {
        flatSQLMetrics(i).set(v)
        metricValuesArray(i) = -1L // reset sentinel for next cycle
      }
      i += 1
    }
  }
}

// (any file content after the class — e.g. a companion object — is unchanged
// context outside the visible diff hunk)