Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions datafusion/sql/src/cte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::sync::Arc;
use crate::planner::{ContextProvider, PlannerContext, SqlToRel};

use datafusion_common::{
Result, not_impl_err, plan_err,
Diagnostic, Result, Span, not_impl_err, plan_err,
tree_node::{TreeNode, TreeNodeRecursion},
};
use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource};
Expand All @@ -37,10 +37,16 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
for cte in with.cte_tables {
// A `WITH` block can't use the same name more than once
let cte_name = self.ident_normalizer.normalize(cte.alias.name.clone());
let cte_name_span = Span::try_from_sqlparser_span(cte.alias.name.span);
if planner_context.contains_cte(&cte_name) {
return plan_err!(
"WITH query name {cte_name:?} specified more than once"
);
let msg =
format!("WITH query name {cte_name:?} specified more than once");
let mut diagnostic = Diagnostic::new_error(&msg, cte_name_span);
if let Some(first_span) = planner_context.get_cte_span(&cte_name) {
diagnostic =
diagnostic.with_note("previously defined here", Some(first_span));
}
return plan_err!("{msg}").map_err(|e| e.with_diagnostic(diagnostic));
}

// Create a logical plan for the CTE
Expand All @@ -53,8 +59,7 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
// Each `WITH` block can change the column names in the last
// projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2").
let final_plan = self.apply_table_alias(cte_plan, cte.alias)?;
// Export the CTE to the outer query
planner_context.insert_cte(cte_name, final_plan);
planner_context.insert_cte_with_span(cte_name, final_plan, cte_name_span);
}
Ok(())
}
Expand Down
28 changes: 23 additions & 5 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use datafusion_common::TableReference;
use datafusion_common::config::SqlParserOptions;
use datafusion_common::datatype::{DataTypeExt, FieldExt};
use datafusion_common::error::add_possible_columns_to_diag;
use datafusion_common::{DFSchema, DataFusionError, Result, not_impl_err, plan_err};
use datafusion_common::{
DFSchema, DataFusionError, Result, Span, not_impl_err, plan_err,
};
use datafusion_common::{
DFSchemaRef, Diagnostic, SchemaError, field_not_found, internal_err,
plan_datafusion_err,
Expand Down Expand Up @@ -258,9 +260,9 @@ pub struct PlannerContext {
/// Data types for numbered parameters ($1, $2, etc), if supplied
/// in `PREPARE` statement
prepare_param_data_types: Arc<Vec<FieldRef>>,
/// Map of CTE name to logical plan of the WITH clause.
/// Map of CTE name to logical plan of the WITH clause and optional span.
/// Use `Arc<LogicalPlan>` to allow cheap cloning
ctes: HashMap<String, Arc<LogicalPlan>>,
ctes: HashMap<String, (Arc<LogicalPlan>, Option<Span>)>,

/// The queries schemas of outer query relations, used to resolve the outer referenced
/// columns in subquery (recursive aware)
Expand Down Expand Up @@ -387,19 +389,35 @@ impl PlannerContext {
/// Subquery for the specified name
pub fn insert_cte(&mut self, cte_name: impl Into<String>, plan: LogicalPlan) {
let cte_name = cte_name.into();
self.ctes.insert(cte_name, Arc::new(plan));
self.ctes.insert(cte_name, (Arc::new(plan), None));
}

/// Inserts a LogicalPlan with an optional span for the CTE
pub(super) fn insert_cte_with_span(
&mut self,
cte_name: impl Into<String>,
plan: LogicalPlan,
span: Option<Span>,
) {
let cte_name = cte_name.into();
self.ctes.insert(cte_name, (Arc::new(plan), span));
}

/// Return a plan for the Common Table Expression (CTE) / Subquery for the
/// specified name
pub fn get_cte(&self, cte_name: &str) -> Option<&LogicalPlan> {
self.ctes.get(cte_name).map(|cte| cte.as_ref())
self.ctes.get(cte_name).map(|(cte, _)| cte.as_ref())
}

/// Remove the plan of CTE / Subquery for the specified name
pub(super) fn remove_cte(&mut self, cte_name: &str) {
self.ctes.remove(cte_name);
}

/// Get the span of a previously defined CTE name
pub(super) fn get_cte_span(&self, name: &str) -> Option<Span> {
self.ctes.get(name).and_then(|(_, span)| *span)
}
}

/// SQL query planner and binder
Expand Down
73 changes: 63 additions & 10 deletions datafusion/sql/src/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::ops::ControlFlow;
use std::sync::Arc;

Expand All @@ -29,7 +29,9 @@ use crate::utils::{

use datafusion_common::error::DataFusionErrorBuilder;
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_common::{Column, DFSchema, Result, not_impl_err, plan_err};
use datafusion_common::{
Column, DFSchema, Diagnostic, Result, Span, not_impl_err, plan_err,
};
use datafusion_common::{RecursionUnnestOption, UnnestOptions};
use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions};
use datafusion_expr::expr_rewriter::{
Expand All @@ -50,7 +52,9 @@ use sqlparser::ast::{
SelectItemQualifiedWildcardKind, WildcardAdditionalOptions, WindowType,
visit_expressions_mut,
};
use sqlparser::ast::{NamedWindowDefinition, Select, SelectItem, TableWithJoins};
use sqlparser::ast::{
NamedWindowDefinition, Select, SelectItem, Spanned, TableFactor, TableWithJoins,
};

/// Result of the `aggregate` function, containing the aggregate plan and
/// rewritten expressions that reference the aggregate output columns.
Expand Down Expand Up @@ -690,21 +694,70 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
self.plan_table_with_joins(input, planner_context)
}
_ => {
let extract_table_name =
|t: &TableWithJoins| -> Option<(String, Option<Span>)> {
let span = Span::try_from_sqlparser_span(t.relation.span());
match &t.relation {
TableFactor::Table { alias: Some(a), .. } => {
let name =
self.ident_normalizer.normalize(a.name.clone());
Some((name, span))
}
TableFactor::Table {
name, alias: None, ..
} => {
let table_name = name
.0
.iter()
.filter_map(|p| p.as_ident())
.map(|id| self.ident_normalizer.normalize(id.clone()))
.next_back()?;
Some((table_name, span))
}
_ => None,
}
};

let mut alias_spans: HashMap<String, Option<Span>> = HashMap::new();

let mut from = from.into_iter();
let first = from.next().unwrap();

if let Some((name, span)) = extract_table_name(&first) {
alias_spans.entry(name).or_insert(span);
}

let mut left = LogicalPlanBuilder::from(
self.plan_table_with_joins(first, planner_context)?,
);

let mut left = LogicalPlanBuilder::from({
let input = from.next().unwrap();
self.plan_table_with_joins(input, planner_context)?
});
let old_outer_from_schema = {
let left_schema = Some(Arc::clone(left.schema()));
planner_context.set_outer_from_schema(left_schema)
};
for input in from {
// Join `input` with the current result (`left`).
let current_name = extract_table_name(&input);

if let Some((ref name, ref span)) = current_name {
alias_spans.entry(name.clone()).or_insert(*span);
}

let right = self.plan_table_with_joins(input, planner_context)?;
left = left.cross_join(right)?;
// Update the outer FROM schema.

left = left.cross_join(right).map_err(|e| {
if let Some((ref name, ref current_span)) = current_name
&& let Some(prior_span) =
alias_spans.get(name).copied().flatten()
{
let diagnostic = Diagnostic::new_error(
"duplicate table alias in FROM clause",
*current_span,
)
.with_note("first defined here", Some(prior_span));
return e.with_diagnostic(diagnostic);
}
e
})?;
let left_schema = Some(Arc::clone(left.schema()));
planner_context.set_outer_from_schema(left_schema);
}
Expand Down
52 changes: 52 additions & 0 deletions datafusion/sql/tests/cases/diagnostic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,55 @@ fn test_syntax_error() -> Result<()> {
},
}
}

#[test]
fn test_duplicate_cte_name() -> Result<()> {
let query = "WITH /*a*/cte/*a*/ AS (SELECT 1 AS col), /*b*/cte/*b*/ AS (SELECT 2 AS col) SELECT 1";
let spans = get_spans(query);
let diag = do_query(query);
assert_snapshot!(diag.message, @r#"WITH query name "cte" specified more than once"#);
assert_eq!(diag.span, Some(spans["b"]));
assert_eq!(diag.notes.len(), 1);
assert_snapshot!(diag.notes[0].message, @"previously defined here");
assert_eq!(diag.notes[0].span, Some(spans["a"]));
Ok(())
}

#[test]
fn test_duplicate_table_alias() -> Result<()> {
let query = "SELECT * FROM /*a*/person a/*a*/, /*b*/person a/*b*/";
let spans = get_spans(query);
let diag = do_query(query);
assert_snapshot!(diag.message, @"duplicate table alias in FROM clause");
assert_eq!(diag.span, Some(spans["b"]));
assert_eq!(diag.notes.len(), 1);
assert_snapshot!(diag.notes[0].message, @"first defined here");
assert_eq!(diag.notes[0].span, Some(spans["a"]));
Ok(())
}

#[test]
fn test_duplicate_table_alias_not_first() -> Result<()> {
let query = "SELECT * FROM person a, /*b*/test_decimal b/*b*/, /*c*/person b/*c*/";
let spans = get_spans(query);
let diag = do_query(query);
assert_snapshot!(diag.message, @"duplicate table alias in FROM clause");
assert_eq!(diag.span, Some(spans["c"]));
assert_eq!(diag.notes.len(), 1);
assert_snapshot!(diag.notes[0].message, @"first defined here");
assert_eq!(diag.notes[0].span, Some(spans["b"]));
Ok(())
}

#[test]
fn test_duplicate_bare_table_in_from() -> Result<()> {
let query = "SELECT * FROM /*a*/person/*a*/, /*b*/person/*b*/";
let spans = get_spans(query);
let diag = do_query(query);
assert_snapshot!(diag.message, @"duplicate table alias in FROM clause");
assert_eq!(diag.span, Some(spans["b"]));
assert_eq!(diag.notes.len(), 1);
assert_snapshot!(diag.notes[0].message, @"first defined here");
assert_eq!(diag.notes[0].span, Some(spans["a"]));
Ok(())
}