diff --git a/datafusion-postgres/src/handlers.rs b/datafusion-postgres/src/handlers.rs index 3eb5e5a..37671cf 100644 --- a/datafusion-postgres/src/handlers.rs +++ b/datafusion-postgres/src/handlers.rs @@ -17,6 +17,7 @@ use pgwire::api::results::{FieldInfo, Response, Tag}; use pgwire::api::stmt::QueryParser; use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type}; use pgwire::error::{PgWireError, PgWireResult}; +use pgwire::messages::PgWireBackendMessage; use pgwire::types::format::FormatOptions; use crate::hooks::set_show::SetShowHook; @@ -119,10 +120,11 @@ impl DfSessionService { impl SimpleQueryHandler for DfSessionService { async fn do_query(&self, client: &mut C, query: &str) -> PgWireResult> where - C: ClientInfo + Unpin + Send + Sync, + C: ClientInfo + futures::Sink + Unpin + Send + Sync, + C::Error: std::fmt::Debug, + PgWireError: From<>::Error>, { - log::debug!("Received query: {query}"); // Log the query for debugging - + log::debug!("Received query: {query}"); let statements = self .parser .sql_parser @@ -206,11 +208,12 @@ impl ExtendedQueryHandler for DfSessionService { _max_rows: usize, ) -> PgWireResult where - C: ClientInfo + Unpin + Send + Sync, + C: ClientInfo + futures::Sink + Unpin + Send + Sync, + C::Error: std::fmt::Debug, + PgWireError: From<>::Error>, { let query = &portal.statement.statement.0; - log::debug!("Received execute extended query: {query}"); // Log for debugging - + log::debug!("Received execute extended query: {query}"); // Check query hooks first if !self.query_hooks.is_empty() { if let (_, Some((statement, plan))) = &portal.statement.statement { @@ -243,13 +246,12 @@ impl ExtendedQueryHandler for DfSessionService { .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let param_values = - df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; // Fixed: Use ¶m_types + df::deserialize_parameters(portal, &ordered_param_types(¶m_types))?; let plan = plan .clone() .replace_params_with_values(¶m_values) - .map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use - // ¶m_values + .map_err(|e| PgWireError::ApiError(Box::new(e)))?; let optimised = self .session_context .state() @@ -345,8 +347,7 @@ impl QueryParser for Parser { where C: ClientInfo + Unpin + Send + Sync, { - log::debug!("Received parse extended query: {sql}"); // Log for debugging - + log::debug!("Received parse extended query: {sql}"); let mut statements = self .sql_parser .parse(sql) @@ -384,7 +385,6 @@ impl QueryParser for Parser { let mut param_types = Vec::with_capacity(params.len()); for param_type in ordered_param_types(¶ms).iter() { - // Fixed: Use ¶ms if let Some(datatype) = param_type { let pgtype = into_pg_type(datatype)?; param_types.push(pgtype); @@ -434,6 +434,8 @@ mod tests { use super::*; use crate::testing::MockClient; + use crate::hooks::HookClient; + struct TestHook; #[async_trait] @@ -442,7 +444,7 @@ mod tests { &self, statement: &sqlparser::ast::Statement, _ctx: &SessionContext, - _client: &mut (dyn ClientInfo + Sync + Send), + _client: &mut dyn HookClient, ) -> Option> { if statement.to_string().contains("magic") { Some(Ok(Response::EmptyQuery)) @@ -466,7 +468,7 @@ mod tests { _logical_plan: &LogicalPlan, _params: &ParamValues, _session_context: &SessionContext, - _client: &mut (dyn ClientInfo + Send + Sync), + _client: &mut dyn HookClient, ) -> Option> { None } @@ -523,4 +525,84 @@ mod tests { assert!(matches!(results[2], Response::EmptyQuery)); assert!(matches!(results[3], Response::Query(_))); } + + #[tokio::test] + async fn test_set_sends_parameter_status_via_sink() { + use pgwire::messages::PgWireBackendMessage; + + let service = crate::testing::setup_handlers(); + let mut client = MockClient::new(); + + let test_cases = vec![ + ("SET datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"), + ( + "SET intervalstyle = 'postgres'", + "IntervalStyle", + "postgres", + ), + ("SET bytea_output = 'hex'", "bytea_output", "hex"), + ( + "SET application_name = 'myapp'", + "application_name", + "myapp", + ), + ("SET search_path = 'public'", "search_path", "public"), + ("SET extra_float_digits = '2'", "extra_float_digits", "2"), + ( + "SET TIME ZONE 'America/New_York'", + "TimeZone", + "America/New_York", + ), + ]; + + for (sql, expected_key, expected_value) in test_cases { + client.sent_messages.clear(); + + let responses = + ::do_query(&service, &mut client, sql) + .await + .unwrap(); + + assert!( + matches!(responses[0], Response::Execution(_)), + "Expected SET tag for {sql}" + ); + + let ps_msgs: Vec<_> = client + .sent_messages() + .iter() + .filter_map(|m| match m { + PgWireBackendMessage::ParameterStatus(ps) => Some(ps), + _ => None, + }) + .collect(); + + assert_eq!(ps_msgs.len(), 1, "Expected 1 ParameterStatus for {sql}"); + assert_eq!(ps_msgs[0].name, expected_key, "Wrong key for {sql}"); + assert_eq!(ps_msgs[0].value, expected_value, "Wrong value for {sql}"); + } + } + + #[tokio::test] + async fn test_set_statement_timeout_no_parameter_status() { + use pgwire::messages::PgWireBackendMessage; + + let service = crate::testing::setup_handlers(); + let mut client = MockClient::new(); + + ::do_query( + &service, + &mut client, + "SET statement_timeout TO '5000ms'", + ) + .await + .unwrap(); + + let has_ps = client + .sent_messages() + .iter() + .any(|m| matches!(m, PgWireBackendMessage::ParameterStatus(_))); + + assert!(!has_ps, "statement_timeout should not send ParameterStatus"); + } } diff --git a/datafusion-postgres/src/hooks/mod.rs b/datafusion-postgres/src/hooks/mod.rs index 2f12ef9..c1c6f58 100644 --- a/datafusion-postgres/src/hooks/mod.rs +++ b/datafusion-postgres/src/hooks/mod.rs @@ -8,9 +8,28 @@ use datafusion::common::ParamValues; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::SessionContext; use datafusion::sql::sqlparser::ast::Statement; +use futures::Sink; use pgwire::api::results::Response; use pgwire::api::ClientInfo; -use pgwire::error::PgWireResult; +use pgwire::error::{PgWireError, PgWireResult}; +use pgwire::messages::PgWireBackendMessage; + +#[async_trait] +pub trait HookClient: ClientInfo + Send + Sync { + async fn send_message(&mut self, item: PgWireBackendMessage) -> PgWireResult<()>; +} + +#[async_trait] +impl HookClient for S +where + S: ClientInfo + Sink + Send + Sync + Unpin, + PgWireError: From<>::Error>, +{ + async fn send_message(&mut self, item: PgWireBackendMessage) -> PgWireResult<()> { + use futures::SinkExt; + self.send(item).await.map_err(PgWireError::from) + } +} #[async_trait] pub trait QueryHook: Send + Sync { @@ -19,7 +38,7 @@ pub trait QueryHook: Send + Sync { &self, statement: &Statement, session_context: &SessionContext, - client: &mut (dyn ClientInfo + Send + Sync), + client: &mut dyn HookClient, ) -> Option>; /// called at extended query parse phase, for generating `LogicalPlan`from statement @@ -37,6 +56,6 @@ pub trait QueryHook: Send + Sync { logical_plan: &LogicalPlan, params: &ParamValues, session_context: &SessionContext, - client: &mut (dyn ClientInfo + Send + Sync), + client: &mut dyn HookClient, ) -> Option>; } diff --git a/datafusion-postgres/src/hooks/permissions.rs b/datafusion-postgres/src/hooks/permissions.rs index ff5a7f6..ac663e1 100644 --- a/datafusion-postgres/src/hooks/permissions.rs +++ b/datafusion-postgres/src/hooks/permissions.rs @@ -5,14 +5,16 @@ use datafusion::common::ParamValues; use datafusion::logical_expr::LogicalPlan; use datafusion::prelude::SessionContext; use datafusion::sql::sqlparser::ast::Statement; -use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType}; use pgwire::api::results::Response; use pgwire::api::ClientInfo; use pgwire::error::{PgWireError, PgWireResult}; use crate::auth::AuthManager; +use crate::hooks::HookClient; use crate::QueryHook; +use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType}; + #[derive(Debug)] pub struct PermissionsHook { auth_manager: Arc, @@ -96,7 +98,7 @@ impl QueryHook for PermissionsHook { &self, statement: &Statement, _session_context: &SessionContext, - client: &mut (dyn ClientInfo + Send + Sync), + client: &mut dyn HookClient, ) -> Option> { if Self::should_skip_permission_check(statement) { return None; @@ -125,7 +127,7 @@ impl QueryHook for PermissionsHook { _logical_plan: &LogicalPlan, _params: &ParamValues, _session_context: &SessionContext, - client: &mut (dyn ClientInfo + Send + Sync), + client: &mut dyn HookClient, ) -> Option> { if Self::should_skip_permission_check(statement) { return None; diff --git a/datafusion-postgres/src/hooks/set_show.rs b/datafusion-postgres/src/hooks/set_show.rs index 93c16c5..3d15179 100644 --- a/datafusion-postgres/src/hooks/set_show.rs +++ b/datafusion-postgres/src/hooks/set_show.rs @@ -12,10 +12,13 @@ use pgwire::api::auth::DefaultServerParameterProvider; use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo, QueryResponse, Response, Tag}; use pgwire::api::ClientInfo; use pgwire::error::{PgWireError, PgWireResult}; +use pgwire::messages::startup::ParameterStatus; +use pgwire::messages::PgWireBackendMessage; use pgwire::types::format::FormatOptions; use postgres_types::Type; use crate::client; +use crate::hooks::HookClient; use crate::QueryHook; #[derive(Debug)] @@ -28,7 +31,7 @@ impl QueryHook for SetShowHook { &self, statement: &Statement, session_context: &SessionContext, - client: &mut (dyn ClientInfo + Send + Sync), + client: &mut dyn HookClient, ) -> Option> { match statement { Statement::Set { .. } => { @@ -85,7 +88,7 @@ impl QueryHook for SetShowHook { _logical_plan: &LogicalPlan, _params: &ParamValues, session_context: &SessionContext, - client: &mut (dyn ClientInfo + Send + Sync), + client: &mut dyn HookClient, ) -> Option> { match statement { Statement::Set { .. } => { @@ -118,14 +121,11 @@ fn mock_show_response(name: &str, value: &str) -> PgWireResult { Ok(QueryResponse::new(Arc::new(fields), Box::pin(row_stream))) } -async fn try_respond_set_statements( - client: &mut C, +async fn try_respond_set_statements( + client: &mut dyn HookClient, statement: &Statement, session_context: &SessionContext, -) -> Option> -where - C: ClientInfo + Send + Sync + ?Sized, -{ +) -> Option> { let Statement::Set(set_statement) = statement else { return None; }; @@ -184,9 +184,18 @@ where // postgres configuration variables let value = values[0].clone(); if let Expr::Value(value) = value { - client - .metadata_mut() - .insert(var, value.into_string().unwrap_or_else(|| "".to_string())); + let val_str = value.into_string().unwrap_or_else(|| "".to_string()); + client.metadata_mut().insert(var.clone(), val_str); + if let Some((name, value)) = parameter_status_for_var(&var, &*client) { + if let Err(e) = client + .send_message(PgWireBackendMessage::ParameterStatus( + ParameterStatus::new(name, value), + )) + .await + { + return Some(Err(e)); + } + } return Some(Ok(Response::Execution(Tag::new("SET")))); } } @@ -205,6 +214,16 @@ where .options_mut() .execution .time_zone = Some(tz.to_string()); + let tz_value = client::get_timezone(client).unwrap_or("UTC").to_string(); + if let Err(e) = client + .send_message(PgWireBackendMessage::ParameterStatus(ParameterStatus::new( + "TimeZone".to_string(), + tz_value, + ))) + .await + { + return Some(Err(e)); + } return Some(Ok(Response::Execution(Tag::new("SET")))); } _ => {} @@ -221,6 +240,23 @@ where Some(Ok(Response::Execution(Tag::new("SET")))) } +fn parameter_status_for_var( + var: &str, + client: &(impl ClientInfo + ?Sized), +) -> Option<(String, String)> { + let display_name = match var { + "datestyle" => "DateStyle", + "intervalstyle" => "IntervalStyle", + "bytea_output" => "bytea_output", + "application_name" => "application_name", + "extra_float_digits" => "extra_float_digits", + "search_path" => "search_path", + _ => return None, + }; + let value = client.metadata().get(var)?.clone(); + Some((display_name.to_string(), value)) +} + async fn execute_set_statement( session_context: &SessionContext, statement: Statement, @@ -239,14 +275,11 @@ async fn execute_set_statement( .map(|_| ()) } -async fn try_respond_show_statements( - client: &C, +async fn try_respond_show_statements( + client: &dyn HookClient, statement: &Statement, session_context: &SessionContext, -) -> Option> -where - C: ClientInfo + ?Sized, -{ +) -> Option> { let Statement::ShowVariable { variable } = statement else { return None; }; @@ -368,7 +401,7 @@ mod tests { let session_context = SessionContext::new(); let mut client = MockClient::new(); - // Test setting timeout to 5000ms + // Test setting bytea_output to hex let statement = Parser::new(&PostgreSqlDialect {}) .try_with_sql("set bytea_output = 'hex'") .unwrap() @@ -380,11 +413,11 @@ mod tests { assert!(set_response.is_some()); assert!(set_response.unwrap().is_ok()); - // Verify the timeout was set in client metadata + // Verify the value was set in client metadata let bytea_output = client.metadata().get("bytea_output").unwrap(); assert_eq!(bytea_output, "hex"); - // Test SHOW statement_timeout + // Test SHOW bytea_output let statement = Parser::new(&PostgreSqlDialect {}) .try_with_sql("show bytea_output") .unwrap() @@ -402,7 +435,7 @@ mod tests { let session_context = SessionContext::new(); let mut client = MockClient::new(); - // Test setting timeout to 5000ms + // Test setting dateStyle let statement = Parser::new(&PostgreSqlDialect {}) .try_with_sql("set dateStyle = 'ISO, DMY'") .unwrap() @@ -414,11 +447,11 @@ mod tests { assert!(set_response.is_some()); assert!(set_response.unwrap().is_ok()); - // Verify the timeout was set in client metadata + // Verify the value was set in client metadata let bytea_output = client.metadata().get("datestyle").unwrap(); assert_eq!(bytea_output, "ISO, DMY"); - // Test SHOW statement_timeout + // Test SHOW dateStyle let statement = Parser::new(&PostgreSqlDialect {}) .try_with_sql("show dateStyle") .unwrap() @@ -460,6 +493,86 @@ mod tests { assert_eq!(timeout, None); } + #[tokio::test] + async fn test_parameter_status_sent_for_all_set_vars() { + use pgwire::messages::PgWireBackendMessage; + + let test_cases = vec![ + ("set bytea_output = 'escape'", "bytea_output", "escape"), + ( + "set intervalstyle = 'postgres'", + "IntervalStyle", + "postgres", + ), + ( + "set application_name = 'myapp'", + "application_name", + "myapp", + ), + ("set search_path = 'public'", "search_path", "public"), + ("set extra_float_digits = '2'", "extra_float_digits", "2"), + ("set datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"), + ( + "set time zone 'America/New_York'", + "TimeZone", + "America/New_York", + ), + ]; + + for (sql, expected_key, expected_value) in test_cases { + let session_context = SessionContext::new(); + let mut client = MockClient::new(); + let statement = Parser::new(&PostgreSqlDialect {}) + .try_with_sql(sql) + .unwrap() + .parse_statement() + .unwrap(); + + let result = + try_respond_set_statements(&mut client, &statement, &session_context).await; + assert!(result.is_some(), "Expected Some for {sql}"); + assert!(result.unwrap().is_ok(), "Expected Ok for {sql}"); + + let ps_msgs: Vec<_> = client + .sent_messages() + .iter() + .filter_map(|m| match m { + PgWireBackendMessage::ParameterStatus(ps) => Some(ps), + _ => None, + }) + .collect(); + + assert_eq!(ps_msgs.len(), 1, "Expected 1 ParameterStatus for {sql}"); + assert_eq!(ps_msgs[0].name, expected_key, "Wrong key for {sql}"); + assert_eq!(ps_msgs[0].value, expected_value, "Wrong value for {sql}"); + } + } + + #[tokio::test] + async fn test_no_parameter_status_for_statement_timeout() { + use pgwire::messages::PgWireBackendMessage; + + let session_context = SessionContext::new(); + let mut client = MockClient::new(); + + let statement = Parser::new(&PostgreSqlDialect {}) + .try_with_sql("set statement_timeout to '5000ms'") + .unwrap() + .parse_statement() + .unwrap(); + + let result = try_respond_set_statements(&mut client, &statement, &session_context).await; + assert!(result.is_some()); + assert!(result.unwrap().is_ok()); + + let has_ps = client + .sent_messages() + .iter() + .any(|m| matches!(m, PgWireBackendMessage::ParameterStatus(_))); + + assert!(!has_ps, "statement_timeout should not send ParameterStatus"); + } + #[tokio::test] async fn test_supported_show_statements_returned_columns() { let session_context = SessionContext::new(); diff --git a/datafusion-postgres/src/hooks/transactions.rs b/datafusion-postgres/src/hooks/transactions.rs index a7bf2fe..13ef260 100644 --- a/datafusion-postgres/src/hooks/transactions.rs +++ b/datafusion-postgres/src/hooks/transactions.rs @@ -10,6 +10,7 @@ use pgwire::api::ClientInfo; use pgwire::error::{PgWireError, PgWireResult}; use pgwire::messages::response::TransactionStatus; +use crate::hooks::HookClient; use crate::QueryHook; /// Hook for processing transaction related statements @@ -26,14 +27,14 @@ impl QueryHook for TransactionStatementHook { &self, statement: &Statement, _session_context: &SessionContext, - client: &mut (dyn ClientInfo + Send + Sync), + client: &mut dyn HookClient, ) -> Option> { let resp = try_respond_transaction_statements(client, statement) .await .transpose(); - if resp.is_some() { - return resp; + if let Some(result) = resp { + return Some(result); } // Check if we're in a failed transaction and block non-transaction @@ -82,7 +83,7 @@ impl QueryHook for TransactionStatementHook { _logical_plan: &LogicalPlan, _params: &ParamValues, session_context: &SessionContext, - client: &mut (dyn ClientInfo + Send + Sync), + client: &mut dyn HookClient, ) -> Option> { self.handle_simple_query(statement, session_context, client) .await diff --git a/datafusion-postgres/src/testing.rs b/datafusion-postgres/src/testing.rs index f2d53db..4f9a7b2 100644 --- a/datafusion-postgres/src/testing.rs +++ b/datafusion-postgres/src/testing.rs @@ -30,6 +30,7 @@ pub fn setup_handlers() -> DfSessionService { pub struct MockClient { metadata: HashMap, portal_store: HashMap, + pub sent_messages: Vec, } impl MockClient { @@ -40,8 +41,13 @@ impl MockClient { MockClient { metadata, portal_store: HashMap::default(), + sent_messages: Vec::new(), } } + + pub fn sent_messages(&self) -> &[PgWireBackendMessage] { + &self.sent_messages + } } impl ClientInfo for MockClient { @@ -101,6 +107,17 @@ impl ClientPortalStore for MockClient { } } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mock_client_captures_messages() { + let client = MockClient::new(); + assert!(client.sent_messages().is_empty()); + } +} + impl Sink for MockClient { type Error = std::io::Error; @@ -112,9 +129,10 @@ impl Sink for MockClient { } fn start_send( - self: std::pin::Pin<&mut Self>, - _item: PgWireBackendMessage, + mut self: std::pin::Pin<&mut Self>, + item: PgWireBackendMessage, ) -> Result<(), Self::Error> { + self.sent_messages.push(item); Ok(()) }