Skip to content

Commit 7a0dfab

Browse files
committed
support for parsing WITH clauses
1 parent e400a71 commit 7a0dfab

3 files changed

Lines changed: 256 additions & 3 deletions

File tree

src/select.rs

Lines changed: 231 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@ use std::fmt::{Display, Formatter};
44
use std::str;
55

66
use column::Column;
7-
use common::FieldDefinitionExpression;
87
use common::{
98
as_alias, field_definition_expr, field_list, statement_terminator, table_list, table_reference,
109
unsigned_number,
1110
};
11+
use common::{sql_identifier, FieldDefinitionExpression};
1212
use compound_select::nested_compound_selection;
1313
use condition::{condition_expr, ConditionExpression};
1414
use join::{join_operator, JoinConstraint, JoinOperator, JoinRightSide};
1515
use nom::branch::alt;
1616
use nom::bytes::complete::{tag, tag_no_case};
1717
use nom::combinator::{map, opt};
18-
use nom::multi::many0;
18+
use nom::multi::{many0, separated_list1};
1919
use nom::sequence::{delimited, preceded, terminated, tuple};
2020
use nom::IResult;
2121
use order::{order_clause, OrderClause};
@@ -126,8 +126,10 @@ pub fn simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
126126
pub fn nested_simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
127127
let (
128128
remaining_input,
129-
(_, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit),
129+
(with, _, _, _, distinct, _, fields, _, tables, join, where_clause, group_by, order, limit),
130130
) = tuple((
131+
opt(with_clause),
132+
multispace0,
131133
tag_no_case("select"),
132134
multispace1,
133135
opt(tag_no_case("distinct")),
@@ -144,6 +146,7 @@ pub fn nested_simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
144146
Ok((
145147
remaining_input,
146148
SelectStatement {
149+
with,
147150
tables,
148151
distinct: distinct.is_some(),
149152
fields,
@@ -156,8 +159,66 @@ pub fn nested_simple_selection(i: &[u8]) -> IResult<&[u8], SelectStatement> {
156159
))
157160
}
158161

162+
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
163+
pub struct WithClause {
164+
pub recursive: bool,
165+
pub subclauses: Vec<WithSubclause>,
166+
}
167+
168+
impl fmt::Display for WithClause {
169+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
170+
write!(f, "WITH ")?;
171+
172+
if self.recursive {
173+
write!(f, "RECURSIVE ")?;
174+
}
175+
176+
write!(
177+
f,
178+
"{}",
179+
self.subclauses
180+
.iter()
181+
.map(|c| format!("{}", c))
182+
.collect::<Vec<_>>()
183+
.join(", ")
184+
)?;
185+
186+
Ok(())
187+
}
188+
}
189+
190+
#[derive(Clone, Debug, Eq, Hash, PartialEq, Serialize, Deserialize)]
191+
pub struct WithSubclause {
192+
pub name: String,
193+
pub columns: Vec<Column>,
194+
pub selection: Box<Selection>,
195+
}
196+
197+
impl fmt::Display for WithSubclause {
198+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
199+
write!(f, "{} ", self.name)?;
200+
201+
if self.columns.len() > 0 {
202+
write!(
203+
f,
204+
"({}) ",
205+
self.columns
206+
.iter()
207+
.map(|c| format!("{}", c))
208+
.collect::<Vec<_>>()
209+
.join(", ")
210+
)?;
211+
}
212+
213+
write!(f, "AS ({})", self.selection)?;
214+
215+
Ok(())
216+
}
217+
}
218+
159219
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq, Serialize, Deserialize)]
160220
pub struct SelectStatement {
221+
pub with: Option<WithClause>,
161222
pub tables: Vec<Table>,
162223
pub distinct: bool,
163224
pub fields: Vec<FieldDefinitionExpression>,
@@ -170,6 +231,10 @@ pub struct SelectStatement {
170231

171232
impl fmt::Display for SelectStatement {
172233
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
234+
if let Some(ref with_clause) = self.with {
235+
write!(f, "{}", with_clause)?;
236+
}
237+
173238
write!(f, "SELECT ")?;
174239
if self.distinct {
175240
write!(f, "DISTINCT ")?;
@@ -348,6 +413,60 @@ pub fn where_clause(i: &[u8]) -> IResult<&[u8], ConditionExpression> {
348413
Ok((remaining_input, where_condition))
349414
}
350415

416+
pub fn with_clause(i: &[u8]) -> IResult<&[u8], WithClause> {
417+
map(
418+
tuple((
419+
tag_no_case("with"),
420+
multispace1,
421+
opt(tag_no_case("recursive")),
422+
multispace0,
423+
separated_list1(tuple((multispace0, tag(","), multispace0)), with_subclause),
424+
)),
425+
|(_, _, recursive, _, subclauses)| WithClause {
426+
recursive: recursive.is_some(),
427+
subclauses,
428+
},
429+
)(i)
430+
}
431+
432+
pub fn with_subclause(i: &[u8]) -> IResult<&[u8], WithSubclause> {
433+
map(
434+
tuple((
435+
sql_identifier,
436+
multispace1,
437+
opt(with_clause_column_list),
438+
multispace0,
439+
tag_no_case("as"),
440+
multispace1,
441+
tag("("),
442+
multispace0,
443+
nested_selection,
444+
multispace0,
445+
tag(")"),
446+
)),
447+
|(name, _, columns, _, _, _, _, _, selection, _, _)| WithSubclause {
448+
name: str::from_utf8(name).unwrap().to_string(),
449+
columns: columns.unwrap_or(vec![]),
450+
selection: Box::new(selection),
451+
},
452+
)(i)
453+
}
454+
455+
pub fn with_clause_column_list(i: &[u8]) -> IResult<&[u8], Vec<Column>> {
456+
let (i, (_, _, columns, _, _)) = tuple((
457+
tag("("),
458+
multispace0,
459+
separated_list1(
460+
tuple((multispace0, tag(","), multispace0)),
461+
map(sql_identifier, |si| str::from_utf8(si).unwrap().into()),
462+
),
463+
multispace0,
464+
tag(")"),
465+
))(i)?;
466+
467+
Ok((i, columns))
468+
}
469+
351470
#[cfg(test)]
352471
mod tests {
353472
use super::*;
@@ -1454,4 +1573,113 @@ mod tests {
14541573

14551574
assert_eq!(res.unwrap().1, expected.into());
14561575
}
1576+
1577+
#[test]
fn with() {
    // `SELECT <c0>, <c1> FROM <table>` with every other field left at its
    // default value.
    let select_from = |table: &str, cols: [&str; 2]| {
        Selection::Statement(SelectStatement {
            tables: vec![Table {
                name: table.to_string(),
                alias: None,
                schema: None,
            }],
            fields: cols
                .iter()
                .map(|col| FieldDefinitionExpression::Col(Column::from(*col)))
                .collect(),
            ..Default::default()
        })
    };
    // A column carrying nothing but its name.
    let bare_column = |name: &str| Column {
        name: name.to_string(),
        alias: None,
        table: None,
        function: None,
    };
    let cte = |name: &str, columns: Vec<Column>, selection: &Selection| WithSubclause {
        name: name.to_string(),
        columns,
        selection: Box::new(selection.clone()),
    };

    let table1_select = select_from("table1", ["a", "b"]);
    let table2_select = select_from("table2", ["c", "d"]);

    // Each case pairs a query string with the clause it must parse into.
    let cases = vec![
        (
            "WITH cte1 AS (SELECT a, b FROM table1)",
            WithClause {
                recursive: false,
                subclauses: vec![cte("cte1", vec![], &table1_select)],
            },
        ),
        (
            "WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, d FROM table2)",
            WithClause {
                recursive: false,
                subclauses: vec![
                    cte("cte1", vec![], &table1_select),
                    cte("cte2", vec![], &table2_select),
                ],
            },
        ),
        (
            "WITH cte1 (e, f) AS (SELECT a, b FROM table1), cte2 AS (SELECT c, d FROM table2)",
            WithClause {
                recursive: false,
                subclauses: vec![
                    cte(
                        "cte1",
                        vec![bare_column("e"), bare_column("f")],
                        &table1_select,
                    ),
                    cte("cte2", vec![], &table2_select),
                ],
            },
        ),
        (
            "WITH RECURSIVE cte1 AS (SELECT a, b FROM table1)",
            WithClause {
                recursive: true,
                subclauses: vec![cte("cte1", vec![], &table1_select)],
            },
        ),
    ];

    for (query, expected) in cases {
        let parsed = with_clause(query.as_bytes());
        assert_eq!(parsed.unwrap().1, expected);
    }
}
14571685
}

tests/cte-queries.txt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
-- simple CTE
2+
WITH cte1 AS (SELECT a, b FROM table1) SELECT b, d FROM cte1;
3+
4+
-- 2 CTEs
5+
WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, d FROM table2) SELECT b, d FROM cte1 JOIN cte2 ON cte1.a = cte2.c;
6+
7+
-- CTE in an exists
8+
SELECT 'found' FROM DUAL WHERE EXISTS (WITH cte1 AS (SELECT a, b FROM table1), cte2 AS (SELECT c, d FROM table2) SELECT b, d FROM cte1 JOIN cte2 ON cte1.a = cte2.c);
9+
10+
-- recursive cte
11+
WITH RECURSIVE cte1 AS (SELECT 1 AS a, 0 AS b FROM dual UNION ALL SELECT cte1.a + 1, cte1.b - 1 FROM table1), cte2 AS (SELECT c, d FROM table2) SELECT b, d FROM cte1 JOIN cte2 ON cte1.a = cte2.c;
12+
13+
-- recursive cte with multiple initialization parts
14+
WITH RECURSIVE cte1 AS (SELECT 1 AS a, 0 AS b FROM dual UNION SELECT MAX(a) as a, MAX(b) as b FROM dual UNION ALL SELECT cte1.a + 1 AS a, cte1.b - 1 AS b FROM cte1), cte2 AS (SELECT c, d FROM table2) SELECT b, d FROM cte1 JOIN cte2 ON cte1.a = cte2.c;
15+
16+
-- recursive cte with multiple recursive parts
17+
WITH RECURSIVE cte1 AS (SELECT 1 AS a, 0 AS b FROM dual UNION SELECT MAX(a) as a, MAX(b) as b FROM dual UNION ALL SELECT cte1.a + 1, cte1.b - 1 FROM cte1 UNION SELECT MAX(a) as a, MAX(b) as b FROM table2 ), cte2 AS (SELECT c, d FROM table2) SELECT b, d FROM cte1 JOIN cte2 ON cte1.a = cte2.c;

tests/lib.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,14 @@ fn tpcw_test_tables() {
123123
assert_eq!(res.unwrap(), 10);
124124
}
125125

126+
#[test]
fn cte_queries() {
    // tests/cte-queries.txt holds this many queries; the helper's result is
    // compared against it.
    let expected_queries = 6;

    let outcome = test_queries_from_file(Path::new("tests/cte-queries.txt"), "CTE queries");

    assert!(outcome.is_ok());
    assert_eq!(outcome.unwrap(), expected_queries);
}
133+
126134
#[test]
127135
fn exists_test_queries() {
128136
let res = test_queries_from_file(

0 commit comments

Comments
 (0)