Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
pub(crate) fn parse_value(
&self,
value: Value,
param_data_types: &[FieldRef],
param_data_types: &[Option<FieldRef>],
) -> Result<Expr> {
match value {
Value::Number(n, _) => self.parse_sql_number(&n, false),
Expand Down Expand Up @@ -105,7 +105,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
/// Both named (`$foo`) and positional (`$1`, `$2`, ...) placeholder styles are supported.
fn create_placeholder_expr(
param: String,
param_data_types: &[FieldRef],
param_data_types: &[Option<FieldRef>],
) -> Result<Expr> {
// Try to parse the placeholder as a number. If the placeholder does not have a valid
// positional value, assume we have a named placeholder.
Expand All @@ -124,13 +124,13 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
// FIXME: This branch is shared by params from PREPARE and CREATE FUNCTION, but
// only CREATE FUNCTION currently supports named params. For now, we rewrite
// these to positional params.
let named_param_pos = param_data_types
.iter()
.position(|v| v.name() == &param[1..]);
let named_param_pos = param_data_types.iter().position(|v| {
v.as_ref().is_some_and(|field| field.name() == &param[1..])
});
match named_param_pos {
Some(pos) => Ok(Expr::Placeholder(Placeholder::new_with_field(
format!("${}", pos + 1),
param_data_types.get(pos).cloned(),
param_data_types.get(pos).and_then(|v| v.clone()),
))),
None => plan_err!("Unknown placeholder: {param}"),
}
Expand All @@ -139,13 +139,12 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
};
// Check if the placeholder is in the parameter list
// FIXME: In the CREATE FUNCTION branch, param_type = None should raise an error
let param_type = param_data_types.get(idx);
let param_type = param_data_types.get(idx).and_then(|v| v.clone());
// Data type of the parameter
debug!("type of param {param} param_data_types[idx]: {param_type:?}");

Ok(Expr::Placeholder(Placeholder::new_with_field(
param,
param_type.cloned(),
param, param_type,
)))
}

Expand Down
6 changes: 3 additions & 3 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ impl IdentNormalizer {
pub struct PlannerContext {
/// Data types for numbered parameters ($1, $2, etc), if supplied
/// in `PREPARE` statement
prepare_param_data_types: Arc<Vec<FieldRef>>,
prepare_param_data_types: Arc<Vec<Option<FieldRef>>>,
/// Map of CTE name to logical plan of the WITH clause.
/// Use `Arc<LogicalPlan>` to allow cheap cloning
ctes: HashMap<String, Arc<LogicalPlan>>,
Expand Down Expand Up @@ -293,7 +293,7 @@ impl PlannerContext {
/// Update the PlannerContext with provided prepare_param_data_types
pub fn with_prepare_param_data_types(
mut self,
prepare_param_data_types: Vec<FieldRef>,
prepare_param_data_types: Vec<Option<FieldRef>>,
) -> Self {
self.prepare_param_data_types = prepare_param_data_types.into();
self
Expand Down Expand Up @@ -373,7 +373,7 @@ impl PlannerContext {
}

/// Return the types of parameters (`$1`, `$2`, etc) if known
pub fn prepare_param_data_types(&self) -> &[FieldRef] {
pub fn prepare_param_data_types(&self) -> &[Option<FieldRef>] {
&self.prepare_param_data_types
}

Expand Down
27 changes: 22 additions & 5 deletions datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -790,8 +790,10 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
.collect::<Result<_>>()?;

// Create planner context with parameters
let mut planner_context =
PlannerContext::new().with_prepare_param_data_types(fields.clone());
let mut planner_context = PlannerContext::new()
.with_prepare_param_data_types(
fields.iter().cloned().map(Some).collect(),
);

// Build logical plan for inner statement of the prepare statement
let plan = self.sql_statement_to_plan_with_context_impl(
Expand All @@ -808,7 +810,9 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
})
.collect();
fields.extend(param_types.iter().cloned());
planner_context.with_prepare_param_data_types(param_types);
planner_context.with_prepare_param_data_types(
param_types.into_iter().map(Some).collect(),
);
}

Ok(LogicalPlan::Statement(PlanStatement::Prepare(Prepare {
Expand Down Expand Up @@ -1332,7 +1336,13 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
}
let mut planner_context = PlannerContext::new()
.with_prepare_param_data_types(arg_types.unwrap_or_default());
.with_prepare_param_data_types(
arg_types
.unwrap_or_default()
.into_iter()
.map(Some)
.collect(),
);

let function_body = match function_body {
Some(r) => Some(self.sql_to_expr(
Expand Down Expand Up @@ -2331,7 +2341,14 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
}
}
let prepare_param_data_types = prepare_param_data_types.into_values().collect();
let prepare_param_data_types: Vec<Option<FieldRef>> =
if let Some(&max_idx) = prepare_param_data_types.keys().last() {
(0..=max_idx)
.map(|i| prepare_param_data_types.remove(&i))
.collect()
} else {
vec![]
};

// Projection
let mut planner_context =
Expand Down
16 changes: 16 additions & 0 deletions datafusion/sql/tests/cases/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,22 @@ fn test_prepare_statement_unknown_hash_param() {
);
}

#[test]
fn test_insert_infer_with_function_wrapped_placeholder() {
let plan = logical_plan(
"INSERT INTO person (id, first_name, age) VALUES ($1, character_length($2), $3)",
)
.unwrap();

let actual_types = plan.get_parameter_types().unwrap();
let expected_types: HashMap<String, Option<DataType>> = HashMap::from([
("$1".to_string(), Some(DataType::UInt32)),
("$2".to_string(), None),
("$3".to_string(), Some(DataType::Int32)),
]);
assert_eq!(actual_types, expected_types);
}

#[test]
fn test_prepare_statement_bad_list_idx() {
let sql = "SELECT id from person where id = $foo";
Expand Down