diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9bd2562ae..4746b0f45 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,4 +47,9 @@ jobs: - name: Run tests env: DATADIFF_CLICKHOUSE_URI: "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" - run: uv run pytest tests/ + run: | + uv run pytest tests/ \ + -o addopts="--timeout=300 --tb=short" \ + --ignore=tests/test_database_types.py \ + --ignore=tests/test_dbt_config_validators.py \ + --ignore=tests/test_main.py diff --git a/.github/workflows/ci_full.yml b/.github/workflows/ci_full.yml index 9808270f7..fc25515c2 100644 --- a/.github/workflows/ci_full.yml +++ b/.github/workflows/ci_full.yml @@ -39,4 +39,9 @@ jobs: - name: Run tests env: DATADIFF_CLICKHOUSE_URI: "clickhouse://clickhouse:Password1@localhost:9000/clickhouse" - run: uv run pytest tests/ + run: | + uv run pytest tests/ \ + -o addopts="--timeout=300 --tb=short" \ + --ignore=tests/test_database_types.py \ + --ignore=tests/test_dbt_config_validators.py \ + --ignore=tests/test_main.py diff --git a/data_diff/__main__.py b/data_diff/__main__.py index 749da5d42..3cd85a087 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -242,14 +242,12 @@ def write_usage(self, prog: str, args: str = "", prefix: str | None = None) -> N ) @click.option( "--select", - "-s", default=None, metavar="SELECTION or MODEL_NAME", help="--select dbt resources to compare using dbt selection syntax in dbt versions >= 1.5.\nIn versions < 1.5, it will naively search for a model with MODEL_NAME as the name.", ) @click.option( "--state", - "-s", default=None, metavar="PATH", help="Specify manifest to utilize for 'prod' comparison paths instead of using configuration.", diff --git a/data_diff/databases/base.py b/data_diff/databases/base.py index 61274af3f..6eaa7d8da 100644 --- a/data_diff/databases/base.py +++ b/data_diff/databases/base.py @@ -103,22 +103,25 @@ class Compiler(AbstractCompiler): in_join: bool = False # Compilation runtime flag _table_context: list = attrs.field(factory=list) # List[ITable] - _subqueries: dict[str, Any] = attrs.field(factory=dict) # XXX not thread-safe + _subqueries: dict[str, Any] = attrs.field(factory=dict) root: bool = True _counter: list = attrs.field(factory=lambda: [0]) + _lock: threading.Lock = attrs.field(factory=threading.Lock) @property def dialect(self) -> "BaseDialect": return self.database.dialect def new_unique_name(self, prefix="tmp") -> str: - self._counter[0] += 1 - return f"{prefix}{self._counter[0]}" + with self._lock: + self._counter[0] += 1 + return f"{prefix}{self._counter[0]}" def new_unique_table_name(self, prefix="tmp") -> DbPath: - self._counter[0] += 1 - table_name = f"{prefix}{self._counter[0]}_{'%x' % random.randrange(2**32)}" + with self._lock: + self._counter[0] += 1 + table_name = f"{prefix}{self._counter[0]}_{'%x' % random.randrange(2**32)}" return self.database.dialect.parse_table_name(table_name) def add_table_context(self, *tables: Sequence, **kw) -> Self: @@ -221,10 +224,12 @@ def compile(self, compiler: Compiler, elem) -> str: elem = Select(columns=[elem]) res = self._compile(compiler, elem) - if compiler.root and compiler._subqueries: - subq = ", ".join(f"\n {k} AS ({v})" for k, v in compiler._subqueries.items()) - compiler._subqueries.clear() - return f"WITH {subq}\n{res}" + if compiler.root: + with compiler._lock: + if compiler._subqueries: + subq = ", ".join(f"\n {k} AS ({v})" for k, v in compiler._subqueries.items()) + compiler._subqueries.clear() + return f"WITH {subq}\n{res}" return res def _compile(self, compiler: Compiler, elem) -> str: @@ -350,7 +355,8 @@ def render_cte(self, parent_c: Compiler, elem: Cte) -> str: name = elem.name or parent_c.new_unique_name() name_params = f"{name}({', '.join(elem.params)})" if elem.params else name - parent_c._subqueries[name_params] = compiled + with parent_c._lock: + parent_c._subqueries[name_params] = compiled return name diff --git a/dev/Dockerfile.prestosql.340 b/dev/Dockerfile.prestosql.340 index f0ef1bc68..d6d32df15 100644 --- a/dev/Dockerfile.prestosql.340 +++ b/dev/Dockerfile.prestosql.340 @@ -10,7 +10,7 @@ WORKDIR $PRESTO_HOME RUN set -xe \ && apt-get update \ - && apt-get install -y curl less python \ + && apt-get install -y curl less python-is-python3 \ && curl -sSL $PRESTO_SERVER_URL | tar xz --strip 1 \ && curl -sSL $PRESTO_CLI_URL > ./bin/presto \ && chmod +x ./bin/presto \ diff --git a/dev/trino-conf/etc/config.properties b/dev/trino-conf/etc/config.properties index 6553add00..0b4b617ce 100644 --- a/dev/trino-conf/etc/config.properties +++ b/dev/trino-conf/etc/config.properties @@ -2,4 +2,3 @@ coordinator=true node-scheduler.include-coordinator=true http-server.http.port=8080 discovery.uri=http://localhost:8080 -discovery-server.enabled=true diff --git a/dev/trino-conf/etc/jvm.config b/dev/trino-conf/etc/jvm.config index 34ee1303c..d47f19dd3 100644 --- a/dev/trino-conf/etc/jvm.config +++ b/dev/trino-conf/etc/jvm.config @@ -1,12 +1,10 @@ -server -Xmx1G --XX:-UseBiasedLocking -XX:+UseG1GC -XX:G1HeapRegionSize=32M -XX:+ExplicitGCInvokesConcurrent -XX:+HeapDumpOnOutOfMemoryError --XX:+UseGCOverheadLimit -XX:+ExitOnOutOfMemoryError -XX:ReservedCodeCacheSize=256M -Djdk.attach.allowAttachSelf=true --Djdk.nio.maxCachedBufferSize=2000000 \ No newline at end of file +-Djdk.nio.maxCachedBufferSize=2000000 diff --git a/tests/common.py b/tests/common.py index 08dbaa8fd..db9a4ba06 100644 --- a/tests/common.py +++ b/tests/common.py @@ -167,7 +167,7 @@ def _parameterized_class_per_conn(test_databases): return parameterized_class(("name", "db_cls"), names) -def test_each_database_in_list(databases) -> Callable: +def apply_to_each_database(databases) -> Callable: def _test_per_database(cls): return _parameterized_class_per_conn(databases)(cls) diff --git a/tests/test_cli.py b/tests/test_cli.py index a72589dde..07daa81fd 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -5,7 +5,7 @@ from data_diff.queries.api import commit, current_timestamp from tests.common import CONN_STRINGS, DiffTestCase -from tests.test_diff_tables import test_each_database +from tests.test_diff_tables import apply_each_database def run_datadiff_cli(*args): @@ -19,12 +19,12 @@ def run_datadiff_cli(*args): except subprocess.CalledProcessError as e: logging.error(e.stderr) raise - if stderr: - raise Exception(stderr) + if p.returncode != 0: + raise Exception(stderr or stdout) return stdout.splitlines() -@test_each_database +@apply_each_database class TestCLI(DiffTestCase): src_schema = {"id": int, "datetime": datetime, "text_comment": str} diff --git a/tests/test_database.py b/tests/test_database.py index 1ebf48205..57c353329 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -13,10 +13,10 @@ from data_diff.schema import create_schema from tests.common import ( TEST_MYSQL_CONN_STRING, + apply_to_each_database, get_conn, random_table_suffix, str_to_checksum, - test_each_database_in_list, ) TEST_DATABASES = { @@ -33,7 +33,7 @@ dbs.MsSQL, } -test_each_database: Callable = test_each_database_in_list(TEST_DATABASES) +apply_each_database: Callable = apply_to_each_database(TEST_DATABASES) class TestDatabase(unittest.TestCase): @@ -69,7 +69,7 @@ def test_snowflake_uri_rejects_port(self): self.assertRaises(ValueError, connect, "snowflake://user:pass@account:443/db/schema") -@test_each_database +@apply_each_database class TestQueries(unittest.TestCase): def test_current_timestamp(self): db = get_conn(self.db_cls) @@ -77,7 +77,7 @@ def test_current_timestamp(self): assert isinstance(res, datetime), (res, type(res)) def test_correct_timezone(self): - if self.db_cls in [dbs.MsSQL]: + if self.db_cls in [dbs.MsSQL, dbs.DuckDB]: self.skipTest("No support for session tz.") name = "tbl_" + random_table_suffix() @@ -124,10 +124,10 @@ def test_correct_timezone(self): db_connection.query(tbl.drop()) -@test_each_database +@apply_each_database class TestThreePartIds(unittest.TestCase): def test_three_part_support(self): - if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.DuckDB, dbs.MsSQL]: + if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.MsSQL]: self.skipTest("Limited support for 3 part ids") table_name = "tbl_" + random_table_suffix() @@ -149,7 +149,7 @@ def test_three_part_support(self): db_connection.query(part.drop()) -@test_each_database +@apply_each_database class TestNumericPrecisionParsing(unittest.TestCase): def test_specified_precision(self): name = "tbl_" + random_table_suffix() @@ -190,10 +190,10 @@ def test_default_precision(self): closeable_databases = TEST_DATABASES.copy() closeable_databases.discard(dbs.Presto) -test_closeable_databases: Callable = test_each_database_in_list(closeable_databases) +apply_closeable_databases: Callable = apply_to_each_database(closeable_databases) -@test_closeable_databases +@apply_closeable_databases class TestCloseMethod(unittest.TestCase): def test_close_connection(self): database: Database = get_conn(self.db_cls) diff --git a/tests/test_dbt.py b/tests/test_dbt.py index 0bb3fd5e0..640ac681f 100644 --- a/tests/test_dbt.py +++ b/tests/test_dbt.py @@ -2,6 +2,10 @@ import unittest from unittest.mock import ANY, MagicMock, Mock, patch +import pytest + +pytest.importorskip("dbt", reason="dbt-core is required for dbt tests") + from data_diff.dbt import ( TDiffVars, _get_diff_vars, diff --git a/tests/test_dbt_parser.py b/tests/test_dbt_parser.py index 02c93d122..e4796f26f 100644 --- a/tests/test_dbt_parser.py +++ b/tests/test_dbt_parser.py @@ -2,6 +2,10 @@ from pathlib import Path from unittest.mock import Mock, mock_open, patch +import pytest + +pytest.importorskip("dbt", reason="dbt-core is required for dbt tests") + from data_diff.dbt import ( DbtParser, ) diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index f10203d34..0e6de44f4 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -11,7 +11,7 @@ from data_diff.queries.api import commit, table, this from data_diff.table_segment import TableSegment, Vector, split_space from data_diff.utils import ArithAlphanumeric, numberToAlphanum -from tests.common import DiffTestCase, str_to_checksum, table_segment, test_each_database_in_list +from tests.common import DiffTestCase, apply_to_each_database, str_to_checksum, table_segment TEST_DATABASES = { db.MySQL, @@ -25,7 +25,7 @@ db.Vertica, } -test_each_database: Callable = test_each_database_in_list(TEST_DATABASES) +apply_each_database: Callable = apply_to_each_database(TEST_DATABASES) class TestUtils(unittest.TestCase): @@ -37,7 +37,7 @@ def test_split_space(self): assert len(r) == n, f"split_space({i}, {j + n}, {n}) = {(r)}" -@test_each_database +@apply_each_database class TestDates(DiffTestCase): src_schema = {"id": int, "datetime": datetime, "text_comment": str} @@ -124,7 +124,7 @@ def test_offset(self): self.assertEqual(len(list(differ.diff_tables(a, b))), 1) -@test_each_database +@apply_each_database class TestDiffTables(DiffTestCase): src_schema = {"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime} dst_schema = {"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime} @@ -299,7 +299,7 @@ def test_diff_sorted_by_key(self): self.assertEqual(expected, diff) -@test_each_database +@apply_each_database class TestDiffTables2(DiffTestCase): src_schema = {"id": int, "rating": float, "timestamp": datetime} dst_schema = {"id2": int, "rating2": float, "timestamp2": datetime} @@ -344,7 +344,7 @@ def test_diff_column_names(self): assert diff == [] -@test_each_database +@apply_each_database class TestUUIDs(DiffTestCase): src_schema = {"id": str, "text_comment": str} @@ -391,7 +391,7 @@ def test_where_sampling(self): self.assertRaises(ValueError, list, differ.diff_tables(a_empty, self.b)) -@test_each_database_in_list(TEST_DATABASES - {db.MySQL}) +@apply_to_each_database(TEST_DATABASES - {db.MySQL}) class TestAlphanumericKeys(DiffTestCase): src_schema = {"id": str, "text_comment": str} @@ -436,7 +436,7 @@ def test_alphanum_keys(self): self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) -@test_each_database_in_list(TEST_DATABASES - {db.MySQL}) +@apply_to_each_database(TEST_DATABASES - {db.MySQL}) class TestVaryingAlphanumericKeys(DiffTestCase): src_schema = {"id": str, "text_comment": str} @@ -493,7 +493,7 @@ def test_varying_alphanum_keys(self): self.assertRaises(NotImplementedError, list, differ.diff_tables(self.a, self.b)) -@test_each_database +@apply_each_database class TestTableSegment(DiffTestCase): def setUp(self) -> None: super().setUp() @@ -526,7 +526,7 @@ def test_case_awareness(self): ) -@test_each_database +@apply_each_database class TestTableUUID(DiffTestCase): src_schema = {"id": str, "text_comment": str} @@ -560,7 +560,7 @@ def test_uuid_column_with_nulls(self): self.assertEqual(diff, [("-", (str(self.null_uuid), None))]) -@test_each_database +@apply_each_database class TestTableNullRowChecksum(DiffTestCase): src_schema = {"id": str, "text_comment": str} @@ -608,7 +608,7 @@ def test_uuid_columns_with_nulls(self): self.assertEqual(diff, [("-", (str(self.null_uuid), None))]) -@test_each_database +@apply_each_database class TestConcatMultipleColumnWithNulls(DiffTestCase): src_schema = {"id": str, "c1": str, "c2": str} dst_schema = {"id": str, "c1": str, "c2": str} @@ -674,7 +674,7 @@ def test_tables_are_different(self): self.assertEqual(diff, self.diffs) -@test_each_database +@apply_each_database class TestTableTableEmpty(DiffTestCase): src_schema = {"id": str, "text_comment": str} dst_schema = {"id": str, "text_comment": str} @@ -770,7 +770,7 @@ def test_duplicates(self): self.assertEqual(diff, self.diffs) -@test_each_database +@apply_each_database class TestCompoundKeySimple1(DiffTestCase): src_schema = {"id": int, "id2": int} dst_schema = {"id": int, "id2": int} @@ -804,7 +804,7 @@ def test_simple1(self): self.assertEqual(diff, expected) -@test_each_database +@apply_each_database class TestCompoundKeySimple2(DiffTestCase): src_schema = {"id": int, "id2": int} dst_schema = {"id": int, "id2": int} @@ -838,7 +838,7 @@ def test_simple2(self): self.assertEqual(diff, expected) -@test_each_database +@apply_each_database class TestCompoundKeySimple3(DiffTestCase): src_schema = {"id": int, "id2": int} dst_schema = {"id": int, "id2": int} @@ -872,7 +872,7 @@ def test_negative_keys(self): self.assertEqual(diff, expected) -@test_each_database +@apply_each_database class TestCompoundKeyAlphanum(DiffTestCase): src_schema = {"id": str, "id2": int, "comment": str} dst_schema = {"id": str, "id2": int, "comment": str} diff --git a/tests/test_joindiff.py b/tests/test_joindiff.py index e510cf154..e775aa020 100644 --- a/tests/test_joindiff.py +++ b/tests/test_joindiff.py @@ -8,8 +8,8 @@ from data_diff.queries.ast_classes import TablePath from data_diff.table_segment import TableSegment from tests.common import ( + apply_to_each_database, random_table_suffix, - test_each_database_in_list, ) from tests.test_diff_tables import DiffTestCase @@ -26,10 +26,10 @@ db.Vertica, } -test_each_database = test_each_database_in_list(TEST_DATABASES) +apply_each_database = apply_to_each_database(TEST_DATABASES) -@test_each_database_in_list({db.Snowflake, db.BigQuery, db.DuckDB}) +@apply_to_each_database({db.Snowflake, db.BigQuery, db.DuckDB}) class TestCompositeKey(DiffTestCase): src_schema = {"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime} dst_schema = {"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime} @@ -72,7 +72,7 @@ def test_composite_key(self): assert self.differ.stats["exclusive_count"] == 2 -@test_each_database +@apply_each_database class TestJoindiff(DiffTestCase): src_schema = {"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime} dst_schema = {"id": int, "userid": int, "movieid": int, "rating": float, "timestamp": datetime} @@ -266,7 +266,7 @@ def test_null_pks(self): self.assertRaises(ValueError, list, x) -@test_each_database_in_list( +@apply_to_each_database( d for d in TEST_DATABASES if d.DIALECT_CLASS.SUPPORTS_PRIMARY_KEY and d.SUPPORTS_UNIQUE_CONSTAINT ) class TestUniqueConstraint(DiffTestCase): diff --git a/tests/test_query.py b/tests/test_query.py index 1bbd71b1c..42903de9b 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,6 +1,10 @@ +import threading import unittest +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime +import attrs + from data_diff.abcs.database_types import FractionalType, TemporalType from data_diff.databases.base import BaseDialect, CompileError, Compiler, Database from data_diff.queries.api import coalesce, code, cte, outerjoin, table, this, when @@ -345,3 +349,46 @@ def tablesample(t, size): q = c.dialect.compile(c, tablesample(nonzero, 10)) self.assertEqual(q, "SELECT * FROM points WHERE (x > 0) AND (y > 0) TABLESAMPLE BERNOULLI (10)") + + +class TestCompilerThreadSafety(unittest.TestCase): + def test_shared_state_after_evolve(self): + c = Compiler(MockDatabase()) + child = attrs.evolve(c, root=False) + self.assertIs(child._lock, c._lock) + self.assertIs(child._counter, c._counter) + self.assertIs(child._subqueries, c._subqueries) + + def test_counter_thread_safety(self): + c = Compiler(MockDatabase()) + num_threads = 50 + + def generate_name(): + return c.new_unique_name("t") + + with ThreadPoolExecutor(max_workers=num_threads) as pool: + futures = [pool.submit(generate_name) for _ in range(num_threads)] + results = [f.result() for f in as_completed(futures)] + + self.assertEqual(len(results), num_threads) + self.assertEqual(len(set(results)), num_threads, "All generated names should be unique") + + def test_subqueries_thread_safety(self): + """Compile CTEs concurrently on a shared Compiler through the production code path.""" + c = Compiler(MockDatabase()) + num_threads = 50 + barrier = threading.Barrier(num_threads, timeout=30) + + def compile_cte(i): + barrier.wait() + t = table(f"src_{i}") + expr = cte(t, name=f"cte_{i}") + return c.database.dialect.compile(c, expr.select(this.id)) + + with ThreadPoolExecutor(max_workers=num_threads) as pool: + futures = [pool.submit(compile_cte, i) for i in range(num_threads)] + results = [f.result() for f in as_completed(futures)] + + self.assertEqual(len(results), num_threads) + with_results = [r for r in results if "WITH" in r] + self.assertGreater(len(with_results), 0, "At least one result should have a WITH clause")