Skip to content

Commit 0734396

Browse files
committed
Support selecting wildcards in node queries
1 parent 6e37fbc commit 0734396

8 files changed

Lines changed: 440 additions & 110 deletions

File tree

datajunction-server/datajunction_server/internal/namespaces.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from sqlalchemy import or_, select
1010
from sqlalchemy.ext.asyncio import AsyncSession
11-
from sqlalchemy.orm import joinedload
11+
from sqlalchemy.orm import joinedload, selectinload
1212

1313
from datajunction_server.api.helpers import get_node_namespace
1414
from datajunction_server.database.history import ActivityType, EntityType, History
@@ -26,6 +26,7 @@
2626
)
2727
from datajunction_server.models.node import NodeMinimumDetail
2828
from datajunction_server.models.node_type import NodeType
29+
from datajunction_server.sql.dag import topological_sort
2930
from datajunction_server.typing import UTCDatetime
3031
from datajunction_server.utils import SEPARATOR
3132

@@ -229,10 +230,10 @@ async def hard_delete_namespace(
229230
"""
230231
Hard delete a node namespace.
231232
"""
232-
node_names = (
233+
nodes = (
233234
(
234235
await session.execute(
235-
select(Node.name)
236+
select(Node)
236237
.where(
237238
or_(
238239
Node.namespace.like(
@@ -241,27 +242,33 @@ async def hard_delete_namespace(
241242
Node.namespace == namespace,
242243
),
243244
)
244-
.order_by(Node.name),
245+
.order_by(Node.name)
246+
.options(
247+
joinedload(Node.current).options(
248+
selectinload(NodeRevision.parents),
249+
),
250+
),
245251
)
246252
)
253+
.unique()
247254
.scalars()
248255
.all()
249256
)
250257

251-
if not cascade and node_names:
258+
if not cascade and nodes:
252259
raise DJActionNotAllowedException(
253260
message=(
254261
f"Cannot hard delete namespace `{namespace}` as there are still the "
255-
f"following nodes under it: `{node_names}`. Set `cascade` to true to "
256-
"additionally hard delete the above nodes in this namespace. WARNING:"
262+
f"following nodes under it: `{[node.name for node in nodes]}`. Set `cascade` to "
263+
"true to additionally hard delete the above nodes in this namespace. WARNING:"
257264
" this action cannot be undone."
258265
),
259266
)
260267

261268
impacts = {}
262-
for node_name in node_names:
263-
impacts[node_name] = await hard_delete_node(
264-
node_name,
269+
for node in reversed(topological_sort(nodes)):
270+
impacts[node.name] = await hard_delete_node(
271+
node.name,
265272
session,
266273
current_user=current_user,
267274
)

datajunction-server/datajunction_server/internal/nodes.py

Lines changed: 56 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -868,7 +868,7 @@ async def propagate_update_downstream( # pylint: disable=too-many-locals
868868
)
869869

870870
# The downstreams need to be sorted topologically in order for the updates to be done
871-
# in the right order. Otherwise it is possible for a leaf node like a metric to be updated
871+
# in the right order. Otherwise, it is possible for a leaf node like a metric to be updated
872872
# before its upstreams are updated.
873873
for downstream in downstreams:
874874
original_node_revision = downstream.current
@@ -1866,14 +1866,13 @@ async def revalidate_node( # pylint: disable=too-many-locals,too-many-statement
18661866

18671867
# Check if any columns have been updated
18681868
existing_columns = {col.name: col for col in node.current.columns} # type: ignore
1869-
updated_columns = False
1869+
updated_columns = len(current_node_revision.columns) != len(node_validator.columns)
18701870
for col in node_validator.columns:
18711871
if existing_col := existing_columns.get(col.name):
18721872
if existing_col.type != col.type:
18731873
existing_col.type = col.type
18741874
updated_columns = True
18751875
else:
1876-
node.current.columns.append(col) # type: ignore # pragma: no cover
18771876
updated_columns = True # pragma: no cover
18781877

18791878
# Only create a new revision if the columns have been updated
@@ -1893,16 +1892,16 @@ async def revalidate_node( # pylint: disable=too-many-locals,too-many-statement
18931892
node_validator.updated_columns = node_validator.modified_columns(
18941893
new_revision, # type: ignore
18951894
)
1896-
new_revision.columns = node_validator.columns
18971895

18981896
# Save the new revision of the child
1897+
new_revision.columns = node_validator.columns
18991898
node.current_version = new_revision.version # type: ignore
19001899
new_revision.node_id = node.id # type: ignore
1901-
session.add(node)
19021900
session.add(new_revision)
1903-
await session.commit()
1904-
await session.refresh(node.current) # type: ignore
1905-
await session.refresh(node, ["current"])
1901+
session.add(node)
1902+
await session.commit()
1903+
await session.refresh(node.current) # type: ignore
1904+
await session.refresh(node, ["current"])
19061905
return node_validator
19071906

19081907

@@ -1918,7 +1917,14 @@ async def hard_delete_node(
19181917
node = await Node.get_by_name(
19191918
session,
19201919
name,
1921-
options=[joinedload(Node.current), joinedload(Node.revisions)],
1920+
options=[
1921+
joinedload(Node.current),
1922+
joinedload(Node.revisions).options(
1923+
selectinload(NodeRevision.columns).options(
1924+
joinedload(Column.attributes),
1925+
),
1926+
),
1927+
],
19221928
include_inactive=True,
19231929
raise_if_not_exists=False,
19241930
)
@@ -1946,42 +1952,50 @@ async def hard_delete_node(
19461952
user=current_user.username if current_user else None,
19471953
),
19481954
)
1949-
node_validator = await revalidate_node(
1950-
name=node.name,
1951-
session=session,
1952-
current_user=current_user,
1953-
)
1954-
impact.append(
1955-
{
1956-
"name": node.name,
1957-
"status": node_validator.status,
1958-
"effect": "downstream node is now invalid",
1959-
},
1960-
)
1955+
try:
1956+
node_validator = await revalidate_node(
1957+
name=node.name,
1958+
session=session,
1959+
current_user=current_user,
1960+
)
1961+
impact.append(
1962+
{
1963+
"name": node.name,
1964+
"status": node_validator.status,
1965+
"effect": "downstream node is now invalid",
1966+
},
1967+
)
1968+
except DJNodeNotFound:
1969+
_logger.warning("Node not found %s", node.name)
19611970

19621971
# Revalidate all linked nodes
19631972
for node in linked_nodes:
1964-
session.add( # Capture this in the downstream node's history
1965-
History(
1966-
entity_type=EntityType.LINK,
1967-
entity_name=name,
1968-
node=node.name,
1969-
activity_type=ActivityType.DELETE,
1970-
user=current_user.username if current_user else None,
1971-
),
1972-
)
1973-
node_validator = await revalidate_node(
1974-
name=node.name,
1975-
session=session,
1976-
current_user=current_user,
1977-
)
1978-
impact.append(
1979-
{
1980-
"name": node.name,
1981-
"status": node_validator.status,
1982-
"effect": "broken link",
1983-
},
1984-
)
1973+
if node:
1974+
session.add( # Capture this in the downstream node's history
1975+
History(
1976+
entity_type=EntityType.LINK,
1977+
entity_name=name,
1978+
node=node.name,
1979+
activity_type=ActivityType.DELETE,
1980+
user=current_user.username if current_user else None,
1981+
),
1982+
)
1983+
try:
1984+
node_validator = await revalidate_node(
1985+
name=node.name,
1986+
session=session,
1987+
current_user=current_user,
1988+
# update=False,
1989+
)
1990+
impact.append(
1991+
{
1992+
"name": node.name,
1993+
"status": node_validator.status,
1994+
"effect": "broken link",
1995+
},
1996+
)
1997+
except DJNodeNotFound:
1998+
_logger.warning("Node not found %s", node.name)
19851999
session.add( # Capture this in the downstream node's history
19862000
History(
19872001
entity_type=EntityType.NODE,

datajunction-server/datajunction_server/internal/validation.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""Node validation functions."""
22
from dataclasses import dataclass, field
3-
from typing import Dict, List, Set, Union
3+
from typing import Dict, List, Optional, Set, Union
44

55
from sqlalchemy.exc import MissingGreenlet
66
from sqlalchemy.ext.asyncio import AsyncSession
77

88
from datajunction_server.api.helpers import find_bound_dimensions
9-
from datajunction_server.database import Column, Node, NodeRevision
9+
from datajunction_server.database import Column, ColumnAttribute, Node, NodeRevision
1010
from datajunction_server.errors import DJError, DJException, ErrorCode
1111
from datajunction_server.models.base import labelize
1212
from datajunction_server.models.node import NodeRevisionBase, NodeStatus
@@ -22,6 +22,7 @@ class NodeValidator: # pylint: disable=too-many-instance-attributes
2222
Node validation
2323
"""
2424

25+
query_ast: Optional[ast.Query] = None
2526
status: NodeStatus = NodeStatus.VALID
2627
columns: List[Column] = field(default_factory=list)
2728
required_dimensions: List[Column] = field(default_factory=list)
@@ -128,7 +129,12 @@ async def validate_node_data( # pylint: disable=too-many-locals,too-many-statem
128129
name=column_name,
129130
display_name=labelize(column_name),
130131
type=column_type,
131-
attributes=existing_column.attributes if existing_column else [],
132+
attributes=[
133+
ColumnAttribute(attribute_type=col_attr.attribute_type)
134+
for col_attr in existing_column.attributes
135+
]
136+
if existing_column
137+
else [],
132138
dimension=existing_column.dimension if existing_column else None,
133139
order=idx,
134140
)

datajunction-server/datajunction_server/sql/parsing/ast.py

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from datajunction_server.models.column import SemanticType
3939
from datajunction_server.models.node import BuildCriteria
4040
from datajunction_server.models.node_type import NodeType as DJNodeType
41+
from datajunction_server.naming import SEPARATOR
4142
from datajunction_server.sql.functions import function_registry, table_function_registry
4243
from datajunction_server.sql.parsing.backends.exceptions import DJParseException
4344
from datajunction_server.sql.parsing.types import (
@@ -1102,7 +1103,6 @@ class Wildcard(Named, Expression):
11021103
Wildcard or '*' expression
11031104
"""
11041105

1105-
name: Name = field(init=False, repr=False, default=Name("*"))
11061106
_table: Optional["Table"] = field(repr=False, default=None)
11071107

11081108
@property
@@ -1121,12 +1121,65 @@ def add_table(self, table: "Table") -> "Wildcard":
11211121
return self
11221122

11231123
def __str__(self) -> str:
1124-
return "*"
1124+
return (self.namespace[0].name + SEPARATOR if self.namespace else "") + "*"
11251125

11261126
@property
11271127
def type(self) -> ColumnType:
11281128
return WildcardType()
11291129

1130+
async def compile(self, ctx: CompileContext):
1131+
"""
1132+
Compile a Wildcard AST node. If the wildcard is used in a SELECT statement, we
1133+
replace it with the equivalent of explicitly selecting all upstream columns.
1134+
"""
1135+
if not isinstance(self.parent, Select):
1136+
return super().compile(ctx)
1137+
1138+
wildcard_parent = cast(Select, self.parent)
1139+
1140+
# If a wildcard is used in a SELECT statement, we pull the columns that it
1141+
# represents by scanning all relations in the FROM clause (both the primary table
1142+
# and any joined tables)
1143+
wildcard_table_namespace = (
1144+
self.namespace[0].name
1145+
if self.namespace and isinstance(self.namespace, list)
1146+
else None
1147+
)
1148+
from_relations = [
1149+
wildcard_parent.from_.relations[0].primary,
1150+
*[ext.right for ext in wildcard_parent.from_.relations[0].extensions],
1151+
]
1152+
1153+
for relation in from_relations:
1154+
await relation.compile(ctx)
1155+
if (
1156+
not wildcard_table_namespace
1157+
or relation.alias_or_name.name == wildcard_table_namespace
1158+
):
1159+
# Figure out where the relation's columns are stored depending on the relation type
1160+
if isinstance(relation, Table):
1161+
wildcard_origin = cast(Table, relation)
1162+
if wildcard_origin_node := wildcard_origin.dj_node:
1163+
relation_columns = wildcard_origin_node.columns
1164+
else:
1165+
relation_columns = wildcard_origin._cte_columns
1166+
else:
1167+
wildcard_origin = cast(Query, relation)
1168+
relation_columns = wildcard_origin.select.projection
1169+
1170+
# Use these columns to replace the wildcard
1171+
for col in relation_columns:
1172+
wildcard_parent.projection.append(
1173+
Column(
1174+
name=Name(col.name)
1175+
if isinstance(col.name, str)
1176+
else col.name,
1177+
_table=wildcard_origin,
1178+
_type=col.type,
1179+
),
1180+
)
1181+
wildcard_parent.projection.remove(self)
1182+
11301183

11311184
@dataclass(eq=False)
11321185
class TableExpression(Aliasable, Expression):
@@ -1140,6 +1193,7 @@ class TableExpression(Aliasable, Expression):
11401193
) # all those expressions that can be had from the table; usually derived from dj node metadata for Table
11411194
# ref (referenced) columns are columns used elsewhere from this table
11421195
_ref_columns: List[Column] = field(init=False, repr=False, default_factory=list)
1196+
_cte_columns: List[Expression] = field(default_factory=list)
11431197

11441198
@property
11451199
def columns(self) -> List[Expression]:
@@ -1345,8 +1399,12 @@ def set_alias(self: TNode, alias: "Name") -> TNode:
13451399
return self
13461400

13471401
async def compile(self, ctx: CompileContext):
1348-
# things we can validate here:
1349-
# - if the node is a dimension in a groupby, is it joinable?
1402+
"""
1403+
Compile a Table AST node by finding and saving the columns it references
1404+
"""
1405+
if self._is_compiled:
1406+
return
1407+
13501408
self._is_compiled = True
13511409
try:
13521410
if not self.dj_node:
@@ -1356,12 +1414,26 @@ async def compile(self, ctx: CompileContext):
13561414
{DJNodeType.SOURCE, DJNodeType.TRANSFORM, DJNodeType.DIMENSION},
13571415
)
13581416
self.set_dj_node(dj_node)
1417+
except DJErrorException as exc:
1418+
ctx.exception.errors.append(exc.dj_error)
1419+
1420+
if self.dj_node:
1421+
# If the Table object is a reference to a DJ node, save the columns of the
1422+
# DJ node into self._columns for later use
13591423
self._columns = [
13601424
Column(Name(col.name), _type=col.type, _table=self)
13611425
for col in self.dj_node.columns
13621426
]
1363-
except DJErrorException as exc:
1364-
ctx.exception.errors.append(exc.dj_error)
1427+
elif query := self.get_nearest_parent_of_type(Query):
1428+
# If the Table object is a reference to a CTE, save the columns output by
1429+
# the CTE into self._columns for later use
1430+
for cte in query.ctes:
1431+
if self.alias_or_name.name == cte.alias_or_name.name:
1432+
await cte.compile(ctx)
1433+
self._cte_columns = [
1434+
Column(col.alias_or_name, _type=col.type, _table=self)
1435+
for col in cte._columns
1436+
]
13651437

13661438

13671439
class Operation(Expression):
@@ -2565,6 +2637,9 @@ async def compile(self, ctx: CompileContext):
25652637
),
25662638
)
25672639
await super().compile(ctx)
2640+
for child in self.projection:
2641+
if isinstance(child, Wildcard):
2642+
await child.compile(ctx)
25682643

25692644

25702645
@dataclass(eq=False)

0 commit comments

Comments
 (0)