Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions integration/pgdog.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ database_name = "shard_1"
shard = 1
role = "replica"

# [[databases]]
# name = "pgdog_sharded"
# host = "localhost"
# database_name = "shard_2"
# shard = 2

# [[databases]]
# name = "pgdog_sharded"
# host = "localhost"
# database_name = "shard_2"
# shard = 2
# role = "replica"

# ------------------------------------------------------------------------------
# ----- Database :: failover ---------------------------------------------------

Expand Down
4 changes: 2 additions & 2 deletions integration/setup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export PGHOST=127.0.0.1
export PGPORT=5432
export PGUSER='pgdog'

for db in pgdog shard_0 shard_1; do
for db in pgdog shard_0 shard_1 shard_2 shard_3; do
psql -c "DROP DATABASE $db" || true
psql -c "CREATE DATABASE $db" || true
for user in pgdog pgdog1 pgdog2 pgdog3; do
Expand All @@ -35,7 +35,7 @@ for db in pgdog shard_0 shard_1; do
done
done

for db in pgdog shard_0 shard_1; do
for db in pgdog shard_0 shard_1 shard_2 shard_3; do
for table in sharded sharded_omni; do
psql -c "DROP TABLE IF EXISTS ${table}" ${db} -U pgdog
psql -c "CREATE TABLE IF NOT EXISTS ${table} (
Expand Down
11 changes: 8 additions & 3 deletions pgdog/src/backend/pool/connection/binding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,11 @@ impl Binding {
let mut shards_sent = servers.len();
let mut futures = Vec::new();

for (shard, server) in servers.iter_mut().enumerate() {
for (position, server) in servers.iter_mut().enumerate() {
// Map positional index to actual shard number.
// When only a subset of shards is connected (Shard::Multi binding),
// positional indices don't match actual shard numbers.
let shard = state.shard_index(position);
let send = match client_request.route().shard() {
Shard::Direct(s) => {
shards_sent = 1;
Expand Down Expand Up @@ -177,9 +181,10 @@ impl Binding {
/// Send copy messages to shards they are destined to go.
pub async fn send_copy(&mut self, rows: Vec<CopyRow>) -> Result<(), Error> {
match self {
Binding::MultiShard(servers, _state) => {
Binding::MultiShard(servers, state) => {
for row in rows {
for (shard, server) in servers.iter_mut().enumerate() {
for (position, server) in servers.iter_mut().enumerate() {
let shard = state.shard_index(position);
match row.shard() {
Shard::Direct(row_shard) => {
if shard == *row_shard {
Expand Down
2 changes: 1 addition & 1 deletion pgdog/src/backend/pool/connection/binding_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ mod tests {
];

let route = Route::write(ShardWithPriority::new_default_unset(Shard::All));
let multishard = MultiShard::new(3, &route);
let multishard = MultiShard::new(vec![0, 1, 2], &route);

let mut binding = Binding::MultiShard(guards, Box::new(multishard));

Expand Down
5 changes: 3 additions & 2 deletions pgdog/src/backend/pool/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ impl Connection {
};
} else {
let mut shards = vec![];
let mut shard_indices = vec![];
for (i, shard) in self.cluster()?.shards().iter().enumerate() {
if let Shard::Multi(numbers) = route.shard() {
if !numbers.contains(&i) {
Expand All @@ -175,11 +176,11 @@ impl Connection {
}

shards.push(server);
shard_indices.push(i);
}
let num_shards = shards.len();

self.binding =
Binding::MultiShard(shards, Box::new(MultiShard::new(num_shards, route)));
Binding::MultiShard(shards, Box::new(MultiShard::new(shard_indices, route)));
}

Ok(())
Expand Down
18 changes: 16 additions & 2 deletions pgdog/src/backend/pool/connection/multi_shard/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ pub struct MultiShard {
shards: usize,
/// Route the query is taking.
route: Route,
/// Maps positional index in the servers vec to actual shard number.
/// When all shards are connected, this is `[0, 1, 2, ...]`.
/// When only a subset is connected (e.g. shards 0 and 2), this is `[0, 2]`.
shard_indices: Vec<usize>,

/// Counters
counters: Counters,
Expand All @@ -64,16 +68,26 @@ pub struct MultiShard {
}

impl MultiShard {
/// New multi-shard state given the number of shards in the cluster.
pub(super) fn new(shards: usize, route: &Route) -> Self {
/// New multi-shard state given the actual shard indices connected.
pub(super) fn new(shard_indices: Vec<usize>, route: &Route) -> Self {
let shards = shard_indices.len();
Self {
shards,
shard_indices,
route: route.clone(),
counters: Counters::default(),
..Default::default()
}
}

/// Translate a position in the connected-servers vec into its real shard id.
///
/// When every shard is connected the table is the identity mapping, so the
/// out-of-range fallback of returning `position` unchanged is also correct
/// for callers built before the mapping existed.
pub(super) fn shard_index(&self, position: usize) -> usize {
    match self.shard_indices.get(position) {
        Some(&shard) => shard,
        None => position,
    }
}

/// Update multi-shard state.
pub(super) fn update(&mut self, shards: usize, route: &Route) {
self.reset();
Expand Down
14 changes: 7 additions & 7 deletions pgdog/src/backend/pool/connection/multi_shard/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use super::*;
#[test]
fn test_inconsistent_row_descriptions() {
let route = Route::default();
let mut multi_shard = MultiShard::new(2, &route);
let mut multi_shard = MultiShard::new(vec![0, 1], &route);

// Create two different row descriptions
let rd1 = RowDescription::new(&[Field::text("name"), Field::bigint("id")]);
Expand All @@ -32,7 +32,7 @@ fn test_inconsistent_row_descriptions() {
#[test]
fn test_inconsistent_data_rows() {
let route = Route::default();
let mut multi_shard = MultiShard::new(2, &route);
let mut multi_shard = MultiShard::new(vec![0, 1], &route);

// Set up row description first
let rd = RowDescription::new(&[Field::text("name"), Field::bigint("id")]);
Expand Down Expand Up @@ -63,7 +63,7 @@ fn test_inconsistent_data_rows() {
#[test]
fn test_rd_before_dr() {
let mut multi_shard = MultiShard::new(
3,
vec![0, 1, 2],
&Route::read(ShardWithPriority::new_default_unset(Shard::All)),
);
let rd = RowDescription::new(&[Field::bigint("id")]);
Expand Down Expand Up @@ -127,7 +127,7 @@ fn test_rd_before_dr() {
#[test]
fn test_ready_for_query_error_preservation() {
let route = Route::default();
let mut multi_shard = MultiShard::new(2, &route);
let mut multi_shard = MultiShard::new(vec![0, 1], &route);

// Create ReadyForQuery messages - one with transaction error, one normal
let rfq_error = ReadyForQuery::error();
Expand All @@ -151,7 +151,7 @@ fn test_ready_for_query_error_preservation() {
fn test_omni_command_complete_not_summed() {
// For omni-sharded tables, we should NOT sum row counts across shards.
let route = Route::write(ShardWithPriority::new_table_omni(Shard::All));
let mut multi_shard = MultiShard::new(3, &route);
let mut multi_shard = MultiShard::new(vec![0, 1, 2], &route);

let backend1 = BackendKeyData { pid: 1, secret: 1 };
let backend2 = BackendKeyData { pid: 2, secret: 2 };
Expand Down Expand Up @@ -193,7 +193,7 @@ fn test_omni_command_complete_not_summed() {
fn test_omni_command_complete_uses_first_shard_row_count() {
// For omni, we use the first shard's row count for consistency with DataRow behavior.
let route = Route::write(ShardWithPriority::new_table_omni(Shard::All));
let mut multi_shard = MultiShard::new(2, &route);
let mut multi_shard = MultiShard::new(vec![0, 1], &route);

let backend1 = BackendKeyData { pid: 1, secret: 1 };
let backend2 = BackendKeyData { pid: 2, secret: 2 };
Expand Down Expand Up @@ -228,7 +228,7 @@ fn test_omni_command_complete_uses_first_shard_row_count() {
fn test_omni_data_rows_only_from_first_server() {
// For omni-sharded tables with RETURNING, only forward DataRows from the first server.
let route = Route::write(ShardWithPriority::new_table_omni(Shard::All));
let mut multi_shard = MultiShard::new(2, &route);
let mut multi_shard = MultiShard::new(vec![0, 1], &route);

let backend1 = BackendKeyData { pid: 1, secret: 1 };
let backend2 = BackendKeyData { pid: 2, secret: 2 };
Expand Down
74 changes: 34 additions & 40 deletions pgdog/src/config/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,52 +207,46 @@ pub fn load_test_replicas() {

#[cfg(test)]
pub fn load_test_sharded() {
load_test_sharded_n(2);
}

/// Load 3-shard test configuration.
pub fn load_test_sharded_3() {
load_test_sharded_n(3);
}

fn load_test_sharded_n(num_shards: usize) {
use pgdog_config::{OmnishardedTables, ShardedSchema};

use crate::backend::databases::init;

let mut config = ConfigAndUsers::default();
config.config.general.min_pool_size = 0;
config.config.databases = vec![
Database {
name: "pgdog".into(),
host: "127.0.0.1".into(),
port: 5432,
role: Role::Primary,
database_name: Some("shard_0".into()),
shard: 0,
..Default::default()
},
Database {
name: "pgdog".into(),
host: "127.0.0.1".into(),
port: 5432,
role: Role::Replica,
read_only: Some(true),
database_name: Some("shard_0".into()),
shard: 0,
..Default::default()
},
Database {
name: "pgdog".into(),
host: "127.0.0.1".into(),
port: 5432,
role: Role::Primary,
database_name: Some("shard_1".into()),
shard: 1,
..Default::default()
},
Database {
name: "pgdog".into(),
host: "127.0.0.1".into(),
port: 5432,
role: Role::Replica,
read_only: Some(true),
database_name: Some("shard_1".into()),
shard: 1,
..Default::default()
},
];
config.config.databases = (0..num_shards)
.flat_map(|shard| {
vec![
Database {
name: "pgdog".into(),
host: "127.0.0.1".into(),
port: 5432,
role: Role::Primary,
database_name: Some(format!("shard_{}", shard)),
shard,
..Default::default()
},
Database {
name: "pgdog".into(),
host: "127.0.0.1".into(),
port: 5432,
role: Role::Replica,
read_only: Some(true),
database_name: Some(format!("shard_{}", shard)),
shard,
..Default::default()
},
]
})
.collect();
config.config.sharded_tables = vec![
ShardedTable {
database: "pgdog".into(),
Expand Down
1 change: 1 addition & 0 deletions pgdog/src/frontend/client/query_engine/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ mod graceful_disconnect;
mod graceful_shutdown;
mod idle_in_transaction_recovery;
mod lock_session;
mod multi_binding;
mod omni;
pub mod prelude;
mod prepared_syntax_error;
Expand Down
101 changes: 101 additions & 0 deletions pgdog/src/frontend/client/query_engine/test/multi_binding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
use crate::{
expect_message,
net::{CommandComplete, Parameters, Query, ReadyForQuery},
};

use super::prelude::*;

/// Regression coverage for `Shard::Multi` bindings that connect fewer servers
/// than the cluster has shards (here: 2 of 3).
///
/// An IN clause hitting shards 0 and 2 builds a MultiShard binding holding
/// only two servers. Prior to the fix, send() matched the route's shard
/// numbers ([0, 2]) against positions in the pre-filtered servers vec (0, 1),
/// so the server at position 1 (actual shard 2) was never written to and the
/// client hung waiting for its reply.
#[tokio::test]
async fn test_multi_binding_select_subset_of_shards() {
    let mut client = TestClient::new_sharded_3(Parameters::default()).await;

    let id0 = client.random_id_for_shard(0);
    let id2 = client.random_id_for_shard(2);

    // Start from a clean slate.
    client
        .send_simple(Query::new(format!(
            "DELETE FROM sharded WHERE id IN ({}, {})",
            id0, id2
        )))
        .await;
    client.read_until('Z').await.unwrap();

    // Seed one row on shard 0 and one on shard 2; shard 1 stays untouched.
    client
        .send_simple(Query::new(format!(
            "INSERT INTO sharded (id, value) VALUES ({}, 'multi0'), ({}, 'multi2') ON CONFLICT (id) DO UPDATE SET value = EXCLUDED.value",
            id0, id2
        )))
        .await;
    client.read_until('Z').await.unwrap();

    // The IN clause routes as Shard::Multi([0, 2]) — only 2 of 3 servers.
    client
        .send_simple(Query::new(format!(
            "SELECT * FROM sharded WHERE id IN ({}, {})",
            id0, id2
        )))
        .await;

    let messages = client.read_until('Z').await.unwrap();

    // One DataRow per targeted shard proves both servers answered.
    let row_count = messages.iter().filter(|m| m.code() == 'D').count();
    assert_eq!(row_count, 2, "should return rows from both targeted shards");

    let complete = messages
        .iter()
        .find(|m| m.code() == 'C')
        .map(|m| CommandComplete::try_from(m.clone()).unwrap())
        .unwrap();
    assert_eq!(complete.command(), "SELECT 2");

    // Remove the fixture rows.
    client
        .send_simple(Query::new(format!(
            "DELETE FROM sharded WHERE id IN ({}, {})",
            id0, id2
        )))
        .await;
    client.read_until('Z').await.unwrap();
}

/// DELETE routed to a strict subset of shards (0 and 2 of 3) via an IN clause.
/// Exercises the same positional-index-vs-shard-number mapping as the SELECT
/// regression test, but on the write path.
#[tokio::test]
async fn test_multi_binding_delete_subset_of_shards() {
    let mut client = TestClient::new_sharded_3(Parameters::default()).await;

    let id0 = client.random_id_for_shard(0);
    let id2 = client.random_id_for_shard(2);

    // Seed a row on each targeted shard.
    client
        .send_simple(Query::new(format!(
            "INSERT INTO sharded (id, value) VALUES ({}, 'del0'), ({}, 'del2') ON CONFLICT (id) DO UPDATE SET value = EXCLUDED.value",
            id0, id2
        )))
        .await;
    client.read_until('Z').await.unwrap();

    // The DELETE itself is the operation under test: Shard::Multi([0, 2]).
    client
        .send_simple(Query::new(format!(
            "DELETE FROM sharded WHERE id IN ({}, {})",
            id0, id2
        )))
        .await;

    // Row counts from both shards must be summed into a single tag.
    let cc = expect_message!(client.read().await, CommandComplete);
    assert_eq!(cc.command(), "DELETE 2");
    expect_message!(client.read().await, ReadyForQuery);
}
Loading
Loading