From 809d73517451436736111cd497f20e763cc17d2b Mon Sep 17 00:00:00 2001 From: Leonardo Yvens Date: Wed, 18 Feb 2026 13:42:49 +0000 Subject: [PATCH] refactor(server): move peak memory tracking to query metrics Consolidate memory usage reporting into the metrics recording path instead of relying on a stream wrapper in common. - Remove `PeakMemoryStream` wrapper from `QueryContext` - Expose `memory_pool()` accessor on `QueryContext` - Record peak memory in `record_query_execution` alongside other metrics - Remove unused `record_query_memory` method Signed-off-by: Leonardo Yvens --- crates/core/common/src/context/query.rs | 62 ++++--------------------- crates/services/server/src/flight.rs | 16 +++++-- crates/services/server/src/metrics.rs | 18 +++++-- 3 files changed, 35 insertions(+), 61 deletions(-) diff --git a/crates/core/common/src/context/query.rs b/crates/core/common/src/context/query.rs index 25788eba2..9c3065993 100644 --- a/crates/core/common/src/context/query.rs +++ b/crates/core/common/src/context/query.rs @@ -1,19 +1,17 @@ use std::{ collections::BTreeMap, - pin::Pin, sync::{Arc, LazyLock}, - task::{Context, Poll}, }; -use arrow::{array::ArrayRef, compute::concat_batches, datatypes::SchemaRef}; +use arrow::{array::ArrayRef, compute::concat_batches}; use datafusion::{ self, arrow::array::RecordBatch, catalog::MemorySchemaProvider, error::DataFusionError, execution::{ - RecordBatchStream, SendableRecordBatchStream, SessionStateBuilder, config::SessionConfig, - context::SessionContext, memory_pool::human_readable_size, runtime_env::RuntimeEnv, + SendableRecordBatchStream, SessionStateBuilder, config::SessionConfig, + context::SessionContext, runtime_env::RuntimeEnv, }, logical_expr::LogicalPlan, physical_optimizer::PhysicalOptimizerRule, @@ -25,7 +23,7 @@ use datafusion_tracing::{ InstrumentationOptions, instrument_with_info_spans, pretty_format_compact_batch, }; use datasets_common::network_id::NetworkId; -use futures::{Stream, TryStreamExt, stream}; +use 
futures::{TryStreamExt, stream}; use regex::Regex; use tracing::field; @@ -82,6 +80,11 @@ impl QueryContext { }) } + /// Returns the tiered memory pool for this query context. + pub fn memory_pool(&self) -> &Arc<TieredMemoryPool> { + &self.tiered_memory_pool + } + /// Returns the catalog snapshot backing this query context. pub fn catalog(&self) -> &CatalogSnapshot { &self.catalog @@ -120,10 +123,7 @@ impl QueryContext { .await .map_err(ExecutePlanError::Execute)?; - Ok(PeakMemoryStream::wrap( - result, - self.tiered_memory_pool.clone(), - )) + Ok(result) } /// This will load the result set entirely in memory, so it should be used with caution. @@ -535,48 +535,6 @@ fn print_physical_plan(plan: &dyn ExecutionPlan) -> String { sanitize_parquet_paths(&plan_str) } -/// A stream wrapper that logs peak memory usage when dropped. -/// -/// Because `execute_plan` returns a lazy `SendableRecordBatchStream`, memory is only -/// allocated when the stream is consumed. This wrapper defers the peak memory log to -/// when the stream is dropped (i.e., after consumption or cancellation). -struct PeakMemoryStream { - inner: SendableRecordBatchStream, - pool: Arc<TieredMemoryPool>, -} - -impl PeakMemoryStream { - fn wrap( - inner: SendableRecordBatchStream, - pool: Arc<TieredMemoryPool>, - ) -> SendableRecordBatchStream { - Box::pin(Self { inner, pool }) - } -} - -impl Drop for PeakMemoryStream { - fn drop(&mut self) { - tracing::debug!( - peak_memory_mb = human_readable_size(self.pool.peak_reserved()), - "Query memory usage" - ); - } -} - -impl Stream for PeakMemoryStream { - type Item = Result<RecordBatch, DataFusionError>; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { - self.inner.as_mut().poll_next(cx) - } -} - -impl RecordBatchStream for PeakMemoryStream { - fn schema(&self) -> SchemaRef { - self.inner.schema() - } -} - /// Creates an instrumentation rule that captures metrics and provides previews of data during execution.
pub fn create_instrumentation_rule() -> Arc<dyn PhysicalOptimizerRule + Send + Sync> { let options_builder = InstrumentationOptions::builder() diff --git a/crates/services/server/src/flight.rs b/crates/services/server/src/flight.rs index 0705d7f80..81ff5521e 100644 --- a/crates/services/server/src/flight.rs +++ b/crates/services/server/src/flight.rs @@ -47,6 +47,7 @@ use common::{ }, dataset_store::{DatasetStore, GetDatasetError}, detached_logical_plan::{AttachPlanError, DetachedLogicalPlan}, + memory_pool::TieredMemoryPool, query_env::QueryEnv, sql::{ ResolveFunctionReferencesError, ResolveTableReferencesError, resolve_function_references, @@ -213,6 +214,7 @@ impl Service { ); } + let memory_pool = ctx.memory_pool().clone(); let record_batches = ctx .execute_plan(plan, true) .await @@ -224,7 +226,12 @@ impl Service { }; if let Some(metrics) = &self.metrics { - Ok(track_query_metrics(stream, metrics, query_start_time)) + Ok(track_query_metrics( + stream, + metrics, + query_start_time, + Some(memory_pool), + )) } else { Ok(stream) } @@ -278,7 +285,7 @@ impl Service { }; if let Some(metrics) = &self.metrics { - Ok(track_query_metrics(stream, metrics, query_start_time)) + Ok(track_query_metrics(stream, metrics, query_start_time, None)) } else { Ok(stream) } @@ -617,6 +624,7 @@ fn track_query_metrics( stream: QueryResultStream, metrics: &Arc<MetricsRegistry>, start_time: std::time::Instant, + memory_pool: Option<Arc<TieredMemoryPool>>, ) -> QueryResultStream { let metrics = metrics.clone(); @@ -647,7 +655,7 @@ fn track_query_metrics( let duration = start_time.elapsed().as_millis() as f64; let err_msg = e.to_string(); metrics.record_query_error(&err_msg); - metrics.record_query_execution(duration, total_rows, total_bytes); + metrics.record_query_execution(duration, total_rows, total_bytes, memory_pool.as_ref()); yield Err(e); return; @@ -657,7 +665,7 @@ fn track_query_metrics( // Stream completed successfully, record metrics let duration = start_time.elapsed().as_millis() as f64; - metrics.record_query_execution(duration, total_rows, total_bytes);
+ metrics.record_query_execution(duration, total_rows, total_bytes, memory_pool.as_ref()); }; QueryResultStream::NonIncremental { diff --git a/crates/services/server/src/metrics.rs b/crates/services/server/src/metrics.rs index dd8af1e78..270bfb20b 100644 --- a/crates/services/server/src/metrics.rs +++ b/crates/services/server/src/metrics.rs @@ -1,3 +1,7 @@ +use std::sync::Arc; + +use common::memory_pool::TieredMemoryPool; +use datafusion::execution::memory_pool::human_readable_size; use monitoring::telemetry; #[derive(Debug, Clone)] @@ -133,11 +137,20 @@ impl MetricsRegistry { duration_millis: f64, rows_returned: u64, bytes_egress: u64, + memory_pool: Option<&Arc<TieredMemoryPool>>, ) { self.query_count.inc(); self.query_duration.record(duration_millis); self.query_rows_returned.inc_by(rows_returned); self.query_bytes_egress.inc_by(bytes_egress); + if let Some(pool) = memory_pool { + let peak = pool.peak_reserved() as u64; + self.query_memory_peak_bytes.record(peak); + tracing::debug!( + peak_memory = human_readable_size(peak as usize), + "Query memory usage" + ); + } } /// Record query error @@ -166,9 +179,4 @@ impl MetricsRegistry { pub fn record_streaming_lifetime(&self, duration_millis: f64) { self.streaming_query_lifetime.record(duration_millis); } - - /// Record query memory usage - pub fn record_query_memory(&self, peak_bytes: u64) { - self.query_memory_peak_bytes.record(peak_bytes); - } }