diff --git a/datafusion/sql/src/cte.rs b/datafusion/sql/src/cte.rs index 18766d7056355..63f0ba20cf7bd 100644 --- a/datafusion/sql/src/cte.rs +++ b/datafusion/sql/src/cte.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{ - Result, not_impl_err, plan_err, + Diagnostic, Result, Span, not_impl_err, plan_err, tree_node::{TreeNode, TreeNodeRecursion}, }; use datafusion_expr::{LogicalPlan, LogicalPlanBuilder, TableSource}; @@ -37,10 +37,16 @@ impl SqlToRel<'_, S> { for cte in with.cte_tables { // A `WITH` block can't use the same name more than once let cte_name = self.ident_normalizer.normalize(cte.alias.name.clone()); + let cte_name_span = Span::try_from_sqlparser_span(cte.alias.name.span); if planner_context.contains_cte(&cte_name) { - return plan_err!( - "WITH query name {cte_name:?} specified more than once" - ); + let msg = + format!("WITH query name {cte_name:?} specified more than once"); + let mut diagnostic = Diagnostic::new_error(&msg, cte_name_span); + if let Some(first_span) = planner_context.get_cte_span(&cte_name) { + diagnostic = + diagnostic.with_note("previously defined here", Some(first_span)); + } + return plan_err!("{msg}").map_err(|e| e.with_diagnostic(diagnostic)); } // Create a logical plan for the CTE @@ -53,8 +59,7 @@ impl SqlToRel<'_, S> { // Each `WITH` block can change the column names in the last // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). let final_plan = self.apply_table_alias(cte_plan, cte.alias)?; - // Export the CTE to the outer query - planner_context.insert_cte(cte_name, final_plan); + planner_context.insert_cte_with_span(cte_name, final_plan, cte_name_span); } Ok(()) } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 307f28e8ff9ad..b816aa6f9e69c 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -27,7 +27,9 @@ use datafusion_common::TableReference; use datafusion_common::config::SqlParserOptions; use datafusion_common::datatype::{DataTypeExt, FieldExt}; use datafusion_common::error::add_possible_columns_to_diag; -use datafusion_common::{DFSchema, DataFusionError, Result, not_impl_err, plan_err}; +use datafusion_common::{ + DFSchema, DataFusionError, Result, Span, not_impl_err, plan_err, +}; use datafusion_common::{ DFSchemaRef, Diagnostic, SchemaError, field_not_found, internal_err, plan_datafusion_err, @@ -258,9 +260,9 @@ pub struct PlannerContext { /// Data types for numbered parameters ($1, $2, etc), if supplied /// in `PREPARE` statement prepare_param_data_types: Arc>, - /// Map of CTE name to logical plan of the WITH clause. + /// Map of CTE name to logical plan of the WITH clause and optional span. /// Use `Arc` to allow cheap cloning - ctes: HashMap>, + ctes: HashMap, Option)>, /// The queries schemas of outer query relations, used to resolve the outer referenced /// columns in subquery (recursive aware) @@ -387,19 +389,35 @@ impl PlannerContext { /// Subquery for the specified name pub fn insert_cte(&mut self, cte_name: impl Into, plan: LogicalPlan) { let cte_name = cte_name.into(); - self.ctes.insert(cte_name, Arc::new(plan)); + self.ctes.insert(cte_name, (Arc::new(plan), None)); + } + + /// Inserts a LogicalPlan with an optional span for the CTE + pub(super) fn insert_cte_with_span( + &mut self, + cte_name: impl Into, + plan: LogicalPlan, + span: Option, + ) { + let cte_name = cte_name.into(); + self.ctes.insert(cte_name, (Arc::new(plan), span)); } /// Return a plan for the Common Table Expression (CTE) / Subquery for the /// specified name pub fn get_cte(&self, cte_name: &str) -> Option<&LogicalPlan> { - self.ctes.get(cte_name).map(|cte| cte.as_ref()) + self.ctes.get(cte_name).map(|(cte, _)| cte.as_ref()) } /// Remove the plan of CTE / Subquery for the specified name pub(super) fn remove_cte(&mut self, cte_name: &str) { self.ctes.remove(cte_name); } + + /// Get the span of a previously defined CTE name + pub(super) fn get_cte_span(&self, name: &str) -> Option { + self.ctes.get(name).and_then(|(_, span)| *span) + } } /// SQL query planner and binder diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index edf4b9ef79e83..a4192411ba735 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::ops::ControlFlow; use std::sync::Arc; @@ -29,7 +29,9 @@ use crate::utils::{ use datafusion_common::error::DataFusionErrorBuilder; use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion}; -use datafusion_common::{Column, DFSchema, Result, not_impl_err, plan_err}; +use datafusion_common::{ + Column, DFSchema, Diagnostic, Result, Span, not_impl_err, plan_err, +}; use datafusion_common::{RecursionUnnestOption, UnnestOptions}; use datafusion_expr::expr::{Alias, PlannedReplaceSelectItem, WildcardOptions}; use datafusion_expr::expr_rewriter::{ @@ -50,7 +52,9 @@ use sqlparser::ast::{ SelectItemQualifiedWildcardKind, WildcardAdditionalOptions, WindowType, visit_expressions_mut, }; -use sqlparser::ast::{NamedWindowDefinition, Select, SelectItem, TableWithJoins}; +use sqlparser::ast::{ + NamedWindowDefinition, Select, SelectItem, Spanned, TableFactor, TableWithJoins, +}; /// Result of the `aggregate` function, containing the aggregate plan and /// rewritten expressions that reference the aggregate output columns. @@ -690,21 +694,69 @@ impl SqlToRel<'_, S> { self.plan_table_with_joins(input, planner_context) } _ => { + let extract_table_name = + |t: &TableWithJoins| -> Option<(String, Option)> { + let span = Span::try_from_sqlparser_span(t.relation.span()); + match &t.relation { + TableFactor::Table { alias: Some(a), .. } => { + let name = + self.ident_normalizer.normalize(a.name.clone()); + Some((name, span)) + } + TableFactor::Table { + name, alias: None, .. + } => { + let table_name = name + .0 + .iter() + .filter_map(|p| p.as_ident()) + .map(|id| self.ident_normalizer.normalize(id.clone())) + .next_back()?; + Some((table_name, span)) + } + _ => None, + } + }; + + let mut alias_spans: HashMap> = HashMap::new(); + let mut from = from.into_iter(); + let first = from.next().unwrap(); + + if let Some((name, span)) = extract_table_name(&first) { + alias_spans.entry(name).or_insert(span); + } + + let mut left = LogicalPlanBuilder::from( + self.plan_table_with_joins(first, planner_context)?, + ); - let mut left = LogicalPlanBuilder::from({ - let input = from.next().unwrap(); - self.plan_table_with_joins(input, planner_context)? - }); let old_outer_from_schema = { let left_schema = Some(Arc::clone(left.schema())); planner_context.set_outer_from_schema(left_schema) }; for input in from { - // Join `input` with the current result (`left`). + let current_name = extract_table_name(&input); + + if let Some((ref name, current_span)) = current_name { + if let Some(prior_span) = alias_spans.get(name) { + let mut diagnostic = Diagnostic::new_error( + "duplicate table alias in FROM clause", + current_span, + ); + if let Some(span) = *prior_span { + diagnostic = diagnostic + .with_note("first defined here", Some(span)); + } + return plan_err!("duplicate table alias in FROM clause") + .map_err(|e| e.with_diagnostic(diagnostic)); + } + alias_spans.insert(name.clone(), current_span); + } + let right = self.plan_table_with_joins(input, planner_context)?; + left = left.cross_join(right)?; - // Update the outer FROM schema. let left_schema = Some(Arc::clone(left.schema())); planner_context.set_outer_from_schema(left_schema); } diff --git a/datafusion/sql/tests/cases/diagnostic.rs b/datafusion/sql/tests/cases/diagnostic.rs index 7a729739469d3..6858a5709a82d 100644 --- a/datafusion/sql/tests/cases/diagnostic.rs +++ b/datafusion/sql/tests/cases/diagnostic.rs @@ -390,3 +390,81 @@ fn test_syntax_error() -> Result<()> { }, } } + +#[test] +fn test_duplicate_cte_name() -> Result<()> { + let query = "WITH /*a*/cte/*a*/ AS (SELECT 1 AS col), /*b*/cte/*b*/ AS (SELECT 2 AS col) SELECT 1"; + let spans = get_spans(query); + let diag = do_query(query); + assert_snapshot!(diag.message, @r#"WITH query name "cte" specified more than once"#); + assert_eq!(diag.span, Some(spans["b"])); + assert_eq!(diag.notes.len(), 1); + assert_snapshot!(diag.notes[0].message, @"previously defined here"); + assert_eq!(diag.notes[0].span, Some(spans["a"])); + Ok(()) +} + +#[test] +fn test_duplicate_table_alias() -> Result<()> { + let query = "SELECT * FROM /*a*/person a/*a*/, /*b*/person a/*b*/"; + let spans = get_spans(query); + let diag = do_query(query); + assert_snapshot!(diag.message, @"duplicate table alias in FROM clause"); + assert_eq!(diag.span, Some(spans["b"])); + assert_eq!(diag.notes.len(), 1); + assert_snapshot!(diag.notes[0].message, @"first defined here"); + assert_eq!(diag.notes[0].span, Some(spans["a"])); + Ok(()) +} + +#[test] +fn test_duplicate_table_alias_not_first() -> Result<()> { + let query = "SELECT * FROM person a, /*b*/test_decimal b/*b*/, /*c*/person b/*c*/"; + let spans = get_spans(query); + let diag = do_query(query); + assert_snapshot!(diag.message, @"duplicate table alias in FROM clause"); + assert_eq!(diag.span, Some(spans["c"])); + assert_eq!(diag.notes.len(), 1); + assert_snapshot!(diag.notes[0].message, @"first defined here"); + assert_eq!(diag.notes[0].span, Some(spans["b"])); + Ok(()) +} + +#[test] +fn test_duplicate_bare_table_in_from() -> Result<()> { + let query = "SELECT * FROM /*a*/person/*a*/, /*b*/person/*b*/"; + let spans = get_spans(query); + let diag = do_query(query); + assert_snapshot!(diag.message, @"duplicate table alias in FROM clause"); + assert_eq!(diag.span, Some(spans["b"])); + assert_eq!(diag.notes.len(), 1); + assert_snapshot!(diag.notes[0].message, @"first defined here"); + assert_eq!(diag.notes[0].span, Some(spans["a"])); + Ok(()) +} + +#[test] +fn test_duplicate_alias_non_overlapping_columns() -> Result<()> { + let query = "SELECT * FROM /*a*/j1 AS t/*a*/, /*b*/j2 AS t/*b*/"; + let spans = get_spans(query); + let diag = do_query(query); + assert_snapshot!(diag.message, @"duplicate table alias in FROM clause"); + assert_eq!(diag.span, Some(spans["b"])); + assert_eq!(diag.notes.len(), 1); + assert_snapshot!(diag.notes[0].message, @"first defined here"); + assert_eq!(diag.notes[0].span, Some(spans["a"])); + Ok(()) +} + +#[test] +fn test_duplicate_alias_non_overlapping_three_tables() -> Result<()> { + let query = "SELECT * FROM j1 AS x, /*a*/j2 AS t/*a*/, j3 AS y, /*b*/j1 AS t/*b*/"; + let spans = get_spans(query); + let diag = do_query(query); + assert_snapshot!(diag.message, @"duplicate table alias in FROM clause"); + assert_eq!(diag.span, Some(spans["b"])); + assert_eq!(diag.notes.len(), 1); + assert_snapshot!(diag.notes[0].message, @"first defined here"); + assert_eq!(diag.notes[0].span, Some(spans["a"])); + Ok(()) +}