Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 34 additions & 7 deletions crates/client-api/src/routes/database.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use futures::TryStreamExt;
use http::StatusCode;
use log::{info, warn};
use serde::Deserialize;
use spacetimedb::auth::identity::ConnectionAuthCtx;
use spacetimedb::database_logger::DatabaseLogger;
use spacetimedb::host::module_host::ClientConnectedError;
use spacetimedb::host::{CallResult, UpdateDatabaseResult};
Expand Down Expand Up @@ -518,22 +519,46 @@ pub async fn sql_direct<S>(
SqlParams { name_or_identity }: SqlParams,
SqlQueryParams { confirmed }: SqlQueryParams,
caller_identity: Identity,
caller_auth: ConnectionAuthCtx,
sql: String,
) -> axum::response::Result<Vec<SqlStmtResult<ProductValue>>>
where
S: NodeDelegate + ControlStateDelegate + Authorization,
{
// Anyone is authorized to execute SQL queries. The SQL engine will determine
// which queries this identity is allowed to execute against the database.
let connection_id = generate_random_connection_id();

let (host, database) = find_leader_and_database(&worker_ctx, name_or_identity).await?;

let auth = worker_ctx
.authorize_sql(caller_identity, database.database_identity)
.await?;
// Run the module's client_connected reducer, if any.
// If it rejects the connection, bail before executing SQL.
let module = host.module().await.map_err(log_and_500)?;
module
.call_identity_connected(caller_auth, connection_id)
.await
.map_err(client_connected_error_to_response)?;

let result = async {
let sql_auth = worker_ctx
.authorize_sql(caller_identity, database.database_identity)
.await?;
Comment thread
bfops marked this conversation as resolved.

host.exec_sql(
sql_auth,
database,
confirmed.unwrap_or(crate::DEFAULT_CONFIRMED_READS),
sql,
)
.await
}
.await;

host.exec_sql(auth, database, confirmed.unwrap_or(crate::DEFAULT_CONFIRMED_READS), sql)
// Always disconnect, even if authorization or execution failed.
module
.call_identity_disconnected(caller_identity, connection_id, false)
.await
.map_err(client_disconnected_error_to_response)?;

result
}

pub async fn sql<S>(
Expand All @@ -546,7 +571,9 @@ pub async fn sql<S>(
where
S: NodeDelegate + ControlStateDelegate + Authorization,
{
let json = sql_direct(worker_ctx, name_or_identity, params, auth.claims.identity, body).await?;
let caller_identity = auth.claims.identity;
let caller_auth: ConnectionAuthCtx = auth.into();
let json = sql_direct(worker_ctx, name_or_identity, params, caller_identity, caller_auth, body).await?;

let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros);

Expand Down
1 change: 1 addition & 0 deletions crates/pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ license-file = "LICENSE"
description = "Postgres wire protocol Server support for SpacetimeDB"

[dependencies]
spacetimedb-auth.workspace = true
spacetimedb-client-api-messages.workspace = true
spacetimedb-client-api.workspace = true
spacetimedb-lib.workspace = true
Expand Down
16 changes: 14 additions & 2 deletions crates/pg/src/pg_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use pgwire::messages::data::DataRow;
use pgwire::messages::startup::Authentication;
use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage};
use pgwire::tokio::process_socket;
use spacetimedb_auth::identity::ConnectionAuthCtx;
use spacetimedb_client_api::auth::validate_token;
use spacetimedb_client_api::routes::database;
use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams};
Expand Down Expand Up @@ -64,6 +65,7 @@ impl From<PgError> for PgWireError {
struct Metadata {
database: String,
caller_identity: Identity,
caller_auth: ConnectionAuthCtx,
}

pub(crate) fn to_rows(
Expand Down Expand Up @@ -163,6 +165,7 @@ where
db,
SqlQueryParams { confirmed: Some(true) },
params.caller_identity,
params.caller_auth.clone(),
Comment thread
bfops marked this conversation as resolved.
query.to_string(),
)
.await,
Expand Down Expand Up @@ -266,8 +269,8 @@ impl<T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDel
}
};

let caller_identity = match validate_token(&self.ctx, &pwd.password).await {
Ok(claims) => claims.identity,
let claims = match validate_token(&self.ctx, &pwd.password).await {
Ok(claims) => claims,
Err(err) => {
log::error!(
"PG: Authentication failed for identity `{}` on database {database}: {err}",
Expand All @@ -277,12 +280,21 @@ impl<T: Sync + Send + ControlStateReadAccess + ControlStateWriteAccess + NodeDel
return close_client(client, err).await;
}
};
let caller_identity = claims.identity;
let caller_auth = ConnectionAuthCtx::try_from(claims).map_err(|e| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"FATAL".to_owned(),
"28000".to_owned(),
e.to_string(),
)))
})?;

log::info!("PG: Connected to database: {database} using identity `{caller_identity}`");

let metadata = Metadata {
database,
caller_identity,
caller_auth,
};
self.cached.lock().await.clone_from(&Some(metadata));
finish_authentication(client, &self.parameter_provider).await?;
Expand Down
1 change: 1 addition & 0 deletions crates/smoketests/modules/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ members = [
"delete-database",
"client-connection-reject",
"client-connection-disconnect-panic",
"sql-connect-hook",

# Log filtering tests
"logs-level-filter",
Expand Down
12 changes: 12 additions & 0 deletions crates/smoketests/modules/sql-connect-hook/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[package]
name = "smoketest-module-sql-connect-hook"
version = "0.1.0"
edition = "2021"
publish = false

[lib]
crate-type = ["cdylib"]

[dependencies]
spacetimedb.workspace = true
log.workspace = true
23 changes: 23 additions & 0 deletions crates/smoketests/modules/sql-connect-hook/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use spacetimedb::{log, ReducerContext, Table};

#[spacetimedb::table(accessor = person, public)]
pub struct Person {
name: String,
}

#[spacetimedb::reducer(init)]
pub fn init(ctx: &ReducerContext) {
ctx.db.person().insert(Person {
name: "Alice".to_string(),
});
}

#[spacetimedb::reducer(client_connected)]
pub fn connected(ctx: &ReducerContext) {
log::info!("sql_connect_hook: client_connected caller={}", ctx.sender);
}

#[spacetimedb::reducer(client_disconnected)]
pub fn disconnected(ctx: &ReducerContext) {
log::info!("sql_connect_hook: client_disconnected caller={}", ctx.sender);
}
11 changes: 7 additions & 4 deletions crates/smoketests/tests/smoketests/client_connection_errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,14 @@ fn test_client_disconnected_error_still_deletes_st_client() {
logs
);

// Verify st_client table is empty (row was deleted despite the panic)
// Verify the websocket's st_client row was deleted despite the panic.
// The SQL query itself creates a temporary connection, so we may see
// exactly one row (the SQL connection's own), but the websocket's row
// should be gone.
let sql_out = test.sql("SELECT * FROM st_client").unwrap();
let row_count = sql_out.lines().filter(|l| l.contains("0x")).count();
assert!(
sql_out.contains("identity | connection_id") && !sql_out.contains("0x"),
"Expected st_client table to be empty, got: {}",
sql_out
row_count <= 1,
"Expected at most 1 st_client row (the SQL connection itself), got {row_count}: {sql_out}",
Comment thread
bfops marked this conversation as resolved.
);
}
1 change: 1 addition & 0 deletions crates/smoketests/tests/smoketests/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ mod rls;
mod schedule_reducer;
mod servers;
mod sql;
mod sql_connect_hook;
mod templates;
mod timestamp_route;
mod views;
64 changes: 64 additions & 0 deletions crates/smoketests/tests/smoketests/sql_connect_hook.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use spacetimedb_smoketests::Smoketest;

/// Test that SQL requests are rejected when client_connected returns an error.
///
/// This verifies that the /sql HTTP endpoint now runs the module's
/// client_connected reducer and rejects the request if it errors.
/// Without PR #4563, this SQL query would succeed.
#[test]
fn test_sql_rejected_when_client_connected_errors() {
let test = Smoketest::builder()
.precompiled_module("client-connection-reject")
.build();

// SQL should fail because client_connected returns an error
let result = test.sql("SELECT * FROM all_u8s");
assert!(
result.is_err(),
"Expected SQL query to be rejected when client_connected errors, but it succeeded"
);
}
Comment thread
bfops marked this conversation as resolved.

/// Test that SQL requests trigger client_connected and client_disconnected hooks.
///
/// This verifies that the /sql HTTP endpoint calls the module's lifecycle
/// reducers. Without PR #4563, no connect/disconnect logs would appear.
#[test]
fn test_sql_triggers_connect_disconnect_hooks() {
let test = Smoketest::builder()
.precompiled_module("sql-connect-hook")
.build();

// Run a SQL query
test.sql("SELECT * FROM person").unwrap();

// Check that both connect and disconnect hooks were called
let logs = test.logs(100).unwrap();
assert!(
logs.iter().any(|l| l.contains("sql_connect_hook: client_connected")),
"Expected client_connected log from SQL request, got: {:?}",
logs
);
assert!(
logs.iter().any(|l| l.contains("sql_connect_hook: client_disconnected")),
"Expected client_disconnected log from SQL request, got: {:?}",
logs
);
}

/// Test that SQL queries still return data when client_connected accepts.
///
/// Ensures the connect hook doesn't break normal SQL functionality.
#[test]
fn test_sql_returns_data_with_connect_hook() {
let test = Smoketest::builder()
.precompiled_module("sql-connect-hook")
.build();

test.assert_sql(
"SELECT * FROM person",
r#" name
-------
Alice"#,
);
}
11 changes: 8 additions & 3 deletions smoketests/tests/client_connected_error_rejects_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def test_client_disconnected_error_still_deletes_st_client(self):

sql_out = self.spacetime("sql", self.database_identity, "select * from st_client")

self.assertMultiLineEqual(sql_out, """ identity | connection_id
----------+---------------
""")
# The SQL query itself now creates a temporary connection, so we may
# see exactly one row (the SQL connection's own). The websocket's row
# should be gone. Count non-header, non-separator lines with content.
lines = sql_out.strip().split('\n')
# Data rows are those that are not the header and not the separator line
data_rows = [l for l in lines if '|' in l and '-+-' not in l and 'identity' not in l.lower()]
self.assertLessEqual(len(data_rows), 1,
f"Expected at most 1 st_client row (the SQL connection itself), got: {sql_out}")
Comment thread
bfops marked this conversation as resolved.
Loading