From dcbb70d2a2ea52725d0ba1f807ea9d6a96be7865 Mon Sep 17 00:00:00 2001 From: buraksenn Date: Fri, 6 Mar 2026 15:36:06 +0300 Subject: [PATCH] fix: initial implementation to discuss --- datafusion/sql/src/expr/value.rs | 17 ++++++++--------- datafusion/sql/src/planner.rs | 6 +++--- datafusion/sql/src/statement.rs | 27 ++++++++++++++++++++++----- datafusion/sql/tests/cases/params.rs | 16 ++++++++++++++++ 4 files changed, 49 insertions(+), 17 deletions(-) diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index bd75ac36306fb..13a47f545cf7e 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -45,7 +45,7 @@ impl SqlToRel<'_, S> { pub(crate) fn parse_value( &self, value: Value, - param_data_types: &[FieldRef], + param_data_types: &[Option], ) -> Result { match value { Value::Number(n, _) => self.parse_sql_number(&n, false), @@ -105,7 +105,7 @@ impl SqlToRel<'_, S> { /// Both named (`$foo`) and positional (`$1`, `$2`, ...) placeholder styles are supported. fn create_placeholder_expr( param: String, - param_data_types: &[FieldRef], + param_data_types: &[Option], ) -> Result { // Try to parse the placeholder as a number. If the placeholder does not have a valid // positional value, assume we have a named placeholder. @@ -124,13 +124,13 @@ impl SqlToRel<'_, S> { // FIXME: This branch is shared by params from PREPARE and CREATE FUNCTION, but // only CREATE FUNCTION currently supports named params. For now, we rewrite // these to positional params. - let named_param_pos = param_data_types - .iter() - .position(|v| v.name() == ¶m[1..]); + let named_param_pos = param_data_types.iter().position(|v| { + v.as_ref().is_some_and(|field| field.name() == ¶m[1..]) + }); match named_param_pos { Some(pos) => Ok(Expr::Placeholder(Placeholder::new_with_field( format!("${}", pos + 1), - param_data_types.get(pos).cloned(), + param_data_types.get(pos).and_then(|v| v.clone()), ))), None => plan_err!("Unknown placeholder: {param}"), } @@ -139,13 +139,12 @@ impl SqlToRel<'_, S> { }; // Check if the placeholder is in the parameter list // FIXME: In the CREATE FUNCTION branch, param_type = None should raise an error - let param_type = param_data_types.get(idx); + let param_type = param_data_types.get(idx).and_then(|v| v.clone()); // Data type of the parameter debug!("type of param {param} param_data_types[idx]: {param_type:?}"); Ok(Expr::Placeholder(Placeholder::new_with_field( - param, - param_type.cloned(), + param, param_type, ))) } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 307f28e8ff9ad..2141207f3af97 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -257,7 +257,7 @@ impl IdentNormalizer { pub struct PlannerContext { /// Data types for numbered parameters ($1, $2, etc), if supplied /// in `PREPARE` statement - prepare_param_data_types: Arc>, + prepare_param_data_types: Arc>>, /// Map of CTE name to logical plan of the WITH clause. /// Use `Arc` to allow cheap cloning ctes: HashMap>, @@ -293,7 +293,7 @@ impl PlannerContext { /// Update the PlannerContext with provided prepare_param_data_types pub fn with_prepare_param_data_types( mut self, - prepare_param_data_types: Vec, + prepare_param_data_types: Vec>, ) -> Self { self.prepare_param_data_types = prepare_param_data_types.into(); self @@ -373,7 +373,7 @@ impl PlannerContext { } /// Return the types of parameters (`$1`, `$2`, etc) if known - pub fn prepare_param_data_types(&self) -> &[FieldRef] { + pub fn prepare_param_data_types(&self) -> &[Option] { &self.prepare_param_data_types } diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 32bc8cb244aae..b9eb1d080faff 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -790,8 +790,10 @@ impl SqlToRel<'_, S> { .collect::>()?; // Create planner context with parameters - let mut planner_context = - PlannerContext::new().with_prepare_param_data_types(fields.clone()); + let mut planner_context = PlannerContext::new() + .with_prepare_param_data_types( + fields.iter().cloned().map(Some).collect(), + ); // Build logical plan for inner statement of the prepare statement let plan = self.sql_statement_to_plan_with_context_impl( @@ -808,7 +810,9 @@ impl SqlToRel<'_, S> { }) .collect(); fields.extend(param_types.iter().cloned()); - planner_context.with_prepare_param_data_types(param_types); + planner_context.with_prepare_param_data_types( + param_types.into_iter().map(Some).collect(), + ); } Ok(LogicalPlan::Statement(PlanStatement::Prepare(Prepare { @@ -1332,7 +1336,13 @@ impl SqlToRel<'_, S> { } } let mut planner_context = PlannerContext::new() - .with_prepare_param_data_types(arg_types.unwrap_or_default()); + .with_prepare_param_data_types( + arg_types + .unwrap_or_default() + .into_iter() + .map(Some) + .collect(), + ); let function_body = match function_body { Some(r) => Some(self.sql_to_expr( @@ -2331,7 +2341,14 @@ impl SqlToRel<'_, S> { } } } - let prepare_param_data_types = prepare_param_data_types.into_values().collect(); + let prepare_param_data_types: Vec> = + if let Some(&max_idx) = prepare_param_data_types.keys().last() { + (0..=max_idx) + .map(|i| prepare_param_data_types.remove(&i)) + .collect() + } else { + vec![] + }; // Projection let mut planner_context = diff --git a/datafusion/sql/tests/cases/params.rs b/datafusion/sql/tests/cases/params.rs index 396f619400c74..68c560ead68cd 100644 --- a/datafusion/sql/tests/cases/params.rs +++ b/datafusion/sql/tests/cases/params.rs @@ -1041,6 +1041,22 @@ fn test_prepare_statement_unknown_hash_param() { ); } +#[test] +fn test_insert_infer_with_function_wrapped_placeholder() { + let plan = logical_plan( + "INSERT INTO person (id, first_name, age) VALUES ($1, character_length($2), $3)", + ) + .unwrap(); + + let actual_types = plan.get_parameter_types().unwrap(); + let expected_types: HashMap> = HashMap::from([ + ("$1".to_string(), Some(DataType::UInt32)), + ("$2".to_string(), None), + ("$3".to_string(), Some(DataType::Int32)), + ]); + assert_eq!(actual_types, expected_types); +} + #[test] fn test_prepare_statement_bad_list_idx() { let sql = "SELECT id from person where id = $foo";