diff --git a/src/config.rs b/src/config.rs index 700f1a5a..fd9b11e8 100644 --- a/src/config.rs +++ b/src/config.rs @@ -299,6 +299,15 @@ pub struct Config { /// Maximum time in milliseconds for a single push RPC to the worker service. This should be greater than the worker's internal timeout. pub push_timeout_ms: u64, + /// Update statuses from the gRPC server in batches? + pub batch_status_updates: bool, + + /// The size of a batch of status updates. + pub status_update_batch_size: usize, + + /// Maximum milliseconds to wait before flushing a batch of status updates. + pub status_update_interval_ms: u64, + /// The hostname used to construct `callback_url` for task push requests. pub callback_addr: String, @@ -409,6 +418,9 @@ impl Default for Config { push_queue_size: 1, push_queue_timeout_ms: 5000, push_timeout_ms: 30000, + batch_status_updates: false, + status_update_batch_size: 1, + status_update_interval_ms: 100, callback_addr: "0.0.0.0".into(), callback_port: 50051, worker_map: [("sentry".into(), "http://127.0.0.1:50052".into())].into(), diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 89b12981..e4ae2fab 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -122,10 +122,11 @@ impl FetchPool { } _ = async { + let start = Instant::now(); + debug!("Fetching next batch of pending activations..."); metrics::counter!("fetch.loop.count").increment(1); - let start = Instant::now(); let mut backoff = false; let result = store.claim_activations_for_push(limit, bucket).await; diff --git a/src/fetch/tests.rs b/src/fetch/tests.rs index ea64bf84..092d3503 100644 --- a/src/fetch/tests.rs +++ b/src/fetch/tests.rs @@ -122,6 +122,14 @@ impl InflightActivationStore for MockStore { unimplemented!() } + async fn set_status_batch( + &self, + _ids: &[String], + _status: InflightActivationStatus, + ) -> Result { + unimplemented!() + } + async fn set_processing_deadline( &self, _id: &str, diff --git a/src/flusher.rs b/src/flusher.rs new file mode 100644 index 00000000..33732dfa --- /dev/null +++ b/src/flusher.rs @@ -0,0 +1,63 @@ +use std::future::Future; +use std::pin::Pin; +use std::time::Duration; + +use anyhow::Result; +use tokio::sync::mpsc::Receiver; + +/// Run flusher that receives values of type T from a channel and flushes +/// them using the provided async `flush` function either when the batch is +/// full or when the max flush interval has elapsed. This function is **not** +/// responsible for draining the buffer - `flush` does that. +pub async fn run_flusher( + mut rx: Receiver, + batch_size: usize, + interval_ms: u64, + mut flush: F, +) -> Result<()> +where + F: for<'a> FnMut(&'a mut Vec) -> Pin + Send + 'a>>, +{ + let batch_size = batch_size.max(1); + let interval_ms = interval_ms.max(1); + + let period = Duration::from_millis(interval_ms); + let mut interval = tokio::time::interval(period); + interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + let mut buffer: Vec = Vec::with_capacity(batch_size); + + loop { + tokio::select! { + msg = rx.recv() => { + match msg { + Some(v) => { + buffer.push(v); + + while buffer.len() < batch_size && let Ok(update) = rx.try_recv() { + buffer.push(update); + } + + if buffer.len() >= batch_size { + flush(&mut buffer).await; + } + } + + None => { + // Channel closed (shutdown), flush remaining and exit + flush(&mut buffer).await; + break; + } + } + } + + _ = interval.tick() => { + if !buffer.is_empty() { + flush(&mut buffer).await; + } + } + } + } + + Ok(()) +} diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index abaa773f..b00e98d4 100644 --- a/src/grpc/mod.rs +++ b/src/grpc/mod.rs @@ -1,5 +1,6 @@ pub mod auth_middleware; pub mod metrics_middleware; pub mod server; + #[cfg(test)] mod server_tests; diff --git a/src/grpc/server.rs b/src/grpc/server.rs index cf4a6bdd..9d669c5e 100644 --- a/src/grpc/server.rs +++ b/src/grpc/server.rs @@ -1,6 +1,8 @@ +use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; +use anyhow::Result; use chrono::Utc; use prost::Message; use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerService; @@ -8,8 +10,9 @@ use sentry_protos::taskbroker::v1::{ FetchNextTask, GetTaskRequest, GetTaskResponse, SetTaskStatusRequest, SetTaskStatusResponse, TaskActivation, TaskActivationStatus, }; +use tokio::sync::mpsc::Sender; use tonic::{Request, Response, Status}; -use tracing::{error, instrument, warn}; +use tracing::{debug, error, instrument, warn}; use crate::config::{Config, DeliveryMode}; use crate::store::activation::InflightActivationStatus; @@ -18,6 +21,7 @@ use crate::store::traits::InflightActivationStore; pub struct TaskbrokerServer { pub store: Arc, pub config: Arc, + pub update_tx: Option>, } #[tonic::async_trait] @@ -97,10 +101,20 @@ impl ConsumerService for TaskbrokerServer { "Invalid status, expects 3 (Failure), 4 (Retry), or 5 (Complete), but got: {status:?}" ))); } + if status == InflightActivationStatus::Failure { metrics::counter!("grpc_server.set_status.failure").increment(1); } + if let Some(ref tx) = self.update_tx { + tx.send((id, status)) + .await + .map_err(|_| Status::internal("Status update channel closed"))?; + + metrics::histogram!("grpc_server.set_status.duration").record(start_time.elapsed()); + return Ok(Response::new(SetTaskStatusResponse { task: None })); + } + match self.store.set_status(&id, status).await { Ok(Some(_)) => metrics::counter!( "grpc_server.set_status", @@ -194,3 +208,80 @@ impl ConsumerService for TaskbrokerServer { res } } + +pub type StatusUpdate = (String, InflightActivationStatus); + +pub async fn flush_updates( + store: Arc, + buffer: &mut Vec, +) { + if buffer.is_empty() { + return; + } + + let mut by_status: HashMap> = HashMap::new(); + + for (id, status) in buffer.drain(..) { + by_status.entry(status).or_default().push(id); + } + + for (status, ids) in by_status { + let requested = ids.len() as u64; + let st = status.to_string(); + + metrics::histogram!("grpc_server.flush_updates.requested", "status" => st.clone()) + .record(requested as f64); + + match store.set_status_batch(&ids, status).await { + Ok(affected) => { + metrics::histogram!( + "grpc_server.flush_updates.affected", + "status" => st.clone() + ) + .record(affected as f64); + + metrics::counter!( + "grpc_server.flush_updates.updated", + "status" => st.clone() + ) + .increment(affected); + + metrics::counter!("grpc_server.flush_updates", "result" => "ok").increment(1); + + if affected < requested { + metrics::counter!( + "grpc_server.flush_updates.partial", + "status" => st.clone() + ) + .increment(1); + + warn!( + ?status, + requested, affected, "Updated fewer rows than IDs requested from server" + ); + } + + debug!( + ?status, + affected, requested, "Flushed status batch from server" + ); + } + + Err(e) => { + metrics::counter!("grpc_server.flush_updates", "result" => "error").increment(1); + + error!( + ?status, + requested, + error = ?e, + "Failed to flush status batch from server" + ); + + // Push failed updates back into the buffer so they can be retried on next flush + for id in ids { + buffer.push((id, status)); + } + } + } + } +} diff --git a/src/grpc/server_tests.rs b/src/grpc/server_tests.rs index 2b986d66..86bc343e 100644 --- a/src/grpc/server_tests.rs +++ b/src/grpc/server_tests.rs @@ -6,10 +6,11 @@ use sentry_protos::taskbroker::v1::consumer_service_server::ConsumerService; use sentry_protos::taskbroker::v1::{ FetchNextTask, GetTaskRequest, SetTaskStatusRequest, TaskActivation, }; +use tokio::sync::mpsc; use tonic::{Code, Request}; use crate::config::{Config, DeliveryMode}; -use crate::grpc::server::TaskbrokerServer; +use crate::grpc::server::{StatusUpdate, TaskbrokerServer}; use crate::store::activation::InflightActivationStatus; use crate::test_utils::{create_config, create_test_store, make_activations}; @@ -21,11 +22,17 @@ async fn test_get_task_push_mode_returns_permission_denied() { ..Config::default() }); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + update_tx: None, + }; + let request = GetTaskRequest { namespace: None, application: None, }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_err()); @@ -42,11 +49,17 @@ async fn test_get_task(#[case] adapter: &str) { let store = create_test_store(adapter).await; let config = create_config(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + update_tx: None, + }; + let request = GetTaskRequest { namespace: None, application: None, }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_err()); let e = response.unwrap_err(); @@ -63,12 +76,18 @@ async fn test_set_task_status(#[case] adapter: &str) { let store = create_test_store(adapter).await; let config = create_config(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + update_tx: None, + }; + let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 5, // Complete fetch_next_task: None, }; + let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); let resp = response.unwrap(); @@ -84,12 +103,18 @@ async fn test_set_task_status_invalid(#[case] adapter: &str) { let store = create_test_store(adapter).await; let config = create_config(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + update_tx: None, + }; + let request = SetTaskStatusRequest { id: "test_task".to_string(), status: 1, // Invalid fetch_next_task: None, }; + let response = service.set_task_status(Request::new(request)).await; assert!(response.is_err()); let e = response.unwrap_err(); @@ -115,11 +140,14 @@ async fn test_get_task_success(#[case] adapter: &str) { let service = TaskbrokerServer { store: store.clone(), config, + update_tx: None, }; + let request = GetTaskRequest { namespace: None, application: None, }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_ok()); let resp = response.unwrap(); @@ -149,11 +177,17 @@ async fn test_get_task_with_application_success(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + update_tx: None, + }; + let request = GetTaskRequest { namespace: None, application: Some("hammers".into()), }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_ok()); let resp = response.unwrap(); @@ -177,11 +211,17 @@ async fn test_get_task_with_namespace_requires_application(#[case] adapter: &str store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + update_tx: None, + }; + let request = GetTaskRequest { namespace: Some(namespace), application: None, }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_err()); @@ -201,12 +241,17 @@ async fn test_set_task_status_success(#[case] adapter: &str) { let activations = make_activations(2); store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + update_tx: None, + }; let request = GetTaskRequest { namespace: None, application: None, }; + let response = service.get_task(Request::new(request)).await; assert!(response.is_ok()); let resp = response.unwrap(); @@ -248,7 +293,12 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + update_tx: None, + }; + let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete @@ -257,6 +307,7 @@ async fn test_set_task_status_with_application(#[case] adapter: &str) { namespace: None, }), }; + let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); @@ -287,7 +338,12 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + update_tx: None, + }; + // Request a task from an application without any activations. let request = SetTaskStatusRequest { id: "id_0".to_string(), @@ -297,6 +353,7 @@ async fn test_set_task_status_with_application_no_match(#[case] adapter: &str) { namespace: None, }), }; + let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); assert!(response.unwrap().get_ref().task.is_none()); @@ -316,7 +373,12 @@ async fn test_set_task_status_with_namespace_requires_application(#[case] adapte store.store(activations).await.unwrap(); - let service = TaskbrokerServer { store, config }; + let service = TaskbrokerServer { + store, + config, + update_tx: None, + }; + let request = SetTaskStatusRequest { id: "id_0".to_string(), status: 5, // Complete @@ -325,6 +387,7 @@ async fn test_set_task_status_with_namespace_requires_application(#[case] adapte namespace: Some(namespace), }), }; + let response = service.set_task_status(Request::new(request)).await; assert!(response.is_ok()); assert!( @@ -332,3 +395,92 @@ async fn test_set_task_status_with_namespace_requires_application(#[case] adapte "namespace without application yields no next task in response" ); } + +#[tokio::test] +#[rstest] +#[case::sqlite("sqlite")] +#[case::postgres("postgres")] +#[allow(deprecated)] +async fn test_set_task_status_forwards_to_update_channel(#[case] adapter: &str) { + let store = create_test_store(adapter).await; + let config = create_config(); + + let (update_tx, mut update_rx) = mpsc::channel::(8); + + let activations = make_activations(2); + store.store(activations).await.unwrap(); + + let service = TaskbrokerServer { + store: store.clone(), + config, + update_tx: Some(update_tx), + }; + + let response = service + .get_task(Request::new(GetTaskRequest { + namespace: None, + application: None, + })) + .await + .unwrap(); + assert_eq!(response.get_ref().task.as_ref().unwrap().id, "id_0"); + + let response = service + .set_task_status(Request::new(SetTaskStatusRequest { + id: "id_0".to_string(), + status: 5, // Complete + fetch_next_task: Some(FetchNextTask { + namespace: None, + application: None, + }), + })) + .await + .unwrap(); + + assert!( + response.get_ref().task.is_none(), + "push path returns no next task from the store" + ); + + let (id, status) = update_rx.recv().await.expect("status update on channel"); + assert_eq!(id, "id_0"); + assert_eq!(status, InflightActivationStatus::Complete); + + let row = store.get_by_id("id_0").await.unwrap().expect("row exists"); + assert_eq!( + row.status, + InflightActivationStatus::Processing, + "handler does not write status; flush_updates applies channel batches" + ); +} + +#[tokio::test] +async fn test_set_task_status_update_channel_closed_returns_internal() { + let store = create_test_store("sqlite").await; + let config = create_config(); + + let (update_tx, update_rx) = mpsc::channel::(8); + drop(update_rx); + + let activations = make_activations(1); + store.store(activations).await.unwrap(); + + let service = TaskbrokerServer { + store, + config, + update_tx: Some(update_tx), + }; + + let response = service + .set_task_status(Request::new(SetTaskStatusRequest { + id: "id_0".to_string(), + status: 5, + fetch_next_task: None, + })) + .await; + + assert!(response.is_err()); + let e = response.unwrap_err(); + assert_eq!(e.code(), Code::Internal); + assert_eq!(e.message(), "Status update channel closed"); +} diff --git a/src/lib.rs b/src/lib.rs index 6ce53cd1..89a17421 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ use std::fs; pub mod config; pub mod fetch; +pub mod flusher; pub mod grpc; pub mod kafka; pub mod logging; diff --git a/src/main.rs b/src/main.rs index c91a766f..c3648a6d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,12 +12,11 @@ use tonic::transport::Server; use tonic_health::ServingStatus; use tracing::{debug, error, info, warn}; -use taskbroker::SERVICE_NAME; use taskbroker::config::{Config, DatabaseAdapter, DeliveryMode}; use taskbroker::fetch::FetchPool; use taskbroker::grpc::auth_middleware::AuthLayer; use taskbroker::grpc::metrics_middleware::MetricsLayer; -use taskbroker::grpc::server::TaskbrokerServer; +use taskbroker::grpc::server::{TaskbrokerServer, flush_updates}; use taskbroker::kafka::admin::create_missing_topics; use taskbroker::kafka::consumer::start_consumer; use taskbroker::kafka::deserialize::{self, DeserializeConfig}; @@ -40,6 +39,7 @@ use taskbroker::store::adapters::sqlite::{InflightActivationStoreConfig, SqliteA use taskbroker::store::traits::InflightActivationStore; use taskbroker::upkeep::upkeep; use taskbroker::{Args, get_version}; +use taskbroker::{SERVICE_NAME, flusher}; async fn log_task_completion>(name: T, task: JoinHandle>) { match task.await { @@ -191,10 +191,33 @@ async fn main() -> Result<(), Error> { } }); + // Status update flush task + let (status_update_tx, status_update_task) = if config.batch_status_updates { + let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size.max(1)); + + let flusher_store = store.clone(); + let flusher_config = config.clone(); + + let handle = tokio::spawn(async move { + flusher::run_flusher( + rx, + flusher_config.status_update_batch_size, + flusher_config.status_update_interval_ms, + move |buffer| Box::pin(flush_updates(flusher_store.clone(), buffer)), + ) + .await + }); + + (Some(tx), Some(handle)) + } else { + (None, None) + }; + // GRPC server let grpc_server_task = tokio::spawn({ let grpc_store = store.clone(); let grpc_config = config.clone(); + let grpc_status_tx = status_update_tx.clone(); async move { let addr = format!("{}:{}", grpc_config.grpc_addr, grpc_config.grpc_port) @@ -211,6 +234,7 @@ async fn main() -> Result<(), Error> { .add_service(ConsumerServiceServer::new(TaskbrokerServer { store: grpc_store, config: grpc_config, + update_tx: grpc_status_tx, })) .add_service(health_service.clone()) .serve(addr); @@ -277,6 +301,10 @@ async fn main() -> Result<(), Error> { departure = departure.on_completion(log_task_completion("fetch_task", task)); } + if let Some(task) = status_update_task { + departure = departure.on_completion(log_task_completion("status_update_task", task)); + } + departure.await; Ok(()) } diff --git a/src/push/mod.rs b/src/push/mod.rs index 50db7a9c..fbf04018 100644 --- a/src/push/mod.rs +++ b/src/push/mod.rs @@ -197,7 +197,7 @@ impl PushPool { Ok(a) => a, // Channel closed - Err(_) => break + Err(_) => break, }; metrics::histogram!("push.queue.latency").record(time.elapsed()); @@ -214,7 +214,7 @@ impl PushPool { "Task application has no worker pool mapping" ); - continue + continue; }; match push_task( @@ -303,7 +303,12 @@ impl PushPool { metrics::counter!("push.push_task", "result" => "ok").increment(1); debug!(task_id = %id, "Activation sent to worker"); - if let Err(e) = store.mark_activation_processing(&id).await { + let start = Instant::now(); + let result = store.mark_activation_processing(&id).await; + metrics::histogram!("push.mark_activation_processing.duration") + .record(start.elapsed()); + + if let Err(e) = result { metrics::counter!("push.mark_activation_processing", "result" => "error").increment(1); error!( diff --git a/src/push/tests.rs b/src/push/tests.rs index af46aeea..888b127d 100644 --- a/src/push/tests.rs +++ b/src/push/tests.rs @@ -103,6 +103,13 @@ impl InflightActivationStore for MockStore { ) -> anyhow::Result> { Ok(None) } + async fn set_status_batch( + &self, + _ids: &[String], + _status: InflightActivationStatus, + ) -> anyhow::Result { + Ok(0) + } async fn pending_activation_max_lag(&self, _now: &DateTime) -> f64 { 0.0 } diff --git a/src/store/activation.rs b/src/store/activation.rs index de54e959..12b2872f 100644 --- a/src/store/activation.rs +++ b/src/store/activation.rs @@ -8,7 +8,7 @@ use sqlx::Type; /// The members of this enum should be a superset of the members /// of `InflightActivationStatus` in `sentry_protos`. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Type)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Type, Hash)] pub enum InflightActivationStatus { /// Unused but necessary to align with sentry-protos Unspecified, diff --git a/src/store/adapters/postgres.rs b/src/store/adapters/postgres.rs index 8f087b59..27ce276f 100644 --- a/src/store/adapters/postgres.rs +++ b/src/store/adapters/postgres.rs @@ -654,6 +654,29 @@ impl InflightActivationStore for PostgresActivationStore { Ok(Some(row.into())) } + #[instrument(skip_all)] + #[framed] + async fn set_status_batch( + &self, + ids: &[String], + status: InflightActivationStatus, + ) -> Result { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self.acquire_write_conn_metric("set_status_batch").await?; + + let result = + sqlx::query("UPDATE inflight_taskactivations SET status = $1 WHERE id = ANY($2)") + .bind(status.to_string()) + .bind(ids) + .execute(&mut *conn) + .await?; + + Ok(result.rows_affected()) + } + #[instrument(skip_all)] #[framed] async fn set_processing_deadline( diff --git a/src/store/adapters/sqlite.rs b/src/store/adapters/sqlite.rs index 8692ac0c..de457de4 100644 --- a/src/store/adapters/sqlite.rs +++ b/src/store/adapters/sqlite.rs @@ -706,6 +706,37 @@ impl InflightActivationStore for SqliteActivationStore { Ok(Some(row.into())) } + #[instrument(skip_all)] + async fn set_status_batch( + &self, + ids: &[String], + status: InflightActivationStatus, + ) -> Result { + if ids.is_empty() { + return Ok(0); + } + + let mut conn = self.acquire_write_conn_metric("set_status_batch").await?; + + let mut query_builder = QueryBuilder::new("UPDATE inflight_taskactivations "); + + query_builder + .push("SET status = ") + .push_bind(status) + .push(" WHERE id IN ("); + + let mut separated = query_builder.separated(", "); + + for id in ids.iter() { + separated.push_bind(id); + } + + separated.push_unseparated(")"); + + let result = query_builder.build().execute(&mut *conn).await?; + Ok(result.rows_affected()) + } + #[instrument(skip_all)] async fn set_processing_deadline( &self, diff --git a/src/store/traits.rs b/src/store/traits.rs index c21a9d41..d8bdb5e0 100644 --- a/src/store/traits.rs +++ b/src/store/traits.rs @@ -78,6 +78,13 @@ pub trait InflightActivationStore: Send + Sync { status: InflightActivationStatus, ) -> Result, Error>; + /// Update the status of multiple activations in one batch. + async fn set_status_batch( + &self, + ids: &[String], + status: InflightActivationStatus, + ) -> Result; + /// COUNT OPERATIONS /// Get the age of the oldest pending activation in seconds async fn pending_activation_max_lag(&self, now: &DateTime) -> f64;