From 64009ae0adf6fe127bb6a256f9ec8bff5a8df4fb Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Feb 2026 12:13:14 -0700 Subject: [PATCH 1/3] perf: replace protobuf metric reporting with pre-allocated flat long[] array MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the protobuf-based metric update path with a flat array approach for significantly faster metric transfers from native to JVM. Previously, every metric update built a HashMap and NativeMetricNode protobuf tree, encoded it, decoded on JVM side, then walked the tree with Map.get() lookups per metric. Now, metrics are flattened into parallel arrays at plan creation time, native fills a Vec and copies to JVM via a single SetLongArrayRegion bulk JNI call, and JVM reads linearly — no protobuf, no tree walking, no map lookups, no string allocations per update cycle. --- native/core/src/execution/jni_api.rs | 21 ++- native/core/src/execution/metrics/utils.rs | 150 +++++++++++++----- .../core/src/jvm_bridge/comet_metric_node.rs | 44 +++-- native/core/src/parquet/mod.rs | 11 +- native/proto/build.rs | 1 - native/proto/src/lib.rs | 6 - native/proto/src/proto/metric.proto | 29 ---- .../spark/sql/comet/CometMetricNode.scala | 85 +++++----- 8 files changed, 212 insertions(+), 135 deletions(-) delete mode 100644 native/proto/src/proto/metric.proto diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 0193f3012c..2367309075 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -21,7 +21,9 @@ use super::{serde, utils::SparkArrowConvert}; use crate::{ errors::{try_unwrap_or_throw, CometError, CometResult}, execution::{ - metrics::utils::update_comet_metric, planner::PhysicalPlanner, serde::to_arrow_datatype, + metrics::utils::{build_metric_layout, update_comet_metric, MetricLayout}, + planner::PhysicalPlanner, + serde::to_arrow_datatype, shuffle::spark_unsafe::row::process_sorted_row_partition, sort::RdxSort, }, jvm_bridge::{jni_new_global_ref, JVMClasses}, @@ -173,6 +175,8 @@ struct ExecutionContext { pub memory_pool_config: MemoryPoolConfig, /// Whether to log memory usage on each call to execute_plan pub tracing_enabled: bool, + /// Pre-computed metric layout for flat array metric updates + pub metric_layout: Option, } /// Accept serialized query plan and return the address of the native query plan. @@ -320,6 +324,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( explain_native, memory_pool_config, tracing_enabled, + metric_layout: None, }); Ok(Box::into_raw(exec_context) as i64) @@ -543,6 +548,11 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( exec_context.root_op = Some(Arc::clone(&root_op)); exec_context.scans = scans; + // Build the flat metric layout for efficient metric updates + let metrics = exec_context.metrics.as_obj(); + exec_context.metric_layout = + Some(build_metric_layout(&mut env, metrics)?); + if exec_context.explain_native { let formatted_plan_str = DisplayableExecutionPlan::new(root_op.native_plan.as_ref()).indent(true); @@ -674,9 +684,14 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan( /// Updates the metrics of the query plan. fn update_metrics(env: &mut JNIEnv, exec_context: &mut ExecutionContext) -> CometResult<()> { - if let Some(native_query) = &exec_context.root_op { + if let Some(ref native_query) = exec_context.root_op { + let native_query = Arc::clone(native_query); let metrics = exec_context.metrics.as_obj(); - update_comet_metric(env, metrics, native_query) + if let Some(ref mut layout) = exec_context.metric_layout { + update_comet_metric(env, metrics, &native_query, layout) + } else { + Ok(()) + } } else { Ok(()) } diff --git a/native/core/src/execution/metrics/utils.rs b/native/core/src/execution/metrics/utils.rs index 9ec35ad951..875743764d 100644 --- a/native/core/src/execution/metrics/utils.rs +++ b/native/core/src/execution/metrics/utils.rs @@ -18,38 +18,94 @@ use crate::execution::spark_plan::SparkPlan; use crate::{errors::CometError, jvm_bridge::jni_call}; use datafusion::physical_plan::metrics::MetricValue; -use datafusion_comet_proto::spark_metric::NativeMetricNode; -use jni::{objects::JObject, JNIEnv}; -use prost::Message; +use jni::objects::{GlobalRef, JIntArray, JLongArray, JObject, JObjectArray}; +use jni::JNIEnv; use std::collections::HashMap; use std::sync::Arc; -/// Updates the metrics of a CometMetricNode. This function is called recursively to -/// update the metrics of all the children nodes. The metrics are pulled from the -/// native execution plan and pushed to the Java side through JNI. -pub(crate) fn update_comet_metric( +/// Pre-computed layout mapping metric names to indices in a flat array. +/// Built once at plan creation, reused on every metric update. +pub(crate) struct MetricLayout { + /// Per SparkPlan node (DFS order), maps metric name to index in the flat values array + node_indices: Vec>, + /// Flat array of metric values, written by native and bulk-copied to JVM + values: Vec, + /// Global reference to the JVM long[] array (kept alive for the lifetime of the plan) + jarray: Arc, +} + +/// Builds a MetricLayout by calling JNI methods on the CometMetricNode to retrieve +/// the flattened metric names, node offsets, and a reference to the pre-allocated long[]. +pub(crate) fn build_metric_layout( env: &mut JNIEnv, metric_node: &JObject, - spark_plan: &Arc, -) -> Result<(), CometError> { - if metric_node.is_null() { - return Ok(()); +) -> Result { + // Get metric names array (String[]) + let names_obj: JObject = + unsafe { jni_call!(env, comet_metric_node(metric_node).get_metric_names() -> JObject) }?; + let names_array = JObjectArray::from(names_obj); + let num_metrics = env.get_array_length(&names_array)? as usize; + + let mut metric_names = Vec::with_capacity(num_metrics); + for i in 0..num_metrics { + let jstr = env.get_object_array_element(&names_array, i as i32)?; + let name: String = env.get_string((&jstr).into())?.into(); + metric_names.push(name); } - let native_metric = to_native_metric_node(spark_plan); - let jbytes = env.byte_array_from_slice(&native_metric?.encode_to_vec())?; + // Get node offsets array (int[]) + let offsets_obj: JObject = + unsafe { jni_call!(env, comet_metric_node(metric_node).get_node_offsets() -> JObject) }?; + let offsets_array = JIntArray::from(offsets_obj); + let num_offsets = env.get_array_length(&offsets_array)? as usize; + let mut offsets = vec![0i32; num_offsets]; + env.get_int_array_region(&offsets_array, 0, &mut offsets)?; + + // Get values array reference (long[]) + let values_obj: JObject = + unsafe { jni_call!(env, comet_metric_node(metric_node).get_values_array() -> JObject) }?; + let jarray = Arc::new(env.new_global_ref(values_obj)?); - unsafe { jni_call!(env, comet_metric_node(metric_node).set_all_from_bytes(&jbytes) -> ()) } + // Build per-node index maps + let num_nodes = num_offsets - 1; + let mut node_indices = Vec::with_capacity(num_nodes); + for node_idx in 0..num_nodes { + let start = offsets[node_idx] as usize; + let end = offsets[node_idx + 1] as usize; + let mut map = HashMap::with_capacity(end - start); + for (i, name) in metric_names.iter().enumerate().take(end).skip(start) { + map.insert(name.clone(), i); + } + node_indices.push(map); + } + + Ok(MetricLayout { + node_indices, + values: vec![0i64; num_metrics], + jarray, + }) } -pub(crate) fn to_native_metric_node( +/// Recursively fills the values array from DataFusion metrics on the SparkPlan tree. +fn fill_metric_values( spark_plan: &Arc, -) -> Result { - let mut native_metric_node = NativeMetricNode { - metrics: HashMap::new(), - children: Vec::new(), - }; + layout: &mut MetricLayout, + node_idx: &mut usize, +) { + let current_node = *node_idx; + *node_idx += 1; + + if current_node >= layout.node_indices.len() { + // Skip if node index exceeds layout (shouldn't happen with correct setup) + for child in spark_plan.children() { + fill_metric_values(child, layout, node_idx); + } + return; + } + + let indices = &layout.node_indices[current_node]; + // Collect metrics from the native plan (and additional plans) let node_metrics = if spark_plan.additional_native_plans.is_empty() { spark_plan.native_plan.metrics() } else { @@ -59,7 +115,7 @@ pub(crate) fn to_native_metric_node( for c in additional_metrics.iter() { match c.value() { MetricValue::OutputRows(_) => { - // we do not want to double count output rows + // do not double count output rows } _ => metrics.push(c.to_owned()), } @@ -68,21 +124,43 @@ pub(crate) fn to_native_metric_node( Some(metrics.aggregate_by_name()) }; - // add metrics - node_metrics - .unwrap_or_default() - .iter() - .map(|m| m.value()) - .map(|m| (m.name(), m.as_usize() as i64)) - .for_each(|(name, value)| { - native_metric_node.metrics.insert(name.to_string(), value); - }); - - // add children - for child_plan in spark_plan.children() { - let child_node = to_native_metric_node(child_plan)?; - native_metric_node.children.push(child_node); + // Write metric values into their pre-assigned slots + if let Some(metrics) = node_metrics { + for m in metrics.iter() { + let value = m.value(); + let name = value.name(); + if let Some(&idx) = indices.get(name) { + layout.values[idx] = value.as_usize() as i64; + } + } + } + + // Recurse into children + for child in spark_plan.children() { + fill_metric_values(child, layout, node_idx); + } +} + +/// Updates metrics by filling the flat values array and bulk-copying to JVM. +pub(crate) fn update_comet_metric( + env: &mut JNIEnv, + metric_node: &JObject, + spark_plan: &Arc, + layout: &mut MetricLayout, +) -> Result<(), CometError> { + if metric_node.is_null() { + return Ok(()); } - Ok(native_metric_node) + // Fill values from native metrics + let mut node_idx = 0; + fill_metric_values(spark_plan, layout, &mut node_idx); + + // Bulk copy values to JVM long[] via SetLongArrayRegion + let local_ref = env.new_local_ref(layout.jarray.as_obj())?; + let jlong_array = JLongArray::from(local_ref); + env.set_long_array_region(&jlong_array, 0, &layout.values)?; + + // Call updateFromValues() on the JVM side + unsafe { jni_call!(env, comet_metric_node(metric_node).update_from_values() -> ()) } } diff --git a/native/core/src/jvm_bridge/comet_metric_node.rs b/native/core/src/jvm_bridge/comet_metric_node.rs index f1f0255845..5646ec59ad 100644 --- a/native/core/src/jvm_bridge/comet_metric_node.rs +++ b/native/core/src/jvm_bridge/comet_metric_node.rs @@ -26,12 +26,14 @@ use jni::{ #[allow(dead_code)] // we need to keep references to Java items to prevent GC pub struct CometMetricNode<'a> { pub class: JClass<'a>, - pub method_get_child_node: JMethodID, - pub method_get_child_node_ret: ReturnType, - pub method_set: JMethodID, - pub method_set_ret: ReturnType, - pub method_set_all_from_bytes: JMethodID, - pub method_set_all_from_bytes_ret: ReturnType, + pub method_get_metric_names: JMethodID, + pub method_get_metric_names_ret: ReturnType, + pub method_get_node_offsets: JMethodID, + pub method_get_node_offsets_ret: ReturnType, + pub method_get_values_array: JMethodID, + pub method_get_values_array_ret: ReturnType, + pub method_update_from_values: JMethodID, + pub method_update_from_values_ret: ReturnType, } impl<'a> CometMetricNode<'a> { @@ -41,20 +43,30 @@ impl<'a> CometMetricNode<'a> { let class = env.find_class(Self::JVM_CLASS)?; Ok(CometMetricNode { - method_get_child_node: env.get_method_id( + method_get_metric_names: env.get_method_id( Self::JVM_CLASS, - "getChildNode", - format!("(I)L{:};", Self::JVM_CLASS).as_str(), + "getMetricNames", + "()[Ljava/lang/String;", )?, - method_get_child_node_ret: ReturnType::Object, - method_set: env.get_method_id(Self::JVM_CLASS, "set", "(Ljava/lang/String;J)V")?, - method_set_ret: ReturnType::Primitive(Primitive::Void), - method_set_all_from_bytes: env.get_method_id( + method_get_metric_names_ret: ReturnType::Object, + method_get_node_offsets: env.get_method_id( Self::JVM_CLASS, - "set_all_from_bytes", - "([B)V", + "getNodeOffsets", + "()[I", )?, - method_set_all_from_bytes_ret: ReturnType::Primitive(Primitive::Void), + method_get_node_offsets_ret: ReturnType::Object, + method_get_values_array: env.get_method_id( + Self::JVM_CLASS, + "getValuesArray", + "()[J", + )?, + method_get_values_array_ret: ReturnType::Object, + method_update_from_values: env.get_method_id( + Self::JVM_CLASS, + "updateFromValues", + "()V", + )?, + method_update_from_values_ret: ReturnType::Primitive(Primitive::Void), class, }) } diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index f2b0e80ab2..ad1924fda6 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -47,7 +47,7 @@ use jni::{ use self::util::jni::TypePromotionInfo; use crate::execution::jni_api::get_runtime; -use crate::execution::metrics::utils::update_comet_metric; +use crate::execution::metrics::utils::{build_metric_layout, update_comet_metric, MetricLayout}; use crate::execution::operators::ExecutionError; use crate::execution::planner::PhysicalPlanner; use crate::execution::serde; @@ -605,6 +605,7 @@ enum ParquetReaderState { struct BatchContext { native_plan: Arc, metrics_node: Arc, + metric_layout: MetricLayout, batch_stream: Option, current_batch: Option, reader_state: ParquetReaderState, @@ -780,9 +781,14 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat let partition_index: usize = 0; let batch_stream = scan.execute(partition_index, session_ctx.task_ctx())?; + let metrics_global_ref = Arc::new(jni_new_global_ref!(env, metrics_node)?); + let metric_layout = + build_metric_layout(&mut env, metrics_global_ref.as_obj())?; + let ctx = BatchContext { native_plan: Arc::new(SparkPlan::new(0, scan, vec![])), - metrics_node: Arc::new(jni_new_global_ref!(env, metrics_node)?), + metrics_node: metrics_global_ref, + metric_layout, batch_stream: Some(batch_stream), current_batch: None, reader_state: ParquetReaderState::Init, @@ -825,6 +831,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_readNextRecordBatch( &mut env, context.metrics_node.as_obj(), &context.native_plan, + &mut context.metric_layout, )?; context.current_batch = None; diff --git a/native/proto/build.rs b/native/proto/build.rs index 634888e29b..2ef0fc9d33 100644 --- a/native/proto/build.rs +++ b/native/proto/build.rs @@ -30,7 +30,6 @@ fn main() -> Result<()> { prost_build::Config::new().out_dir(out_dir).compile_protos( &[ "src/proto/expr.proto", - "src/proto/metric.proto", "src/proto/partitioning.proto", "src/proto/operator.proto", "src/proto/config.proto", diff --git a/native/proto/src/lib.rs b/native/proto/src/lib.rs index a55657b7af..b67574a4e9 100644 --- a/native/proto/src/lib.rs +++ b/native/proto/src/lib.rs @@ -39,12 +39,6 @@ pub mod spark_operator { include!(concat!("generated", "/spark.spark_operator.rs")); } -// Include generated modules from .proto files. -#[allow(missing_docs)] -pub mod spark_metric { - include!(concat!("generated", "/spark.spark_metric.rs")); -} - // Include generated modules from .proto files. #[allow(missing_docs)] pub mod spark_config { diff --git a/native/proto/src/proto/metric.proto b/native/proto/src/proto/metric.proto deleted file mode 100644 index f026e505ae..0000000000 --- a/native/proto/src/proto/metric.proto +++ /dev/null @@ -1,29 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - - - -syntax = "proto3"; - -package spark.spark_metric; - -option java_package = "org.apache.comet.serde"; - -message NativeMetricNode { - map metrics = 1; - repeated NativeMetricNode children = 2; -} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index 8c75df1d45..a7339477c1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -19,70 +19,71 @@ package org.apache.spark.sql.comet -import scala.jdk.CollectionConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.comet.serde.Metric - /** - * A node carrying SQL metrics from SparkPlan, and metrics of its children. Native code will call - * [[getChildNode]] and [[set]] to update the metrics. + * A node carrying SQL metrics from SparkPlan, and metrics of its children. Metrics are flattened + * into parallel arrays for efficient bulk transfer from native code via JNI. * * @param metrics - * the mapping between metric name of native operator to `SQLMetric` of Spark operator. For - * example, `numOutputRows` -> `SQLMetrics("numOutputRows")` means the native operator will - * update `numOutputRows` metric with the value of `SQLMetrics("numOutputRows")` in Spark - * operator. + * the mapping between metric name of native operator to `SQLMetric` of Spark operator. */ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometMetricNode]) extends Logging { /** - * Gets a child node. Called from native. + * DFS flattening of the metric tree into parallel arrays. Within each node, metric names are + * sorted alphabetically for deterministic ordering. Returns (names, sqlMetrics, offsets) where + * offsets(i) is the start index of node i's metrics and offsets.length = numNodes + 1. */ - def getChildNode(i: Int): CometMetricNode = { - if (i < 0 || i >= children.length) { - // TODO: throw an exception, e.g. IllegalArgumentException, instead? - return null + private lazy val (flatMetricNames, flatSQLMetrics, nodeOffsets) = { + val names = new ArrayBuffer[String]() + val sqlMetrics = new ArrayBuffer[SQLMetric]() + val offsets = new ArrayBuffer[Int]() + + def walk(node: CometMetricNode): Unit = { + offsets += names.length + val sorted = node.metrics.toSeq.sortBy(_._1) + sorted.foreach { case (name, metric) => + names += name + sqlMetrics += metric + } + node.children.foreach(walk) } - children(i) + + walk(this) + offsets += names.length // sentinel + (names.toArray, sqlMetrics.toArray, offsets.toArray) } + /** Pre-allocated array for native code to write metric values into. */ + lazy val metricValuesArray: Array[Long] = new Array[Long](flatMetricNames.length) + + /** Returns the flattened metric names. Called from native once during setup. */ + def getMetricNames(): Array[String] = flatMetricNames + + /** Returns node start offsets into the flat arrays. Called from native once during setup. */ + def getNodeOffsets(): Array[Int] = nodeOffsets + + /** Returns the pre-allocated values array. Called from native once during setup. */ + def getValuesArray(): Array[Long] = metricValuesArray + /** - * Update the value of a metric. This method will typically be called multiple times for the - * same metric during multiple calls to executePlan. - * - * @param metricName - * the name of the metric at native operator. - * @param v - * the value to set. + * Updates all SQLMetrics from the values array. Called from native after bulk-copying metric + * values via SetLongArrayRegion. */ - def set(metricName: String, v: Long): Unit = { - metrics.get(metricName) match { - case Some(metric) => metric.set(v) - case None => - // no-op - logDebug(s"Non-existing metric: $metricName. Ignored") - } - } - - private def set_all(metricNode: Metric.NativeMetricNode): Unit = { - metricNode.getMetricsMap.forEach((name, value) => { - set(name, value) - }) - metricNode.getChildrenList.asScala.zip(children).foreach { case (child, childNode) => - childNode.set_all(child) + def updateFromValues(): Unit = { + var i = 0 + while (i < flatSQLMetrics.length) { + flatSQLMetrics(i).set(metricValuesArray(i)) + i += 1 } } - - def set_all_from_bytes(bytes: Array[Byte]): Unit = { - val metricNode = Metric.NativeMetricNode.parseFrom(bytes) - set_all(metricNode) - } } object CometMetricNode { From eef646865d713be7a3052837b656bd6dce97ad7b Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Feb 2026 12:49:46 -0700 Subject: [PATCH 2/3] style: apply cargo fmt formatting --- native/core/src/execution/jni_api.rs | 6 +++--- native/core/src/parquet/mod.rs | 3 +-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 2367309075..f3691e2733 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -24,7 +24,8 @@ use crate::{ metrics::utils::{build_metric_layout, update_comet_metric, MetricLayout}, planner::PhysicalPlanner, serde::to_arrow_datatype, - shuffle::spark_unsafe::row::process_sorted_row_partition, sort::RdxSort, + shuffle::spark_unsafe::row::process_sorted_row_partition, + sort::RdxSort, }, jvm_bridge::{jni_new_global_ref, JVMClasses}, }; @@ -550,8 +551,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( // Build the flat metric layout for efficient metric updates let metrics = exec_context.metrics.as_obj(); - exec_context.metric_layout = - Some(build_metric_layout(&mut env, metrics)?); + exec_context.metric_layout = Some(build_metric_layout(&mut env, metrics)?); if exec_context.explain_native { let formatted_plan_str = diff --git a/native/core/src/parquet/mod.rs b/native/core/src/parquet/mod.rs index ad1924fda6..e0d8eb1778 100644 --- a/native/core/src/parquet/mod.rs +++ b/native/core/src/parquet/mod.rs @@ -782,8 +782,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat let batch_stream = scan.execute(partition_index, session_ctx.task_ctx())?; let metrics_global_ref = Arc::new(jni_new_global_ref!(env, metrics_node)?); - let metric_layout = - build_metric_layout(&mut env, metrics_global_ref.as_obj())?; + let metric_layout = build_metric_layout(&mut env, metrics_global_ref.as_obj())?; let ctx = BatchContext { native_plan: Arc::new(SparkPlan::new(0, scan, vec![])), From 19d95c85af9dffe65d71f10c11a2f948a43382c2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Feb 2026 15:55:13 -0700 Subject: [PATCH 3/3] fix: use sentinel value to avoid overwriting JVM-only metrics Metrics like conversionTime are accumulated on the JVM side and have no native counterpart. Initialize the values array to -1 (sentinel) and only update SQLMetrics where native actually wrote a value (>= 0), preventing native metric updates from zeroing out JVM-accumulated metrics. --- native/core/src/execution/metrics/utils.rs | 2 +- .../spark/sql/comet/CometMetricNode.scala | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/native/core/src/execution/metrics/utils.rs b/native/core/src/execution/metrics/utils.rs index 875743764d..73a095ac03 100644 --- a/native/core/src/execution/metrics/utils.rs +++ b/native/core/src/execution/metrics/utils.rs @@ -81,7 +81,7 @@ pub(crate) fn build_metric_layout( Ok(MetricLayout { node_indices, - values: vec![0i64; num_metrics], + values: vec![-1i64; num_metrics], jarray, }) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala index a7339477c1..7bbacee9bc 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometMetricNode.scala @@ -61,8 +61,15 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM (names.toArray, sqlMetrics.toArray, offsets.toArray) } - /** Pre-allocated array for native code to write metric values into. */ - lazy val metricValuesArray: Array[Long] = new Array[Long](flatMetricNames.length) + /** + * Pre-allocated array for native code to write metric values into. Initialized to -1 (sentinel) + * so that JVM-only metrics that native doesn't produce are not overwritten. + */ + lazy val metricValuesArray: Array[Long] = { + val arr = new Array[Long](flatMetricNames.length) + java.util.Arrays.fill(arr, -1L) + arr + } /** Returns the flattened metric names. Called from native once during setup. */ def getMetricNames(): Array[String] = flatMetricNames @@ -80,7 +87,11 @@ case class CometMetricNode(metrics: Map[String, SQLMetric], children: Seq[CometM def updateFromValues(): Unit = { var i = 0 while (i < flatSQLMetrics.length) { - flatSQLMetrics(i).set(metricValuesArray(i)) + val v = metricValuesArray(i) + if (v >= 0) { + flatSQLMetrics(i).set(v) + metricValuesArray(i) = -1L // reset sentinel for next cycle + } i += 1 } }