From cee11dd331917afa158782a832e42591bd8a9f11 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Thu, 2 Apr 2026 11:17:58 -0700 Subject: [PATCH] fix: deadlock in multi-shard queries --- integration/pgdog.toml | 13 +++ integration/setup.sh | 4 +- pgdog/src/backend/pool/connection/binding.rs | 11 +- .../backend/pool/connection/binding_test.rs | 2 +- pgdog/src/backend/pool/connection/mod.rs | 5 +- .../pool/connection/multi_shard/mod.rs | 18 +++- .../pool/connection/multi_shard/test.rs | 14 +-- pgdog/src/config/mod.rs | 74 ++++++------- .../frontend/client/query_engine/test/mod.rs | 1 + .../client/query_engine/test/multi_binding.rs | 101 ++++++++++++++++++ pgdog/src/frontend/client/test/test_client.rs | 8 +- 11 files changed, 193 insertions(+), 58 deletions(-) create mode 100644 pgdog/src/frontend/client/query_engine/test/multi_binding.rs diff --git a/integration/pgdog.toml b/integration/pgdog.toml index b70a5a2c5..74ce6a3b4 100644 --- a/integration/pgdog.toml +++ b/integration/pgdog.toml @@ -78,6 +78,19 @@ database_name = "shard_1" shard = 1 role = "replica" +# [[databases]] +# name = "pgdog_sharded" +# host = "localhost" +# database_name = "shard_2" +# shard = 2 + +# [[databases]] +# name = "pgdog_sharded" +# host = "localhost" +# database_name = "shard_2" +# shard = 2 +# role = "replica" + # ------------------------------------------------------------------------------ # ----- Database :: failover --------------------------------------------------- diff --git a/integration/setup.sh b/integration/setup.sh index 2a168e23f..b7faa9311 100644 --- a/integration/setup.sh +++ b/integration/setup.sh @@ -26,7 +26,7 @@ export PGHOST=127.0.0.1 export PGPORT=5432 export PGUSER='pgdog' -for db in pgdog shard_0 shard_1; do +for db in pgdog shard_0 shard_1 shard_2 shard_3; do psql -c "DROP DATABASE $db" || true psql -c "CREATE DATABASE $db" || true for user in pgdog pgdog1 pgdog2 pgdog3; do @@ -35,7 +35,7 @@ for db in pgdog shard_0 shard_1; do done done -for db in pgdog shard_0 
shard_1; do +for db in pgdog shard_0 shard_1 shard_2 shard_3; do for table in sharded sharded_omni; do psql -c "DROP TABLE IF EXISTS ${table}" ${db} -U pgdog psql -c "CREATE TABLE IF NOT EXISTS ${table} ( diff --git a/pgdog/src/backend/pool/connection/binding.rs b/pgdog/src/backend/pool/connection/binding.rs index 677cd3cf3..82c55cbe0 100644 --- a/pgdog/src/backend/pool/connection/binding.rs +++ b/pgdog/src/backend/pool/connection/binding.rs @@ -136,7 +136,11 @@ impl Binding { let mut shards_sent = servers.len(); let mut futures = Vec::new(); - for (shard, server) in servers.iter_mut().enumerate() { + for (position, server) in servers.iter_mut().enumerate() { + // Map positional index to actual shard number. + // When only a subset of shards is connected (Shard::Multi binding), + // positional indices don't match actual shard numbers. + let shard = state.shard_index(position); let send = match client_request.route().shard() { Shard::Direct(s) => { shards_sent = 1; @@ -177,9 +181,10 @@ impl Binding { /// Send copy messages to shards they are destined to go. 
pub async fn send_copy(&mut self, rows: Vec<CopyRow>) -> Result<(), Error> { match self { - Binding::MultiShard(servers, _state) => { + Binding::MultiShard(servers, state) => { for row in rows { - for (shard, server) in servers.iter_mut().enumerate() { + for (position, server) in servers.iter_mut().enumerate() { + let shard = state.shard_index(position); match row.shard() { Shard::Direct(row_shard) => { if shard == *row_shard { diff --git a/pgdog/src/backend/pool/connection/binding_test.rs b/pgdog/src/backend/pool/connection/binding_test.rs index 7b544eb32..7c4d3c578 100644 --- a/pgdog/src/backend/pool/connection/binding_test.rs +++ b/pgdog/src/backend/pool/connection/binding_test.rs @@ -49,7 +49,7 @@ mod tests { ]; let route = Route::write(ShardWithPriority::new_default_unset(Shard::All)); - let multishard = MultiShard::new(3, &route); + let multishard = MultiShard::new(vec![0, 1, 2], &route); let mut binding = Binding::MultiShard(guards, Box::new(multishard)); diff --git a/pgdog/src/backend/pool/connection/mod.rs b/pgdog/src/backend/pool/connection/mod.rs index ea07b9109..26846324e 100644 --- a/pgdog/src/backend/pool/connection/mod.rs +++ b/pgdog/src/backend/pool/connection/mod.rs @@ -158,6 +158,7 @@ impl Connection { }; } else { let mut shards = vec![]; + let mut shard_indices = vec![]; for (i, shard) in self.cluster()?.shards().iter().enumerate() { if let Shard::Multi(numbers) = route.shard() { if !numbers.contains(&i) { @@ -175,11 +176,11 @@ impl Connection { } shards.push(server); + shard_indices.push(i); } - let num_shards = shards.len(); self.binding = - Binding::MultiShard(shards, Box::new(MultiShard::new(num_shards, route))); + Binding::MultiShard(shards, Box::new(MultiShard::new(shard_indices, route))); } Ok(()) diff --git a/pgdog/src/backend/pool/connection/multi_shard/mod.rs b/pgdog/src/backend/pool/connection/multi_shard/mod.rs index 719e2078b..31f9073fc 100644 --- a/pgdog/src/backend/pool/connection/multi_shard/mod.rs +++
b/pgdog/src/backend/pool/connection/multi_shard/mod.rs @@ -52,6 +52,10 @@ pub struct MultiShard { shards: usize, /// Route the query is taking. route: Route, + /// Maps positional index in the servers vec to actual shard number. + /// When all shards are connected, this is `[0, 1, 2, ...]`. + /// When only a subset is connected (e.g. shards 0 and 2), this is `[0, 2]`. + shard_indices: Vec<usize>, /// Counters counters: Counters, @@ -64,16 +68,26 @@ pub struct MultiShard { } impl MultiShard { - /// New multi-shard state given the number of shards in the cluster. - pub(super) fn new(shards: usize, route: &Route) -> Self { + /// New multi-shard state given the actual shard indices connected. + pub(super) fn new(shard_indices: Vec<usize>, route: &Route) -> Self { + let shards = shard_indices.len(); Self { shards, + shard_indices, route: route.clone(), counters: Counters::default(), ..Default::default() } } + /// Map a positional index to the actual shard number. + pub(super) fn shard_index(&self, position: usize) -> usize { + self.shard_indices + .get(position) + .copied() + .unwrap_or(position) + } + /// Update multi-shard state.
pub(super) fn update(&mut self, shards: usize, route: &Route) { self.reset(); diff --git a/pgdog/src/backend/pool/connection/multi_shard/test.rs b/pgdog/src/backend/pool/connection/multi_shard/test.rs index 1a5492c91..94b3cd63e 100644 --- a/pgdog/src/backend/pool/connection/multi_shard/test.rs +++ b/pgdog/src/backend/pool/connection/multi_shard/test.rs @@ -8,7 +8,7 @@ use super::*; #[test] fn test_inconsistent_row_descriptions() { let route = Route::default(); - let mut multi_shard = MultiShard::new(2, &route); + let mut multi_shard = MultiShard::new(vec![0, 1], &route); // Create two different row descriptions let rd1 = RowDescription::new(&[Field::text("name"), Field::bigint("id")]); @@ -32,7 +32,7 @@ fn test_inconsistent_row_descriptions() { #[test] fn test_inconsistent_data_rows() { let route = Route::default(); - let mut multi_shard = MultiShard::new(2, &route); + let mut multi_shard = MultiShard::new(vec![0, 1], &route); // Set up row description first let rd = RowDescription::new(&[Field::text("name"), Field::bigint("id")]); @@ -63,7 +63,7 @@ fn test_inconsistent_data_rows() { #[test] fn test_rd_before_dr() { let mut multi_shard = MultiShard::new( - 3, + vec![0, 1, 2], &Route::read(ShardWithPriority::new_default_unset(Shard::All)), ); let rd = RowDescription::new(&[Field::bigint("id")]); @@ -127,7 +127,7 @@ fn test_rd_before_dr() { #[test] fn test_ready_for_query_error_preservation() { let route = Route::default(); - let mut multi_shard = MultiShard::new(2, &route); + let mut multi_shard = MultiShard::new(vec![0, 1], &route); // Create ReadyForQuery messages - one with transaction error, one normal let rfq_error = ReadyForQuery::error(); @@ -151,7 +151,7 @@ fn test_ready_for_query_error_preservation() { fn test_omni_command_complete_not_summed() { // For omni-sharded tables, we should NOT sum row counts across shards. 
let route = Route::write(ShardWithPriority::new_table_omni(Shard::All)); - let mut multi_shard = MultiShard::new(3, &route); + let mut multi_shard = MultiShard::new(vec![0, 1, 2], &route); let backend1 = BackendKeyData { pid: 1, secret: 1 }; let backend2 = BackendKeyData { pid: 2, secret: 2 }; @@ -193,7 +193,7 @@ fn test_omni_command_complete_not_summed() { fn test_omni_command_complete_uses_first_shard_row_count() { // For omni, we use the first shard's row count for consistency with DataRow behavior. let route = Route::write(ShardWithPriority::new_table_omni(Shard::All)); - let mut multi_shard = MultiShard::new(2, &route); + let mut multi_shard = MultiShard::new(vec![0, 1], &route); let backend1 = BackendKeyData { pid: 1, secret: 1 }; let backend2 = BackendKeyData { pid: 2, secret: 2 }; @@ -228,7 +228,7 @@ fn test_omni_command_complete_uses_first_shard_row_count() { fn test_omni_data_rows_only_from_first_server() { // For omni-sharded tables with RETURNING, only forward DataRows from the first server. let route = Route::write(ShardWithPriority::new_table_omni(Shard::All)); - let mut multi_shard = MultiShard::new(2, &route); + let mut multi_shard = MultiShard::new(vec![0, 1], &route); let backend1 = BackendKeyData { pid: 1, secret: 1 }; let backend2 = BackendKeyData { pid: 2, secret: 2 }; diff --git a/pgdog/src/config/mod.rs b/pgdog/src/config/mod.rs index a8e342eb2..835a0f10e 100644 --- a/pgdog/src/config/mod.rs +++ b/pgdog/src/config/mod.rs @@ -207,52 +207,46 @@ pub fn load_test_replicas() { #[cfg(test)] pub fn load_test_sharded() { + load_test_sharded_n(2); +} + +/// Load 3-shard test configuration. 
+pub fn load_test_sharded_3() { + load_test_sharded_n(3); +} + +fn load_test_sharded_n(num_shards: usize) { use pgdog_config::{OmnishardedTables, ShardedSchema}; use crate::backend::databases::init; let mut config = ConfigAndUsers::default(); config.config.general.min_pool_size = 0; - config.config.databases = vec![ - Database { - name: "pgdog".into(), - host: "127.0.0.1".into(), - port: 5432, - role: Role::Primary, - database_name: Some("shard_0".into()), - shard: 0, - ..Default::default() - }, - Database { - name: "pgdog".into(), - host: "127.0.0.1".into(), - port: 5432, - role: Role::Replica, - read_only: Some(true), - database_name: Some("shard_0".into()), - shard: 0, - ..Default::default() - }, - Database { - name: "pgdog".into(), - host: "127.0.0.1".into(), - port: 5432, - role: Role::Primary, - database_name: Some("shard_1".into()), - shard: 1, - ..Default::default() - }, - Database { - name: "pgdog".into(), - host: "127.0.0.1".into(), - port: 5432, - role: Role::Replica, - read_only: Some(true), - database_name: Some("shard_1".into()), - shard: 1, - ..Default::default() - }, - ]; + config.config.databases = (0..num_shards) + .flat_map(|shard| { + vec![ + Database { + name: "pgdog".into(), + host: "127.0.0.1".into(), + port: 5432, + role: Role::Primary, + database_name: Some(format!("shard_{}", shard)), + shard, + ..Default::default() + }, + Database { + name: "pgdog".into(), + host: "127.0.0.1".into(), + port: 5432, + role: Role::Replica, + read_only: Some(true), + database_name: Some(format!("shard_{}", shard)), + shard, + ..Default::default() + }, + ] + }) + .collect(); config.config.sharded_tables = vec![ ShardedTable { database: "pgdog".into(), diff --git a/pgdog/src/frontend/client/query_engine/test/mod.rs b/pgdog/src/frontend/client/query_engine/test/mod.rs index d4149f433..a48f740f4 100644 --- a/pgdog/src/frontend/client/query_engine/test/mod.rs +++ b/pgdog/src/frontend/client/query_engine/test/mod.rs @@ -15,6 +15,7 @@ mod graceful_disconnect; mod 
graceful_shutdown; mod idle_in_transaction_recovery; mod lock_session; +mod multi_binding; mod omni; pub mod prelude; mod prepared_syntax_error; diff --git a/pgdog/src/frontend/client/query_engine/test/multi_binding.rs b/pgdog/src/frontend/client/query_engine/test/multi_binding.rs new file mode 100644 index 000000000..0d73f7dca --- /dev/null +++ b/pgdog/src/frontend/client/query_engine/test/multi_binding.rs @@ -0,0 +1,101 @@ +use crate::{ + expect_message, + net::{CommandComplete, Parameters, Query, ReadyForQuery}, +}; + +use super::prelude::*; + +/// Regression test: Shard::Multi with fewer shards than total (e.g. 2 of 3) +/// used to get stuck because send() compared actual shard numbers against +/// positional indices in the pre-filtered servers vec. +/// +/// With 3 shards, an IN clause targeting shards 0 and 2 creates a MultiShard +/// binding with only 2 servers. Before the fix, send() used positional indices +/// (0, 1) to match against actual shard numbers ([0, 2]), so server at index 1 +/// (shard 2) never received the query — hanging the response. +#[tokio::test] +async fn test_multi_binding_select_subset_of_shards() { + let mut client = TestClient::new_sharded_3(Parameters::default()).await; + + let id_shard0 = client.random_id_for_shard(0); + let id_shard2 = client.random_id_for_shard(2); + + // Cleanup + client + .send_simple(Query::new(format!( + "DELETE FROM sharded WHERE id IN ({}, {})", + id_shard0, id_shard2 + ))) + .await; + client.read_until('Z').await.unwrap(); + + // Insert rows into shards 0 and 2 (skipping shard 1). + client + .send_simple(Query::new(format!( + "INSERT INTO sharded (id, value) VALUES ({}, 'multi0'), ({}, 'multi2') ON CONFLICT (id) DO UPDATE SET value = EXCLUDED.value", + id_shard0, id_shard2 + ))) + .await; + client.read_until('Z').await.unwrap(); + + // SELECT with IN clause targeting shards 0 and 2 only. + // This creates a Shard::Multi([0, 2]) binding with 2 of 3 servers. 
+ client + .send_simple(Query::new(format!( + "SELECT * FROM sharded WHERE id IN ({}, {})", + id_shard0, id_shard2 + ))) + .await; + + let messages = client.read_until('Z').await.unwrap(); + + let data_rows: Vec<_> = messages.iter().filter(|m| m.code() == 'D').collect(); + assert_eq!( + data_rows.len(), + 2, + "should return rows from both targeted shards" + ); + + let cc_msg = messages.iter().find(|m| m.code() == 'C').unwrap(); + let cc = CommandComplete::try_from(cc_msg.clone()).unwrap(); + assert_eq!(cc.command(), "SELECT 2"); + + // Cleanup + client + .send_simple(Query::new(format!( + "DELETE FROM sharded WHERE id IN ({}, {})", + id_shard0, id_shard2 + ))) + .await; + client.read_until('Z').await.unwrap(); +} + +/// Test multi-binding DELETE targeting a subset of shards (2 of 3). +#[tokio::test] +async fn test_multi_binding_delete_subset_of_shards() { + let mut client = TestClient::new_sharded_3(Parameters::default()).await; + + let id_shard0 = client.random_id_for_shard(0); + let id_shard2 = client.random_id_for_shard(2); + + // Setup + client + .send_simple(Query::new(format!( + "INSERT INTO sharded (id, value) VALUES ({}, 'del0'), ({}, 'del2') ON CONFLICT (id) DO UPDATE SET value = EXCLUDED.value", + id_shard0, id_shard2 + ))) + .await; + client.read_until('Z').await.unwrap(); + + // DELETE targeting shards 0 and 2 via IN clause (2 of 3 shards). 
+ client + .send_simple(Query::new(format!( + "DELETE FROM sharded WHERE id IN ({}, {})", + id_shard0, id_shard2 + ))) + .await; + + let cc = expect_message!(client.read().await, CommandComplete); + assert_eq!(cc.command(), "DELETE 2"); + expect_message!(client.read().await, ReadyForQuery); +} diff --git a/pgdog/src/frontend/client/test/test_client.rs b/pgdog/src/frontend/client/test/test_client.rs index f4d5f60df..bb3bdd078 100644 --- a/pgdog/src/frontend/client/test/test_client.rs +++ b/pgdog/src/frontend/client/test/test_client.rs @@ -10,7 +10,7 @@ use tokio::{ use crate::{ backend::databases::{reload_from_existing, shutdown}, - config::{config, load_test_replicas, load_test_sharded, set}, + config::{config, load_test_replicas, load_test_sharded, load_test_sharded_3, set}, frontend::{ client::query_engine::QueryEngine, router::{parser::Shard, sharding::ContextBuilder}, @@ -137,6 +137,12 @@ impl TestClient { Self::new(params).await } + /// New 3-shard client with parameters. + pub(crate) async fn new_sharded_3(params: Parameters) -> Self { + load_test_sharded_3(); + Self::new(params).await + } + pub(crate) fn leak_pool(mut self) -> Self { self.leak_pool = true; self