diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 3cf52d31..c5fb8179 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -635,8 +635,19 @@ def source_table(self) -> "ITable": @property def schema(self) -> Schema: - # TODO add cte to schema - return self.table.schema + s = self.table.schema + if not self.params: + return s + if s is None: + raise QueryBuilderError(f"CTE params were provided ({self.params!r}) but the source table has no schema") + if len(self.params) != len(s): + raise QueryBuilderError( + f"CTE params length ({len(self.params)}) does not match source schema length ({len(s)})" + ) + result = type(s)(dict(zip(self.params, s.values()))) + if len(result) != len(s): + raise QueryBuilderError(f"CTE params contain duplicate column names: {self.params!r}") + return result def _named_exprs_as_aliases(named_exprs): diff --git a/tests/test_query.py b/tests/test_query.py index 42903de9..722a809c 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -8,7 +8,7 @@ from data_diff.abcs.database_types import FractionalType, TemporalType from data_diff.databases.base import BaseDialect, CompileError, Compiler, Database from data_diff.queries.api import coalesce, code, cte, outerjoin, table, this, when -from data_diff.queries.ast_classes import Random +from data_diff.queries.ast_classes import QueryBuilderError, Random from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict @@ -196,6 +196,57 @@ def test_cte(self): expected = "WITH tmp1(y) AS (SELECT x FROM a) SELECT y FROM tmp1" assert normalize_spaces(c.dialect.compile(c, t3)) == expected + def test_cte_schema(self): + # Non-parameterized CTE passes through source schema unchanged + t = table("a", schema=CaseSensitiveDict({"x": int, "y": str})) + ct = cte(t.select(this.x, this.y)) + assert ct.schema == t.schema + + # Parameterized CTE reflects renamed columns with correct types + t = table("a", schema=CaseSensitiveDict({"x": int, "y": str})) + ct = cte(t.select(this.x, this.y), params=["a", "b"]) + s = ct.schema + assert list(s.keys()) == ["a", "b"] + assert list(s.values()) == [int, str] + + # Param count mismatch raises QueryBuilderError + t = table("a", schema=CaseSensitiveDict({"x": int, "y": str})) + ct = cte(t.select(this.x, this.y), params=["a"]) + with self.assertRaises(QueryBuilderError): + _ = ct.schema + + # Schema type (case sensitivity) is preserved with correct values + t = table("a", schema=CaseInsensitiveDict({"X": int, "Y": str})) + ct = cte(t.select(this.X, this.Y), params=["A", "B"]) + s = ct.schema + assert isinstance(s, CaseInsensitiveDict) + assert s["A"] is int + assert s["a"] is int + assert s["B"] is str + + # Duplicate params raises QueryBuilderError + t = table("a", schema=CaseSensitiveDict({"x": int, "y": str})) + ct = cte(t.select(this.x, this.y), params=["a", "a"]) + with self.assertRaises(QueryBuilderError): + _ = ct.schema + + # Case-insensitive duplicate params raises QueryBuilderError + t = table("a", schema=CaseInsensitiveDict({"X": int, "Y": str})) + ct = cte(t.select(this.X, this.Y), params=["A", "a"]) + with self.assertRaises(QueryBuilderError): + _ = ct.schema + + # Params on a schema-less source raises QueryBuilderError + t = table("a") + ct = cte(t.select(this.x), params=["renamed"]) + with self.assertRaises(QueryBuilderError): + _ = ct.schema + + # Empty params list passes through source schema unchanged + t = table("a", schema=CaseSensitiveDict({"x": int})) + ct = cte(t.select(this.x), params=[]) + assert ct.schema == t.schema + def test_funcs(self): c = Compiler(MockDatabase()) t = table("a")