diff --git a/datafaker/main.py b/datafaker/main.py index ff768c6..e3d5e5f 100644 --- a/datafaker/main.py +++ b/datafaker/main.py @@ -84,7 +84,7 @@ def _check_file_non_existence(file_path: Path) -> None: """Check that a given file does not exist. Exit with an error message if it does.""" if file_path.exists(): logger.error("%s should not already exist. Exiting...", file_path) - sys.exit(1) + raise Exit(1) def load_metadata_config( @@ -580,7 +580,7 @@ def convert_table_names_to_tables( "%s is not the name of a table in the destination database", name ) if failed_count: - sys.exit(1) + raise Exit(1) return results @@ -669,7 +669,7 @@ def dump_data( "Must specify exactly one table if the output name is" " specified, or specify an existing directory" ) - sys.exit(1) + raise Exit(1) dst_dsn = get_destination_dsn() schema_name = get_destination_schema() config = read_config_file(config_file) if config_file is not None else {} @@ -702,7 +702,7 @@ def validate_config( validate(config, schema_config) except ValidationError as e: logger.error(e) - sys.exit(1) + raise Exit(1) from e logger.debug("Config file is valid.") @@ -798,7 +798,7 @@ def remove_tables( except InternalError as exc: logger.error("Failed to drop tables: %s", exc) logger.error("Please try again using the --all option.") - sys.exit(1) + raise Exit(1) from exc logger.debug("Tables dropped.") else: logger.info("Would remove tables if called with --yes.") diff --git a/datafaker/make.py b/datafaker/make.py index a8dd58d..a19b146 100644 --- a/datafaker/make.py +++ b/datafaker/make.py @@ -11,6 +11,7 @@ import pandas as pd import snsql +import typer import yaml from black import FileMode, format_str from jinja2 import Environment, FileSystemLoader, Template @@ -31,6 +32,7 @@ create_db_engine, download_table, get_columns_assigned, + get_metadata, get_property, get_related_table_names, get_row_generators, @@ -606,7 +608,21 @@ def make_table_generators( # pylint: disable=too-many-locals :return: A string that is a valid Python module, once written to file. """ row_generator_module_name: str = config.get("row_generators_module", None) + if row_generator_module_name and "-" in row_generator_module_name: + logger.error( + "Row generator name %s specified in %s should not contain a hyphen", + row_generator_module_name, + config_filename, + ) + raise typer.Exit(1) story_generator_module_name = config.get("story_generators_module", None) + if story_generator_module_name and "-" in story_generator_module_name: + logger.error( + "Story generator name %s specified in %s should not contain a hyphen", + story_generator_module_name, + config_filename, + ) + raise typer.Exit(1) object_instantiation: dict[str, dict] = config.get("object_instantiation", {}) tables_config = config.get("tables", {}) @@ -703,8 +719,7 @@ def make_tables_file( """Construct the YAML file representing the schema.""" engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name)) - metadata = MetaData() - metadata.reflect(engine) + metadata = get_metadata(engine) meta_dict = metadata_to_dict(metadata, schema_name, engine, parquet_dir) if parquet_dir is not None: diff --git a/datafaker/remove.py b/datafaker/remove.py index faa4e7d..71213e2 100644 --- a/datafaker/remove.py +++ b/datafaker/remove.py @@ -6,6 +6,7 @@ from datafaker.settings import get_destination_dsn, get_destination_schema from datafaker.utils import ( create_db_engine, + get_metadata, get_sync_engine, get_vocabulary_table_names, logger, @@ -67,6 +68,5 @@ def remove_db_tables(metadata: Optional[MetaData]) -> None: ) ) if metadata is None: - metadata = MetaData() - metadata.reflect(dst_engine) + metadata = get_metadata(dst_engine) metadata.drop_all(dst_engine) diff --git a/datafaker/utils.py b/datafaker/utils.py index 9a1ffbc..a0ab358 100644 --- a/datafaker/utils.py +++ b/datafaker/utils.py @@ -33,7 +33,12 @@ from jsonschema.validators import validate from sqlalchemy import Connection, Engine, ForeignKey, create_engine, event, select from sqlalchemy.engine.interfaces import DBAPIConnection -from sqlalchemy.exc import IntegrityError, ProgrammingError +from sqlalchemy.exc import ( + IntegrityError, + NoSuchModuleError, + OperationalError, + ProgrammingError, +) from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from sqlalchemy.orm import Session from sqlalchemy.schema import ( @@ -43,6 +48,7 @@ MetaData, Table, ) +from typer import Exit # Define some types used repeatedly in the code base MaybeAsyncEngine = Union[Engine, AsyncEngine] @@ -110,7 +116,11 @@ def import_file(file_path: str) -> ModuleType: if spec is None or spec.loader is None: raise ImportError(f"No loadable module at {file_path}") module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) + try: + spec.loader.exec_module(module) + except ModuleNotFoundError as e: + logger.error("Failed to load module at %s with error:", file_path) + logger.error(e) return module @@ -193,11 +203,19 @@ def create_db_engine( **kwargs: Any, ) -> MaybeAsyncEngine: """Create a SQLAlchemy Engine.""" - if use_asyncio: - async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://") - engine: MaybeAsyncEngine = create_async_engine(async_dsn, **kwargs) - else: - engine = create_engine(db_dsn, **kwargs) + try: + if use_asyncio: + async_dsn = db_dsn.replace("postgresql://", "postgresql+asyncpg://") + engine: MaybeAsyncEngine = create_async_engine(async_dsn, **kwargs) + else: + engine = create_engine(db_dsn, **kwargs) + except NoSuchModuleError as exc: + logger.error("Failed to connect to the database: %s", exc) + logger.error("Perhaps the dialect '%s' is invalid.", db_dsn.split(":")[0]) + raise Exit(1) from exc + except ValueError as exc: + logger.error("DSN %s is malformed: %s", db_dsn, exc) + raise Exit(1) from exc settings = {} if schema_name is not None: @@ -248,6 +266,17 @@ def create_db_engine_dst( return create_db_engine(db_dsn, schema_name, use_asyncio) +def get_metadata(engine: Engine) -> MetaData: + """Get the MetaData object associated with the engine passed.""" + md = MetaData() + try: + md.reflect(engine) + except OperationalError as exc: + logger.error("Cannot connect to database: %s", exc) + raise Exit(1) from exc + return md + + def _find_parquet_directories(parquet_dir: Path) -> list[str]: """Find all the directories under ``parquet_dir`` that contain parquet files.""" return [ diff --git a/tests/test_functional.py b/tests/test_functional.py index eac6708..1f1d12a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Any, Mapping +import yaml from sqlalchemy import create_engine, inspect from typer.testing import CliRunner, Result @@ -606,6 +607,81 @@ def test_create_schema(self) -> None: inspector = inspect(engine) self.assertTrue(inspector.has_schema(env["dst_schema"])) + def test_story_incorrect_name(self) -> None: + """Test we get a proper error message if the story generator module does not exist.""" + config_file = "config_story_incorrect.yaml" + config = { + "story_generators_module": "incorrect_module", + } + with Path(config_file).open("w", encoding="utf-8") as fh: + fh.write(yaml.dump(config)) + self.invoke( + "make-tables", + "--force", + ) + completed_process = self.invoke( + "create-generators", + "--force", + "--config-file", + config_file, + ) + self.assertSuccess(completed_process) + self.invoke( + "create-tables", + "--config-file", + config_file, + ) + self.assertSuccess(completed_process) + completed_process = self.invoke( + "create-data", + "--config-file", + config_file, + expected_error="No module named 'incorrect_module'", + ) + self.assertReturnCode(completed_process, 1) + + def test_story_hyphens_in_name(self) -> None: + """Test hyphens in story generator names cause an error to be emitted.""" + config_file = "config_story_hyphens.yaml" + config = { + "story_generators_module": "story-generators", + } + with Path(config_file).open("w", encoding="utf-8") as fh: + fh.write(yaml.dump(config)) + self.invoke( + "make-tables", + "--force", + ) + completed_process = self.invoke( + "create-generators", + "--force", + "--config-file", + config_file, + expected_error="hyphen", + ) + self.assertReturnCode(completed_process, 1) + + def test_row_hyphens_in_name(self) -> None: + """Test hyphens in row generator names cause an error to be emitted.""" + config_file = "config_row_hyphens.yaml" + config = { + "row_generators_module": "row-generators", + } + with Path(config_file).open("w", encoding="utf-8") as fh: + fh.write(yaml.dump(config)) + self.invoke( + "make-tables", + "--force", + ) + completed_process = self.invoke( + "create-generators", + "--force", + "--config-file", + config_file, + expected_error="hyphen", + ) + self.assertReturnCode(completed_process, 1) + class DuckDbFunctionalTestCase(DBFunctionalTestCaseBase): """End-to-end tests for the DuckDB workflow.""" diff --git a/tests/test_main.py b/tests/test_main.py index 50d8602..a67ac30 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -341,6 +341,90 @@ def test_make_tables_with_force_enabled( mock_make_tables.reset_mock() mock_path.reset_mock() + @patch("datafaker.main.Path") + @patch("datafaker.settings.get_settings") + def test_incorrect_dialect_causes_nice_error_message( + self, + mock_get_settings: MagicMock, + mock_path: MagicMock, + ) -> None: + """Test the make-tables sub-command, when the force option is activated.""" + mock_get_settings.return_value = Settings( + # postgres: not postgresql: will cause sqlalchemy to fail to connect + src_dsn="postgres://suser:spassword@shost:5432/sdbname", + dst_dsn="postgresql://duser:dpassword@dhost:5432/ddbname", + # To stop any local .env files influencing the test + # The mypy ignore can be removed once we upgrade to pydantic 2. + _env_file=None, # type: ignore[call-arg] + ) + mock_path.return_value.exists.return_value = True + + result = runner.invoke( + app, + [ + "make-tables", + "--force", + "--orm-file=tests/examples/example_orm.yaml", + ], + ) + self.assertIs(type(result.exception), SystemExit) + + @patch("datafaker.main.Path") + @patch("datafaker.settings.get_settings") + def test_invalid_host_causes_nice_error_message( + self, + mock_get_settings: MagicMock, + mock_path: MagicMock, + ) -> None: + """Test the make-tables sub-command, when the force option is activated.""" + mock_get_settings.return_value = Settings( + # postgres: not postgresql: will cause sqlalchemy to fail to connect + src_dsn="postgresql://suser:spassword@invalid_host:5432/sdbname", + dst_dsn="postgresql://duser:dpassword@dhost:5432/ddbname", + # To stop any local .env files influencing the test + # The mypy ignore can be removed once we upgrade to pydantic 2. + _env_file=None, # type: ignore[call-arg] + ) + mock_path.return_value.exists.return_value = True + + result = runner.invoke( + app, + [ + "make-tables", + "--force", + "--orm-file=tests/examples/example_orm.yaml", + ], + ) + self.assertIs(type(result.exception), SystemExit) + + @patch("datafaker.main.Path") + @patch("datafaker.settings.get_settings") + def test_incorrect_dsn_causes_nice_error_message( + self, + mock_get_settings: MagicMock, + mock_path: MagicMock, + ) -> None: + """Test the make-tables sub-command, when the force option is activated.""" + mock_get_settings.return_value = Settings( + # postgres: not postgresql: will cause sqlalchemy to fail to connect + src_dsn="postgresql://suser:spassword:localhost:5432/sdbname", + dst_dsn="postgresql://duser:dpassword@dhost:5432/ddbname", + # To stop any local .env files influencing the test + # The mypy ignore can be removed once we upgrade to pydantic 2. + _env_file=None, # type: ignore[call-arg] + ) + mock_path.return_value.exists.return_value = True + + result = runner.invoke( + app, + [ + "make-tables", + "--force", + "--orm-file=tests/examples/example_orm.yaml", + ], + ) + self.assertIs(type(result.exception), SystemExit) + def test_validate_config(self) -> None: """Test the validate-config sub-command.""" result = runner.invoke( diff --git a/tests/utils.py b/tests/utils.py index 791ffed..f2d0685 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -38,10 +38,6 @@ ) -class SysExit(Exception): - """To force the function to exit as sys.exit() would.""" - - @lru_cache(1) def get_test_settings() -> settings.Settings: """Get a Settings object that ignores .env files and environment variables."""