Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 34 additions & 7 deletions crates/client-api/src/routes/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use futures::TryStreamExt;
use http::StatusCode;
use log::{info, warn};
use serde::Deserialize;
use spacetimedb::auth::identity::ConnectionAuthCtx;
use spacetimedb::database_logger::DatabaseLogger;
use spacetimedb::host::module_host::ClientConnectedError;
use spacetimedb::host::{CallResult, UpdateDatabaseResult};
Expand Down Expand Up @@ -518,22 +519,46 @@ pub async fn sql_direct<S>(
SqlParams { name_or_identity }: SqlParams,
SqlQueryParams { confirmed }: SqlQueryParams,
caller_identity: Identity,
caller_auth: ConnectionAuthCtx,
sql: String,
) -> axum::response::Result<Vec<SqlStmtResult<ProductValue>>>
where
S: NodeDelegate + ControlStateDelegate + Authorization,
{
// Anyone is authorized to execute SQL queries. The SQL engine will determine
// which queries this identity is allowed to execute against the database.
let connection_id = generate_random_connection_id();

let (host, database) = find_leader_and_database(&worker_ctx, name_or_identity).await?;

let auth = worker_ctx
.authorize_sql(caller_identity, database.database_identity)
.await?;
// Run the module's client_connected reducer, if any.
// If it rejects the connection, bail before executing SQL.
let module = host.module().await.map_err(log_and_500)?;
module
.call_identity_connected(caller_auth, connection_id)
.await
.map_err(client_connected_error_to_response)?;

let result = async {
let sql_auth = worker_ctx
.authorize_sql(caller_identity, database.database_identity)
.await?;

host.exec_sql(
sql_auth,
database,
confirmed.unwrap_or(crate::DEFAULT_CONFIRMED_READS),
sql,
)
.await
}
.await;

host.exec_sql(auth, database, confirmed.unwrap_or(crate::DEFAULT_CONFIRMED_READS), sql)
// Always disconnect, even if authorization or execution failed.
module
.call_identity_disconnected(caller_identity, connection_id, false)
.await
.map_err(client_disconnected_error_to_response)?;

result
}

pub async fn sql<S>(
Expand All @@ -546,7 +571,9 @@ pub async fn sql<S>(
where
S: NodeDelegate + ControlStateDelegate + Authorization,
{
let json = sql_direct(worker_ctx, name_or_identity, params, auth.claims.identity, body).await?;
let caller_identity = auth.claims.identity;
let caller_auth: ConnectionAuthCtx = auth.into();
let json = sql_direct(worker_ctx, name_or_identity, params, caller_identity, caller_auth, body).await?;

let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros);

Expand Down
1 change: 1 addition & 0 deletions crates/pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ license-file = "LICENSE"
description = "Postgres wire protocol Server support for SpacetimeDB"

[dependencies]
spacetimedb-auth.workspace = true
spacetimedb-client-api-messages.workspace = true
spacetimedb-client-api.workspace = true
spacetimedb-lib.workspace = true
Expand Down
17 changes: 15 additions & 2 deletions crates/pg/src/pg_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use pgwire::messages::data::DataRow;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use pgwire::tokio::process_socket;
use spacetimedb_auth::identity::ConnectionAuthCtx;
use spacetimedb_client_api::auth::validate_token;
use spacetimedb_client_api::routes::database;
use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams};
Expand Down Expand Up @@ -64,6 +65,7 @@ impl From<PgError> for PgWireError {
struct Metadata {
database: String,
caller_identity: Identity,
caller_auth: ConnectionAuthCtx,
}

pub(crate) fn to_rows(
Expand Down Expand Up @@ -163,6 +165,7 @@ where
db,
SqlQueryParams { confirmed: Some(true) },
params.caller_identity,
params.caller_auth.clone(),
query.to_string(),
)
.await,
Expand Down Expand Up @@ -266,8 +269,8 @@ impl<T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDel
}
};

let caller_identity = match validate_token(&self.ctx, &pwd.password).await {
Ok(claims) => claims.identity,
let claims = match validate_token(&self.ctx, &pwd.password).await {
Ok(claims) => claims,
Err(err) => {
log::error!(
"PG: Authentication failed for identity `{}` on database {database}: {err}",
Expand All @@ -277,12 +280,22 @@ impl<T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDel
return close_client(client, err).await;
}
};
let caller_identity = claims.identity;
let caller_auth = ConnectionAuthCtx::try_from(claims).map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_owned(),
// "invalid_authorization_specification"
"28000".to_owned(),
e.to_string(),
)))
})?;

log::info!("PG: Connected to database: {database} using identity `{caller_identity}`");

let metadata = Metadata {
database,
caller_identity,
caller_auth,
};
self.cached.lock().await.clone_from(&Some(metadata));
finish_authentication(client, &self.parameter_provider).await?;
Expand Down
8 changes: 8 additions & 0 deletions crates/smoketests/modules/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/smoketests/modules/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ members = [
"delete-database",
"client-connection-reject",
"client-connection-disconnect-panic",
"sql-connect-hook",

# Log filtering tests
"logs-level-filter",
Expand Down
12 changes: 12 additions & 0 deletions crates/smoketests/modules/sql-connect-hook/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "smoketest-module-sql-connect-hook"
version = "0.1.0"
edition = "2021"
publish = false

[lib]
crate-type = ["cdylib"]

[dependencies]
spacetimedb.workspace = true
log.workspace = true
23 changes: 23 additions & 0 deletions crates/smoketests/modules/sql-connect-hook/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use spacetimedb::{log, ReducerContext, Table};

#[spacetimedb::table(accessor = person, public)]
pub struct Person {
name: String,
}

#[spacetimedb::reducer(init)]
pub fn init(ctx: &ReducerContext) {
ctx.db.person().insert(Person {
name: "Alice".to_string(),
});
}

#[spacetimedb::reducer(client_connected)]
pub fn connected(ctx: &ReducerContext) {
log::info!("sql_connect_hook: client_connected caller={}", ctx.sender());
}

#[spacetimedb::reducer(client_disconnected)]
pub fn disconnected(ctx: &ReducerContext) {
log::info!("sql_connect_hook: client_disconnected caller={}", ctx.sender());
}
11 changes: 7 additions & 4 deletions crates/smoketests/tests/smoketests/client_connection_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,14 @@ fn test_client_disconnected_error_still_deletes_st_client() {
logs
);

// Verify st_client table is empty (row was deleted despite the panic)
// Verify the websocket's st_client row was deleted despite the panic.
// The SQL query itself creates a temporary connection, so we may see
// exactly one row (the SQL connection's own), but the websocket's row
// should be gone.
let sql_out = test.sql("SELECT * FROM st_client").unwrap();
let row_count = sql_out.lines().filter(|l| l.contains("0x")).count();
assert!(
sql_out.contains("identity | connection_id") && !sql_out.contains("0x"),
"Expected st_client table to be empty, got: {}",
sql_out
row_count <= 1,
"Expected at most 1 st_client row (the SQL connection itself), got {row_count}: {sql_out}",
);
}
1 change: 1 addition & 0 deletions crates/smoketests/tests/smoketests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ mod rls;
mod schedule_reducer;
mod servers;
mod sql;
mod sql_connect_hook;
mod templates;
mod timestamp_route;
mod views;
86 changes: 86 additions & 0 deletions crates/smoketests/tests/smoketests/sql_connect_hook.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use spacetimedb_smoketests::Smoketest;

/// Test that SQL requests are rejected when client_connected returns an error.
///
/// This verifies that the /sql HTTP endpoint now runs the module's
/// client_connected reducer and rejects the request if it errors.
/// Without PR #4563, this SQL query would succeed.
#[test]
fn test_sql_rejected_when_client_connected_errors() {
let test = Smoketest::builder()
.precompiled_module("client-connection-reject")
.build();

// SQL should fail because client_connected returns an error
let result = test.sql("SELECT * FROM all_u8s");
assert!(
result.is_err(),
"Expected SQL query to be rejected when client_connected errors, but it succeeded"
);
}

/// Test that SQL requests trigger client_connected and client_disconnected hooks.
///
/// This verifies that the /sql HTTP endpoint calls the module's lifecycle
/// reducers. Without PR #4563, no connect/disconnect logs would appear.
#[test]
fn test_sql_triggers_connect_disconnect_hooks() {
let test = Smoketest::builder().precompiled_module("sql-connect-hook").build();

// Run a SQL query
test.sql("SELECT * FROM person").unwrap();

// Check that both connect and disconnect hooks were called
let logs = test.logs(100).unwrap();
assert!(
logs.iter().any(|l| l.contains("sql_connect_hook: client_connected")),
"Expected client_connected log from SQL request, got: {:?}",
logs
);
assert!(
logs.iter().any(|l| l.contains("sql_connect_hook: client_disconnected")),
"Expected client_disconnected log from SQL request, got: {:?}",
logs
);
}

/// Test that SQL queries still return data when client_connected accepts.
///
/// Ensures the connect hook doesn't break normal SQL functionality.
#[test]
fn test_sql_returns_data_with_connect_hook() {
let test = Smoketest::builder().precompiled_module("sql-connect-hook").build();

test.assert_sql(
"SELECT * FROM person",
r#" name
---------
"Alice""#,
);
}

/// Test that client_disconnected is still called even when the SQL query fails.
///
/// The `authorize_sql` and `exec_sql` errors are captured inside an async block,
/// so `call_identity_disconnected` runs regardless of query success or failure.
#[test]
fn test_sql_disconnect_called_on_query_error() {
let test = Smoketest::builder().precompiled_module("sql-connect-hook").build();

// Run an invalid SQL query — this will fail in exec_sql
let result = test.sql("SELECT * FROM nonexistent_table");
assert!(result.is_err(), "Expected invalid SQL to fail");

// Despite the query error, both connect and disconnect should have been called
let logs = test.logs(100).unwrap();
assert!(
logs.iter().any(|l| l.contains("sql_connect_hook: client_connected")),
"Expected client_connected even on failed SQL, got: {:?}",
logs
);
assert!(
logs.iter().any(|l| l.contains("sql_connect_hook: client_disconnected")),
"Expected client_disconnected even on failed SQL, got: {:?}",
logs
);
}
11 changes: 8 additions & 3 deletions smoketests/tests/client_connected_error_rejects_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def test_client_disconnected_error_still_deletes_st_client(self):

sql_out = self.spacetime("sql", self.database_identity, "select * from st_client")

self.assertMultiLineEqual(sql_out, """ identity | connection_id
----------+---------------
""")
# The SQL query itself now creates a temporary connection, so we may
# see exactly one row (the SQL connection's own). The websocket's row
# should be gone. Count non-header, non-separator lines with content.
lines = sql_out.strip().split('\n')
# Data rows are those that are not the header and not the separator line
data_rows = [l for l in lines if '|' in l and '-+-' not in l and 'identity' not in l.lower()]
self.assertLessEqual(len(data_rows), 1,
f"Expected at most 1 st_client row (the SQL connection itself), got: {sql_out}")
Loading