From 564fc99b5a0e68e4d5aa23fcd0ea23338dbff73d Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Mon, 2 Mar 2026 01:00:14 -0800 Subject: [PATCH 1/3] fix: register CTE schema correctly when params rename columns (#9) When a CTE uses params to rename columns, the schema now reflects the renamed column names instead of passing through the original names. This prevents silent failures in downstream column lookups. Co-Authored-By: Claude Opus 4.6 --- data_diff/queries/ast_classes.py | 10 ++++++++-- tests/test_query.py | 26 +++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index 3cf52d31..e29249bf 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -635,8 +635,14 @@ def source_table(self) -> "ITable": @property def schema(self) -> Schema: - # TODO add cte to schema - return self.table.schema + s = self.table.schema + if s is None or not self.params: + return s + if len(self.params) != len(s): + raise QueryBuilderError( + f"CTE params length ({len(self.params)}) does not match source schema length ({len(s)})" + ) + return type(s)(dict(zip(self.params, s.values()))) def _named_exprs_as_aliases(named_exprs): diff --git a/tests/test_query.py b/tests/test_query.py index 42903de9..21f630bd 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -8,7 +8,7 @@ from data_diff.abcs.database_types import FractionalType, TemporalType from data_diff.databases.base import BaseDialect, CompileError, Compiler, Database from data_diff.queries.api import coalesce, code, cte, outerjoin, table, this, when -from data_diff.queries.ast_classes import Random +from data_diff.queries.ast_classes import QueryBuilderError, Random from data_diff.utils import CaseInsensitiveDict, CaseSensitiveDict @@ -196,6 +196,30 @@ def test_cte(self): expected = "WITH tmp1(y) AS (SELECT x FROM a) SELECT y FROM tmp1" assert normalize_spaces(c.dialect.compile(c, t3)) == expected + def test_cte_schema(self): + # Non-parameterized CTE passes through source schema unchanged + t = table("a", schema=CaseSensitiveDict({"x": int, "y": str})) + ct = cte(t.select(this.x, this.y)) + assert ct.schema == t.schema + + # Parameterized CTE reflects renamed columns with correct types + t = table("a", schema=CaseSensitiveDict({"x": int, "y": str})) + ct = cte(t.select(this.x, this.y), params=["a", "b"]) + s = ct.schema + assert list(s.keys()) == ["a", "b"] + assert list(s.values()) == [int, str] + + # Param count mismatch raises QueryBuilderError + t = table("a", schema=CaseSensitiveDict({"x": int, "y": str})) + ct = cte(t.select(this.x, this.y), params=["a"]) + with self.assertRaises(QueryBuilderError): + _ = ct.schema + + # Schema type (case sensitivity) is preserved + t = table("a", schema=CaseInsensitiveDict({"X": int, "Y": str})) + ct = cte(t.select(this.X, this.Y), params=["A", "B"]) + assert isinstance(ct.schema, CaseInsensitiveDict) + def test_funcs(self): c = Compiler(MockDatabase()) t = table("a") From 40d3bb3bbc5c29656855e8d4adb0bbc651f5304b Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Mon, 2 Mar 2026 01:08:42 -0800 Subject: [PATCH 2/3] fix: detect duplicate CTE params that silently corrupt schema Add a post-construction length check to catch duplicate param names that would collapse into fewer columns (both case-sensitive and case-insensitive). Also strengthen tests with value assertions and duplicate detection cases. Co-Authored-By: Claude Opus 4.6 --- data_diff/queries/ast_classes.py | 5 ++++- tests/test_query.py | 20 ++++++++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index e29249bf..f8f19181 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -642,7 +642,10 @@ def schema(self) -> Schema: raise QueryBuilderError( f"CTE params length ({len(self.params)}) does not match source schema length ({len(s)})" ) - return type(s)(dict(zip(self.params, s.values()))) + result = type(s)(dict(zip(self.params, s.values()))) + if len(result) != len(s): + raise QueryBuilderError(f"CTE params contain duplicate column names: {self.params!r}") + return result def _named_exprs_as_aliases(named_exprs): diff --git a/tests/test_query.py b/tests/test_query.py index 21f630bd..c6f7bd17 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -215,10 +215,26 @@ def test_cte_schema(self): with self.assertRaises(QueryBuilderError): _ = ct.schema - # Schema type (case sensitivity) is preserved + # Schema type (case sensitivity) is preserved with correct values t = table("a", schema=CaseInsensitiveDict({"X": int, "Y": str})) ct = cte(t.select(this.X, this.Y), params=["A", "B"]) - assert isinstance(ct.schema, CaseInsensitiveDict) + s = ct.schema + assert isinstance(s, CaseInsensitiveDict) + assert s["A"] is int + assert s["a"] is int + assert s["B"] is str + + # Duplicate params raises QueryBuilderError + t = table("a", schema=CaseSensitiveDict({"x": int, "y": str})) + ct = cte(t.select(this.x, this.y), params=["a", "a"]) + with self.assertRaises(QueryBuilderError): + _ = ct.schema + + # Case-insensitive duplicate params raises QueryBuilderError + t = table("a", schema=CaseInsensitiveDict({"X": int, "Y": str})) + ct = cte(t.select(this.X, this.Y), params=["A", "a"]) + with self.assertRaises(QueryBuilderError): + _ = ct.schema def test_funcs(self): c = Compiler(MockDatabase()) From 1dc436016cb24ef9b3ea72bacf6c5c04f3164f62 Mon Sep 17 00:00:00 2001 From: Daniel Song Date: Mon, 2 Mar 2026 01:14:53 -0800 Subject: [PATCH 3/3] fix: raise error when CTE params provided without source schema Split the guard clause so params on a schema-less source raises QueryBuilderError instead of silently discarding them. Also pin the params=[] passthrough behavior with a test. Co-Authored-By: Claude Opus 4.6 --- data_diff/queries/ast_classes.py | 4 +++- tests/test_query.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/data_diff/queries/ast_classes.py b/data_diff/queries/ast_classes.py index f8f19181..c5fb8179 100644 --- a/data_diff/queries/ast_classes.py +++ b/data_diff/queries/ast_classes.py @@ -636,8 +636,10 @@ def source_table(self) -> "ITable": @property def schema(self) -> Schema: s = self.table.schema - if s is None or not self.params: + if not self.params: return s + if s is None: + raise QueryBuilderError(f"CTE params were provided ({self.params!r}) but the source table has no schema") if len(self.params) != len(s): raise QueryBuilderError( f"CTE params length ({len(self.params)}) does not match source schema length ({len(s)})" diff --git a/tests/test_query.py b/tests/test_query.py index c6f7bd17..722a809c 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -236,6 +236,17 @@ def test_cte_schema(self): with self.assertRaises(QueryBuilderError): _ = ct.schema + # Params on a schema-less source raises QueryBuilderError + t = table("a") + ct = cte(t.select(this.x), params=["renamed"]) + with self.assertRaises(QueryBuilderError): + _ = ct.schema + + # Empty params list passes through source schema unchanged + t = table("a", schema=CaseSensitiveDict({"x": int})) + ct = cte(t.select(this.x), params=[]) + assert ct.schema == t.schema + def test_funcs(self): c = Compiler(MockDatabase()) t = table("a")