From 264c538a7fd1e94da80d52105642c3c06e7a0098 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 14 May 2026 16:04:16 -0700 Subject: [PATCH 01/14] Add Optional Batching for Push Updates --- src/config.rs | 12 ++ src/fetch/tests.rs | 6 +- src/main.rs | 34 ++++- src/push/mod.rs | 105 +++++++++++++-- src/push/tests.rs | 232 +++++++++++++++++++++++++++++++-- src/store/adapters/postgres.rs | 39 +++++- src/store/adapters/sqlite.rs | 40 +++++- src/store/traits.rs | 7 +- 8 files changed, 430 insertions(+), 45 deletions(-) diff --git a/src/config.rs b/src/config.rs index fd9b11e8..a953e96a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -308,6 +308,15 @@ pub struct Config { /// Maximum milliseconds to wait before flushing a batch of status updates. pub status_update_interval_ms: u64, + /// Update claimed → processing (dispatch) updates in batches? + pub batch_push_updates: bool, + + /// The size of a batch of dispatch updates. + pub push_update_batch_size: usize, + + /// Maximum milliseconds to wait before flushing a batch of dispatch updates. + pub push_update_interval_ms: u64, + /// The hostname used to construct `callback_url` for task push requests. pub callback_addr: String, @@ -421,6 +430,9 @@ impl Default for Config { batch_status_updates: false, status_update_batch_size: 1, status_update_interval_ms: 100, + batch_push_updates: false, + push_update_batch_size: 1, + push_update_interval_ms: 100, callback_addr: "0.0.0.0".into(), callback_port: 50051, worker_map: [("sentry".into(), "http://127.0.0.1:50052".into())].into(), diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 092d3503..3e89a730 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -98,10 +98,14 @@ impl InflightActivationStore for MockStore { }) } - async fn mark_activation_processing(&self, _id: &str) -> Result<(), Error> { + async fn mark_processing(&self, _id: &str) -> Result<(), Error> { Ok(()) } + async fn mark_processing_batch(&self, _ids: &[String]) -> Result { + unimplemented!() + } + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { unimplemented!() } diff --git a/src/main.rs b/src/main.rs index c3648a6d..543334d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; use taskbroker::fetch::FetchPool; use taskbroker::grpc::auth_middleware::AuthLayer; use taskbroker::grpc::metrics_middleware::MetricsLayer; -use taskbroker::grpc::server::{TaskbrokerServer, flush_updates}; +use taskbroker::grpc::server::TaskbrokerServer; use taskbroker::kafka::admin::create_missing_topics; use taskbroker::kafka::consumer::start_consumer; use taskbroker::kafka::deserialize::{self, DeserializeConfig}; @@ -27,7 +27,6 @@ use taskbroker::kafka::inflight_activation_writer::{ ActivationWriterConfig, InflightActivationWriter, }; use taskbroker::kafka::os_stream_writer::{OsStream, OsStreamWriter}; -use taskbroker::logging; use taskbroker::metrics; use taskbroker::processing_strategy; use taskbroker::push::PushPool; @@ -40,6 +39,7 @@ use taskbroker::store::traits::InflightActivationStore; use taskbroker::upkeep::upkeep; use taskbroker::{Args, get_version}; use taskbroker::{SERVICE_NAME, flusher}; +use taskbroker::{grpc, logging, push}; async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { @@ -203,7 +203,7 @@ async fn main() -> Result<(), Error> { rx, flusher_config.status_update_batch_size, flusher_config.status_update_interval_ms, - move |buffer| Box::pin(flush_updates(flusher_store.clone(), buffer)), + move |buffer| Box::pin(grpc::server::flush_updates(flusher_store.clone(), buffer)), ) .await }); @@ -265,8 +265,30 @@ async fn main() -> Result<(), Error> { } }); + // Push update flush task + let (push_update_tx, push_update_task) = if config.batch_push_updates { + let (tx, rx) = tokio::sync::mpsc::channel(config.push_update_batch_size.max(1)); + + let flusher_store = store.clone(); + let flusher_config = config.clone(); + + let handle = tokio::spawn(async move { + flusher::run_flusher( + rx, + flusher_config.push_update_batch_size, + flusher_config.push_update_interval_ms, + move |buffer| Box::pin(push::flush_updates(flusher_store.clone(), buffer)), + ) + .await + }); + + (Some(tx), Some(handle)) + } else { + (None, None) + }; + // Initialize push and fetch pools - let push_pool = Arc::new(PushPool::new(config.clone(), store.clone())); + let push_pool = Arc::new(PushPool::new(config.clone(), store.clone(), push_update_tx)); let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); // Initialize push threads @@ -305,6 +327,10 @@ async fn main() -> Result<(), Error> { departure = departure.on_completion(log_task_completion("status_update_task", task)); } + if let Some(task) = push_update_task { + departure = departure.on_completion(log_task_completion("push_update_task", task)); + } + departure.await; Ok(()) } diff --git a/src/push/mod.rs b/src/push/mod.rs index fbf04018..8a5fc001 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1,4 +1,3 @@ -use chrono::Utc; use std::cmp::max; use std::collections::HashMap; use std::future::Future; @@ -8,6 +7,7 @@ use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use async_backtrace::framed; +use chrono::Utc; use elegant_departure::get_shutdown_guard; use flume::{Receiver, SendError, Sender}; use hmac::{Hmac, Mac}; @@ -15,11 +15,12 @@ use prost::Message; use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; use sha2::Sha256; +use tokio::sync::mpsc; use tokio::task::JoinSet; use tonic::async_trait; use tonic::metadata::MetadataValue; use tonic::transport::Channel; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; use crate::config::Config; use crate::store::activation::InflightActivation; @@ -100,6 +101,9 @@ pub struct PushPool { /// The receiving end of a channel that accepts task activations. receiver: Receiver<(InflightActivation, Instant)>, + /// Queue for batching claimed → processing updates. + update_tx: Option>, + /// Taskbroker configuration. config: Arc, @@ -111,28 +115,36 @@ pub struct PushPool { impl PushPool { /// Initialize a new push pool. - pub fn new(config: Arc, store: Arc) -> Self { + pub fn new( + config: Arc, + store: Arc, + update_tx: Option>, + ) -> Self { let worker_factory: WorkerFactory = Arc::new(|endpoint: String| { Box::pin(async move { let client = WorkerServiceClient::connect(endpoint).await?; Ok(Box::new(client) as Box) }) }); - Self::new_with_factory(config, store, worker_factory) + + Self::new_with_factory(config, store, worker_factory, update_tx) } fn new_with_factory( config: Arc, store: Arc, worker_factory: WorkerFactory, + update_tx: Option>, ) -> Self { let (sender, receiver) = flume::bounded(config.push_queue_size); + Self { sender, receiver, config, store, worker_factory, + update_tx, } } @@ -148,6 +160,7 @@ impl PushPool { let receiver = self.receiver.clone(); let store = store.clone(); let worker_factory = worker_factory.clone(); + let update_tx = self.update_tx.clone(); let guard = get_shutdown_guard().shutdown_on_drop(); @@ -244,11 +257,31 @@ impl PushPool { } let start = Instant::now(); - let result = store.mark_activation_processing(&id).await; - metrics::histogram!("push.mark_activation_processing.duration").record(start.elapsed()); + + // Are we batching claimed → processing updates? + if let Some(ref tx) = update_tx { + let result = tx.send(id.clone()).await; + metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); + + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to enqueue push update" + ); + } + + continue; + } + + // Fall back to individual updates + let result = store.mark_processing(&id).await; + metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); if let Err(e) = result { - metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); + metrics::counter!("push.mark_processing", "result" => "error").increment(1); error!( task_id = %id, @@ -304,12 +337,13 @@ impl PushPool { debug!(task_id = %id, "Activation sent to worker"); let start = Instant::now(); - let result = store.mark_activation_processing(&id).await; - metrics::histogram!("push.mark_activation_processing.duration") + let result = store.mark_processing(&id).await; + metrics::histogram!("push.mark_processing.duration") .record(start.elapsed()); if let Err(e) = result { - metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); + metrics::counter!("push.mark_processing", "result" => "error") + .increment(1); error!( task_id = %id, @@ -422,5 +456,56 @@ async fn push_task( result } +pub async fn flush_updates(store: Arc, buffer: &mut Vec) { + if buffer.is_empty() { + return; + } + + let start = Instant::now(); + + let ids = std::mem::take(buffer); + + let requested = ids.len() as u64; + metrics::histogram!("push.flush_updates.requested").record(requested as f64); + + let result = store.mark_processing_batch(&ids).await; + metrics::histogram!("push.mark_processing_batch.duration").record(start.elapsed()); + + match result { + Ok(affected) => { + metrics::histogram!("push.flush_updates.affected").record(affected as f64); + + metrics::counter!("push.flush_updates.updated").increment(affected); + metrics::counter!("push.flush_updates", "result" => "ok").increment(1); + + if affected < requested { + metrics::counter!("push.flush_updates.partial").increment(1); + + warn!( + requested, + affected, "Updated fewer rows than IDs requested from push pool" + ); + } + + debug!(affected, requested, "Flushed update batch from push pool"); + } + + Err(e) => { + metrics::counter!("push.flush_updates", "result" => "error").increment(1); + + error!( + requested, + error = ?e, + "Failed to flush update batch from push pool" + ); + + // Push failed updates back into the buffer so they can be retried on next flush + for id in ids { + buffer.push(id); + } + } + } +} + #[cfg(test)] mod tests; diff --git a/src/push/tests.rs b/src/push/tests.rs index 888b127d..3b5ee0f4 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -1,10 +1,11 @@ use std::sync::{Arc, Mutex}; +use std::time::Instant; use anyhow::anyhow; use async_trait::async_trait; use chrono::{DateTime, Utc}; use sentry_protos::taskbroker::v1::PushTaskRequest; -use tokio::sync::Notify; +use tokio::sync::{Notify, mpsc}; use tokio::time::{Duration, timeout}; use crate::config::Config; @@ -63,15 +64,42 @@ impl WorkerClient for NotifyingWorkerClient { } /// Minimal fake store that records which activation IDs have been marked as processing. -#[derive(Default, Clone)] +/// All IDs marked via either `mark_processing` or successful `mark_processing_batch`. +#[derive(Clone)] struct MockStore { marked_processing: Arc>>, + mark_processing_calls: Arc>>, + mark_processing_batches: Arc>>>, + mark_processing_batch_should_fail: Arc>, +} + +impl Default for MockStore { + fn default() -> Self { + Self { + marked_processing: Arc::new(Mutex::new(vec![])), + mark_processing_calls: Arc::new(Mutex::new(vec![])), + mark_processing_batches: Arc::new(Mutex::new(vec![])), + mark_processing_batch_should_fail: Arc::new(Mutex::new(false)), + } + } } impl MockStore { fn marked_ids(&self) -> Vec { self.marked_processing.lock().unwrap().clone() } + + fn mark_processing_direct_calls(&self) -> Vec { + self.mark_processing_calls.lock().unwrap().clone() + } + + fn mark_processing_batch_calls(&self) -> Vec> { + self.mark_processing_batches.lock().unwrap().clone() + } + + fn set_mark_processing_batch_fail(&self, fail: bool) { + *self.mark_processing_batch_should_fail.lock().unwrap() = fail; + } } #[async_trait] @@ -79,9 +107,11 @@ impl InflightActivationStore for MockStore { async fn store(&self, _batch: Vec) -> anyhow::Result { Ok(0) } + fn assign_partitions(&self, _partitions: Vec) -> anyhow::Result<()> { Ok(()) } + async fn claim_activations( &self, _application: Option<&str>, @@ -92,10 +122,35 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result> { Ok(vec![]) } - async fn mark_activation_processing(&self, id: &str) -> anyhow::Result<()> { + + async fn mark_processing(&self, id: &str) -> anyhow::Result<()> { + self.mark_processing_calls + .lock() + .unwrap() + .push(id.to_string()); self.marked_processing.lock().unwrap().push(id.to_string()); Ok(()) } + + async fn mark_processing_batch(&self, ids: &[String]) -> anyhow::Result { + if *self.mark_processing_batch_should_fail.lock().unwrap() { + return Err(anyhow!("mock mark_processing_batch failure")); + } + + self.mark_processing_batches + .lock() + .unwrap() + .push(ids.to_vec()); + + let mut guard = self.marked_processing.lock().unwrap(); + + for id in ids { + guard.push(id.clone()); + } + + Ok(ids.len() as u64) + } + async fn set_status( &self, _id: &str, @@ -103,6 +158,7 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result> { Ok(None) } + async fn set_status_batch( &self, _ids: &[String], @@ -110,18 +166,23 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result { Ok(0) } + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { 0.0 } + async fn count_by_status(&self, _status: InflightActivationStatus) -> anyhow::Result { Ok(0) } + async fn count(&self) -> anyhow::Result { Ok(0) } + async fn get_by_id(&self, _id: &str) -> anyhow::Result> { Ok(None) } + async fn set_processing_deadline( &self, _id: &str, @@ -129,51 +190,66 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result<()> { Ok(()) } + async fn delete_activation(&self, _id: &str) -> anyhow::Result<()> { Ok(()) } + async fn vacuum_db(&self) -> anyhow::Result<()> { Ok(()) } + async fn full_vacuum_db(&self) -> anyhow::Result<()> { Ok(()) } + async fn db_size(&self) -> anyhow::Result { Ok(0) } + async fn get_retry_activations(&self) -> anyhow::Result> { Ok(vec![]) } + async fn handle_claim_expiration(&self) -> anyhow::Result { Ok(0) } + async fn handle_processing_deadline(&self) -> anyhow::Result { Ok(0) } + async fn handle_processing_attempts(&self) -> anyhow::Result { Ok(0) } + async fn handle_expires_at(&self) -> anyhow::Result { Ok(0) } + async fn handle_delay_until(&self) -> anyhow::Result { Ok(0) } + async fn handle_failed_tasks(&self) -> anyhow::Result { Ok(FailedTasksForwarder { to_discard: vec![], to_deadletter: vec![], }) } + async fn mark_completed(&self, _ids: Vec) -> anyhow::Result { Ok(0) } + async fn remove_completed(&self) -> anyhow::Result { Ok(0) } + async fn remove_killswitched(&self, _killswitched_tasks: Vec) -> anyhow::Result { Ok(0) } + async fn clear(&self) -> anyhow::Result<()> { Ok(()) } @@ -269,7 +345,7 @@ async fn push_pool_submit_enqueues_item() { }); let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store); + let pool = PushPool::new(config, store, None); let activation = make_activations(1).remove(0); let time = Instant::now(); @@ -285,7 +361,7 @@ async fn push_pool_submit_backpressures_when_queue_full() { }); let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store); + let pool = PushPool::new(config, store, None); let time = Instant::now(); let first = make_activations(1).remove(0); @@ -321,7 +397,7 @@ async fn push_pool_start_worker_connect_failure_returns_error() { ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = PushPool::new_with_factory(config, store, failing_connect_factory()); + let pool = PushPool::new_with_factory(config, store, failing_connect_factory(), None); let result = pool.start().await; assert!( @@ -331,7 +407,7 @@ async fn push_pool_start_worker_connect_failure_returns_error() { } /// After a successful push for a first-attempt activation (processing_attempts == 0), -/// mark_activation_processing must be called on the store. +/// mark_processing must be called on the store. #[tokio::test] async fn push_pool_start_marks_activation_processing_on_first_attempt() { let notify = Arc::new(Notify::new()); @@ -346,6 +422,7 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { config, store.clone(), notifying_factory(false, notify.clone()), + None, )); let pool_start = pool.clone(); @@ -361,7 +438,7 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { .await .expect("submit should succeed"); - // Wait for the worker to call send(), then give it time to call mark_activation_processing + // Wait for the worker to call send(), then give it time to call mark_processing timeout(Duration::from_secs(2), notify.notified()) .await .expect("timed out waiting for push to be delivered"); @@ -370,12 +447,12 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { assert_eq!( store.marked_ids(), vec![id], - "mark_activation_processing should be called after a successful first-attempt push" + "mark_processing should be called after a successful first-attempt push" ); } /// After a successful push for a retried activation (processing_attempts > 0), -/// mark_activation_processing must be called and latency recording is skipped. +/// mark_processing must be called and latency recording is skipped. #[tokio::test] async fn push_pool_start_marks_activation_processing_on_retry() { let notify = Arc::new(Notify::new()); @@ -390,6 +467,7 @@ async fn push_pool_start_marks_activation_processing_on_retry() { config, store.clone(), notifying_factory(false, notify.clone()), + None, )); let pool_start = pool.clone(); @@ -413,13 +491,13 @@ async fn push_pool_start_marks_activation_processing_on_retry() { assert_eq!( store.marked_ids(), vec![id], - "mark_activation_processing should be called after a successful retry push" + "mark_processing should be called after a successful retry push" ); } -/// When the worker fails to deliver an activation, mark_activation_processing must NOT be called. +/// When the worker fails to deliver an activation, mark_processing must NOT be called. #[tokio::test] -async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { +async fn push_pool_start_does_not_mark_processing_on_push_failure() { let notify = Arc::new(Notify::new()); let config = Arc::new(Config { worker_map: [("sentry".into(), "unused".into())].into(), @@ -432,6 +510,7 @@ async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { config, store.clone(), notifying_factory(true, notify.clone()), + None, )); let pool_start = pool.clone(); @@ -451,6 +530,131 @@ async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { assert!( store.marked_ids().is_empty(), - "mark_activation_processing should not be called when push fails" + "mark_processing should not be called when push fails" ); } + +/// With `update_tx` set, a successful push on the main loop enqueues the task ID on the channel. +/// Shutdown drain does not use batching - it applies `mark_processing` per activation, so this test +/// does not assert on direct `mark_processing` calls (those can appear only from drain under shutdown). +#[tokio::test] +async fn push_pool_forwards_successful_push_to_update_channel() { + let notify = Arc::new(Notify::new()); + let (update_tx, mut update_rx) = mpsc::channel::(8); + + let config = Arc::new(Config { + worker_map: [("sentry".into(), "unused".into())].into(), + push_threads: 1, + push_queue_size: 10, + ..Config::default() + }); + let store = Arc::new(MockStore::default()); + let pool = Arc::new(PushPool::new_with_factory( + config, + store.clone(), + notifying_factory(false, notify.clone()), + Some(update_tx), + )); + + let pool_start = pool.clone(); + tokio::spawn(async move { pool_start.start().await }); + + let activation = make_activations(1).remove(0); + let id = activation.id.clone(); + let time = Instant::now(); + + pool.submit(activation, time) + .await + .expect("Submit should succeed"); + + timeout(Duration::from_secs(2), notify.notified()) + .await + .expect("Timed out waiting for push to be delivered"); + tokio::time::sleep(Duration::from_millis(50)).await; + + assert!( + store.mark_processing_batch_calls().is_empty(), + "Method `mark_processing_batch` runs only via `flush_updates`, not the push worker" + ); + + let ch_id = update_rx + .recv() + .await + .expect("Task ID should be sent on update channel"); + assert_eq!(ch_id, id); +} + +/// Function `flush_updates` drains the buffer into `mark_processing_batch` and clears the buffer. +#[tokio::test] +async fn flush_updates_applies_batch_and_clears_buffer() { + let store = Arc::new(MockStore::default()); + let mut buf = vec!["id_0".to_string()]; + + flush_updates(store.clone(), &mut buf).await; + + assert!( + buf.is_empty(), + "buffer should be cleared after successful flush" + ); + assert!(store.mark_processing_direct_calls().is_empty()); + assert_eq!( + store.mark_processing_batch_calls(), + vec![vec!["id_0".to_string()]] + ); + assert_eq!(store.marked_ids(), vec!["id_0".to_string()]); +} + +/// On `mark_processing_batch` error, `flush_updates` restores IDs into the buffer for retry. +#[tokio::test] +async fn flush_updates_restores_buffer_on_batch_error() { + let store = Arc::new(MockStore::default()); + store.set_mark_processing_batch_fail(true); + + let mut buf = vec!["a".to_string(), "b".to_string()]; + flush_updates(store.clone(), &mut buf).await; + + assert_eq!(buf, vec!["a".to_string(), "b".to_string()]); + assert!(store.mark_processing_batch_calls().is_empty()); + assert!(store.marked_ids().is_empty()); +} + +/// After a successful worker push, a closed `update_tx` receiver means the main loop cannot enqueue +/// the id (no `mark_processing` there). Shutdown drain still uses individual `mark_processing` for +/// any remaining queue items, so this test does not assert the mock store stayed untouched. +#[tokio::test] +async fn push_pool_does_not_fallback_to_mark_processing_when_update_channel_closed() { + let notify = Arc::new(Notify::new()); + let (update_tx, update_rx) = mpsc::channel::(8); + drop(update_rx); + + let config = Arc::new(Config { + worker_map: [("sentry".into(), "unused".into())].into(), + push_threads: 1, + push_queue_size: 10, + ..Config::default() + }); + let store = Arc::new(MockStore::default()); + let pool = Arc::new(PushPool::new_with_factory( + config, + store.clone(), + notifying_factory(false, notify.clone()), + Some(update_tx), + )); + + let pool_start = pool.clone(); + tokio::spawn(async move { pool_start.start().await }); + + let activation = make_activations(1).remove(0); + let time = Instant::now(); + + pool.submit(activation, time) + .await + .expect("Submit should succeed"); + + timeout(Duration::from_secs(2), notify.notified()) + .await + .expect("Timed out waiting for push to be delivered"); + tokio::time::sleep(Duration::from_millis(50)).await; + + assert!(store.mark_processing_batch_calls().is_empty()); +} diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 27ce276f..c82adf59 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -503,10 +503,8 @@ impl InflightActivationStore for PostgresActivationStore { #[instrument(skip_all)] #[framed] - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error> { - let mut conn = self - .acquire_write_conn_metric("mark_activation_processing") - .await?; + async fn mark_processing(&self, id: &str) -> Result<(), Error> { + let mut conn = self.acquire_write_conn_metric("mark_processing").await?; let grace_period = self.config.processing_deadline_grace_sec; let result = sqlx::query(&format!( @@ -523,20 +521,47 @@ impl InflightActivationStore for PostgresActivationStore { .await?; if result.rows_affected() == 0 { - metrics::counter!("push.mark_activation_processing", "result" => "not_found") - .increment(1); + metrics::counter!("push.mark_processing", "result" => "not_found").increment(1); warn!( task_id = %id, "Activation could not be marked as processing, it may be missing or its status may have already changed" ); } else { - metrics::counter!("push.mark_activation_processing", "result" => "ok").increment(1); + metrics::counter!("push.mark_processing", "result" => "ok").increment(1); } Ok(()) } + #[instrument(skip_all)] + #[framed] + async fn mark_processing_batch(&self, ids: &[String]) -> Result { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self + .acquire_write_conn_metric("mark_processing_batch") + .await?; + + let grace_period = self.config.processing_deadline_grace_sec; + let result = sqlx::query(&format!( + "UPDATE inflight_taskactivations SET + status = $1, + processing_deadline = now() + (processing_deadline_duration * interval '1 second') + (interval '{grace_period} seconds'), + claim_expires_at = NULL + WHERE id = ANY($2) AND status = $3", + )) + .bind(InflightActivationStatus::Processing.to_string()) + .bind(ids) + .bind(InflightActivationStatus::Claimed.to_string()) + .execute(&mut *conn) + .await?; + + Ok(result.rows_affected()) + } + /// Get the age of the oldest pending activation in seconds. /// Only activations with status=pending and processing_attempts=0 are considered /// as we are interested in latency to the *first* attempt. diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index de457de4..8bca255b 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -597,10 +597,8 @@ impl InflightActivationStore for SqliteActivationStore { } #[instrument(skip_all)] - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error> { - let mut conn = self - .acquire_write_conn_metric("mark_activation_processing") - .await?; + async fn mark_processing(&self, id: &str) -> Result<(), Error> { + let mut conn = self.acquire_write_conn_metric("mark_processing").await?; let grace_period = self.config.processing_deadline_grace_sec; let result = sqlx::query(&format!( @@ -617,20 +615,48 @@ impl InflightActivationStore for SqliteActivationStore { .await?; if result.rows_affected() == 0 { - metrics::counter!("push.mark_activation_processing", "result" => "not_found") - .increment(1); + metrics::counter!("push.mark_processing", "result" => "not_found").increment(1); warn!( task_id = %id, "Activation could not be marked as sent, it may be missing or its status may have already changed" ); } else { - metrics::counter!("push.mark_activation_processing", "result" => "ok").increment(1); + metrics::counter!("push.mark_processing", "result" => "ok").increment(1); } Ok(()) } + #[instrument(skip_all)] + async fn mark_processing_batch(&self, ids: &[String]) -> Result { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self + .acquire_write_conn_metric("mark_processing_batch") + .await?; + + let grace_period = self.config.processing_deadline_grace_sec; + let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations SET status = "); + query_builder.push_bind(InflightActivationStatus::Processing); + query_builder.push(format!( + ", processing_deadline = unixepoch('now', '+' || (processing_deadline_duration + {grace_period}) || ' seconds'), claim_expires_at = NULL WHERE status = ", + )); + query_builder.push_bind(InflightActivationStatus::Claimed); + query_builder.push(" AND id IN ("); + + let mut separated = query_builder.separated(", "); + for id in ids.iter() { + separated.push_bind(id); + } + separated.push_unseparated(")"); + + let result = query_builder.build().execute(&mut *conn).await?; + Ok(result.rows_affected()) + } + /// Get the age of the oldest pending activation in seconds. /// Only activations with status=pending and processing_attempts=0 are considered /// as we are interested in latency to the *first* attempt. diff --git a/src/store/traits.rs b/src/store/traits.rs index d8bdb5e0..287fc12f 100644 --- a/src/store/traits.rs +++ b/src/store/traits.rs @@ -27,7 +27,7 @@ pub trait InflightActivationStore: Send + Sync { mark_processing: bool, ) -> Result, Error>; - /// Claims `limit` activations within the `bucket` range. Push mode uses status `Claimed` until `mark_activation_processing` moves to `Processing`. + /// Claims `limit` activations within the `bucket` range. Push mode uses status `Claimed` until `mark_processing` moves to `Processing`. async fn claim_activations_for_push( &self, limit: Option, @@ -69,7 +69,10 @@ pub trait InflightActivationStore: Send + Sync { } /// Record successful push. - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error>; + async fn mark_processing(&self, id: &str) -> Result<(), Error>; + + /// Record a batch of successful pushes. + async fn mark_processing_batch(&self, ids: &[String]) -> Result; /// Update the status of a specific activation async fn set_status( From 26dfa5f764886e7fff028ed268ff71f5fe7b544d Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 14 May 2026 16:47:33 -0700 Subject: [PATCH 02/14] Fix Unbounded Buffering Bug --- src/fetch/mod.rs | 3 +-- src/flusher.rs | 16 +++++++++++++++- src/push/mod.rs | 2 ++ 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index e4ae2fab..89b12981 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -122,11 +122,10 @@ impl FetchPool { } _ = async { - let start = Instant::now(); - debug!("Fetching next batch of pending activations..."); metrics::counter!("fetch.loop.count").increment(1); + let start = Instant::now(); let mut backoff = false; let result = store.claim_activations_for_push(limit, bucket).await; diff --git a/src/flusher.rs b/src/flusher.rs index 33732dfa..b0e7cd5a 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -29,7 +29,8 @@ where loop { tokio::select! { - msg = rx.recv() => { + // When the buffer is NOT full, try to receive another message + msg = rx.recv(), if buffer.len() < batch_size => { match msg { Some(v) => { buffer.push(v); @@ -51,6 +52,19 @@ where } } + // If the buffer IS full, the branch above will never execute, and we will never + // discover that the channel is now closed, which is why this branch is necessary + _ = std::future::ready(()), if rx.is_closed() => { + while let Ok(update) = rx.try_recv() { + // Buffer may grow beyond configured limit, which is OK because we are shutting down + buffer.push(update); + } + + flush(&mut buffer).await; + break; + } + + // Otherwise, try flushing whatever is in the buffer every `interval_ms` milliseconds _ = interval.tick() => { if !buffer.is_empty() { flush(&mut buffer).await; diff --git a/src/push/mod.rs b/src/push/mod.rs index 8a5fc001..6a2535eb 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -337,6 +337,8 @@ impl PushPool { debug!(task_id = %id, "Activation sent to worker"); let start = Instant::now(); + + // We won't batch these updates to keep things simple during shutdown let result = store.mark_processing(&id).await; metrics::histogram!("push.mark_processing.duration") .record(start.elapsed()); From 04b79622c658c2e5bc5509f4e87910d431d39dc1 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 14 May 2026 16:55:43 -0700 Subject: [PATCH 03/14] Add Debug Logs for Flusher --- src/flusher.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/flusher.rs b/src/flusher.rs index b0e7cd5a..741dae33 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -4,6 +4,7 @@ use std::time::Duration; use anyhow::Result; use tokio::sync::mpsc::Receiver; +use tracing::debug; /// Run flusher that receives values of type T from a channel and flushes /// them using the provided async `flush` function either when the batch is @@ -31,6 +32,8 @@ where tokio::select! { // When the buffer is NOT full, try to receive another message msg = rx.recv(), if buffer.len() < batch_size => { + debug!("Buffer is NOT full, receiving a message..."); + match msg { Some(v) => { buffer.push(v); @@ -40,12 +43,14 @@ where } if buffer.len() >= batch_size { + debug!("Flushing full buffer..."); flush(&mut buffer).await; } } None => { // Channel closed (shutdown), flush remaining and exit + debug!("Channel closed due to shutdown, flushing remaining before exit..."); flush(&mut buffer).await; break; } @@ -55,6 +60,8 @@ where // If the buffer IS full, the branch above will never execute, and we will never // discover that the channel is now closed, which is why this branch is necessary _ = std::future::ready(()), if rx.is_closed() => { + debug!("Channel is closed and buffer is full, draining channel before exiting..."); + while let Ok(update) = rx.try_recv() { // Buffer may grow beyond configured limit, which is OK because we are shutting down buffer.push(update); @@ -66,6 +73,8 @@ where // Otherwise, try flushing whatever is in the buffer every `interval_ms` milliseconds _ = interval.tick() => { + debug!("Performing periodic flush..."); + if !buffer.is_empty() { flush(&mut buffer).await; } From 9c87519b8fe0bd11f9633d512cd954b11f58ad5f Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 14 May 2026 17:01:15 -0700 Subject: [PATCH 04/14] Minor Formatting --- src/push/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/push/mod.rs b/src/push/mod.rs index 6a2535eb..fb7acdbc 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -464,7 +464,6 @@ pub async fn flush_updates(store: Arc, buffer: &mut } let start = Instant::now(); - let ids = std::mem::take(buffer); let requested = ids.len() as u64; From 27ae1de4d573e4e0aefb0d0e942473f5fe3d53cf Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Thu, 14 May 2026 18:56:53 -0700 Subject: [PATCH 05/14] Fix Hanging Test by Batching on Push Pool Drain --- src/push/mod.rs | 42 +++++++++++++++++++++++++++++------------- src/push/tests.rs | 11 +++++------ 2 files changed, 34 insertions(+), 19 deletions(-) diff --git a/src/push/mod.rs b/src/push/mod.rs index fb7acdbc..05612c07 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -338,20 +338,36 @@ impl PushPool { let start = Instant::now(); - // We won't batch these updates to keep things simple during shutdown - let result = store.mark_processing(&id).await; - metrics::histogram!("push.mark_processing.duration") - .record(start.elapsed()); + if let Some(ref tx) = update_tx { + let result = tx.send(id.clone()).await; + metrics::histogram!("push.mark_processing.duration") + .record(start.elapsed()); - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error") - .increment(1); + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error") + .increment(1); - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as processing after push" - ); + error!( + task_id = %id, + error = ?e, + "Failed to enqueue push update during shutdown drain" + ); + } + } else { + let result = store.mark_processing(&id).await; + metrics::histogram!("push.mark_processing.duration") + .record(start.elapsed()); + + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error") + .increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to mark activation as processing after push" + ); + } } } @@ -464,7 +480,7 @@ pub async fn flush_updates(store: Arc, buffer: &mut } let start = Instant::now(); - let ids = std::mem::take(buffer); + let ids: Vec<_> = std::mem::take(buffer); let requested = ids.len() as u64; metrics::histogram!("push.flush_updates.requested").record(requested as f64); diff --git a/src/push/tests.rs b/src/push/tests.rs index 3b5ee0f4..e88458eb 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -534,9 +534,7 @@ async fn push_pool_start_does_not_mark_processing_on_push_failure() { ); } -/// With `update_tx` set, a successful push on the main loop enqueues the task ID on the channel. -/// Shutdown drain does not use batching - it applies `mark_processing` per activation, so this test -/// does not assert on direct `mark_processing` calls (those can appear only from drain under shutdown). +/// With `update_tx` set, a successful push enqueues the task ID on the channel. #[tokio::test] async fn push_pool_forwards_successful_push_to_update_channel() { let notify = Arc::new(Notify::new()); @@ -618,9 +616,8 @@ async fn flush_updates_restores_buffer_on_batch_error() { assert!(store.marked_ids().is_empty()); } -/// After a successful worker push, a closed `update_tx` receiver means the main loop cannot enqueue -/// the id (no `mark_processing` there). Shutdown drain still uses individual `mark_processing` for -/// any remaining queue items, so this test does not assert the mock store stayed untouched. +/// After a successful worker push, a closed `update_tx` receiver means neither the main loop nor +/// shutdown drain can enqueue the ID. #[tokio::test] async fn push_pool_does_not_fallback_to_mark_processing_when_update_channel_closed() { let notify = Arc::new(Notify::new()); @@ -657,4 +654,6 @@ async fn push_pool_does_not_fallback_to_mark_processing_when_update_channel_clos tokio::time::sleep(Duration::from_millis(50)).await; assert!(store.mark_processing_batch_calls().is_empty()); + assert!(store.mark_processing_direct_calls().is_empty()); + assert!(store.marked_ids().is_empty()); } From 67a692d733b825270c3fb1555d6793856ff4e0f5 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 15 May 2026 13:35:32 -0700 Subject: [PATCH 06/14] Move Push Logic into Seprate Function, Add Update Queue Size Metrics, Bias Flusher --- src/flusher.rs | 2 + src/grpc/server.rs | 3 + src/push/mod.rs | 253 +++++++++++++++------------------------------ src/push/tests.rs | 27 +---- 4 files changed, 92 insertions(+), 193 deletions(-) diff --git a/src/flusher.rs b/src/flusher.rs index 741dae33..c6208675 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -30,6 +30,8 @@ where loop { tokio::select! { + biased; + // When the buffer is NOT full, try to receive another message msg = rx.recv(), if buffer.len() < batch_size => { debug!("Buffer is NOT full, receiving a message..."); diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 9d669c5e..3f68de7d 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -107,6 +107,9 @@ impl ConsumerService for TaskbrokerServer { } if let Some(ref tx) = self.update_tx { + let depth = tx.max_capacity() - tx.capacity(); + metrics::gauge!("grpc_server.update_queue.depth").set(depth as f64); + tx.send((id, status)) .await .map_err(|_| Status::internal("Status update channel closed"))?; diff --git a/src/push/mod.rs b/src/push/mod.rs index 05612c07..ae1547c9 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1,4 +1,3 @@ -use std::cmp::max; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; @@ -7,7 +6,6 @@ use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use async_backtrace::framed; -use chrono::Utc; use elegant_departure::get_shutdown_guard; use flume::{Receiver, SendError, Sender}; use hmac::{Hmac, Mac}; @@ -164,11 +162,6 @@ impl PushPool { let guard = get_shutdown_guard().shutdown_on_drop(); - let callback_url = format!( - "{}:{}", - self.config.callback_addr, self.config.callback_port - ); - let timeout = Duration::from_millis(self.config.push_timeout_ms); let grpc_shared_secret = self.config.grpc_shared_secret.clone(); @@ -215,174 +208,22 @@ impl PushPool { metrics::histogram!("push.queue.latency").record(time.elapsed()); - let id = activation.id.clone(); - let callback_url = callback_url.clone(); - - let Some(worker) = workers.get_mut(&activation.application) else { - metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); - - error!( - task_id = %id, - application = activation.application, - "Task application has no worker pool mapping" - ); - - continue; - }; - - match push_task( - worker.as_mut(), - activation.clone(), - callback_url, - timeout, - grpc_shared_secret.as_slice(), - ) - .await - { - Ok(_) => { - metrics::counter!("push.push_task", "result" => "ok").increment(1); - debug!(task_id = %id, "Activation sent to worker"); - - if activation.processing_attempts < 1 { - let latency = max(0, activation.received_latency(Utc::now())); - - metrics::histogram!( - "push.received_to_push.latency", - "namespace" => activation.namespace, - "taskname" => activation.taskname, - ) - .record(latency as f64); - } else { - debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → push latency recording"); - } - - let start = Instant::now(); - - // Are we batching claimed → processing updates? - if let Some(ref tx) = update_tx { - let result = tx.send(id.clone()).await; - metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to enqueue push update" - ); - } - - continue; - } - - // Fall back to individual updates - let result = store.mark_processing(&id).await; - metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as sent after push" - ); - } - } - - // Once claim expires, status will be set back to pending - Err(e) => { - metrics::counter!("push.push_task", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to send activation to worker" - ) - } - }; + push_task(store.clone(), update_tx.as_ref(), activation, &mut workers, timeout, grpc_shared_secret.as_slice()).await; } } } // Drain channel before exiting without recording duration metrics since they don't matter at this time for (activation, _) in receiver.drain() { - let id = activation.id.clone(); - let callback_url = callback_url.clone(); - - let Some(worker) = workers.get_mut(&activation.application) else { - metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); - - error!( - task_id = %id, - application = activation.application, - "Task application has no worker pool mapping" - ); - - continue; - }; - - match push_task( - worker.as_mut(), + push_task( + store.clone(), + update_tx.as_ref(), activation, - callback_url, + &mut workers, timeout, grpc_shared_secret.as_slice(), ) - .await - { - Ok(_) => { - metrics::counter!("push.push_task", "result" => "ok").increment(1); - debug!(task_id = %id, "Activation sent to worker"); - - let start = Instant::now(); - - if let Some(ref tx) = update_tx { - let result = tx.send(id.clone()).await; - metrics::histogram!("push.mark_processing.duration") - .record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error") - .increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to enqueue push update during shutdown drain" - ); - } - } else { - let result = store.mark_processing(&id).await; - metrics::histogram!("push.mark_processing.duration") - .record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_processing", "result" => "error") - .increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as processing after push" - ); - } - } - } - - // Once processing deadline expires, status will be set back to pending - Err(e) => { - metrics::counter!("push.push_task", "result" => "error") - .increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to send activation to worker" - ) - } - }; + .await; } Ok(()) @@ -438,12 +279,86 @@ impl PushPool { } } -/// Decode task activation and push it to a worker. -#[framed] +/// Determine which worker should receive an activation, send the activation, and update its status. async fn push_task( + store: Arc, + update_tx: Option<&mpsc::Sender>, + activation: InflightActivation, + workers: &mut HashMap>, + timeout: Duration, + grpc_shared_secret: &[String], +) { + let id = activation.id.clone(); + + let Some(worker) = workers.get_mut(&activation.application) else { + metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); + + error!( + task_id = %id, + application = activation.application, + "Task application has no worker pool mapping" + ); + + return; + }; + + match send_task(worker.as_mut(), activation, timeout, grpc_shared_secret).await { + Ok(_) => { + metrics::counter!("push.push_task", "result" => "ok").increment(1); + debug!(task_id = %id, "Activation sent to worker"); + + let start = Instant::now(); + + if let Some(tx) = update_tx { + let depth = tx.max_capacity() - tx.capacity(); + metrics::gauge!("push.update_queue.depth").set(depth as f64); + + let result = tx.send(id.clone()).await; + metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); + + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to enqueue push update during shutdown drain" + ); + } + } else { + let result = store.mark_processing(&id).await; + metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); + + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to mark activation as processing after push" + ); + } + } + } + + // Once processing deadline expires, status will be set back to pending + Err(e) => { + metrics::counter!("push.push_task", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to send activation to worker" + ) + } + }; +} + +/// Decode task activation and send it to the worker service for a particular application. +#[framed] +async fn send_task( worker: &mut (dyn WorkerClient + Send), activation: InflightActivation, - callback_url: String, timeout: Duration, grpc_shared_secret: &[String], ) -> Result<()> { @@ -461,7 +376,7 @@ async fn push_task( let request = PushTaskRequest { task: Some(task), - callback_url, + callback_url: "".into(), }; let result = match tokio::time::timeout(timeout, worker.send(request, grpc_shared_secret)).await diff --git a/src/push/tests.rs b/src/push/tests.rs index e88458eb..8a3f2e8d 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -279,14 +279,7 @@ async fn push_task_returns_ok_on_client_success() { let mut worker = MockWorkerClient::new(false); let callback_url = "taskbroker:50051".to_string(); - let result = push_task( - &mut worker, - activation.clone(), - callback_url.clone(), - Duration::from_secs(5), - &[], - ) - .await; + let result = send_task(&mut worker, activation.clone(), Duration::from_secs(5), &[]).await; assert!(result.is_ok(), "push_task should succeed"); assert_eq!(worker.captured_requests.len(), 1); @@ -304,14 +297,7 @@ async fn push_task_returns_err_on_invalid_payload() { activation.activation = vec![1, 2, 3, 4]; let mut worker = MockWorkerClient::new(false); - let result = push_task( - &mut worker, - activation, - "taskbroker:50051".to_string(), - Duration::from_secs(5), - &[], - ) - .await; + let result = send_task(&mut worker, activation, Duration::from_secs(5), &[]).await; assert!(result.is_err(), "invalid payload should fail decoding"); assert!( @@ -325,14 +311,7 @@ async fn push_task_propagates_client_error() { let activation = make_activations(1).remove(0); let mut worker = MockWorkerClient::new(true); - let result = push_task( - &mut worker, - activation, - "taskbroker:50051".to_string(), - Duration::from_secs(5), - &[], - ) - .await; + let result = send_task(&mut worker, activation, Duration::from_secs(5), &[]).await; assert!(result.is_err(), "worker send errors should propagate"); assert_eq!(worker.captured_requests.len(), 1); } From c2112fe62b9ecdbbd9f3abe4ef880566bbe3adc5 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 15 May 2026 13:42:26 -0700 Subject: [PATCH 07/14] Fix Callback URL Issues --- src/config.rs | 4 ++-- src/push/mod.rs | 1 + src/push/tests.rs | 2 -- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/config.rs b/src/config.rs index a953e96a..ed8a3da8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -317,10 +317,10 @@ pub struct Config { /// Maximum milliseconds to wait before flushing a batch of dispatch updates. pub push_update_interval_ms: u64, - /// The hostname used to construct `callback_url` for task push requests. + /// (DEPRECATED) The hostname used to construct `callback_url` for task push requests. pub callback_addr: String, - /// The port used to construct `callback_url` for task push requests. + /// (DEPRECATED) The port used to construct `callback_url` for task push requests. pub callback_port: u32, /// Maps every application to its worker endpoint, both represented as strings. diff --git a/src/push/mod.rs b/src/push/mod.rs index ae1547c9..cbde41c8 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -374,6 +374,7 @@ async fn send_task( } }; + // The callback URL isn't used by push taskworkers anymore, so we can use an empty string until it's removed from the schema let request = PushTaskRequest { task: Some(task), callback_url: "".into(), diff --git a/src/push/tests.rs b/src/push/tests.rs index 8a3f2e8d..a59bc826 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -277,14 +277,12 @@ fn failing_connect_factory() -> WorkerFactory { async fn push_task_returns_ok_on_client_success() { let activation = make_activations(1).remove(0); let mut worker = MockWorkerClient::new(false); - let callback_url = "taskbroker:50051".to_string(); let result = send_task(&mut worker, activation.clone(), Duration::from_secs(5), &[]).await; assert!(result.is_ok(), "push_task should succeed"); assert_eq!(worker.captured_requests.len(), 1); let request = &worker.captured_requests[0]; - assert_eq!(request.callback_url, callback_url); assert_eq!( request.task.as_ref().map(|task| task.id.as_str()), Some(activation.id.as_str()) From 81f949d3effa55477dacff215cb435854df1be81 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 15 May 2026 14:20:15 -0700 Subject: [PATCH 08/14] Fix Misleading Error Message --- src/push/mod.rs | 26 ++++++++++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/src/push/mod.rs b/src/push/mod.rs index cbde41c8..aee8bb8e 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1,3 +1,4 @@ +use std::cmp::max; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; @@ -6,6 +7,7 @@ use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use async_backtrace::framed; +use chrono::Utc; use elegant_departure::get_shutdown_guard; use flume::{Receiver, SendError, Sender}; use hmac::{Hmac, Mac}; @@ -302,11 +304,31 @@ async fn push_task( return; }; - match send_task(worker.as_mut(), activation, timeout, grpc_shared_secret).await { + match send_task( + worker.as_mut(), + activation.clone(), + timeout, + grpc_shared_secret, + ) + .await + { Ok(_) => { metrics::counter!("push.push_task", "result" => "ok").increment(1); debug!(task_id = %id, "Activation sent to worker"); + if activation.processing_attempts < 1 { + let latency = max(0, activation.received_latency(Utc::now())); + + metrics::histogram!( + "push.received_to_push.latency", + "namespace" => activation.namespace, + "taskname" => activation.taskname, + ) + .record(latency as f64); + } else { + debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → push latency recording"); + } + let start = Instant::now(); if let Some(tx) = update_tx { @@ -322,7 +344,7 @@ async fn push_task( error!( task_id = %id, error = ?e, - "Failed to enqueue push update during shutdown drain" + "Failed to enqueue push update" ); } } else { From 88dce44a9bbca47f698d3de9ef85f7d75725aee0 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Fri, 15 May 2026 19:40:41 -0700 Subject: [PATCH 09/14] Changes to Flusher to Ensure Smooth Shutdown --- src/flusher.rs | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/flusher.rs b/src/flusher.rs index c6208675..8d43bc82 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -3,6 +3,7 @@ use std::pin::Pin; use std::time::Duration; use anyhow::Result; +use elegant_departure::get_shutdown_guard; use tokio::sync::mpsc::Receiver; use tracing::debug; @@ -28,6 +29,8 @@ where let mut buffer: Vec = Vec::with_capacity(batch_size); + let guard = get_shutdown_guard().shutdown_on_drop(); + loop { tokio::select! { biased; @@ -51,38 +54,37 @@ where } None => { - // Channel closed (shutdown), flush remaining and exit - debug!("Channel closed due to shutdown, flushing remaining before exit..."); - flush(&mut buffer).await; + // Channel closed + debug!("Channel closed!"); break; } } } - // If the buffer IS full, the branch above will never execute, and we will never - // discover that the channel is now closed, which is why this branch is necessary - _ = std::future::ready(()), if rx.is_closed() => { - debug!("Channel is closed and buffer is full, draining channel before exiting..."); + // Otherwise, try flushing whatever is in the buffer every `interval_ms` milliseconds + _ = interval.tick() => { + debug!("Performing periodic flush..."); - while let Ok(update) = rx.try_recv() { - // Buffer may grow beyond configured limit, which is OK because we are shutting down - buffer.push(update); + if rx.is_closed() { + debug!("Channel closed on tick!"); + break; } flush(&mut buffer).await; - break; } - // Otherwise, try flushing whatever is in the buffer every `interval_ms` milliseconds - _ = interval.tick() => { - debug!("Performing periodic flush..."); - - if !buffer.is_empty() { - flush(&mut buffer).await; - } + _ = guard.wait() => { + debug!("Shutdown guard triggered!"); + break; } } } + // Drain and flush before exit + while let Ok(update) = rx.try_recv() { + buffer.push(update); + } + + flush(&mut buffer).await; Ok(()) } From bb5b66ef2ff9ae35ca6a4f979559aff018b7a7ce Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 13:19:48 -0700 Subject: [PATCH 10/14] Break Long Lines --- src/fetch/mod.rs | 11 +++++++++-- src/push/mod.rs | 8 +++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 89b12981..d898bf07 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -48,7 +48,8 @@ pub fn bucket_range_for_fetch_thread(thread_index: usize, fetch_threads: usize) (low, high) } -/// Thin interface for the push pool. It mostly serves to enable proper unit testing, but it also decouples fetch logic from push logic even further. +/// Thin interface for the push pool. It mostly serves to enable proper unit testing, +/// but it also decouples fetch logic from push logic even further. #[async_trait] pub trait TaskPusher { /// Submit a single task to the push pool. @@ -164,7 +165,13 @@ impl FetchPool { ) .record(latency as f64); } else { - debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → claimed latency recording"); + debug!( + task_id = %id, + namespace = activation.namespace, + taskname = activation.taskname, + "Activation already processed, skipping \ + received → claimed latency recording" + ); } match pusher.submit_task(activation, start).await { diff --git a/src/push/mod.rs b/src/push/mod.rs index aee8bb8e..31c0636b 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -326,7 +326,13 @@ async fn push_task( ) .record(latency as f64); } else { - debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → push latency recording"); + debug!( + task_id = %id, + namespace = activation.namespace, + taskname = activation.taskname, + "Activation already processed, skipping \ + received → push latency recording" + ); } let start = Instant::now(); From bee392a8f4d5083ff5b7308957924abd57fe6dcb Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 14:41:41 -0700 Subject: [PATCH 11/14] Improve Claim Expiration Computation --- src/config.rs | 6 ++++++ src/store/adapters/postgres.rs | 28 +++++++++++++++++++++++++--- src/store/adapters/sqlite.rs | 24 +++++++++++++++++++++++- 3 files changed, 54 insertions(+), 4 deletions(-) diff --git a/src/config.rs b/src/config.rs index ed8a3da8..75325df1 100644 --- a/src/config.rs +++ b/src/config.rs @@ -222,6 +222,11 @@ pub struct Config { /// brokers are under load, or there are small networking delays. pub processing_deadline_grace_sec: u64, + /// The number of additional seconds that claim expirations + /// are extended by. This helps reduce claim expirations when + /// brokers are under load, or there are small networking delays. + pub claim_expiration_grace_sec: u64, + /// The frequency at which upkeep tasks /// (discarding, retrying activations, etc.) are executed. pub upkeep_task_interval_ms: u64, @@ -406,6 +411,7 @@ impl Default for Config { max_processing_count: 2048, max_processing_attempts: 5, processing_deadline_grace_sec: 3, + claim_expiration_grace_sec: 3, upkeep_task_interval_ms: 1000, upkeep_unhealthy_interval_ms: 5000, health_check_killswitched: false, diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index c82adf59..47c99c4d 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -154,6 +154,29 @@ impl PostgresActivationStoreConfig { url.as_ref().split('?').next().unwrap().to_string() + "?" + extra_query_params; conn_opts = PgConnectOptions::from_str(&new_url).unwrap(); } + + // Compute the longest amount of time an activation may be claimed + let claim_lease_ms = { + // In the worst case, every activation in the batch will time out when appending to the push queue + let queue_ms = config.fetch_batch_size as u64 * config.push_queue_timeout_ms; + + // In the worst case, every activation in the push queue will time out when sending + let send_ms = config.push_queue_size as u64 * config.push_timeout_ms; + + let update_ms = if config.batch_push_updates { + // In the worst case, we will need to wait an entire interval before flushing a batch of push updates + config.push_update_interval_ms + } else { + // Grace seconds will cover the update query duration until we decide to implement query timeouts + 0 + }; + + // Account for grace seconds specified in configuration + let grace_ms = config.claim_expiration_grace_sec * 1000; + + queue_ms + send_ms + update_ms + grace_ms + }; + Self { pg_connection: conn_opts, pg_database_name: config.pg_database_name.clone(), @@ -161,9 +184,9 @@ impl PostgresActivationStoreConfig { run_migrations: config.run_migrations, max_processing_attempts: config.max_processing_attempts, vacuum_page_count: config.vacuum_page_count, - processing_deadline_grace_sec: config.processing_deadline_grace_sec, - claim_lease_ms: config.fetch_batch_size.max(1) as u64 * config.push_queue_timeout_ms, enable_sqlite_status_metrics: config.enable_sqlite_status_metrics, + processing_deadline_grace_sec: config.processing_deadline_grace_sec, + claim_lease_ms, } } } @@ -505,7 +528,6 @@ impl InflightActivationStore for PostgresActivationStore { #[framed] async fn mark_processing(&self, id: &str) -> Result<(), Error> { let mut conn = self.acquire_write_conn_metric("mark_processing").await?; - let grace_period = self.config.processing_deadline_grace_sec; let result = sqlx::query(&format!( "UPDATE inflight_taskactivations SET diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index 8bca255b..e8a8e3eb 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -147,12 +147,34 @@ pub struct InflightActivationStoreConfig { impl InflightActivationStoreConfig { pub fn from_config(config: &Config) -> Self { + // Compute the longest amount of time an activation may be claimed + let claim_lease_ms = { + // In the worst case, every activation in the batch will time out when appending to the push queue + let queue_ms = config.fetch_batch_size as u64 * config.push_queue_timeout_ms; + + // In the worst case, every activation in the push queue will time out when sending + let send_ms = config.push_queue_size as u64 * config.push_timeout_ms; + + let update_ms = if config.batch_push_updates { + // In the worst case, we will need to wait an entire interval before flushing a batch of push updates + config.push_update_interval_ms + } else { + // Grace seconds will cover the update query duration until we decide to implement query timeouts + 0 + }; + + // Account for grace seconds specified in configuration + let grace_ms = config.claim_expiration_grace_sec * 1000; + + queue_ms + send_ms + update_ms + grace_ms + }; + Self { max_processing_attempts: config.max_processing_attempts, vacuum_page_count: config.vacuum_page_count, processing_deadline_grace_sec: config.processing_deadline_grace_sec, - claim_lease_ms: config.fetch_batch_size.max(1) as u64 * config.push_queue_timeout_ms, enable_sqlite_status_metrics: config.enable_sqlite_status_metrics, + claim_lease_ms, } } } From 0956223cc71463ede4fa349ea210cc08f5161ed7 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 15:32:43 -0700 Subject: [PATCH 12/14] Fix Flusher Drain Shutdown Logic --- src/flusher.rs | 14 ++++++-------- src/main.rs | 3 +-- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/flusher.rs b/src/flusher.rs index 8d43bc82..dbfd32d2 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -28,8 +28,7 @@ where interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); let mut buffer: Vec = Vec::with_capacity(batch_size); - - let guard = get_shutdown_guard().shutdown_on_drop(); + let guard = get_shutdown_guard(); loop { tokio::select! { @@ -54,7 +53,7 @@ where } None => { - // Channel closed + // Channel closed because all senders were dropped debug!("Channel closed!"); break; } @@ -66,17 +65,13 @@ where debug!("Performing periodic flush..."); if rx.is_closed() { + // Channel closed because all senders were dropped debug!("Channel closed on tick!"); break; } flush(&mut buffer).await; } - - _ = guard.wait() => { - debug!("Shutdown guard triggered!"); - break; - } } } @@ -85,6 +80,9 @@ where buffer.push(update); } + // Delay shutdown until we have flushed everything in the buffer flush(&mut buffer).await; + drop(guard); + Ok(()) } diff --git a/src/main.rs b/src/main.rs index 543334d9..8cf145d4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -217,7 +217,6 @@ async fn main() -> Result<(), Error> { let grpc_server_task = tokio::spawn({ let grpc_store = store.clone(); let grpc_config = config.clone(); - let grpc_status_tx = status_update_tx.clone(); async move { let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) @@ -234,7 +233,7 @@ async fn main() -> Result<(), Error> { .add_service(ConsumerServiceServer::new(TaskbrokerServer { store: grpc_store, config: grpc_config, - update_tx: grpc_status_tx, + update_tx: status_update_tx, })) .add_service(health_service.clone()) .serve(addr); From 8b17becde195315ead5f8650c185aed7abc17654 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 15:47:20 -0700 Subject: [PATCH 13/14] Fix Claim Lease Double Count Grace Seconds --- src/store/adapters/postgres.rs | 9 +++++---- src/store/adapters/sqlite.rs | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 47c99c4d..393fecc5 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -444,9 +444,6 @@ impl InflightActivationStore for PostgresActivationStore { ) -> Result, Error> { let now = Utc::now(); - let grace_period = self.config.processing_deadline_grace_sec; - let claim_lease_ms = self.config.claim_lease_ms as i64; - let mut query_builder = QueryBuilder::::new( "WITH selected_activations AS ( SELECT id @@ -492,6 +489,8 @@ impl InflightActivationStore for PostgresActivationStore { query_builder.push(" FOR UPDATE SKIP LOCKED)"); if mark_processing { + let grace_period = self.config.processing_deadline_grace_sec; + query_builder.push(format!( "UPDATE inflight_taskactivations SET processing_deadline = now() + (processing_deadline_duration * interval '1 second') + (interval '{grace_period} seconds'), @@ -501,9 +500,11 @@ impl InflightActivationStore for PostgresActivationStore { query_builder.push_bind(InflightActivationStatus::Processing.to_string()); } else { + let claim_lease = self.config.claim_lease_ms as i64; + query_builder.push(format!( "UPDATE inflight_taskactivations - SET claim_expires_at = now() + ({claim_lease_ms} * interval '1 millisecond') + (interval '{grace_period} seconds'), + SET claim_expires_at = now() + ({claim_lease} * interval '1 millisecond'), processing_deadline = NULL, status = " )); diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index e8a8e3eb..9c01e154 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -557,20 +557,22 @@ impl InflightActivationStore for SqliteActivationStore { mark_processing: bool, ) -> Result, Error> { let now = Utc::now(); - let grace_period = self.config.processing_deadline_grace_sec; let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations SET "); if mark_processing { + let grace_period = self.config.processing_deadline_grace_sec; + query_builder.push(format!( "processing_deadline = unixepoch('now', '+' || (processing_deadline_duration + {grace_period}) || ' seconds'), claim_expires_at = NULL, status = " )); query_builder.push_bind(InflightActivationStatus::Processing); } else { + let claim_lease = self.config.claim_lease_ms as f64 / 1000.0; + query_builder.push(format!( - "claim_expires_at = unixepoch('now', '+' || {:.3} || ' seconds', '+' || {grace_period} || ' seconds'), processing_deadline = NULL, status = ", - self.config.claim_lease_ms as f64 / 1000.0, + "claim_expires_at = unixepoch('now', '+' || {claim_lease:.3} || ' seconds'), processing_deadline = NULL, status = " )); query_builder.push_bind(InflightActivationStatus::Claimed); From 9acd2b99137ca3786a461f3f46ec7dc907b81f99 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Mon, 18 May 2026 15:48:46 -0700 Subject: [PATCH 14/14] Remove Unneeded Comments --- src/store/adapters/postgres.rs | 1 - src/store/adapters/sqlite.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 393fecc5..6210a914 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -135,7 +135,6 @@ pub struct PostgresActivationStoreConfig { pub run_migrations: bool, pub max_processing_attempts: usize, pub processing_deadline_grace_sec: u64, - /// Milliseconds added to `claim_expires_at` before grace: `fetch_batch_size * push_queue_timeout_ms`. pub claim_lease_ms: u64, pub vacuum_page_count: Option, pub enable_sqlite_status_metrics: bool, diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index 9c01e154..d68c6df6 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -139,7 +139,6 @@ pub async fn create_sqlite_pool(url: &str) -> Result<(Pool, Pool pub struct InflightActivationStoreConfig { pub max_processing_attempts: usize, pub processing_deadline_grace_sec: u64, - /// Milliseconds added to `claim_expires_at` before grace: `fetch_batch_size * push_queue_timeout_ms`. pub claim_lease_ms: u64, pub vacuum_page_count: Option, pub enable_sqlite_status_metrics: bool,