Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 96 additions & 14 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use pgwire::api::results::{FieldInfo, Response, Tag};
use pgwire::api::stmt::QueryParser;
use pgwire::api::{ClientInfo, ErrorHandler, PgWireServerHandlers, Type};
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::PgWireBackendMessage;
use pgwire::types::format::FormatOptions;

use crate::hooks::set_show::SetShowHook;
Expand Down Expand Up @@ -119,10 +120,11 @@ impl DfSessionService {
impl SimpleQueryHandler for DfSessionService {
async fn do_query<C>(&self, client: &mut C, query: &str) -> PgWireResult<Vec<Response>>
where
C: ClientInfo + Unpin + Send + Sync,
C: ClientInfo + futures::Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: std::fmt::Debug,
PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
{
log::debug!("Received query: {query}"); // Log the query for debugging

log::debug!("Received query: {query}");
let statements = self
.parser
.sql_parser
Expand Down Expand Up @@ -206,11 +208,12 @@ impl ExtendedQueryHandler for DfSessionService {
_max_rows: usize,
) -> PgWireResult<Response>
where
C: ClientInfo + Unpin + Send + Sync,
C: ClientInfo + futures::Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: std::fmt::Debug,
PgWireError: From<<C as futures::Sink<PgWireBackendMessage>>::Error>,
{
let query = &portal.statement.statement.0;
log::debug!("Received execute extended query: {query}"); // Log for debugging

log::debug!("Received execute extended query: {query}");
// Check query hooks first
if !self.query_hooks.is_empty() {
if let (_, Some((statement, plan))) = &portal.statement.statement {
Expand Down Expand Up @@ -243,13 +246,12 @@ impl ExtendedQueryHandler for DfSessionService {
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

let param_values =
df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
df::deserialize_parameters(portal, &ordered_param_types(&param_types))?;

let plan = plan
.clone()
.replace_params_with_values(&param_values)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use
// &param_values
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let optimised = self
.session_context
.state()
Expand Down Expand Up @@ -345,8 +347,7 @@ impl QueryParser for Parser {
where
C: ClientInfo + Unpin + Send + Sync,
{
log::debug!("Received parse extended query: {sql}"); // Log for debugging

log::debug!("Received parse extended query: {sql}");
let mut statements = self
.sql_parser
.parse(sql)
Expand Down Expand Up @@ -384,7 +385,6 @@ impl QueryParser for Parser {

let mut param_types = Vec::with_capacity(params.len());
for param_type in ordered_param_types(&params).iter() {
// Fixed: Use &params
if let Some(datatype) = param_type {
let pgtype = into_pg_type(datatype)?;
param_types.push(pgtype);
Expand Down Expand Up @@ -434,6 +434,8 @@ mod tests {
use super::*;
use crate::testing::MockClient;

use crate::hooks::HookClient;

struct TestHook;

#[async_trait]
Expand All @@ -442,7 +444,7 @@ mod tests {
&self,
statement: &sqlparser::ast::Statement,
_ctx: &SessionContext,
_client: &mut (dyn ClientInfo + Sync + Send),
_client: &mut dyn HookClient,
) -> Option<PgWireResult<Response>> {
if statement.to_string().contains("magic") {
Some(Ok(Response::EmptyQuery))
Expand All @@ -466,7 +468,7 @@ mod tests {
_logical_plan: &LogicalPlan,
_params: &ParamValues,
_session_context: &SessionContext,
_client: &mut (dyn ClientInfo + Send + Sync),
_client: &mut dyn HookClient,
) -> Option<PgWireResult<Response>> {
None
}
Expand Down Expand Up @@ -523,4 +525,84 @@ mod tests {
assert!(matches!(results[2], Response::EmptyQuery));
assert!(matches!(results[3], Response::Query(_)));
}

#[tokio::test]
async fn test_set_sends_parameter_status_via_sink() {
use pgwire::messages::PgWireBackendMessage;

let service = crate::testing::setup_handlers();
let mut client = MockClient::new();

let test_cases = vec![
("SET datestyle = 'ISO, MDY'", "DateStyle", "ISO, MDY"),
(
"SET intervalstyle = 'postgres'",
"IntervalStyle",
"postgres",
),
("SET bytea_output = 'hex'", "bytea_output", "hex"),
(
"SET application_name = 'myapp'",
"application_name",
"myapp",
),
("SET search_path = 'public'", "search_path", "public"),
("SET extra_float_digits = '2'", "extra_float_digits", "2"),
(
"SET TIME ZONE 'America/New_York'",
"TimeZone",
"America/New_York",
),
];

for (sql, expected_key, expected_value) in test_cases {
client.sent_messages.clear();

let responses =
<DfSessionService as SimpleQueryHandler>::do_query(&service, &mut client, sql)
.await
.unwrap();

assert!(
matches!(responses[0], Response::Execution(_)),
"Expected SET tag for {sql}"
);

let ps_msgs: Vec<_> = client
.sent_messages()
.iter()
.filter_map(|m| match m {
PgWireBackendMessage::ParameterStatus(ps) => Some(ps),
_ => None,
})
.collect();

assert_eq!(ps_msgs.len(), 1, "Expected 1 ParameterStatus for {sql}");
assert_eq!(ps_msgs[0].name, expected_key, "Wrong key for {sql}");
assert_eq!(ps_msgs[0].value, expected_value, "Wrong value for {sql}");
}
}

#[tokio::test]
async fn test_set_statement_timeout_no_parameter_status() {
use pgwire::messages::PgWireBackendMessage;

let service = crate::testing::setup_handlers();
let mut client = MockClient::new();

<DfSessionService as SimpleQueryHandler>::do_query(
&service,
&mut client,
"SET statement_timeout TO '5000ms'",
)
.await
.unwrap();

let has_ps = client
.sent_messages()
.iter()
.any(|m| matches!(m, PgWireBackendMessage::ParameterStatus(_)));

assert!(!has_ps, "statement_timeout should not send ParameterStatus");
}
}
25 changes: 22 additions & 3 deletions datafusion-postgres/src/hooks/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,28 @@ use datafusion::common::ParamValues;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::SessionContext;
use datafusion::sql::sqlparser::ast::Statement;
use futures::Sink;
use pgwire::api::results::Response;
use pgwire::api::ClientInfo;
use pgwire::error::PgWireResult;
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::PgWireBackendMessage;

#[async_trait]
pub trait HookClient: ClientInfo + Send + Sync {
async fn send_message(&mut self, item: PgWireBackendMessage) -> PgWireResult<()>;
}

#[async_trait]
impl<S> HookClient for S
where
S: ClientInfo + Sink<PgWireBackendMessage> + Send + Sync + Unpin,
PgWireError: From<<S as Sink<PgWireBackendMessage>>::Error>,
{
async fn send_message(&mut self, item: PgWireBackendMessage) -> PgWireResult<()> {
use futures::SinkExt;
self.send(item).await.map_err(PgWireError::from)
}
}

#[async_trait]
pub trait QueryHook: Send + Sync {
Expand All @@ -19,7 +38,7 @@ pub trait QueryHook: Send + Sync {
&self,
statement: &Statement,
session_context: &SessionContext,
client: &mut (dyn ClientInfo + Send + Sync),
client: &mut dyn HookClient,
) -> Option<PgWireResult<Response>>;

/// called at extended query parse phase, for generating `LogicalPlan`from statement
Expand All @@ -37,6 +56,6 @@ pub trait QueryHook: Send + Sync {
logical_plan: &LogicalPlan,
params: &ParamValues,
session_context: &SessionContext,
client: &mut (dyn ClientInfo + Send + Sync),
client: &mut dyn HookClient,
) -> Option<PgWireResult<Response>>;
}
8 changes: 5 additions & 3 deletions datafusion-postgres/src/hooks/permissions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ use datafusion::common::ParamValues;
use datafusion::logical_expr::LogicalPlan;
use datafusion::prelude::SessionContext;
use datafusion::sql::sqlparser::ast::Statement;
use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};
use pgwire::api::results::Response;
use pgwire::api::ClientInfo;
use pgwire::error::{PgWireError, PgWireResult};

use crate::auth::AuthManager;
use crate::hooks::HookClient;
use crate::QueryHook;

use datafusion_pg_catalog::pg_catalog::context::{Permission, ResourceType};

#[derive(Debug)]
pub struct PermissionsHook {
auth_manager: Arc<AuthManager>,
Expand Down Expand Up @@ -96,7 +98,7 @@ impl QueryHook for PermissionsHook {
&self,
statement: &Statement,
_session_context: &SessionContext,
client: &mut (dyn ClientInfo + Send + Sync),
client: &mut dyn HookClient,
) -> Option<PgWireResult<Response>> {
if Self::should_skip_permission_check(statement) {
return None;
Expand Down Expand Up @@ -125,7 +127,7 @@ impl QueryHook for PermissionsHook {
_logical_plan: &LogicalPlan,
_params: &ParamValues,
_session_context: &SessionContext,
client: &mut (dyn ClientInfo + Send + Sync),
client: &mut dyn HookClient,
) -> Option<PgWireResult<Response>> {
if Self::should_skip_permission_check(statement) {
return None;
Expand Down
Loading
Loading