diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 000000000..5fcec96d6 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,194 @@ +""" +Tests for tortoise.config module - TortoiseConfig class. +""" + +import shutil +from pathlib import Path + +import orjson +import pytest +import yaml + +from tortoise.backends.base.config_generator import expand_db_url +from tortoise.config import AppConfig, DBUrlConfig, TortoiseConfig +from tortoise.exceptions import ConfigurationError + + +class TestTortoiseConfig: + @pytest.fixture + def db_url(self) -> str: + return "sqlite://db.sqlite3" + + @pytest.fixture + def simple_config(self, db_url: str) -> dict: + return { + "connections": {"default": db_url}, + "apps": { + "app": { + "models": ["app.models"], + "default_connection": "default", + } + }, + } + + @pytest.mark.parametrize( + "config,msg", + [ + ([], "TortoiseConfig must be created from a mapping"), + ({}, 'Config must define "connections" section'), + ({"connections": ""}, 'Config must define "apps" section'), + ({"connections": "", "apps": ""}, 'Config "connections" must be a mapping'), + ( + {"connections": {"default": []}, "apps": ""}, + "Connection values must be mapping or string", + ), + ( + {"connections": {"default": ""}, "apps": ""}, + "DBUrlConfig.url must be a non-empty string", + ), + ( + {"connections": {"default": "db.sqlite3"}, "apps": ""}, + 'Config "apps" must be a mapping', + ), + ( + {"connections": {"default": "db.sqlite3"}, "apps": {"auth": ""}}, + "App values must be mappings", + ), + ( + {"connections": {"default": "db.sqlite3"}, "apps": {"auth": {}}, "routers": {}}, + 'AppConfig requires "models"', + ), + ( + { + "connections": {"default": "db.sqlite3"}, + "apps": {"auth": {"models": []}}, + "routers": {}, + }, + "AppConfig.models must be a non-empty list of strings", + ), + ( + { + "connections": {"default": "db.sqlite3"}, + "apps": {"auth": {"models": ["models"]}}, + "routers": "", + }, + "TortoiseConfig.routers must be a list or None", + ), + ], + ) + def test_from_invalid_dict(self, config: list | dict, msg: str): + with pytest.raises(ConfigurationError, match=msg): + TortoiseConfig.from_dict(config) # type: ignore + + def test_from_dict(self, simple_config: dict): + assert TortoiseConfig.from_dict(simple_config) == TortoiseConfig( + connections={"default": DBUrlConfig(url="sqlite://db.sqlite3")}, + apps={ + "app": AppConfig( + models=["app.models"], default_connection="default", migrations=None + ) + }, + routers=None, + use_tz=None, + timezone=None, + ) + full = { + "connections": { + "default": "sqlite://db.sqlite3", + "second": "sqlite://db2.sqlite3", + }, + "apps": { + "app1": { + "models": ["app1.models"], + "migrations": "app1.migrations", + }, + "app2": { + "models": ["app2.models"], + "default_connection": "second", + "migrations": "app2.migrations", + }, + }, + "routers": ["path.Router"], + "use_tz": True, + "timezone": "UTC", + } + assert TortoiseConfig.from_dict(full) == TortoiseConfig( + connections={ + "default": DBUrlConfig(url="sqlite://db.sqlite3"), + "second": DBUrlConfig(url="sqlite://db2.sqlite3"), + }, + apps={ + "app1": AppConfig( + models=["app1.models"], default_connection=None, migrations="app1.migrations" + ), + "app2": AppConfig( + models=["app2.models"], + default_connection="second", + migrations="app2.migrations", + ), + }, + routers=["path.Router"], + use_tz=True, + timezone="UTC", + ) + + def test_from_config_file(self, tmp_path: Path, simple_config: dict): + file = tmp_path / "tortoise_conf.json" + file.write_bytes(orjson.dumps(simple_config)) + filename: str = file.as_posix() + assert ( + TortoiseConfig.from_config_file(file) + == TortoiseConfig.from_config_file(filename) + == TortoiseConfig.from_dict(simple_config) + == TortoiseConfig.resolve_args(config_file=file) + ) + + yaml_file = file.with_suffix(".yml") + with yaml_file.open("w") as f: + yaml.safe_dump(simple_config, f, default_flow_style=False) + yaml_file_2 = file.with_suffix(".yaml") + shutil.copy(yaml_file, yaml_file_2) + assert ( + TortoiseConfig.from_config_file(yaml_file) + == TortoiseConfig.from_config_file(str(yaml_file)) + == TortoiseConfig.from_config_file(yaml_file_2) + == TortoiseConfig.from_config_file(file) + == TortoiseConfig.resolve_args(config_file=yaml_file) + ) + + def test_from_db_url_and_modules(self, simple_config: dict, db_url: str): + modules = {"app": simple_config["apps"]["app"]["models"]} + typed_config = TortoiseConfig.from_db_url_and_modules(db_url, modules) + assert typed_config == TortoiseConfig.resolve_args(db_url=db_url, modules=modules) + assert typed_config.apps == TortoiseConfig.from_dict(simple_config).apps + + @pytest.mark.parametrize( + "config,msg", + [ + ({}, "Must provide either 'config', 'config_file', or both 'db_url' and 'modules'"), + ( + dict(db_url=""), + "Must provide either 'config', 'config_file', or both 'db_url' and 'modules'", + ), + ( + dict(config={}, config_file="a.json"), + "Cannot specify both 'config' and 'config_file'", + ), + ], + ) + def test_resolve_args_invalid(self, config: dict, msg: str): + with pytest.raises(ConfigurationError, match=msg): + TortoiseConfig.resolve_args(**config) + + def test_resolve_args(self, tmp_path: Path, db_url: str, simple_config: dict): + config_file = tmp_path / "config.json" + config_file.write_bytes(orjson.dumps(simple_config)) + typed_config = TortoiseConfig.resolve_args(simple_config) + assert typed_config == TortoiseConfig.resolve_args(config_file=config_file) + + typed_config_2 = TortoiseConfig.resolve_args(db_url=db_url, modules={"app": ["app.models"]}) + assert typed_config.apps == typed_config_2.apps + assert ( + expand_db_url(str(typed_config.connections["default"].to_config())) + == typed_config_2.connections["default"].to_config() + ) diff --git a/tortoise/config.py b/tortoise/config.py index f7527093b..c0d334f35 100644 --- a/tortoise/config.py +++ b/tortoise/config.py @@ -1,9 +1,9 @@ from __future__ import annotations import json -import os from collections.abc import Mapping from dataclasses import dataclass, field +from pathlib import Path from typing import TYPE_CHECKING, Any from tortoise.backends.base.config_generator import generate_config @@ -217,12 +217,12 @@ def from_dict(cls, data: Mapping[str, Any]) -> Self: ) @classmethod - def from_config_file(cls, config_file: str) -> Self: + def from_config_file(cls, config_file: Path | str) -> Self: """ Load configuration from a YAML or JSON file. Args: - config_file (str): Path to the configuration file. Supported extensions: .yml, .yaml, .json. + config_file: Path to the configuration file. Supported extensions: .yml, .yaml, .json. Returns: Self: The constructed TortoiseConfig. @@ -230,19 +230,19 @@ def from_config_file(cls, config_file: str) -> Self: Raises: ConfigurationError: If the file is missing, unsupported, or contents are invalid. """ - _, extension = os.path.splitext(config_file) - if extension in (".yml", ".yaml"): - import yaml # pylint: disable=C0415 - - with open(config_file) as f: - config = yaml.safe_load(f) - elif extension == ".json": - with open(config_file) as f: - config = json.load(f) - else: - raise ConfigurationError( - f"Unknown config extension {extension}, only .yml and .json are supported" - ) + config_path = Path(config_file) + match config_path.suffix: + case ".yml" | ".yaml": + import yaml # pylint: disable=C0415 + + with open(config_file) as f: + config = yaml.safe_load(f) + case ".json": + config = json.loads(config_path.read_bytes()) + case _ as extension: + raise ConfigurationError( + f"Unknown config extension {extension}, only .yml and .json are supported" + ) return cls.from_dict(config) @classmethod @@ -274,7 +274,7 @@ def from_db_url_and_modules( def resolve_args( cls, config: dict[str, Any] | Self | None = None, - config_file: str | None = None, + config_file: Path | str | None = None, db_url: str | None = None, modules: dict[str, Iterable[str | ModuleType]] | None = None, ) -> Self: @@ -286,14 +286,9 @@ def resolve_args( - `config_file` path, - or both `db_url` and `modules`. - Args: - config (dict[str, Any] | TortoiseConfig | None): - config_file (str | None): Path to a config YAML or JSON file. - db_url (str | None): Database URL for config generation. - modules (dict[str, Iterable[str | ModuleType]] | None): App modules for config generation. Args: config: A configuration dict or TortoiseConfig instance. - config_file: Path to config file. + config_file: Path to a config YAML or JSON file. db_url: Database URL for config generation. modules: App modules for config generation.