Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions data_diff/queries/ast_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,8 +635,19 @@ def source_table(self) -> "ITable":

@property
def schema(self) -> Schema:
    """Return the schema exposed by this CTE.

    When no ``params`` are set (``None`` or empty), the source table's
    schema is returned unchanged. Otherwise the columns of the source
    schema are renamed positionally to the names in ``params``, and the
    result is built with the same mapping type as the source schema.

    Raises:
        QueryBuilderError: if ``params`` are given but the source table has
            no schema, if the number of params differs from the number of
            source columns, or if two params collapse to the same key in
            the schema's mapping type (duplicate column names).
    """
    source_schema = self.table.schema
    if not self.params:
        return source_schema
    if source_schema is None:
        raise QueryBuilderError(f"CTE params were provided ({self.params!r}) but the source table has no schema")
    if len(self.params) != len(source_schema):
        raise QueryBuilderError(
            f"CTE params length ({len(self.params)}) does not match source schema length ({len(source_schema)})"
        )
    # Rebuild with the source's own mapping type so its key-handling rules
    # carry over to the renamed columns.
    renamed = type(source_schema)(dict(zip(self.params, source_schema.values())))
    # Lengths were equal above, so a shorter result means some params
    # mapped to the same key.
    if len(renamed) != len(source_schema):
        raise QueryBuilderError(f"CTE params contain duplicate column names: {self.params!r}")
    return renamed


def _named_exprs_as_aliases(named_exprs):
Expand Down
53 changes: 52 additions & 1 deletion tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from data_diff.abcs.database_types import FractionalType, TemporalType
from data_diff.databases.base import BaseDialect, CompileError, Compiler, Database
from data_diff.queries.api import coalesce, code, cte, outerjoin, table, this, when
from data_diff.queries.ast_classes import Random
from data_diff.queries.ast_classes import QueryBuilderError, Random
from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict


Expand Down Expand Up @@ -196,6 +196,57 @@ def test_cte(self):
expected = "WITH tmp1(y) AS (SELECT x FROM a) SELECT y FROM tmp1"
assert normalize_spaces(c.dialect.compile(c, t3)) == expected

def test_cte_schema(self):
    """Check the schema a CTE exposes, with and without column params."""

    def two_col_table():
        # Fresh case-sensitive source table for each scenario.
        return table("a", schema=CaseSensitiveDict({"x": int, "y": str}))

    # No params: the CTE schema equals the source schema.
    src = two_col_table()
    plain = cte(src.select(this.x, this.y))
    assert plain.schema == src.schema

    # Params rename the columns positionally and keep their types.
    src = two_col_table()
    renamed = cte(src.select(this.x, this.y), params=["a", "b"]).schema
    assert list(renamed.keys()) == ["a", "b"]
    assert list(renamed.values()) == [int, str]

    # Fewer params than columns is rejected.
    src = two_col_table()
    too_few = cte(src.select(this.x, this.y), params=["a"])
    with self.assertRaises(QueryBuilderError):
        _ = too_few.schema

    # The mapping type of the source schema carries over to the result.
    src = table("a", schema=CaseInsensitiveDict({"X": int, "Y": str}))
    insensitive = cte(src.select(this.X, this.Y), params=["A", "B"]).schema
    assert isinstance(insensitive, CaseInsensitiveDict)
    assert insensitive["A"] is int
    assert insensitive["a"] is int
    assert insensitive["B"] is str

    # Repeating a param name is rejected.
    src = two_col_table()
    repeated = cte(src.select(this.x, this.y), params=["a", "a"])
    with self.assertRaises(QueryBuilderError):
        _ = repeated.schema

    # Names differing only by case are duplicates for a case-insensitive schema.
    src = table("a", schema=CaseInsensitiveDict({"X": int, "Y": str}))
    case_clash = cte(src.select(this.X, this.Y), params=["A", "a"])
    with self.assertRaises(QueryBuilderError):
        _ = case_clash.schema

    # Params on a source table without a schema are rejected.
    src = table("a")
    no_schema = cte(src.select(this.x), params=["renamed"])
    with self.assertRaises(QueryBuilderError):
        _ = no_schema.schema

    # An empty params list behaves like no params at all.
    src = table("a", schema=CaseSensitiveDict({"x": int}))
    empty_params = cte(src.select(this.x), params=[])
    assert empty_params.schema == src.schema

def test_funcs(self):
c = Compiler(MockDatabase())
t = table("a")
Expand Down
Loading