Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,9 @@ jobs:
- name: Run tests
env:
DATADIFF_CLICKHOUSE_URI: "clickhouse://clickhouse:Password1@localhost:9000/clickhouse"
run: uv run pytest tests/
run: |
uv run pytest tests/ \
-o addopts="--timeout=300 --tb=short" \
--ignore=tests/test_database_types.py \
--ignore=tests/test_dbt_config_validators.py \
--ignore=tests/test_main.py
7 changes: 6 additions & 1 deletion .github/workflows/ci_full.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,9 @@ jobs:
- name: Run tests
env:
DATADIFF_CLICKHOUSE_URI: "clickhouse://clickhouse:Password1@localhost:9000/clickhouse"
run: uv run pytest tests/
run: |
uv run pytest tests/ \
-o addopts="--timeout=300 --tb=short" \
--ignore=tests/test_database_types.py \
--ignore=tests/test_dbt_config_validators.py \
--ignore=tests/test_main.py
2 changes: 0 additions & 2 deletions data_diff/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,14 +242,12 @@ def write_usage(self, prog: str, args: str = "", prefix: str | None = None) -> N
)
@click.option(
"--select",
"-s",
default=None,
metavar="SELECTION or MODEL_NAME",
help="--select dbt resources to compare using dbt selection syntax in dbt versions >= 1.5.\nIn versions < 1.5, it will naively search for a model with MODEL_NAME as the name.",
)
@click.option(
"--state",
"-s",
default=None,
metavar="PATH",
help="Specify manifest to utilize for 'prod' comparison paths instead of using configuration.",
Expand Down
26 changes: 16 additions & 10 deletions data_diff/databases/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,22 +103,25 @@ class Compiler(AbstractCompiler):
in_join: bool = False # Compilation runtime flag

_table_context: list = attrs.field(factory=list) # List[ITable]
_subqueries: dict[str, Any] = attrs.field(factory=dict) # XXX not thread-safe
_subqueries: dict[str, Any] = attrs.field(factory=dict)
root: bool = True

_counter: list = attrs.field(factory=lambda: [0])
_lock: threading.Lock = attrs.field(factory=threading.Lock)

@property
def dialect(self) -> "BaseDialect":
return self.database.dialect

def new_unique_name(self, prefix="tmp") -> str:
self._counter[0] += 1
return f"{prefix}{self._counter[0]}"
with self._lock:
self._counter[0] += 1
return f"{prefix}{self._counter[0]}"

def new_unique_table_name(self, prefix="tmp") -> DbPath:
self._counter[0] += 1
table_name = f"{prefix}{self._counter[0]}_{'%x' % random.randrange(2**32)}"
with self._lock:
self._counter[0] += 1
table_name = f"{prefix}{self._counter[0]}_{'%x' % random.randrange(2**32)}"
return self.database.dialect.parse_table_name(table_name)

def add_table_context(self, *tables: Sequence, **kw) -> Self:
Expand Down Expand Up @@ -221,10 +224,12 @@ def compile(self, compiler: Compiler, elem) -> str:
elem = Select(columns=[elem])

res = self._compile(compiler, elem)
if compiler.root and compiler._subqueries:
subq = ", ".join(f"\n {k} AS ({v})" for k, v in compiler._subqueries.items())
compiler._subqueries.clear()
return f"WITH {subq}\n{res}"
if compiler.root:
with compiler._lock:
if compiler._subqueries:
subq = ", ".join(f"\n {k} AS ({v})" for k, v in compiler._subqueries.items())
compiler._subqueries.clear()
return f"WITH {subq}\n{res}"
return res

def _compile(self, compiler: Compiler, elem) -> str:
Expand Down Expand Up @@ -350,7 +355,8 @@ def render_cte(self, parent_c: Compiler, elem: Cte) -> str:

name = elem.name or parent_c.new_unique_name()
name_params = f"{name}({', '.join(elem.params)})" if elem.params else name
parent_c._subqueries[name_params] = compiled
with parent_c._lock:
parent_c._subqueries[name_params] = compiled

return name

Expand Down
2 changes: 1 addition & 1 deletion dev/Dockerfile.prestosql.340
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ WORKDIR $PRESTO_HOME

RUN set -xe \
&& apt-get update \
&& apt-get install -y curl less python \
&& apt-get install -y curl less python-is-python3 \
&& curl -sSL $PRESTO_SERVER_URL | tar xz --strip 1 \
&& curl -sSL $PRESTO_CLI_URL > ./bin/presto \
&& chmod +x ./bin/presto \
Expand Down
1 change: 0 additions & 1 deletion dev/trino-conf/etc/config.properties
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@ coordinator=true
node-scheduler.include-coordinator=true
http-server.http.port=8080
discovery.uri=http://localhost:8080
discovery-server.enabled=true
4 changes: 1 addition & 3 deletions dev/trino-conf/etc/jvm.config
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
-server
-Xmx1G
-XX:-UseBiasedLocking
-XX:+UseG1GC
-XX:G1HeapRegionSize=32M
-XX:+ExplicitGCInvokesConcurrent
-XX:+HeapDumpOnOutOfMemoryError
-XX:+UseGCOverheadLimit
-XX:+ExitOnOutOfMemoryError
-XX:ReservedCodeCacheSize=256M
-Djdk.attach.allowAttachSelf=true
-Djdk.nio.maxCachedBufferSize=2000000
-Djdk.nio.maxCachedBufferSize=2000000
2 changes: 1 addition & 1 deletion tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def _parameterized_class_per_conn(test_databases):
return parameterized_class(("name", "db_cls"), names)


def test_each_database_in_list(databases) -> Callable:
def apply_to_each_database(databases) -> Callable:
def _test_per_database(cls):
return _parameterized_class_per_conn(databases)(cls)

Expand Down
8 changes: 4 additions & 4 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from data_diff.queries.api import commit, current_timestamp
from tests.common import CONN_STRINGS, DiffTestCase
from tests.test_diff_tables import test_each_database
from tests.test_diff_tables import apply_each_database


def run_datadiff_cli(*args):
Expand All @@ -19,12 +19,12 @@ def run_datadiff_cli(*args):
except subprocess.CalledProcessError as e:
logging.error(e.stderr)
raise
if stderr:
raise Exception(stderr)
if p.returncode != 0:
raise Exception(stderr or stdout)
return stdout.splitlines()


@test_each_database
@apply_each_database
class TestCLI(DiffTestCase):
src_schema = {"id": int, "datetime": datetime, "text_comment": str}

Expand Down
18 changes: 9 additions & 9 deletions tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
from data_diff.schema import create_schema
from tests.common import (
TEST_MYSQL_CONN_STRING,
apply_to_each_database,
get_conn,
random_table_suffix,
str_to_checksum,
test_each_database_in_list,
)

TEST_DATABASES = {
Expand All @@ -33,7 +33,7 @@
dbs.MsSQL,
}

test_each_database: Callable = test_each_database_in_list(TEST_DATABASES)
apply_each_database: Callable = apply_to_each_database(TEST_DATABASES)


class TestDatabase(unittest.TestCase):
Expand Down Expand Up @@ -69,15 +69,15 @@ def test_snowflake_uri_rejects_port(self):
self.assertRaises(ValueError, connect, "snowflake://user:pass@account:443/db/schema")


@test_each_database
@apply_each_database
class TestQueries(unittest.TestCase):
def test_current_timestamp(self):
db = get_conn(self.db_cls)
res = db.query(current_timestamp(), datetime)
assert isinstance(res, datetime), (res, type(res))

def test_correct_timezone(self):
if self.db_cls in [dbs.MsSQL]:
if self.db_cls in [dbs.MsSQL, dbs.DuckDB]:
self.skipTest("No support for session tz.")
name = "tbl_" + random_table_suffix()

Expand Down Expand Up @@ -124,10 +124,10 @@ def test_correct_timezone(self):
db_connection.query(tbl.drop())


@test_each_database
@apply_each_database
class TestThreePartIds(unittest.TestCase):
def test_three_part_support(self):
if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.DuckDB, dbs.MsSQL]:
if self.db_cls not in [dbs.PostgreSQL, dbs.Redshift, dbs.Snowflake, dbs.MsSQL]:
self.skipTest("Limited support for 3 part ids")

table_name = "tbl_" + random_table_suffix()
Expand All @@ -149,7 +149,7 @@ def test_three_part_support(self):
db_connection.query(part.drop())


@test_each_database
@apply_each_database
class TestNumericPrecisionParsing(unittest.TestCase):
def test_specified_precision(self):
name = "tbl_" + random_table_suffix()
Expand Down Expand Up @@ -190,10 +190,10 @@ def test_default_precision(self):
closeable_databases = TEST_DATABASES.copy()
closeable_databases.discard(dbs.Presto)

test_closeable_databases: Callable = test_each_database_in_list(closeable_databases)
apply_closeable_databases: Callable = apply_to_each_database(closeable_databases)


@test_closeable_databases
@apply_closeable_databases
class TestCloseMethod(unittest.TestCase):
def test_close_connection(self):
database: Database = get_conn(self.db_cls)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_dbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import unittest
from unittest.mock import ANY, MagicMock, Mock, patch

import pytest

pytest.importorskip("dbt", reason="dbt-core is required for dbt tests")

from data_diff.dbt import (
TDiffVars,
_get_diff_vars,
Expand Down
4 changes: 4 additions & 0 deletions tests/test_dbt_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
from pathlib import Path
from unittest.mock import Mock, mock_open, patch

import pytest

pytest.importorskip("dbt", reason="dbt-core is required for dbt tests")

from data_diff.dbt import (
DbtParser,
)
Expand Down
Loading
Loading