diff --git a/src/config.rs b/src/config.rs index fd9b11e8..ed8a3da8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -308,10 +308,19 @@ pub struct Config { /// Maximum milliseconds to wait before flushing a batch of status updates. pub status_update_interval_ms: u64, - /// The hostname used to construct `callback_url` for task push requests. + /// Update claimed → processing (dispatch) updates in batches? + pub batch_push_updates: bool, + + /// The size of a batch of dispatch updates. + pub push_update_batch_size: usize, + + /// Maximum milliseconds to wait before flushing a batch of dispatch updates. + pub push_update_interval_ms: u64, + + /// (DEPRECATED) The hostname used to construct `callback_url` for task push requests. pub callback_addr: String, - /// The port used to construct `callback_url` for task push requests. + /// (DEPRECATED) The port used to construct `callback_url` for task push requests. pub callback_port: u32, /// Maps every application to its worker endpoint, both represented as strings. @@ -421,6 +430,9 @@ impl Default for Config { batch_status_updates: false, status_update_batch_size: 1, status_update_interval_ms: 100, + batch_push_updates: false, + push_update_batch_size: 1, + push_update_interval_ms: 100, callback_addr: "0.0.0.0".into(), callback_port: 50051, worker_map: [("sentry".into(), "http://127.0.0.1:50052".into())].into(), diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index e4ae2fab..89b12981 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -122,11 +122,10 @@ impl FetchPool { } _ = async { - let start = Instant::now(); - debug!("Fetching next batch of pending activations..."); metrics::counter!("fetch.loop.count").increment(1); + let start = Instant::now(); let mut backoff = false; let result = store.claim_activations_for_push(limit, bucket).await; diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index 092d3503..3e89a730 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -98,10 +98,14 @@ impl InflightActivationStore for MockStore { }) } - async fn mark_activation_processing(&self, _id: &str) -> Result<(), Error> { + async fn mark_processing(&self, _id: &str) -> Result<(), Error> { Ok(()) } + async fn mark_processing_batch(&self, _ids: &[String]) -> Result { + unimplemented!() + } + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { unimplemented!() } diff --git a/src/flusher.rs b/src/flusher.rs index 33732dfa..8d43bc82 100644 --- a/src/flusher.rs +++ b/src/flusher.rs @@ -3,7 +3,9 @@ use std::pin::Pin; use std::time::Duration; use anyhow::Result; +use elegant_departure::get_shutdown_guard; use tokio::sync::mpsc::Receiver; +use tracing::debug; /// Run flusher that receives values of type T from a channel and flushes /// them using the provided async `flush` function either when the batch is @@ -27,9 +29,16 @@ where let mut buffer: Vec = Vec::with_capacity(batch_size); + let guard = get_shutdown_guard().shutdown_on_drop(); + loop { tokio::select! { - msg = rx.recv() => { + biased; + + // When the buffer is NOT full, try to receive another message + msg = rx.recv(), if buffer.len() < batch_size => { + debug!("Buffer is NOT full, receiving a message..."); + match msg { Some(v) => { buffer.push(v); @@ -39,25 +48,43 @@ where } if buffer.len() >= batch_size { + debug!("Flushing full buffer..."); flush(&mut buffer).await; } } None => { - // Channel closed (shutdown), flush remaining and exit - flush(&mut buffer).await; + // Channel closed + debug!("Channel closed!"); break; } } } + // Otherwise, try flushing whatever is in the buffer every `interval_ms` milliseconds _ = interval.tick() => { - if !buffer.is_empty() { - flush(&mut buffer).await; + debug!("Performing periodic flush..."); + + if rx.is_closed() { + debug!("Channel closed on tick!"); + break; } + + flush(&mut buffer).await; + } + + _ = guard.wait() => { + debug!("Shutdown guard triggered!"); + break; } } } + // Drain and flush before exit + while let Ok(update) = rx.try_recv() { + buffer.push(update); + } + + flush(&mut buffer).await; Ok(()) } diff --git a/src/grpc/server.rs b/src/grpc/server.rs index 9d669c5e..3f68de7d 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -107,6 +107,9 @@ impl ConsumerService for TaskbrokerServer { } if let Some(ref tx) = self.update_tx { + let depth = tx.max_capacity() - tx.capacity(); + metrics::gauge!("grpc_server.update_queue.depth").set(depth as f64); + tx.send((id, status)) .await .map_err(|_| Status::internal("Status update channel closed"))?; diff --git a/src/main.rs b/src/main.rs index c3648a6d..543334d9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; use taskbroker::fetch::FetchPool; use taskbroker::grpc::auth_middleware::AuthLayer; use taskbroker::grpc::metrics_middleware::MetricsLayer; -use taskbroker::grpc::server::{TaskbrokerServer, flush_updates}; +use taskbroker::grpc::server::TaskbrokerServer; use taskbroker::kafka::admin::create_missing_topics; use taskbroker::kafka::consumer::start_consumer; use taskbroker::kafka::deserialize::{self, DeserializeConfig}; @@ -27,7 +27,6 @@ use taskbroker::kafka::inflight_activation_writer::{ ActivationWriterConfig, InflightActivationWriter, }; use taskbroker::kafka::os_stream_writer::{OsStream, OsStreamWriter}; -use taskbroker::logging; use taskbroker::metrics; use taskbroker::processing_strategy; use taskbroker::push::PushPool; @@ -40,6 +39,7 @@ use taskbroker::store::traits::InflightActivationStore; use taskbroker::upkeep::upkeep; use taskbroker::{Args, get_version}; use taskbroker::{SERVICE_NAME, flusher}; +use taskbroker::{grpc, logging, push}; async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { @@ -203,7 +203,7 @@ async fn main() -> Result<(), Error> { rx, flusher_config.status_update_batch_size, flusher_config.status_update_interval_ms, - move |buffer| Box::pin(flush_updates(flusher_store.clone(), buffer)), + move |buffer| Box::pin(grpc::server::flush_updates(flusher_store.clone(), buffer)), ) .await }); @@ -265,8 +265,30 @@ async fn main() -> Result<(), Error> { } }); + // Push update flush task + let (push_update_tx, push_update_task) = if config.batch_push_updates { + let (tx, rx) = tokio::sync::mpsc::channel(config.push_update_batch_size.max(1)); + + let flusher_store = store.clone(); + let flusher_config = config.clone(); + + let handle = tokio::spawn(async move { + flusher::run_flusher( + rx, + flusher_config.push_update_batch_size, + flusher_config.push_update_interval_ms, + move |buffer| Box::pin(push::flush_updates(flusher_store.clone(), buffer)), + ) + .await + }); + + (Some(tx), Some(handle)) + } else { + (None, None) + }; + // Initialize push and fetch pools - let push_pool = Arc::new(PushPool::new(config.clone(), store.clone())); + let push_pool = Arc::new(PushPool::new(config.clone(), store.clone(), push_update_tx)); let fetch_pool = FetchPool::new(store.clone(), config.clone(), push_pool.clone()); // Initialize push threads @@ -305,6 +327,10 @@ async fn main() -> Result<(), Error> { departure = departure.on_completion(log_task_completion("status_update_task", task)); } + if let Some(task) = push_update_task { + departure = departure.on_completion(log_task_completion("push_update_task", task)); + } + departure.await; Ok(()) } diff --git a/src/push/mod.rs b/src/push/mod.rs index fbf04018..aee8bb8e 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -1,4 +1,3 @@ -use chrono::Utc; use std::cmp::max; use std::collections::HashMap; use std::future::Future; @@ -8,6 +7,7 @@ use std::time::{Duration, Instant}; use anyhow::{Context, Result}; use async_backtrace::framed; +use chrono::Utc; use elegant_departure::get_shutdown_guard; use flume::{Receiver, SendError, Sender}; use hmac::{Hmac, Mac}; @@ -15,11 +15,12 @@ use prost::Message; use sentry_protos::taskbroker::v1::worker_service_client::WorkerServiceClient; use sentry_protos::taskbroker::v1::{PushTaskRequest, TaskActivation}; use sha2::Sha256; +use tokio::sync::mpsc; use tokio::task::JoinSet; use tonic::async_trait; use tonic::metadata::MetadataValue; use tonic::transport::Channel; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; use crate::config::Config; use crate::store::activation::InflightActivation; @@ -100,6 +101,9 @@ pub struct PushPool { /// The receiving end of a channel that accepts task activations. receiver: Receiver<(InflightActivation, Instant)>, + /// Queue for batching claimed → processing updates. + update_tx: Option>, + /// Taskbroker configuration. config: Arc, @@ -111,28 +115,36 @@ pub struct PushPool { impl PushPool { /// Initialize a new push pool. - pub fn new(config: Arc, store: Arc) -> Self { + pub fn new( + config: Arc, + store: Arc, + update_tx: Option>, + ) -> Self { let worker_factory: WorkerFactory = Arc::new(|endpoint: String| { Box::pin(async move { let client = WorkerServiceClient::connect(endpoint).await?; Ok(Box::new(client) as Box) }) }); - Self::new_with_factory(config, store, worker_factory) + + Self::new_with_factory(config, store, worker_factory, update_tx) } fn new_with_factory( config: Arc, store: Arc, worker_factory: WorkerFactory, + update_tx: Option>, ) -> Self { let (sender, receiver) = flume::bounded(config.push_queue_size); + Self { sender, receiver, config, store, worker_factory, + update_tx, } } @@ -148,14 +160,10 @@ impl PushPool { let receiver = self.receiver.clone(); let store = store.clone(); let worker_factory = worker_factory.clone(); + let update_tx = self.update_tx.clone(); let guard = get_shutdown_guard().shutdown_on_drop(); - let callback_url = format!( - "{}:{}", - self.config.callback_addr, self.config.callback_port - ); - let timeout = Duration::from_millis(self.config.push_timeout_ms); let grpc_shared_secret = self.config.grpc_shared_secret.clone(); @@ -202,135 +210,22 @@ impl PushPool { metrics::histogram!("push.queue.latency").record(time.elapsed()); - let id = activation.id.clone(); - let callback_url = callback_url.clone(); - - let Some(worker) = workers.get_mut(&activation.application) else { - metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); - - error!( - task_id = %id, - application = activation.application, - "Task application has no worker pool mapping" - ); - - continue; - }; - - match push_task( - worker.as_mut(), - activation.clone(), - callback_url, - timeout, - grpc_shared_secret.as_slice(), - ) - .await - { - Ok(_) => { - metrics::counter!("push.push_task", "result" => "ok").increment(1); - debug!(task_id = %id, "Activation sent to worker"); - - if activation.processing_attempts < 1 { - let latency = max(0, activation.received_latency(Utc::now())); - - metrics::histogram!( - "push.received_to_push.latency", - "namespace" => activation.namespace, - "taskname" => activation.taskname, - ) - .record(latency as f64); - } else { - debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → push latency recording"); - } - - let start = Instant::now(); - let result = store.mark_activation_processing(&id).await; - metrics::histogram!("push.mark_activation_processing.duration").record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as sent after push" - ); - } - } - - // Once claim expires, status will be set back to pending - Err(e) => { - metrics::counter!("push.push_task", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to send activation to worker" - ) - } - }; + push_task(store.clone(), update_tx.as_ref(), activation, &mut workers, timeout, grpc_shared_secret.as_slice()).await; } } } // Drain channel before exiting without recording duration metrics since they don't matter at this time for (activation, _) in receiver.drain() { - let id = activation.id.clone(); - let callback_url = callback_url.clone(); - - let Some(worker) = workers.get_mut(&activation.application) else { - metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); - - error!( - task_id = %id, - application = activation.application, - "Task application has no worker pool mapping" - ); - - continue; - }; - - match push_task( - worker.as_mut(), + push_task( + store.clone(), + update_tx.as_ref(), activation, - callback_url, + &mut workers, timeout, grpc_shared_secret.as_slice(), ) - .await - { - Ok(_) => { - metrics::counter!("push.push_task", "result" => "ok").increment(1); - debug!(task_id = %id, "Activation sent to worker"); - - let start = Instant::now(); - let result = store.mark_activation_processing(&id).await; - metrics::histogram!("push.mark_activation_processing.duration") - .record(start.elapsed()); - - if let Err(e) = result { - metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to mark activation as processing after push" - ); - } - } - - // Once processing deadline expires, status will be set back to pending - Err(e) => { - metrics::counter!("push.push_task", "result" => "error") - .increment(1); - - error!( - task_id = %id, - error = ?e, - "Failed to send activation to worker" - ) - } - }; + .await; } Ok(()) @@ -386,12 +281,106 @@ impl PushPool { } } -/// Decode task activation and push it to a worker. -#[framed] +/// Determine which worker should receive an activation, send the activation, and update its status. async fn push_task( + store: Arc, + update_tx: Option<&mpsc::Sender>, + activation: InflightActivation, + workers: &mut HashMap>, + timeout: Duration, + grpc_shared_secret: &[String], +) { + let id = activation.id.clone(); + + let Some(worker) = workers.get_mut(&activation.application) else { + metrics::counter!("push.missing_worker_mapping", "application" => activation.application.clone()).increment(1); + + error!( + task_id = %id, + application = activation.application, + "Task application has no worker pool mapping" + ); + + return; + }; + + match send_task( + worker.as_mut(), + activation.clone(), + timeout, + grpc_shared_secret, + ) + .await + { + Ok(_) => { + metrics::counter!("push.push_task", "result" => "ok").increment(1); + debug!(task_id = %id, "Activation sent to worker"); + + if activation.processing_attempts < 1 { + let latency = max(0, activation.received_latency(Utc::now())); + + metrics::histogram!( + "push.received_to_push.latency", + "namespace" => activation.namespace, + "taskname" => activation.taskname, + ) + .record(latency as f64); + } else { + debug!(task_id = %id, namespace = activation.namespace, taskname = activation.taskname, "Activation already processed, skipping received → push latency recording"); + } + + let start = Instant::now(); + + if let Some(tx) = update_tx { + let depth = tx.max_capacity() - tx.capacity(); + metrics::gauge!("push.update_queue.depth").set(depth as f64); + + let result = tx.send(id.clone()).await; + metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); + + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to enqueue push update" + ); + } + } else { + let result = store.mark_processing(&id).await; + metrics::histogram!("push.mark_processing.duration").record(start.elapsed()); + + if let Err(e) = result { + metrics::counter!("push.mark_processing", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to mark activation as processing after push" + ); + } + } + } + + // Once processing deadline expires, status will be set back to pending + Err(e) => { + metrics::counter!("push.push_task", "result" => "error").increment(1); + + error!( + task_id = %id, + error = ?e, + "Failed to send activation to worker" + ) + } + }; +} + +/// Decode task activation and send it to the worker service for a particular application. +#[framed] +async fn send_task( worker: &mut (dyn WorkerClient + Send), activation: InflightActivation, - callback_url: String, timeout: Duration, grpc_shared_secret: &[String], ) -> Result<()> { @@ -407,9 +396,10 @@ async fn push_task( } }; + // The callback URL isn't used by push taskworkers anymore, so we can use an empty string until it's removed from the schema let request = PushTaskRequest { task: Some(task), - callback_url, + callback_url: "".into(), }; let result = match tokio::time::timeout(timeout, worker.send(request, grpc_shared_secret)).await @@ -422,5 +412,55 @@ async fn push_task( result } +pub async fn flush_updates(store: Arc, buffer: &mut Vec) { + if buffer.is_empty() { + return; + } + + let start = Instant::now(); + let ids: Vec<_> = std::mem::take(buffer); + + let requested = ids.len() as u64; + metrics::histogram!("push.flush_updates.requested").record(requested as f64); + + let result = store.mark_processing_batch(&ids).await; + metrics::histogram!("push.mark_processing_batch.duration").record(start.elapsed()); + + match result { + Ok(affected) => { + metrics::histogram!("push.flush_updates.affected").record(affected as f64); + + metrics::counter!("push.flush_updates.updated").increment(affected); + metrics::counter!("push.flush_updates", "result" => "ok").increment(1); + + if affected < requested { + metrics::counter!("push.flush_updates.partial").increment(1); + + warn!( + requested, + affected, "Updated fewer rows than IDs requested from push pool" + ); + } + + debug!(affected, requested, "Flushed update batch from push pool"); + } + + Err(e) => { + metrics::counter!("push.flush_updates", "result" => "error").increment(1); + + error!( + requested, + error = ?e, + "Failed to flush update batch from push pool" + ); + + // Push failed updates back into the buffer so they can be retried on next flush + for id in ids { + buffer.push(id); + } + } + } +} + #[cfg(test)] mod tests; diff --git a/src/push/tests.rs b/src/push/tests.rs index 888b127d..a59bc826 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -1,10 +1,11 @@ use std::sync::{Arc, Mutex}; +use std::time::Instant; use anyhow::anyhow; use async_trait::async_trait; use chrono::{DateTime, Utc}; use sentry_protos::taskbroker::v1::PushTaskRequest; -use tokio::sync::Notify; +use tokio::sync::{Notify, mpsc}; use tokio::time::{Duration, timeout}; use crate::config::Config; @@ -63,15 +64,42 @@ impl WorkerClient for NotifyingWorkerClient { } /// Minimal fake store that records which activation IDs have been marked as processing. -#[derive(Default, Clone)] +/// All IDs marked via either `mark_processing` or successful `mark_processing_batch`. +#[derive(Clone)] struct MockStore { marked_processing: Arc>>, + mark_processing_calls: Arc>>, + mark_processing_batches: Arc>>>, + mark_processing_batch_should_fail: Arc>, +} + +impl Default for MockStore { + fn default() -> Self { + Self { + marked_processing: Arc::new(Mutex::new(vec![])), + mark_processing_calls: Arc::new(Mutex::new(vec![])), + mark_processing_batches: Arc::new(Mutex::new(vec![])), + mark_processing_batch_should_fail: Arc::new(Mutex::new(false)), + } + } } impl MockStore { fn marked_ids(&self) -> Vec { self.marked_processing.lock().unwrap().clone() } + + fn mark_processing_direct_calls(&self) -> Vec { + self.mark_processing_calls.lock().unwrap().clone() + } + + fn mark_processing_batch_calls(&self) -> Vec> { + self.mark_processing_batches.lock().unwrap().clone() + } + + fn set_mark_processing_batch_fail(&self, fail: bool) { + *self.mark_processing_batch_should_fail.lock().unwrap() = fail; + } } #[async_trait] @@ -79,9 +107,11 @@ impl InflightActivationStore for MockStore { async fn store(&self, _batch: Vec) -> anyhow::Result { Ok(0) } + fn assign_partitions(&self, _partitions: Vec) -> anyhow::Result<()> { Ok(()) } + async fn claim_activations( &self, _application: Option<&str>, @@ -92,10 +122,35 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result> { Ok(vec![]) } - async fn mark_activation_processing(&self, id: &str) -> anyhow::Result<()> { + + async fn mark_processing(&self, id: &str) -> anyhow::Result<()> { + self.mark_processing_calls + .lock() + .unwrap() + .push(id.to_string()); self.marked_processing.lock().unwrap().push(id.to_string()); Ok(()) } + + async fn mark_processing_batch(&self, ids: &[String]) -> anyhow::Result { + if *self.mark_processing_batch_should_fail.lock().unwrap() { + return Err(anyhow!("mock mark_processing_batch failure")); + } + + self.mark_processing_batches + .lock() + .unwrap() + .push(ids.to_vec()); + + let mut guard = self.marked_processing.lock().unwrap(); + + for id in ids { + guard.push(id.clone()); + } + + Ok(ids.len() as u64) + } + async fn set_status( &self, _id: &str, @@ -103,6 +158,7 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result> { Ok(None) } + async fn set_status_batch( &self, _ids: &[String], @@ -110,18 +166,23 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result { Ok(0) } + async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { 0.0 } + async fn count_by_status(&self, _status: InflightActivationStatus) -> anyhow::Result { Ok(0) } + async fn count(&self) -> anyhow::Result { Ok(0) } + async fn get_by_id(&self, _id: &str) -> anyhow::Result> { Ok(None) } + async fn set_processing_deadline( &self, _id: &str, @@ -129,51 +190,66 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result<()> { Ok(()) } + async fn delete_activation(&self, _id: &str) -> anyhow::Result<()> { Ok(()) } + async fn vacuum_db(&self) -> anyhow::Result<()> { Ok(()) } + async fn full_vacuum_db(&self) -> anyhow::Result<()> { Ok(()) } + async fn db_size(&self) -> anyhow::Result { Ok(0) } + async fn get_retry_activations(&self) -> anyhow::Result> { Ok(vec![]) } + async fn handle_claim_expiration(&self) -> anyhow::Result { Ok(0) } + async fn handle_processing_deadline(&self) -> anyhow::Result { Ok(0) } + async fn handle_processing_attempts(&self) -> anyhow::Result { Ok(0) } + async fn handle_expires_at(&self) -> anyhow::Result { Ok(0) } + async fn handle_delay_until(&self) -> anyhow::Result { Ok(0) } + async fn handle_failed_tasks(&self) -> anyhow::Result { Ok(FailedTasksForwarder { to_discard: vec![], to_deadletter: vec![], }) } + async fn mark_completed(&self, _ids: Vec) -> anyhow::Result { Ok(0) } + async fn remove_completed(&self) -> anyhow::Result { Ok(0) } + async fn remove_killswitched(&self, _killswitched_tasks: Vec) -> anyhow::Result { Ok(0) } + async fn clear(&self) -> anyhow::Result<()> { Ok(()) } @@ -201,21 +277,12 @@ fn failing_connect_factory() -> WorkerFactory { async fn push_task_returns_ok_on_client_success() { let activation = make_activations(1).remove(0); let mut worker = MockWorkerClient::new(false); - let callback_url = "taskbroker:50051".to_string(); - - let result = push_task( - &mut worker, - activation.clone(), - callback_url.clone(), - Duration::from_secs(5), - &[], - ) - .await; + + let result = send_task(&mut worker, activation.clone(), Duration::from_secs(5), &[]).await; assert!(result.is_ok(), "push_task should succeed"); assert_eq!(worker.captured_requests.len(), 1); let request = &worker.captured_requests[0]; - assert_eq!(request.callback_url, callback_url); assert_eq!( request.task.as_ref().map(|task| task.id.as_str()), Some(activation.id.as_str()) @@ -228,14 +295,7 @@ async fn push_task_returns_err_on_invalid_payload() { activation.activation = vec![1, 2, 3, 4]; let mut worker = MockWorkerClient::new(false); - let result = push_task( - &mut worker, - activation, - "taskbroker:50051".to_string(), - Duration::from_secs(5), - &[], - ) - .await; + let result = send_task(&mut worker, activation, Duration::from_secs(5), &[]).await; assert!(result.is_err(), "invalid payload should fail decoding"); assert!( @@ -249,14 +309,7 @@ async fn push_task_propagates_client_error() { let activation = make_activations(1).remove(0); let mut worker = MockWorkerClient::new(true); - let result = push_task( - &mut worker, - activation, - "taskbroker:50051".to_string(), - Duration::from_secs(5), - &[], - ) - .await; + let result = send_task(&mut worker, activation, Duration::from_secs(5), &[]).await; assert!(result.is_err(), "worker send errors should propagate"); assert_eq!(worker.captured_requests.len(), 1); } @@ -269,7 +322,7 @@ async fn push_pool_submit_enqueues_item() { }); let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store); + let pool = PushPool::new(config, store, None); let activation = make_activations(1).remove(0); let time = Instant::now(); @@ -285,7 +338,7 @@ async fn push_pool_submit_backpressures_when_queue_full() { }); let store = create_test_store("sqlite").await; - let pool = PushPool::new(config, store); + let pool = PushPool::new(config, store, None); let time = Instant::now(); let first = make_activations(1).remove(0); @@ -321,7 +374,7 @@ async fn push_pool_start_worker_connect_failure_returns_error() { ..Config::default() }); let store = Arc::new(MockStore::default()); - let pool = PushPool::new_with_factory(config, store, failing_connect_factory()); + let pool = PushPool::new_with_factory(config, store, failing_connect_factory(), None); let result = pool.start().await; assert!( @@ -331,7 +384,7 @@ async fn push_pool_start_worker_connect_failure_returns_error() { } /// After a successful push for a first-attempt activation (processing_attempts == 0), -/// mark_activation_processing must be called on the store. +/// mark_processing must be called on the store. #[tokio::test] async fn push_pool_start_marks_activation_processing_on_first_attempt() { let notify = Arc::new(Notify::new()); @@ -346,6 +399,7 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { config, store.clone(), notifying_factory(false, notify.clone()), + None, )); let pool_start = pool.clone(); @@ -361,7 +415,7 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { .await .expect("submit should succeed"); - // Wait for the worker to call send(), then give it time to call mark_activation_processing + // Wait for the worker to call send(), then give it time to call mark_processing timeout(Duration::from_secs(2), notify.notified()) .await .expect("timed out waiting for push to be delivered"); @@ -370,12 +424,12 @@ async fn push_pool_start_marks_activation_processing_on_first_attempt() { assert_eq!( store.marked_ids(), vec![id], - "mark_activation_processing should be called after a successful first-attempt push" + "mark_processing should be called after a successful first-attempt push" ); } /// After a successful push for a retried activation (processing_attempts > 0), -/// mark_activation_processing must be called and latency recording is skipped. +/// mark_processing must be called and latency recording is skipped. #[tokio::test] async fn push_pool_start_marks_activation_processing_on_retry() { let notify = Arc::new(Notify::new()); @@ -390,6 +444,7 @@ async fn push_pool_start_marks_activation_processing_on_retry() { config, store.clone(), notifying_factory(false, notify.clone()), + None, )); let pool_start = pool.clone(); @@ -413,13 +468,13 @@ async fn push_pool_start_marks_activation_processing_on_retry() { assert_eq!( store.marked_ids(), vec![id], - "mark_activation_processing should be called after a successful retry push" + "mark_processing should be called after a successful retry push" ); } -/// When the worker fails to deliver an activation, mark_activation_processing must NOT be called. +/// When the worker fails to deliver an activation, mark_processing must NOT be called. #[tokio::test] -async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { +async fn push_pool_start_does_not_mark_processing_on_push_failure() { let notify = Arc::new(Notify::new()); let config = Arc::new(Config { worker_map: [("sentry".into(), "unused".into())].into(), @@ -432,6 +487,7 @@ async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { config, store.clone(), notifying_factory(true, notify.clone()), + None, )); let pool_start = pool.clone(); @@ -451,6 +507,130 @@ async fn push_pool_start_does_not_mark_activation_processing_on_push_failure() { assert!( store.marked_ids().is_empty(), - "mark_activation_processing should not be called when push fails" + "mark_processing should not be called when push fails" + ); +} + +/// With `update_tx` set, a successful push enqueues the task ID on the channel. +#[tokio::test] +async fn push_pool_forwards_successful_push_to_update_channel() { + let notify = Arc::new(Notify::new()); + let (update_tx, mut update_rx) = mpsc::channel::(8); + + let config = Arc::new(Config { + worker_map: [("sentry".into(), "unused".into())].into(), + push_threads: 1, + push_queue_size: 10, + ..Config::default() + }); + let store = Arc::new(MockStore::default()); + let pool = Arc::new(PushPool::new_with_factory( + config, + store.clone(), + notifying_factory(false, notify.clone()), + Some(update_tx), + )); + + let pool_start = pool.clone(); + tokio::spawn(async move { pool_start.start().await }); + + let activation = make_activations(1).remove(0); + let id = activation.id.clone(); + let time = Instant::now(); + + pool.submit(activation, time) + .await + .expect("Submit should succeed"); + + timeout(Duration::from_secs(2), notify.notified()) + .await + .expect("Timed out waiting for push to be delivered"); + tokio::time::sleep(Duration::from_millis(50)).await; + + assert!( + store.mark_processing_batch_calls().is_empty(), + "Method `mark_processing_batch` runs only via `flush_updates`, not the push worker" + ); + + let ch_id = update_rx + .recv() + .await + .expect("Task ID should be sent on update channel"); + assert_eq!(ch_id, id); +} + +/// Function `flush_updates` drains the buffer into `mark_processing_batch` and clears the buffer. +#[tokio::test] +async fn flush_updates_applies_batch_and_clears_buffer() { + let store = Arc::new(MockStore::default()); + let mut buf = vec!["id_0".to_string()]; + + flush_updates(store.clone(), &mut buf).await; + + assert!( + buf.is_empty(), + "buffer should be cleared after successful flush" + ); + assert!(store.mark_processing_direct_calls().is_empty()); + assert_eq!( + store.mark_processing_batch_calls(), + vec![vec!["id_0".to_string()]] ); + assert_eq!(store.marked_ids(), vec!["id_0".to_string()]); +} + +/// On `mark_processing_batch` error, `flush_updates` restores IDs into the buffer for retry. +#[tokio::test] +async fn flush_updates_restores_buffer_on_batch_error() { + let store = Arc::new(MockStore::default()); + store.set_mark_processing_batch_fail(true); + + let mut buf = vec!["a".to_string(), "b".to_string()]; + flush_updates(store.clone(), &mut buf).await; + + assert_eq!(buf, vec!["a".to_string(), "b".to_string()]); + assert!(store.mark_processing_batch_calls().is_empty()); + assert!(store.marked_ids().is_empty()); +} + +/// After a successful worker push, a closed `update_tx` receiver means neither the main loop nor +/// shutdown drain can enqueue the ID. +#[tokio::test] +async fn push_pool_does_not_fallback_to_mark_processing_when_update_channel_closed() { + let notify = Arc::new(Notify::new()); + let (update_tx, update_rx) = mpsc::channel::(8); + drop(update_rx); + + let config = Arc::new(Config { + worker_map: [("sentry".into(), "unused".into())].into(), + push_threads: 1, + push_queue_size: 10, + ..Config::default() + }); + let store = Arc::new(MockStore::default()); + let pool = Arc::new(PushPool::new_with_factory( + config, + store.clone(), + notifying_factory(false, notify.clone()), + Some(update_tx), + )); + + let pool_start = pool.clone(); + tokio::spawn(async move { pool_start.start().await }); + + let activation = make_activations(1).remove(0); + let time = Instant::now(); + + pool.submit(activation, time) + .await + .expect("Submit should succeed"); + + timeout(Duration::from_secs(2), notify.notified()) + .await + .expect("Timed out waiting for push to be delivered"); + tokio::time::sleep(Duration::from_millis(50)).await; + + assert!(store.mark_processing_batch_calls().is_empty()); + assert!(store.mark_processing_direct_calls().is_empty()); + assert!(store.marked_ids().is_empty()); } diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 27ce276f..c82adf59 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -503,10 +503,8 @@ impl InflightActivationStore for PostgresActivationStore { #[instrument(skip_all)] #[framed] - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error> { - let mut conn = self - .acquire_write_conn_metric("mark_activation_processing") - .await?; + async fn mark_processing(&self, id: &str) -> Result<(), Error> { + let mut conn = self.acquire_write_conn_metric("mark_processing").await?; let grace_period = self.config.processing_deadline_grace_sec; let result = sqlx::query(&format!( @@ -523,20 +521,47 @@ impl InflightActivationStore for PostgresActivationStore { .await?; if result.rows_affected() == 0 { - metrics::counter!("push.mark_activation_processing", "result" => "not_found") - .increment(1); + metrics::counter!("push.mark_processing", "result" => "not_found").increment(1); warn!( task_id = %id, "Activation could not be marked as processing, it may be missing or its status may have already changed" ); } else { - metrics::counter!("push.mark_activation_processing", "result" => "ok").increment(1); + metrics::counter!("push.mark_processing", "result" => "ok").increment(1); } Ok(()) } + #[instrument(skip_all)] + #[framed] + async fn mark_processing_batch(&self, ids: &[String]) -> Result { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self + .acquire_write_conn_metric("mark_processing_batch") + .await?; + + let grace_period = self.config.processing_deadline_grace_sec; + let result = sqlx::query(&format!( + "UPDATE inflight_taskactivations SET + status = $1, + processing_deadline = now() + (processing_deadline_duration * interval '1 second') + (interval '{grace_period} seconds'), + claim_expires_at = NULL + WHERE id = ANY($2) AND status = $3", + )) + .bind(InflightActivationStatus::Processing.to_string()) + .bind(ids) + .bind(InflightActivationStatus::Claimed.to_string()) + .execute(&mut *conn) + .await?; + + Ok(result.rows_affected()) + } + /// Get the age of the oldest pending activation in seconds. /// Only activations with status=pending and processing_attempts=0 are considered /// as we are interested in latency to the *first* attempt. diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index de457de4..8bca255b 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -597,10 +597,8 @@ impl InflightActivationStore for SqliteActivationStore { } #[instrument(skip_all)] - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error> { - let mut conn = self - .acquire_write_conn_metric("mark_activation_processing") - .await?; + async fn mark_processing(&self, id: &str) -> Result<(), Error> { + let mut conn = self.acquire_write_conn_metric("mark_processing").await?; let grace_period = self.config.processing_deadline_grace_sec; let result = sqlx::query(&format!( @@ -617,20 +615,48 @@ impl InflightActivationStore for SqliteActivationStore { .await?; if result.rows_affected() == 0 { - metrics::counter!("push.mark_activation_processing", "result" => "not_found") - .increment(1); + metrics::counter!("push.mark_processing", "result" => "not_found").increment(1); warn!( task_id = %id, "Activation could not be marked as sent, it may be missing or its status may have already changed" ); } else { - metrics::counter!("push.mark_activation_processing", "result" => "ok").increment(1); + metrics::counter!("push.mark_processing", "result" => "ok").increment(1); } Ok(()) } + #[instrument(skip_all)] + async fn mark_processing_batch(&self, ids: &[String]) -> Result { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self + .acquire_write_conn_metric("mark_processing_batch") + .await?; + + let grace_period = self.config.processing_deadline_grace_sec; + let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations SET status = "); + query_builder.push_bind(InflightActivationStatus::Processing); + query_builder.push(format!( + ", processing_deadline = unixepoch('now', '+' || (processing_deadline_duration + {grace_period}) || ' seconds'), claim_expires_at = NULL WHERE status = ", + )); + query_builder.push_bind(InflightActivationStatus::Claimed); + query_builder.push(" AND id IN ("); + + let mut separated = query_builder.separated(", "); + for id in ids.iter() { + separated.push_bind(id); + } + separated.push_unseparated(")"); + + let result = query_builder.build().execute(&mut *conn).await?; + Ok(result.rows_affected()) + } + /// Get the age of the oldest pending activation in seconds. /// Only activations with status=pending and processing_attempts=0 are considered /// as we are interested in latency to the *first* attempt. diff --git a/src/store/traits.rs b/src/store/traits.rs index d8bdb5e0..287fc12f 100644 --- a/src/store/traits.rs +++ b/src/store/traits.rs @@ -27,7 +27,7 @@ pub trait InflightActivationStore: Send + Sync { mark_processing: bool, ) -> Result, Error>; - /// Claims `limit` activations within the `bucket` range. Push mode uses status `Claimed` until `mark_activation_processing` moves to `Processing`. + /// Claims `limit` activations within the `bucket` range. Push mode uses status `Claimed` until `mark_processing` moves to `Processing`. async fn claim_activations_for_push( &self, limit: Option, @@ -69,7 +69,10 @@ pub trait InflightActivationStore: Send + Sync { } /// Record successful push. - async fn mark_activation_processing(&self, id: &str) -> Result<(), Error>; + async fn mark_processing(&self, id: &str) -> Result<(), Error>; + + /// Record a batch of successful pushes. + async fn mark_processing_batch(&self, ids: &[String]) -> Result; /// Update the status of a specific activation async fn set_status(