From c326d523790587f64c20f3d6cba7b7f5171bad80 Mon Sep 17 00:00:00 2001 From: hasansezertasan Date: Sun, 12 Apr 2026 17:30:11 +0300 Subject: [PATCH 1/2] feat: add aiomysql adapter for async MySQL connectivity Add a new async MySQL adapter using aiomysql (built on PyMySQL), complementing the existing asyncmy adapter. This gives teams with inherited aiomysql codebases a path to adopt sqlspec without a driver migration. Key implementation details: - Exception handling uses pymysql.err.* (aiomysql's underlying error hierarchy) - Connection.cursor() is synchronous (same as asyncmy, despite aiomysql docs) - Pool API uses aiomysql.create_pool() with "db" key (not "database" like asyncmy) - Parameter style: QMARK input -> POSITIONAL_PYFORMAT execution (identical to asyncmy) - Extensions (ADK, Litestar, Events) deferred to follow-up Refs: litestar-org/sqlspec#412 Co-Authored-By: Claude Opus 4.6 (1M context) --- pyproject.toml | 4 + sqlspec/adapters/aiomysql/__init__.py | 21 + sqlspec/adapters/aiomysql/_typing.py | 113 +++ sqlspec/adapters/aiomysql/config.py | 305 ++++++++ sqlspec/adapters/aiomysql/core.py | 500 +++++++++++++ sqlspec/adapters/aiomysql/data_dictionary.py | 123 ++++ sqlspec/adapters/aiomysql/driver.py | 381 ++++++++++ .../integration/adapters/aiomysql/__init__.py | 3 + .../integration/adapters/aiomysql/conftest.py | 89 +++ .../adapters/aiomysql/test_arrow.py | 204 +++++ .../adapters/aiomysql/test_config.py | 220 ++++++ .../adapters/aiomysql/test_driver.py | 556 ++++++++++++++ .../adapters/aiomysql/test_exceptions.py | 137 ++++ .../adapters/aiomysql/test_explain.py | 138 ++++ .../adapters/aiomysql/test_features.py | 284 +++++++ .../adapters/aiomysql/test_migrations.py | 405 ++++++++++ .../aiomysql/test_parameter_styles.py | 696 ++++++++++++++++++ .../adapters/aiomysql/test_storage_bridge.py | 61 ++ tests/unit/adapters/test_aiomysql/__init__.py | 0 .../adapters/test_aiomysql/test_config.py | 36 + uv.lock | 18 +- 21 files changed, 4293 insertions(+), 1 deletion(-) create mode 100644 sqlspec/adapters/aiomysql/__init__.py create mode 100644 sqlspec/adapters/aiomysql/_typing.py create mode 100644 sqlspec/adapters/aiomysql/config.py create mode 100644 sqlspec/adapters/aiomysql/core.py create mode 100644 sqlspec/adapters/aiomysql/data_dictionary.py create mode 100644 sqlspec/adapters/aiomysql/driver.py create mode 100644 tests/integration/adapters/aiomysql/__init__.py create mode 100644 tests/integration/adapters/aiomysql/conftest.py create mode 100644 tests/integration/adapters/aiomysql/test_arrow.py create mode 100644 tests/integration/adapters/aiomysql/test_config.py create mode 100644 tests/integration/adapters/aiomysql/test_driver.py create mode 100644 tests/integration/adapters/aiomysql/test_exceptions.py create mode 100644 tests/integration/adapters/aiomysql/test_explain.py create mode 100644 tests/integration/adapters/aiomysql/test_features.py create mode 100644 tests/integration/adapters/aiomysql/test_migrations.py create mode 100644 tests/integration/adapters/aiomysql/test_parameter_styles.py create mode 100644 tests/integration/adapters/aiomysql/test_storage_bridge.py create mode 100644 tests/unit/adapters/test_aiomysql/__init__.py create mode 100644 tests/unit/adapters/test_aiomysql/test_config.py diff --git a/pyproject.toml b/pyproject.toml index 24033b09f..96d555775 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ Source = "https://github.com/litestar-org/sqlspec" [project.optional-dependencies] adbc = ["adbc_driver_manager", "pyarrow"] adk = ["google-adk"] +aiomysql = ["aiomysql"] aioodbc = ["aioodbc"] aiosqlite = ["aiosqlite"] alloydb = ["google-cloud-alloydb-connector"] @@ -358,6 +359,7 @@ markers = [ "mssql: marks tests specific to Microsoft SQL Server", # Driver markers "adbc: marks tests using ADBC drivers", + "aiomysql: marks tests using aiomysql", "aioodbc: marks tests using aioodbc", "aiosqlite: marks tests using aiosqlite", "asyncmy: marks tests using asyncmy", @@ -400,6 +402,8 @@ module = [ "orjson", "uvicorn.*", "uvloop.*", + "aiomysql", + "aiomysql.*", "asyncmy", "asyncmy.*", "pyarrow", diff --git a/sqlspec/adapters/aiomysql/__init__.py b/sqlspec/adapters/aiomysql/__init__.py new file mode 100644 index 000000000..f699ac1c6 --- /dev/null +++ b/sqlspec/adapters/aiomysql/__init__.py @@ -0,0 +1,21 @@ +from sqlspec.adapters.aiomysql._typing import AiomysqlConnection, AiomysqlCursor +from sqlspec.adapters.aiomysql.config import ( + AiomysqlConfig, + AiomysqlConnectionParams, + AiomysqlDriverFeatures, + AiomysqlPoolParams, +) +from sqlspec.adapters.aiomysql.core import default_statement_config +from sqlspec.adapters.aiomysql.driver import AiomysqlDriver, AiomysqlExceptionHandler + +__all__ = ( + "AiomysqlConfig", + "AiomysqlConnection", + "AiomysqlConnectionParams", + "AiomysqlCursor", + "AiomysqlDriver", + "AiomysqlDriverFeatures", + "AiomysqlExceptionHandler", + "AiomysqlPoolParams", + "default_statement_config", +) diff --git a/sqlspec/adapters/aiomysql/_typing.py b/sqlspec/adapters/aiomysql/_typing.py new file mode 100644 index 000000000..c3c3aec37 --- /dev/null +++ b/sqlspec/adapters/aiomysql/_typing.py @@ -0,0 +1,113 @@ +"""aiomysql adapter type definitions. + +This module contains type aliases and classes that are excluded from mypyc +compilation to avoid ABI boundary issues. +""" + +from typing import TYPE_CHECKING, Any + +from aiomysql import Connection # pyright: ignore +from aiomysql.cursors import Cursor as _AiomysqlCursor # pyright: ignore + +if TYPE_CHECKING: + from collections.abc import Callable + from types import TracebackType + from typing import Protocol, TypeAlias + + from sqlspec.adapters.aiomysql.driver import AiomysqlDriver + from sqlspec.core import StatementConfig + + class AiomysqlConnectionProtocol(Protocol): + def cursor(self) -> "AiomysqlRawCursor": ... + + async def commit(self) -> Any: ... + + async def rollback(self) -> Any: ... + + def close(self) -> Any: ... + + AiomysqlConnection: TypeAlias = AiomysqlConnectionProtocol + AiomysqlRawCursor: TypeAlias = _AiomysqlCursor + +if not TYPE_CHECKING: + AiomysqlConnection = Connection + AiomysqlRawCursor = _AiomysqlCursor + + +class AiomysqlCursor: + """Context manager for aiomysql cursor operations. + + Provides automatic cursor acquisition and cleanup for database operations. + """ + + __slots__ = ("connection", "cursor") + + def __init__(self, connection: "AiomysqlConnection") -> None: + self.connection = connection + self.cursor: AiomysqlRawCursor | None = None + + async def __aenter__(self) -> "AiomysqlRawCursor": + self.cursor = self.connection.cursor() + return self.cursor + + async def __aexit__(self, *_: Any) -> None: + if self.cursor is not None: + await self.cursor.close() + + +class AiomysqlSessionContext: + """Async context manager for aiomysql sessions. + + This class is intentionally excluded from mypyc compilation to avoid ABI + boundary issues. It receives callables from uncompiled config classes and + instantiates compiled Driver objects, acting as a bridge between compiled + and uncompiled code. + + Uses callable-based connection management to decouple from config implementation. + """ + + __slots__ = ( + "_acquire_connection", + "_connection", + "_driver", + "_driver_features", + "_prepare_driver", + "_release_connection", + "_statement_config", + ) + + def __init__( + self, + acquire_connection: "Callable[[], Any]", + release_connection: "Callable[[Any], Any]", + statement_config: "StatementConfig", + driver_features: "dict[str, Any]", + prepare_driver: "Callable[[AiomysqlDriver], AiomysqlDriver]", + ) -> None: + self._acquire_connection = acquire_connection + self._release_connection = release_connection + self._statement_config = statement_config + self._driver_features = driver_features + self._prepare_driver = prepare_driver + self._connection: Any = None + self._driver: AiomysqlDriver | None = None + + async def __aenter__(self) -> "AiomysqlDriver": + from sqlspec.adapters.aiomysql.driver import AiomysqlDriver + + self._connection = await self._acquire_connection() + self._driver = AiomysqlDriver( + connection=self._connection, statement_config=self._statement_config, driver_features=self._driver_features + ) + return self._prepare_driver(self._driver) + + async def __aexit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> "bool | None": + if self._connection is not None: + await self._release_connection(self._connection) + self._connection = None + return None + + +__all__ = ("AiomysqlConnection", "AiomysqlCursor", "AiomysqlRawCursor", "AiomysqlSessionContext") diff --git a/sqlspec/adapters/aiomysql/config.py b/sqlspec/adapters/aiomysql/config.py new file mode 100644 index 000000000..e7079b0a6 --- /dev/null +++ b/sqlspec/adapters/aiomysql/config.py @@ -0,0 +1,305 @@ +"""aiomysql database configuration.""" + +from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast +from weakref import WeakSet + +import aiomysql # pyright: ignore +from aiomysql.cursors import Cursor, DictCursor # pyright: ignore +from mypy_extensions import mypyc_attr +from typing_extensions import NotRequired + +from sqlspec.adapters.aiomysql._typing import AiomysqlConnection, AiomysqlCursor, AiomysqlSessionContext +from sqlspec.adapters.aiomysql.core import apply_driver_features, default_statement_config +from sqlspec.adapters.aiomysql.driver import AiomysqlDriver, AiomysqlExceptionHandler +from sqlspec.config import AsyncDatabaseConfig, ExtensionConfigs +from sqlspec.driver._async import AsyncPoolConnectionContext, AsyncPoolSessionFactory +from sqlspec.extensions.events import EventRuntimeHints +from sqlspec.utils.config_tools import normalize_connection_config + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + from types import TracebackType + + from aiomysql import Pool # pyright: ignore + from aiomysql.cursors import Cursor, DictCursor # pyright: ignore + + from sqlspec.core import StatementConfig + from sqlspec.observability import ObservabilityConfig + + +__all__ = ("AiomysqlConfig", "AiomysqlConnectionParams", "AiomysqlDriverFeatures", "AiomysqlPoolParams") + + +class AiomysqlConnectionParams(TypedDict): + """aiomysql connection parameters.""" + + host: NotRequired[str] + user: NotRequired[str] + password: NotRequired[str] + db: NotRequired[str] + port: NotRequired[int] + unix_socket: NotRequired[str] + charset: NotRequired[str] + connect_timeout: NotRequired[int] + read_default_file: NotRequired[str] + read_default_group: NotRequired[str] + autocommit: NotRequired[bool] + local_infile: NotRequired[bool] + ssl: NotRequired[Any] + sql_mode: NotRequired[str] + init_command: NotRequired[str] + cursor_class: NotRequired[type["Cursor"] | type["DictCursor"]] + + +class AiomysqlPoolParams(AiomysqlConnectionParams): + """aiomysql pool parameters.""" + + minsize: NotRequired[int] + maxsize: NotRequired[int] + echo: NotRequired[bool] + pool_recycle: NotRequired[int] + + +class AiomysqlDriverFeatures(TypedDict): + """aiomysql driver feature flags. + + MySQL/MariaDB handle JSON natively, but custom serializers can be provided + for specialized use cases (e.g., orjson for performance, msgspec for type safety). + + json_serializer: Custom JSON serializer function. + Defaults to sqlspec.utils.serializers.to_json. + Use for performance (orjson) or custom encoding. + json_deserializer: Custom JSON deserializer function. + Defaults to sqlspec.utils.serializers.from_json. + Use for performance (orjson) or custom decoding. + on_connection_create: Async callback executed when a connection is acquired from pool. + Receives the raw aiomysql connection for low-level driver configuration. + Called exactly once per physical connection using WeakSet tracking. + enable_events: Enable database event channel support. + Defaults to True when extension_config["events"] is configured. + Provides pub/sub capabilities via table-backed queue (MySQL/MariaDB have no native pub/sub). + Requires extension_config["events"] for migration setup. + events_backend: Event channel backend selection. + Only option: "table_queue" (durable table-backed queue with retries and exactly-once delivery). + MySQL/MariaDB do not have native pub/sub, so table_queue is the only backend. + Defaults to "table_queue". + """ + + json_serializer: NotRequired["Callable[[Any], str]"] + json_deserializer: NotRequired["Callable[[str], Any]"] + on_connection_create: "NotRequired[Callable[[AiomysqlConnection], Awaitable[None]]]" + enable_events: NotRequired[bool] + events_backend: NotRequired[str] + + +class _AiomysqlSessionFactory(AsyncPoolSessionFactory): + __slots__ = ("_ctx",) + + def __init__(self, config: "AiomysqlConfig") -> None: + super().__init__(config) + self._ctx: Any | None = None + + async def acquire_connection(self) -> "AiomysqlConnection": + pool = self._config.connection_instance + if pool is None: + pool = await self._config.create_pool() + self._config.connection_instance = pool + ctx = pool.acquire() + self._ctx = ctx + connection = cast("AiomysqlConnection", await ctx.__aenter__()) + await self._config._ensure_connection_initialized(connection) # pyright: ignore[reportPrivateUsage] + return connection + + async def release_connection(self, _conn: "AiomysqlConnection", **kwargs: Any) -> None: + if self._ctx is not None: + await self._ctx.__aexit__(None, None, None) + self._ctx = None + + +class AiomysqlConnectionContext(AsyncPoolConnectionContext): + """Async context manager for aiomysql connections.""" + + __slots__ = ("_ctx",) + + def __init__(self, config: "AiomysqlConfig") -> None: + super().__init__(config) + self._ctx: Any = None + + async def __aenter__(self) -> AiomysqlConnection: + pool = self._config.connection_instance + if pool is None: + pool = await self._config.create_pool() + self._config.connection_instance = pool + ctx = pool.acquire() + self._ctx = ctx + connection = cast("AiomysqlConnection", await ctx.__aenter__()) + await self._config._ensure_connection_initialized(connection) # pyright: ignore[reportPrivateUsage] + return connection + + async def __aexit__( + self, exc_type: "type[BaseException] | None", exc_val: "BaseException | None", exc_tb: "TracebackType | None" + ) -> bool | None: + if self._ctx: + return cast("bool | None", await self._ctx.__aexit__(exc_type, exc_val, exc_tb)) + return None + + +@mypyc_attr(native_class=False) +class AiomysqlConfig(AsyncDatabaseConfig[AiomysqlConnection, "aiomysql.Pool", AiomysqlDriver]): # pyright: ignore + """Configuration for aiomysql database connections. + + Example:: + + config = AiomysqlConfig( + connection_config=AiomysqlPoolParams( + host="localhost", user="root", db="mydb" + ) + ) + """ + + driver_type: ClassVar[type[AiomysqlDriver]] = AiomysqlDriver + connection_type: "ClassVar[type[Any]]" = cast("type[Any]", AiomysqlConnection) + supports_transactional_ddl: ClassVar[bool] = False + supports_native_arrow_export: ClassVar[bool] = True + supports_native_parquet_export: ClassVar[bool] = True + supports_native_arrow_import: ClassVar[bool] = True + supports_native_parquet_import: ClassVar[bool] = True + _connection_context_class: "ClassVar[type[AiomysqlConnectionContext]]" = AiomysqlConnectionContext + _session_factory_class: "ClassVar[type[_AiomysqlSessionFactory]]" = _AiomysqlSessionFactory + _session_context_class: "ClassVar[type[AiomysqlSessionContext]]" = AiomysqlSessionContext + _default_statement_config = default_statement_config + + def __init__( + self, + *, + connection_config: "AiomysqlPoolParams | dict[str, Any] | None" = None, + connection_instance: "aiomysql.Pool | None" = None, + migration_config: "dict[str, Any] | None" = None, + statement_config: "StatementConfig | None" = None, + driver_features: "AiomysqlDriverFeatures | dict[str, Any] | None" = None, + bind_key: "str | None" = None, + extension_config: "ExtensionConfigs | None" = None, + observability_config: "ObservabilityConfig | None" = None, + **kwargs: Any, + ) -> None: + """Initialize aiomysql configuration. + + Args: + connection_config: Connection and pool configuration parameters + connection_instance: Existing pool instance to use + migration_config: Migration configuration + statement_config: Statement configuration override + driver_features: Driver feature configuration (TypedDict or dict) + bind_key: Optional unique identifier for this configuration + extension_config: Extension-specific configuration (e.g., Litestar plugin settings) + observability_config: Adapter-level observability overrides for lifecycle hooks and observers + **kwargs: Additional keyword arguments + """ + connection_config = normalize_connection_config(connection_config) + + connection_config.setdefault("host", "localhost") + connection_config.setdefault("port", 3306) + + statement_config = statement_config or default_statement_config + statement_config, driver_features = apply_driver_features(statement_config, driver_features) + + # Extract user connection hook before storing driver_features + features_dict = dict(driver_features) if driver_features else {} + self._user_connection_hook: Callable[[AiomysqlConnection], Awaitable[None]] | None = features_dict.pop( + "on_connection_create", None + ) + # Track initialized connections to ensure callback runs exactly once per physical connection + self._initialized_connections: WeakSet[Any] = WeakSet() + + super().__init__( + connection_config=connection_config, + connection_instance=connection_instance, + migration_config=migration_config, + statement_config=statement_config, + driver_features=features_dict, + bind_key=bind_key, + extension_config=extension_config, + observability_config=observability_config, + **kwargs, + ) + + async def _create_pool(self) -> "aiomysql.Pool": + """Create the actual async connection pool. + + MySQL/MariaDB handle JSON types natively without requiring connection-level + type handlers. JSON serialization is handled via type_coercion_map in the + driver's statement_config (see driver.py). + """ + return cast("aiomysql.Pool", await aiomysql.create_pool(**dict(self.connection_config))) + + async def _ensure_connection_initialized(self, connection: "AiomysqlConnection") -> None: + """Ensure connection callback has been called exactly once for this connection. + + Uses WeakSet tracking to ensure the callback runs once per physical connection. + """ + if self._user_connection_hook is None: + return + if connection not in self._initialized_connections: + await self._user_connection_hook(connection) + self._initialized_connections.add(connection) + + async def _close_pool(self) -> None: + """Close the actual async connection pool.""" + if self.connection_instance: + self.connection_instance.close() + await self.connection_instance.wait_closed() + self.connection_instance = None + + async def close_pool(self) -> None: + """Close the connection pool.""" + await self._close_pool() + + async def create_connection(self) -> AiomysqlConnection: + """Create a single async connection (not from pool). + + Returns: + An aiomysql connection instance. + """ + pool = self.connection_instance + if pool is None: + pool = await self.create_pool() + self.connection_instance = pool + connection = cast("AiomysqlConnection", await pool.acquire()) + await self._ensure_connection_initialized(connection) + return connection + + async def provide_pool(self, *args: Any, **kwargs: Any) -> "Pool": + """Provide async pool instance. + + Returns: + The async connection pool. + """ + if not self.connection_instance: + self.connection_instance = await self.create_pool() + return self.connection_instance + + def get_signature_namespace(self) -> "dict[str, Any]": + """Get the signature namespace for aiomysql types. + + Returns: + Dictionary mapping type names to types. + """ + + namespace = super().get_signature_namespace() + namespace.update({ + "AiomysqlConnectionContext": AiomysqlConnectionContext, + "AiomysqlConnection": AiomysqlConnection, + "AiomysqlConnectionParams": AiomysqlConnectionParams, + "AiomysqlCursor": AiomysqlCursor, + "AiomysqlDriver": AiomysqlDriver, + "AiomysqlDriverFeatures": AiomysqlDriverFeatures, + "AiomysqlExceptionHandler": AiomysqlExceptionHandler, + "AiomysqlPoolParams": AiomysqlPoolParams, + "AiomysqlSessionContext": AiomysqlSessionContext, + }) + return namespace + + def get_event_runtime_hints(self) -> "EventRuntimeHints": + """Return queue polling defaults for aiomysql adapters.""" + + return EventRuntimeHints(poll_interval=0.25, lease_seconds=5, select_for_update=True, skip_locked=True) diff --git a/sqlspec/adapters/aiomysql/core.py b/sqlspec/adapters/aiomysql/core.py new file mode 100644 index 000000000..696ebed59 --- /dev/null +++ b/sqlspec/adapters/aiomysql/core.py @@ -0,0 +1,500 @@ +"""aiomysql adapter compiled helpers.""" + +from collections.abc import Callable, Sized +from typing import TYPE_CHECKING, Any + +from sqlspec.core import DriverParameterProfile, ParameterStyle, StatementConfig, build_statement_config_from_profile +from sqlspec.exceptions import ( + CheckViolationError, + ConnectionTimeoutError, + DatabaseConnectionError, + DataError, + DeadlockError, + ForeignKeyViolationError, + IntegrityError, + NotNullViolationError, + PermissionDeniedError, + QueryTimeoutError, + SQLParsingError, + SQLSpecError, + TransactionError, + UniqueViolationError, +) +from sqlspec.utils.serializers import from_json, to_json +from sqlspec.utils.type_converters import build_uuid_coercions +from sqlspec.utils.type_guards import has_cursor_metadata, has_lastrowid, has_rowcount, has_sqlstate + +if TYPE_CHECKING: + from collections.abc import Mapping, Sequence + +__all__ = ( + "apply_driver_features", + "build_insert_statement", + "build_profile", + "build_statement_config", + "collect_rows", + "create_mapped_exception", + "default_statement_config", + "detect_json_columns", + "detect_json_columns_from_description", + "driver_profile", + "format_identifier", + "normalize_execute_many_parameters", + "normalize_execute_parameters", + "normalize_lastrowid", + "resolve_column_names", + "resolve_many_rowcount", + "resolve_rowcount", +) + +# MySQL error codes for constraint violations +MYSQL_ER_DUP_ENTRY = 1062 +MYSQL_ER_NO_DEFAULT_FOR_FIELD = 1364 +MYSQL_ER_CHECK_CONSTRAINT_VIOLATED = 3819 + +# MySQL error codes for permission/access errors +MYSQL_ER_DBACCESS_DENIED = 1044 +MYSQL_ER_ACCESS_DENIED = 1045 +MYSQL_ER_TABLEACCESS_DENIED = 1142 + +# MySQL error codes for transaction errors +MYSQL_ER_LOCK_WAIT_TIMEOUT = 1205 +MYSQL_ER_LOCK_DEADLOCK = 1213 + +# MySQL error codes for connection errors +MYSQL_CR_CONNECTION_ERROR = 2002 +MYSQL_CR_CONN_HOST_ERROR = 2003 +MYSQL_CR_UNKNOWN_HOST = 2005 +MYSQL_CR_SERVER_GONE_ERROR = 2006 +MYSQL_CR_SERVER_LOST = 2013 + + +def _bool_to_int(value: bool) -> int: + return int(value) + + +def _quote_mysql_identifier(identifier: str) -> str: + normalized = identifier.replace("`", "``") + return f"`{normalized}`" + + +def format_identifier(identifier: str) -> str: + cleaned = identifier.strip() + if not cleaned: + msg = "Table name must not be empty" + raise SQLSpecError(msg) + parts = [part for part in cleaned.split(".") if part] + formatted = ".".join(_quote_mysql_identifier(part) for part in parts) + return formatted or _quote_mysql_identifier(cleaned) + + +def build_insert_statement(table: str, columns: "list[str]") -> str: + column_clause = ", ".join(_quote_mysql_identifier(column) for column in columns) + placeholders = ", ".join("%s" for _ in columns) + return f"INSERT INTO {format_identifier(table)} ({column_clause}) VALUES ({placeholders})" + + +def normalize_execute_parameters(parameters: Any) -> Any: + """Normalize parameters for aiomysql execute calls. + + Args: + parameters: Prepared parameters payload. + + Returns: + Normalized parameters payload. + """ + return parameters or None + + +def normalize_execute_many_parameters(parameters: Any) -> Any: + """Normalize parameters for aiomysql executemany calls. + + Args: + parameters: Prepared parameters payload. + + Returns: + Normalized parameters payload. + + Raises: + ValueError: When parameters are missing for executemany. + """ + if not parameters: + msg = "execute_many requires parameters" + raise ValueError(msg) + return parameters + + +def build_profile() -> "DriverParameterProfile": + """Create the aiomysql driver parameter profile.""" + coercions: dict[type, Callable[[Any], Any]] = {bool: _bool_to_int, **build_uuid_coercions()} + return DriverParameterProfile( + name="aiomysql", + default_style=ParameterStyle.QMARK, + supported_styles={ParameterStyle.QMARK}, + default_execution_style=ParameterStyle.POSITIONAL_PYFORMAT, + supported_execution_styles={ParameterStyle.POSITIONAL_PYFORMAT}, + has_native_list_expansion=False, + preserve_parameter_format=True, + needs_static_script_compilation=True, + allow_mixed_parameter_styles=False, + preserve_original_params_for_many=False, + json_serializer_strategy="helper", + custom_type_coercions=coercions, + default_dialect="mysql", + ) + + +driver_profile = build_profile() + + +def build_statement_config( + *, json_serializer: "Callable[[Any], str] | None" = None, json_deserializer: "Callable[[str], Any] | None" = None +) -> "StatementConfig": + """Construct the aiomysql statement configuration with optional JSON codecs.""" + serializer = json_serializer or to_json + deserializer = json_deserializer or from_json + profile = driver_profile + return build_statement_config_from_profile( + profile, statement_overrides={"dialect": "mysql"}, json_serializer=serializer, json_deserializer=deserializer + ) + + +default_statement_config = build_statement_config() + + +def apply_driver_features( + statement_config: "StatementConfig", driver_features: "Mapping[str, Any] | None" +) -> "tuple[StatementConfig, dict[str, Any]]": + """Apply aiomysql driver feature defaults to statement config.""" + features: dict[str, Any] = dict(driver_features) if driver_features else {} + json_serializer = features.setdefault("json_serializer", to_json) + json_deserializer = features.setdefault("json_deserializer", from_json) + + if json_serializer is not None: + parameter_config = statement_config.parameter_config.with_json_serializers( + json_serializer, deserializer=json_deserializer + ) + statement_config = statement_config.replace(parameter_config=parameter_config) + + return statement_config, features + + +def _create_mysql_error( + error: Any, sqlstate: "str | None", code: "int | None", error_class: type[SQLSpecError], description: str +) -> SQLSpecError: + """Create a MySQL error instance without raising it.""" + code_str = f"[{sqlstate or code}]" if sqlstate or code else "" + msg = f"MySQL {description} {code_str}: {error}" if code_str else f"MySQL {description}: {error}" + exc = error_class(msg) + exc.__cause__ = error + return exc + + +def create_mapped_exception(error: Any, *, logger: Any | None = None) -> "SQLSpecError | bool": + """Map aiomysql exceptions to SQLSpec errors. + + aiomysql re-exports PyMySQL's error hierarchy (pymysql.err.*), so MySQL + error codes and SQLSTATE values are identical to asyncmy — only the Python + exception class import paths differ. + + Mapping priority: + 1. Specific error codes (most reliable for MySQL) + 2. SQLSTATE codes (where available) + 3. Generic error code ranges + 4. Default SQLSpecError fallback + + Args: + error: The aiomysql/pymysql exception to map + logger: Optional logger for migration warnings + + Returns: + True to suppress expected migration errors, or a SQLSpec exception + """ + error_code = error.args[0] if len(error.args) >= 1 and isinstance(error.args[0], int) else None + sqlstate_attr = error.sqlstate if has_sqlstate(error) else None + sqlstate = sqlstate_attr if sqlstate_attr is not None else None + + # Migration-specific errors to suppress + if error_code in {1061, 1091}: + if logger is not None: + logger.warning("aiomysql MySQL expected migration error (ignoring): %s", error) + return True + + # Integrity constraint violations + if sqlstate == "23505" or error_code == MYSQL_ER_DUP_ENTRY: + return _create_mysql_error(error, sqlstate, error_code, UniqueViolationError, "unique constraint violation") + if sqlstate == "23503" or error_code in {1216, 1217, 1451, 1452}: + return _create_mysql_error( + error, sqlstate, error_code, ForeignKeyViolationError, "foreign key constraint violation" + ) + if sqlstate == "23502" or error_code in {1048, MYSQL_ER_NO_DEFAULT_FOR_FIELD}: + return _create_mysql_error(error, sqlstate, error_code, NotNullViolationError, "not-null constraint violation") + if sqlstate == "23514" or error_code == MYSQL_ER_CHECK_CONSTRAINT_VIOLATED: + return _create_mysql_error(error, sqlstate, error_code, CheckViolationError, "check constraint violation") + if sqlstate and sqlstate.startswith("23"): + return _create_mysql_error(error, sqlstate, error_code, IntegrityError, "integrity constraint violation") + + # Permission/access errors (check specific codes first) + if error_code in {MYSQL_ER_DBACCESS_DENIED, MYSQL_ER_ACCESS_DENIED, MYSQL_ER_TABLEACCESS_DENIED}: + return _create_mysql_error(error, sqlstate, error_code, PermissionDeniedError, "access denied") + if sqlstate and sqlstate.startswith("28"): + return _create_mysql_error(error, sqlstate, error_code, PermissionDeniedError, "authorization error") + + # Transaction errors (deadlock vs lock wait timeout) + if error_code == MYSQL_ER_LOCK_DEADLOCK: + return _create_mysql_error(error, sqlstate, error_code, DeadlockError, "deadlock detected") + if error_code == MYSQL_ER_LOCK_WAIT_TIMEOUT: + return _create_mysql_error(error, sqlstate, error_code, QueryTimeoutError, "lock wait timeout") + if sqlstate and sqlstate.startswith("40"): + return _create_mysql_error(error, sqlstate, error_code, TransactionError, "transaction error") + + # SQL syntax errors + if sqlstate and sqlstate.startswith("42"): + return _create_mysql_error(error, sqlstate, error_code, SQLParsingError, "SQL syntax error") + if error_code in range(1064, 1100): + return _create_mysql_error(error, sqlstate, error_code, SQLParsingError, "SQL syntax error") + + # Connection errors + if sqlstate and sqlstate.startswith("08"): + return _create_mysql_error(error, sqlstate, error_code, DatabaseConnectionError, "connection error") + if error_code == MYSQL_CR_SERVER_LOST: + return _create_mysql_error(error, sqlstate, error_code, ConnectionTimeoutError, "connection lost") + if error_code in { + MYSQL_CR_CONNECTION_ERROR, + MYSQL_CR_CONN_HOST_ERROR, + MYSQL_CR_UNKNOWN_HOST, + MYSQL_CR_SERVER_GONE_ERROR, + }: + return _create_mysql_error(error, sqlstate, error_code, DatabaseConnectionError, "connection error") + + # Data errors + if sqlstate and sqlstate.startswith("22"): + return _create_mysql_error(error, sqlstate, error_code, DataError, "data error") + + return _create_mysql_error(error, sqlstate, error_code, SQLSpecError, "database error") + + +def resolve_column_names(description: "Sequence[Any] | None") -> "list[str]": + """Resolve ordered column names from cursor metadata.""" + if not description: + return [] + return [desc[0] for desc in description] + + +def detect_json_columns_from_description( + description: "Sequence[Any] | None", json_type_codes: "set[int]" +) -> "list[int]": + """Identify JSON column indexes from pre-fetched cursor description metadata.""" + if not description or not json_type_codes: + return [] + + json_indexes: list[int] = [] + append = json_indexes.append + for index, column in enumerate(description): + if isinstance(column, (tuple, list)): + type_code = column[1] if len(column) > 1 else None + else: + type_code = getattr(column, "type_code", None) + if type_code in json_type_codes: + append(index) + return json_indexes + + +def detect_json_columns( + cursor: Any, json_type_codes: "set[int]", description: "Sequence[Any] | None" = None +) -> "list[int]": + """Identify JSON column indexes from cursor metadata. + + Args: + cursor: Database cursor with description metadata available. + json_type_codes: Set of type codes identifying JSON columns. + description: Optional pre-fetched cursor description metadata. + + Returns: + List of index positions where JSON values are present. + """ + if description is None: + if not has_cursor_metadata(cursor): + return [] + description = cursor.description + return detect_json_columns_from_description(description, json_type_codes) + + +def _deserialize_json_dict_rows( + column_names: "list[str]", + rows: "list[dict[str, Any]]", + json_indexes: "list[int]", + deserializer: "Callable[[Any], Any]", + *, + logger: Any | None = None, +) -> "list[dict[str, Any]]": + """Apply JSON deserialization to dict rows (DictCursor path). + + Args: + column_names: Ordered column names from the cursor description. + rows: Result rows represented as dictionaries. + json_indexes: Column indexes to deserialize. + deserializer: Callable used to decode JSON values. + logger: Optional logger for debug output. + + Returns: + Rows with JSON columns decoded when possible. + """ + if not rows or not column_names or not json_indexes: + return rows + + target_columns = [column_names[index] for index in json_indexes if index < len(column_names)] + if not target_columns: + return rows + + for row in rows: + for column in target_columns: + if column not in row: + continue + raw_value = row[column] + if raw_value is None: + continue + if isinstance(raw_value, bytearray): + raw_value = bytes(raw_value) + if not isinstance(raw_value, (str, bytes)): + continue + try: + row[column] = deserializer(raw_value) + except Exception: + if logger is not None: + logger.debug("Failed to deserialize JSON column %s", column, exc_info=True) + return rows + + +def _deserialize_json_tuple_rows( + rows: "list[Any]", json_indexes: "list[int]", deserializer: "Callable[[Any], Any]", *, logger: Any | None = None +) -> "list[Any]": + """Apply JSON deserialization to tuple rows using index-based access. + + Args: + rows: Result rows as tuples. + json_indexes: Column indexes to deserialize. + deserializer: Callable used to decode JSON values. + logger: Optional logger for debug output. + + Returns: + Rows with JSON columns decoded when possible. + """ + if not rows or not json_indexes: + return rows + + result: list[Any] = [] + for row in rows: + row_list = list(row) + mutated = False + for idx in json_indexes: + if idx >= len(row_list): + continue + raw_value = row_list[idx] + if raw_value is None: + continue + if isinstance(raw_value, bytearray): + raw_value = bytes(raw_value) + if not isinstance(raw_value, (str, bytes)): + continue + try: + row_list[idx] = deserializer(raw_value) + mutated = True + except Exception: + if logger is not None: + logger.debug("Failed to deserialize JSON column index %d", idx, exc_info=True) + result.append(tuple(row_list) if mutated else row) + return result + + +def collect_rows( + fetched_data: "Sequence[Any] | None", + description: "Sequence[Any] | None", + json_indexes: "list[int]", + deserializer: "Callable[[Any], Any]", + *, + column_names: "list[str] | None" = None, + logger: Any | None = None, +) -> "tuple[list[Any], list[str], str]": + """Collect aiomysql rows with JSON decoding, preserving raw format. + + Args: + fetched_data: Rows returned from cursor.fetchall(). + description: Cursor description metadata. + json_indexes: Column indexes containing JSON values. + deserializer: JSON deserializer function. + column_names: Optional precomputed ordered column names. + logger: Optional logger for debug output. + + Returns: + Tuple of (rows, column_names, row_format). + """ + if not description: + return [], [], "tuple" + resolved_column_names = resolve_column_names(description) if column_names is None else column_names + if not fetched_data: + return [], resolved_column_names, "tuple" + + first_row = fetched_data[0] + if isinstance(first_row, dict): + if not json_indexes: + rows = fetched_data if isinstance(fetched_data, list) else list(fetched_data) + return rows, resolved_column_names, "dict" + rows = [dict(row) for row in fetched_data] + rows = _deserialize_json_dict_rows( + resolved_column_names, rows, json_indexes, deserializer, logger=logger + ) + return rows, resolved_column_names, "dict" + rows = fetched_data if isinstance(fetched_data, list) else list(fetched_data) + if json_indexes: + rows = _deserialize_json_tuple_rows(rows, json_indexes, deserializer, logger=logger) + return rows, resolved_column_names, "tuple" + + +def resolve_rowcount(cursor: Any) -> int: + """Resolve rowcount from an aiomysql cursor. + + Args: + cursor: aiomysql cursor with optional rowcount metadata. + + Returns: + Rowcount value or 0 when unknown. + """ + if not has_rowcount(cursor): + return 0 + rowcount = cursor.rowcount + if isinstance(rowcount, int) and rowcount >= 0: + return rowcount + return 0 + + +def resolve_many_rowcount(cursor: Any, parameters: Any, *, fallback_count: "int | None" = None) -> int: + """Resolve execute_many rowcount using cursor metadata with payload fallback.""" + rowcount = resolve_rowcount(cursor) + if rowcount > 0: + return rowcount + if fallback_count is not None: + return fallback_count + if isinstance(parameters, Sized): + return len(parameters) + return 0 + + +def normalize_lastrowid(cursor: Any) -> int | None: + """Normalize lastrowid for aiomysql when rowcount indicates success. + + Args: + cursor: aiomysql cursor with optional lastrowid metadata. + + Returns: + Last inserted id or None when unavailable. + """ + if not has_rowcount(cursor): + return None + rowcount = cursor.rowcount + if not isinstance(rowcount, int) or rowcount <= 0: + return None + if not has_lastrowid(cursor): + return None + last_id = cursor.lastrowid + return last_id if isinstance(last_id, int) else None diff --git a/sqlspec/adapters/aiomysql/data_dictionary.py b/sqlspec/adapters/aiomysql/data_dictionary.py new file mode 100644 index 000000000..6ebcc8f30 --- /dev/null +++ b/sqlspec/adapters/aiomysql/data_dictionary.py @@ -0,0 +1,123 @@ +"""MySQL-specific data dictionary for metadata queries via aiomysql.""" + +from typing import TYPE_CHECKING, ClassVar + +from mypy_extensions import mypyc_attr + +from sqlspec.driver import AsyncDataDictionaryBase +from sqlspec.typing import ColumnMetadata, ForeignKeyMetadata, IndexMetadata, TableMetadata, VersionInfo + +if TYPE_CHECKING: + from sqlspec.adapters.aiomysql.driver import AiomysqlDriver + +__all__ = ("AiomysqlDataDictionary",) + + +@mypyc_attr(allow_interpreted_subclasses=True, native_class=False) +class AiomysqlDataDictionary(AsyncDataDictionaryBase): + """MySQL-specific async data dictionary.""" + + dialect: ClassVar[str] = "mysql" + + def __init__(self) -> None: + super().__init__() + + async def get_version(self, driver: "AiomysqlDriver") -> "VersionInfo | None": + """Get MySQL database version information.""" + driver_id = id(driver) + # Inline cache check to avoid cross-module method call that causes mypyc segfault + if driver_id in self._version_fetch_attempted: + return self._version_cache.get(driver_id) + + version_value = await driver.select_value_or_none(self.get_query("version")) + if not version_value: + self._log_version_unavailable(type(self).dialect, "missing") + self.cache_version(driver_id, None) + return None + + version_info = self.parse_version_with_pattern(self.get_dialect_config().version_pattern, str(version_value)) + if version_info is None: + self._log_version_unavailable(type(self).dialect, "parse_failed") + self.cache_version(driver_id, None) + return None + + self._log_version_detected(type(self).dialect, version_info) + self.cache_version(driver_id, version_info) + return version_info + + async def get_feature_flag(self, driver: "AiomysqlDriver", feature: str) -> bool: + """Check if MySQL database supports a specific feature.""" + version_info = await self.get_version(driver) + return self.resolve_feature_flag(feature, version_info) + + async def get_optimal_type(self, driver: "AiomysqlDriver", type_category: str) -> str: + """Get optimal MySQL type for a category.""" + config = self.get_dialect_config() + version_info = await self.get_version(driver) + + if type_category == "json": + json_version = config.get_feature_version("supports_json") + if version_info and json_version and version_info >= json_version: + return "JSON" + return "TEXT" + + return config.get_optimal_type(type_category) + + async def get_tables(self, driver: "AiomysqlDriver", schema: "str | None" = None) -> "list[TableMetadata]": + """Get tables sorted by topological dependency order using the MySQL catalog.""" + schema_name = self.resolve_schema(schema) + self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="tables") + return await driver.select( + self.get_query("tables_by_schema"), schema_name=schema_name, schema_type=TableMetadata + ) + + async def get_columns( + self, driver: "AiomysqlDriver", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ColumnMetadata]": + """Get column information for a table or schema.""" + schema_name = self.resolve_schema(schema) + if table is None: + self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="columns") + return await driver.select( + self.get_query("columns_by_schema"), schema_name=schema_name, schema_type=ColumnMetadata + ) + + self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="columns") + return await driver.select( + self.get_query("columns_by_table"), table_name=table, schema_name=schema_name, schema_type=ColumnMetadata + ) + + async def get_indexes( + self, driver: "AiomysqlDriver", table: "str | None" = None, schema: "str | None" = None + ) -> "list[IndexMetadata]": + """Get index metadata for a table or schema.""" + schema_name = self.resolve_schema(schema) + if table is None: + self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="indexes") + return await driver.select( + self.get_query("indexes_by_schema"), schema_name=schema_name, schema_type=IndexMetadata + ) + + self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="indexes") + return await driver.select( + self.get_query("indexes_by_table"), table_name=table, schema_name=schema_name, schema_type=IndexMetadata + ) + + async def get_foreign_keys( + self, driver: "AiomysqlDriver", table: "str | None" = None, schema: "str | None" = None + ) -> "list[ForeignKeyMetadata]": + """Get foreign key metadata.""" + schema_name = self.resolve_schema(schema) + if table is None: + self._log_schema_introspect(driver, schema_name=schema_name, table_name=None, operation="foreign_keys") + return await driver.select( + self.get_query("foreign_keys_by_schema"), schema_name=schema_name, schema_type=ForeignKeyMetadata + ) + + self._log_table_describe(driver, schema_name=schema_name, table_name=table, operation="foreign_keys") + return await driver.select( + self.get_query("foreign_keys_by_table"), + table_name=table, + schema_name=schema_name, + schema_type=ForeignKeyMetadata, + ) diff --git a/sqlspec/adapters/aiomysql/driver.py b/sqlspec/adapters/aiomysql/driver.py new file mode 100644 index 000000000..0a936b6fb --- /dev/null +++ b/sqlspec/adapters/aiomysql/driver.py @@ -0,0 +1,381 @@ +"""aiomysql MySQL driver implementation. + +Provides MySQL/MariaDB connectivity with parameter style conversion, +type coercion, error handling, and transaction management. + +aiomysql is built on top of PyMySQL, so error classes come from pymysql.err +rather than a driver-specific error module. +""" + +from collections.abc import Sized +from typing import TYPE_CHECKING, Any, Final, cast + +import pymysql.err # pyright: ignore +from pymysql.constants import FIELD_TYPE as PYMYSQL_FIELD_TYPE # pyright: ignore + +from sqlspec.adapters.aiomysql._typing import AiomysqlCursor, AiomysqlSessionContext +from sqlspec.adapters.aiomysql.core import ( + build_insert_statement, + collect_rows, + create_mapped_exception, + default_statement_config, + detect_json_columns_from_description, + driver_profile, + format_identifier, + normalize_execute_many_parameters, + normalize_execute_parameters, + normalize_lastrowid, + resolve_column_names, + resolve_many_rowcount, + resolve_rowcount, +) +from sqlspec.adapters.aiomysql.data_dictionary import AiomysqlDataDictionary +from sqlspec.core import ArrowResult, get_cache_config, register_driver_profile +from sqlspec.driver import AsyncDriverAdapterBase, BaseAsyncExceptionHandler +from sqlspec.exceptions import SQLSpecError +from sqlspec.utils.logging import get_logger +from sqlspec.utils.serializers import from_json +from sqlspec.utils.type_guards import supports_json_type + +if TYPE_CHECKING: + from collections.abc import Callable + + from sqlspec.adapters.aiomysql._typing import AiomysqlConnection + from sqlspec.core import SQL, StatementConfig + from sqlspec.driver import ExecutionResult + from sqlspec.storage import StorageBridgeJob, StorageDestination, StorageFormat, StorageTelemetry + +__all__ = ("AiomysqlCursor", "AiomysqlDriver", "AiomysqlExceptionHandler", "AiomysqlSessionContext") + +logger = get_logger(__name__) + +json_type_value = ( + PYMYSQL_FIELD_TYPE.JSON if PYMYSQL_FIELD_TYPE is not None and supports_json_type(PYMYSQL_FIELD_TYPE) else None +) +AIOMYSQL_JSON_TYPE_CODES: Final[set[int]] = {json_type_value} if json_type_value is not None else set() + + +class AiomysqlExceptionHandler(BaseAsyncExceptionHandler): + """Async context manager for handling aiomysql (MySQL) database exceptions. + + Maps MySQL error codes and SQLSTATE to specific SQLSpec exceptions + for better error handling in application code. + + aiomysql uses pymysql.err.Error as its base exception class since + aiomysql is built on top of PyMySQL. + + Uses deferred exception pattern for mypyc compatibility: exceptions + are stored in pending_exception rather than raised from __aexit__ + to avoid ABI boundary violations with compiled code. + """ + + __slots__ = () + + def _handle_exception(self, exc_type: "type[BaseException] | None", exc_val: "BaseException") -> bool: + if exc_type is None: + return False + if issubclass(exc_type, pymysql.err.Error): + result = create_mapped_exception(exc_val, logger=logger) + if result is True: + return True + self.pending_exception = cast("Exception", result) + return True + return False + + +class AiomysqlDriver(AsyncDriverAdapterBase): + """MySQL/MariaDB database driver using aiomysql client library. + + Implements asynchronous database operations for MySQL and MariaDB servers + with support for parameter style conversion, type coercion, error handling, + and transaction management. + """ + + __slots__ = ("_data_dictionary",) + dialect = "mysql" + + def __init__( + self, + connection: "AiomysqlConnection", + statement_config: "StatementConfig | None" = None, + driver_features: "dict[str, Any] | None" = None, + ) -> None: + if statement_config is None: + statement_config = default_statement_config.replace( + enable_caching=get_cache_config().compiled_cache_enabled + ) + + super().__init__(connection=connection, statement_config=statement_config, driver_features=driver_features) + self._data_dictionary: AiomysqlDataDictionary | None = None + + # ───────────────────────────────────────────────────────────────────────────── + # CORE DISPATCH METHODS - The Execution Engine + # ───────────────────────────────────────────────────────────────────────────── + + async def dispatch_execute(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + """Execute single SQL statement. + + Args: + cursor: aiomysql cursor object + statement: SQL statement to execute + + Returns: + ExecutionResult: Statement execution results with data or row counts + """ + sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) + await cursor.execute(sql, normalize_execute_parameters(prepared_parameters)) + + if statement.returns_rows(): + fetched_data = await cursor.fetchall() + description = cursor.description or None + column_names = resolve_column_names(description) + json_indexes = detect_json_columns_from_description(description, AIOMYSQL_JSON_TYPE_CODES) + deserializer = cast("Callable[[Any], Any]", self.driver_features.get("json_deserializer", from_json)) + rows, column_names, row_format = collect_rows( + fetched_data, description, json_indexes, deserializer, column_names=column_names, logger=logger + ) + + return self.create_execution_result( + cursor, + selected_data=rows, + column_names=column_names, + data_row_count=len(rows), + is_select_result=True, + row_format=row_format, + ) + + affected_rows = resolve_rowcount(cursor) + last_id = normalize_lastrowid(cursor) + return self.create_execution_result(cursor, rowcount_override=affected_rows, last_inserted_id=last_id) + + async def dispatch_execute_many(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + """Execute SQL statement with multiple parameter sets. + + Args: + cursor: aiomysql cursor object + statement: SQL statement with multiple parameter sets + + Returns: + ExecutionResult: Batch execution results + + Raises: + ValueError: If no parameters provided for executemany operation + """ + sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) + + prepared_parameters = normalize_execute_many_parameters(prepared_parameters) + parameter_count = len(prepared_parameters) if isinstance(prepared_parameters, Sized) else None + await cursor.executemany(sql, prepared_parameters) + + affected_rows = resolve_many_rowcount(cursor, prepared_parameters, fallback_count=parameter_count) + + return self.create_execution_result(cursor, rowcount_override=affected_rows, is_many_result=True) + + async def dispatch_execute_script(self, cursor: Any, statement: "SQL") -> "ExecutionResult": + """Execute SQL script with statement splitting and parameter handling. + + Args: + cursor: aiomysql cursor object + statement: SQL script to execute + + Returns: + ExecutionResult: Script execution results with statement count + """ + sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config) + statements = self.split_script_statements(sql, statement.statement_config, strip_trailing_semicolon=True) + + successful_count = 0 + last_cursor = cursor + + for stmt in statements: + await cursor.execute(stmt, normalize_execute_parameters(prepared_parameters)) + successful_count += 1 + + return self.create_execution_result( + last_cursor, statement_count=len(statements), successful_statements=successful_count, is_script_result=True + ) + + # ───────────────────────────────────────────────────────────────────────────── + # TRANSACTION MANAGEMENT + # ───────────────────────────────────────────────────────────────────────────── + + async def begin(self) -> None: + """Begin a database transaction. + + Raises: + SQLSpecError: If transaction initialization fails + """ + try: + async with AiomysqlCursor(self.connection) as cursor: + await cursor.execute("BEGIN") + except pymysql.err.MySQLError as e: + msg = f"Failed to begin MySQL transaction: {e}" + raise SQLSpecError(msg) from e + + async def commit(self) -> None: + """Commit the current transaction. + + Raises: + SQLSpecError: If transaction commit fails + """ + try: + await self.connection.commit() + except pymysql.err.MySQLError as e: + msg = f"Failed to commit MySQL transaction: {e}" + raise SQLSpecError(msg) from e + + async def rollback(self) -> None: + """Rollback the current transaction. + + Raises: + SQLSpecError: If transaction rollback fails + """ + try: + await self.connection.rollback() + except pymysql.err.MySQLError as e: + msg = f"Failed to rollback MySQL transaction: {e}" + raise SQLSpecError(msg) from e + + def with_cursor(self, connection: "AiomysqlConnection") -> "AiomysqlCursor": + """Create cursor context manager for the connection. + + Args: + connection: aiomysql database connection + + Returns: + AiomysqlCursor: Context manager for cursor operations + """ + return AiomysqlCursor(connection) + + def handle_database_exceptions(self) -> "AiomysqlExceptionHandler": + """Provide exception handling context manager. + + Returns: + AiomysqlExceptionHandler: Context manager for aiomysql exception handling + """ + return AiomysqlExceptionHandler() + + # ───────────────────────────────────────────────────────────────────────────── + # STORAGE API METHODS + # ───────────────────────────────────────────────────────────────────────────── + + async def select_to_storage( + self, + statement: "SQL | str", + destination: "StorageDestination", + /, + *parameters: Any, + statement_config: "StatementConfig | None" = None, + partitioner: "dict[str, object] | None" = None, + format_hint: "StorageFormat | None" = None, + telemetry: "StorageTelemetry | None" = None, + **kwargs: Any, + ) -> "StorageBridgeJob": + """Execute a query and stream Arrow-formatted results into storage.""" + + self._require_capability("arrow_export_enabled") + arrow_result = await self.select_to_arrow(statement, *parameters, statement_config=statement_config, **kwargs) + async_pipeline = self._storage_pipeline() + telemetry_payload = await self._write_result_to_storage_async( + arrow_result, destination, format_hint=format_hint, pipeline=async_pipeline + ) + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_arrow( + self, + table: str, + source: "ArrowResult | Any", + *, + partitioner: "dict[str, object] | None" = None, + overwrite: bool = False, + telemetry: "StorageTelemetry | None" = None, + ) -> "StorageBridgeJob": + """Load Arrow data into MySQL using batched inserts.""" + + self._require_capability("arrow_import_enabled") + arrow_table = self._coerce_arrow_table(source) + if overwrite: + statement = f"TRUNCATE TABLE {format_identifier(table)}" + exc_handler = self.handle_database_exceptions() + async with exc_handler, self.with_cursor(self.connection) as cursor: + await cursor.execute(statement) + if exc_handler.pending_exception is not None: + raise exc_handler.pending_exception from None + + columns, records = self._arrow_table_to_rows(arrow_table) + if records: + insert_sql = build_insert_statement(table, columns) + exc_handler = self.handle_database_exceptions() + async with exc_handler, self.with_cursor(self.connection) as cursor: + await cursor.executemany(insert_sql, records) + if exc_handler.pending_exception is not None: + raise exc_handler.pending_exception from None + + telemetry_payload = self._build_ingest_telemetry(arrow_table) + telemetry_payload["destination"] = table + self._attach_partition_telemetry(telemetry_payload, partitioner) + return self._create_storage_job(telemetry_payload, telemetry) + + async def load_from_storage( + self, + table: str, + source: "StorageDestination", + *, + file_format: "StorageFormat", + partitioner: "dict[str, object] | None" = None, + overwrite: bool = False, + ) -> "StorageBridgeJob": + """Load staged artifacts from storage into MySQL.""" + + arrow_table, inbound = await self._read_arrow_from_storage_async(source, file_format=file_format) + return await self.load_from_arrow( + table, arrow_table, partitioner=partitioner, overwrite=overwrite, telemetry=inbound + ) + + # ───────────────────────────────────────────────────────────────────────────── + # UTILITY METHODS + # ───────────────────────────────────────────────────────────────────────────── + + @property + def data_dictionary(self) -> "AiomysqlDataDictionary": + """Get the data dictionary for this driver. + + Returns: + Data dictionary instance for metadata queries + """ + if self._data_dictionary is None: + self._data_dictionary = AiomysqlDataDictionary() + return self._data_dictionary + + # ───────────────────────────────────────────────────────────────────────────── + # PRIVATE/INTERNAL METHODS + # ───────────────────────────────────────────────────────────────────────────── + + def collect_rows(self, cursor: Any, fetched: "list[Any]") -> "tuple[list[Any], list[str], int]": + """Collect aiomysql rows for the direct execution path.""" + description = cursor.description or None + column_names = resolve_column_names(description) + json_indexes = detect_json_columns_from_description(description, AIOMYSQL_JSON_TYPE_CODES) + deserializer = cast("Callable[[Any], Any]", self.driver_features.get("json_deserializer", from_json)) + rows, column_names, _row_format = collect_rows( + fetched, description, json_indexes, deserializer, column_names=column_names, logger=logger + ) + return rows, column_names, len(rows) + + def resolve_rowcount(self, cursor: Any) -> int: + """Resolve rowcount from aiomysql cursor for the direct execution path.""" + return resolve_rowcount(cursor) + + def _connection_in_transaction(self) -> bool: + """Check if connection is in transaction. + + aiomysql does not expose reliable transaction state. + + Returns: + False - aiomysql requires explicit transaction management. + """ + return False + + +register_driver_profile("aiomysql", driver_profile) diff --git a/tests/integration/adapters/aiomysql/__init__.py b/tests/integration/adapters/aiomysql/__init__.py new file mode 100644 index 000000000..34a4e8354 --- /dev/null +++ b/tests/integration/adapters/aiomysql/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytestmark = [pytest.mark.mysql, pytest.mark.aiomysql] diff --git a/tests/integration/adapters/aiomysql/conftest.py b/tests/integration/adapters/aiomysql/conftest.py new file mode 100644 index 000000000..d456d4f65 --- /dev/null +++ b/tests/integration/adapters/aiomysql/conftest.py @@ -0,0 +1,89 @@ +"""Shared fixtures for aiomysql integration tests.""" + +from collections.abc import AsyncGenerator + +import pytest +from pytest_databases.docker.mysql import MySQLService + +from sqlspec.adapters.aiomysql import AiomysqlConfig, AiomysqlDriver, default_statement_config + + +@pytest.fixture(scope="function") +async def aiomysql_config(mysql_service: "MySQLService") -> "AsyncGenerator[AiomysqlConfig, None]": + """Create aiomysql configuration for testing with proper cleanup.""" + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "autocommit": True, + "minsize": 1, + "maxsize": 5, + }, + statement_config=default_statement_config, + ) + try: + yield config + finally: + pool = config.connection_instance + if pool is not None: + pool.close() + await pool.wait_closed() + config.connection_instance = None + + +@pytest.fixture +async def aiomysql_driver(aiomysql_config: "AiomysqlConfig") -> "AsyncGenerator[AiomysqlDriver, None]": + """Create aiomysql driver instance for testing.""" + async with aiomysql_config.provide_session() as driver: + yield driver + + +@pytest.fixture +async def aiomysql_clean_driver(aiomysql_config: "AiomysqlConfig") -> "AsyncGenerator[AiomysqlDriver, None]": + """Create aiomysql driver with clean database state.""" + async with aiomysql_config.provide_session() as driver: + await driver.execute("SET sql_notes = 0") + cleanup_tables = [ + "test_table_aiomysql", + "data_types_test_aiomysql", + "user_profiles_aiomysql", + "test_parameter_conversion_aiomysql", + "transaction_test_aiomysql", + "concurrent_test_aiomysql", + "arrow_users_aiomysql", + "arrow_table_test_aiomysql", + "arrow_batch_test_aiomysql", + "arrow_params_test_aiomysql", + "arrow_empty_test_aiomysql", + "arrow_null_test_aiomysql", + "arrow_polars_test_aiomysql", + "arrow_large_test_aiomysql", + "arrow_types_test_aiomysql", + "arrow_json_test_aiomysql", + "driver_feature_test_aiomysql", + ] + + for table in cleanup_tables: + await driver.execute_script(f"DROP TABLE IF EXISTS {table}") + + cleanup_procedures = ["test_procedure", "simple_procedure"] + + for proc in cleanup_procedures: + await driver.execute_script(f"DROP PROCEDURE IF EXISTS {proc}") + + await driver.execute("SET sql_notes = 1") + + yield driver + + await driver.execute("SET sql_notes = 0") + + for table in cleanup_tables: + await driver.execute_script(f"DROP TABLE IF EXISTS {table}") + + for proc in cleanup_procedures: + await driver.execute_script(f"DROP PROCEDURE IF EXISTS {proc}") + + await driver.execute("SET sql_notes = 1") diff --git a/tests/integration/adapters/aiomysql/test_arrow.py b/tests/integration/adapters/aiomysql/test_arrow.py new file mode 100644 index 000000000..7e963ea13 --- /dev/null +++ b/tests/integration/adapters/aiomysql/test_arrow.py @@ -0,0 +1,204 @@ +"""Integration tests for aiomysql Arrow query support.""" + +import pytest + +from sqlspec.adapters.aiomysql import AiomysqlDriver + +pytestmark = pytest.mark.xdist_group("mysql") + + +async def test_select_to_arrow_basic(aiomysql_driver: AiomysqlDriver) -> None: + """Test basic select_to_arrow functionality.""" + import pyarrow as pa + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_users") + await aiomysql_driver.execute("CREATE TABLE IF NOT EXISTS arrow_users (id INT, name VARCHAR(100), age INT)") + await aiomysql_driver.execute("INSERT INTO arrow_users VALUES (1, 'Alice', 30), (2, 'Bob', 25)") + + result = await aiomysql_driver.select_to_arrow("SELECT * FROM arrow_users ORDER BY id") + + assert result is not None + assert isinstance(result.data, (pa.Table, pa.RecordBatch)) + assert result.rows_affected == 2 + + df = result.to_pandas() + assert len(df) == 2 + assert list(df["name"]) == ["Alice", "Bob"] + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_users") + + +async def test_select_to_arrow_table_format(aiomysql_driver: AiomysqlDriver) -> None: + """Test select_to_arrow with table return format (default).""" + import pyarrow as pa + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_table_test") + await aiomysql_driver.execute("CREATE TABLE IF NOT EXISTS arrow_table_test (id INT, value VARCHAR(100))") + await aiomysql_driver.execute("INSERT INTO arrow_table_test VALUES (1, 'a'), (2, 'b'), (3, 'c')") + + result = await aiomysql_driver.select_to_arrow("SELECT * FROM arrow_table_test ORDER BY id", return_format="table") + + assert isinstance(result.data, pa.Table) + assert result.rows_affected == 3 + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_table_test") + + +async def test_select_to_arrow_batch_format(aiomysql_driver: AiomysqlDriver) -> None: + """Test select_to_arrow with batch return format.""" + import pyarrow as pa + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_batch_test") + await aiomysql_driver.execute("CREATE TABLE IF NOT EXISTS arrow_batch_test (id INT, value VARCHAR(100))") + await aiomysql_driver.execute("INSERT INTO arrow_batch_test VALUES (1, 'a'), (2, 'b')") + + result = await aiomysql_driver.select_to_arrow("SELECT * FROM arrow_batch_test ORDER BY id", return_format="batches") + + assert isinstance(result.data, list) + for batch in result.data: + assert isinstance(batch, pa.RecordBatch) + assert result.rows_affected == 2 + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_batch_test") + + +async def test_select_to_arrow_with_parameters(aiomysql_driver: AiomysqlDriver) -> None: + """Test select_to_arrow with query parameters.""" + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_params_test") + await aiomysql_driver.execute("CREATE TABLE IF NOT EXISTS arrow_params_test (id INT, value INT)") + await aiomysql_driver.execute("INSERT INTO arrow_params_test VALUES (1, 100), (2, 200), (3, 300)") + + result = await aiomysql_driver.select_to_arrow( + "SELECT * FROM arrow_params_test WHERE value > %s ORDER BY id", (150,) + ) + + assert result.rows_affected == 2 + df = result.to_pandas() + assert list(df["value"]) == [200, 300] + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_params_test") + + +async def test_select_to_arrow_empty_result(aiomysql_driver: AiomysqlDriver) -> None: + """Test select_to_arrow with empty result set.""" + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_empty_test") + await aiomysql_driver.execute("CREATE TABLE IF NOT EXISTS arrow_empty_test (id INT)") + + result = await aiomysql_driver.select_to_arrow("SELECT * FROM arrow_empty_test") + + assert result.rows_affected == 0 + assert len(result.to_pandas()) == 0 + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_empty_test") + + +async def test_select_to_arrow_null_handling(aiomysql_driver: AiomysqlDriver) -> None: + """Test select_to_arrow with NULL values.""" + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_null_test") + await aiomysql_driver.execute("CREATE TABLE IF NOT EXISTS arrow_null_test (id INT, value VARCHAR(100))") + await aiomysql_driver.execute("INSERT INTO arrow_null_test VALUES (1, 'a'), (2, NULL), (3, 'c')") + + result = await aiomysql_driver.select_to_arrow("SELECT * FROM arrow_null_test ORDER BY id") + + df = result.to_pandas() + assert len(df) == 3 + assert df.iloc[1]["value"] is None or df.isna().iloc[1]["value"] + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_null_test") + + +async def test_select_to_arrow_to_polars(aiomysql_driver: AiomysqlDriver) -> None: + """Test select_to_arrow conversion to Polars DataFrame.""" + pytest.importorskip("polars") + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_polars_test") + await aiomysql_driver.execute("CREATE TABLE IF NOT EXISTS arrow_polars_test (id INT, value VARCHAR(100))") + await aiomysql_driver.execute("INSERT INTO arrow_polars_test VALUES (1, 'a'), (2, 'b')") + + result = await aiomysql_driver.select_to_arrow("SELECT * FROM arrow_polars_test ORDER BY id") + df = result.to_polars() + + assert len(df) == 2 + assert df["value"].to_list() == ["a", "b"] + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_polars_test") + + +async def test_select_to_arrow_large_dataset(aiomysql_driver: AiomysqlDriver) -> None: + """Test select_to_arrow with larger dataset.""" + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_large_test") + await aiomysql_driver.execute("CREATE TABLE IF NOT EXISTS arrow_large_test (id INT, value INT)") + + values = ", ".join(f"({i}, {i * 10})" for i in range(1, 1001)) + await aiomysql_driver.execute(f"INSERT INTO arrow_large_test VALUES {values}") + + result = await aiomysql_driver.select_to_arrow("SELECT * FROM arrow_large_test ORDER BY id") + + assert result.rows_affected == 1000 + df = result.to_pandas() + assert len(df) == 1000 + assert df["value"].sum() == sum(i * 10 for i in range(1, 1001)) + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_large_test") + + +async def test_select_to_arrow_type_preservation(aiomysql_driver: AiomysqlDriver) -> None: + """Test that MySQL types are properly converted to Arrow types.""" + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_types_test") + await aiomysql_driver.execute( + """ + CREATE TABLE IF NOT EXISTS arrow_types_test ( + id INT, + name VARCHAR(100), + price DECIMAL(10, 2), + created_at DATETIME, + is_active BOOLEAN + ) + """ + ) + await aiomysql_driver.execute( + """ + INSERT INTO arrow_types_test VALUES + (1, 'Item 1', 19.99, '2025-01-01 10:00:00', true), + (2, 'Item 2', 29.99, '2025-01-02 15:30:00', false) + """ + ) + + result = await aiomysql_driver.select_to_arrow("SELECT * FROM arrow_types_test ORDER BY id") + + df = result.to_pandas() + from pandas.api.types import is_string_dtype + + assert len(df) == 2 + assert is_string_dtype(df["name"]) + assert df["is_active"].dtype in (bool, int, "int64", "Int64") + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_types_test") + + +async def test_select_to_arrow_json_handling(aiomysql_driver: AiomysqlDriver) -> None: + """Test JSON type handling in Arrow results.""" + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_json_test") + await aiomysql_driver.execute("CREATE TABLE IF NOT EXISTS arrow_json_test (id INT, data JSON)") + await aiomysql_driver.execute( + """ + INSERT INTO arrow_json_test VALUES + (1, '{"name": "Alice", "age": 30}'), + (2, '{"name": "Bob", "age": 25}') + """ + ) + + result = await aiomysql_driver.select_to_arrow("SELECT * FROM arrow_json_test ORDER BY id") + + df = result.to_pandas() + assert len(df) == 2 + first_value = df["data"].iloc[0] + assert isinstance(first_value, (dict, str, object)) + + await aiomysql_driver.execute("DROP TABLE IF EXISTS arrow_json_test") diff --git a/tests/integration/adapters/aiomysql/test_config.py b/tests/integration/adapters/aiomysql/test_config.py new file mode 100644 index 000000000..a642e0a6d --- /dev/null +++ b/tests/integration/adapters/aiomysql/test_config.py @@ -0,0 +1,220 @@ +"""Unit tests for aiomysql configuration.""" + +import pytest +from pytest_databases.docker.mysql import MySQLService + +from sqlspec.adapters.aiomysql import ( + AiomysqlConfig, + AiomysqlConnectionParams, + AiomysqlDriver, + AiomysqlDriverFeatures, + AiomysqlPoolParams, +) +from sqlspec.core import StatementConfig + +pytestmark = pytest.mark.xdist_group("mysql") + + +def test_aiomysql_typed_dict_structure() -> None: + """Test aiomysql TypedDict structure.""" + + connection_parameters: AiomysqlConnectionParams = { + "host": "localhost", + "port": 3306, + "user": "test_user", + "password": "test_password", + "db": "test_db", + } + assert connection_parameters["host"] == "localhost" + assert connection_parameters["port"] == 3306 + + pool_parameters: AiomysqlPoolParams = {"host": "localhost", "port": 3306, "minsize": 5, "maxsize": 20, "echo": True} + assert pool_parameters["host"] == "localhost" + assert pool_parameters["minsize"] == 5 + + +def test_aiomysql_config_basic_creation() -> None: + """Test aiomysql config creation with basic parameters.""" + + connection_config = { + "host": "localhost", + "port": 3306, + "user": "test_user", + "password": "test_password", + "db": "test_db", + } + config = AiomysqlConfig(connection_config=connection_config) + assert config.connection_config["host"] == "localhost" + assert config.connection_config["port"] == 3306 + assert config.connection_config["user"] == "test_user" + assert config.connection_config["password"] == "test_password" + assert config.connection_config["db"] == "test_db" + + connection_config_full = { + "host": "localhost", + "port": 3306, + "user": "test_user", + "password": "test_password", + "db": "test_db", + "custom": "value", + } + config_full = AiomysqlConfig(connection_config=connection_config_full) + assert config_full.connection_config["host"] == "localhost" + assert config_full.connection_config["port"] == 3306 + assert config_full.connection_config["user"] == "test_user" + assert config_full.connection_config["password"] == "test_password" + assert config_full.connection_config["db"] == "test_db" + assert config_full.connection_config["custom"] == "value" + + +def test_aiomysql_config_initialization() -> None: + """Test aiomysql config initialization.""" + + connection_config = { + "host": "localhost", + "port": 3306, + "user": "test_user", + "password": "test_password", + "db": "test_db", + } + config = AiomysqlConfig(connection_config=connection_config) + assert isinstance(config.statement_config, StatementConfig) + + custom_statement_config = StatementConfig(dialect="custom") + config = AiomysqlConfig(connection_config=connection_config, statement_config=custom_statement_config) + assert config.statement_config.dialect == "custom" + + +async def test_aiomysql_config_provide_session(mysql_service: MySQLService) -> None: + """Test aiomysql config provide_session context manager.""" + + connection_config = { + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + } + config = AiomysqlConfig(connection_config=connection_config) + + async with config.provide_session() as session: + assert isinstance(session, AiomysqlDriver) + + assert session.statement_config is not None + assert session.statement_config.parameter_config is not None + + +def test_aiomysql_config_driver_type() -> None: + """Test aiomysql config driver_type property.""" + connection_config = { + "host": "localhost", + "port": 3306, + "user": "test_user", + "password": "test_password", + "db": "test_db", + } + config = AiomysqlConfig(connection_config=connection_config) + assert config.driver_type is AiomysqlDriver + + +def test_aiomysql_config_is_async() -> None: + """Test aiomysql config is_async attribute.""" + connection_config = { + "host": "localhost", + "port": 3306, + "user": "test_user", + "password": "test_password", + "db": "test_db", + } + config = AiomysqlConfig(connection_config=connection_config) + assert config.is_async is True + assert AiomysqlConfig.is_async is True + + +def test_aiomysql_config_supports_connection_pooling() -> None: + """Test aiomysql config supports_connection_pooling attribute.""" + connection_config = { + "host": "localhost", + "port": 3306, + "user": "test_user", + "password": "test_password", + "db": "test_db", + } + config = AiomysqlConfig(connection_config=connection_config) + assert config.supports_connection_pooling is True + assert AiomysqlConfig.supports_connection_pooling is True + + +def test_aiomysql_driver_features_typed_dict_structure() -> None: + """Test AiomysqlDriverFeatures TypedDict structure.""" + features: AiomysqlDriverFeatures = {"json_serializer": lambda x: str(x), "json_deserializer": lambda x: x} + + assert "json_serializer" in features + assert "json_deserializer" in features + assert callable(features["json_serializer"]) + assert callable(features["json_deserializer"]) + + +def test_aiomysql_driver_features_partial_dict() -> None: + """Test AiomysqlDriverFeatures with partial configuration.""" + features: AiomysqlDriverFeatures = {"json_serializer": lambda x: str(x)} + + assert "json_serializer" in features + assert "json_deserializer" not in features + + +def test_aiomysql_driver_features_empty_dict() -> None: + """Test AiomysqlDriverFeatures with empty configuration.""" + features: AiomysqlDriverFeatures = {} + + assert len(features) == 0 + + +def test_aiomysql_config_with_driver_features() -> None: + """Test AiomysqlConfig initialization with driver_features.""" + + def custom_serializer(data: object) -> str: + return str(data) + + def custom_deserializer(data: str) -> object: + return data + + features: AiomysqlDriverFeatures = {"json_serializer": custom_serializer, "json_deserializer": custom_deserializer} + + config = AiomysqlConfig(connection_config={"host": "localhost", "port": 3306}, driver_features=features) + + assert config.driver_features["json_serializer"] is custom_serializer + assert config.driver_features["json_deserializer"] is custom_deserializer + + +def test_aiomysql_config_with_empty_driver_features() -> None: + """Test AiomysqlConfig with empty driver_features still provides defaults.""" + config = AiomysqlConfig(connection_config={"host": "localhost", "port": 3306}, driver_features={}) + + assert "json_serializer" in config.driver_features + assert "json_deserializer" in config.driver_features + assert callable(config.driver_features["json_serializer"]) + assert callable(config.driver_features["json_deserializer"]) + + +def test_aiomysql_config_without_driver_features() -> None: + """Test AiomysqlConfig without driver_features provides sensible defaults.""" + config = AiomysqlConfig(connection_config={"host": "localhost", "port": 3306}) + + assert "json_serializer" in config.driver_features + assert "json_deserializer" in config.driver_features + assert callable(config.driver_features["json_serializer"]) + assert callable(config.driver_features["json_deserializer"]) + + +def test_aiomysql_config_driver_features_as_plain_dict() -> None: + """Test AiomysqlConfig with driver_features as plain dict.""" + + def custom_serializer(data: object) -> str: + return str(data) + + config = AiomysqlConfig( + connection_config={"host": "localhost", "port": 3306}, driver_features={"json_serializer": custom_serializer} + ) + + assert config.driver_features["json_serializer"] is custom_serializer diff --git a/tests/integration/adapters/aiomysql/test_driver.py b/tests/integration/adapters/aiomysql/test_driver.py new file mode 100644 index 000000000..61687fd32 --- /dev/null +++ b/tests/integration/adapters/aiomysql/test_driver.py @@ -0,0 +1,556 @@ +"""Integration tests for aiomysql driver implementation. + +This serves as a comprehensive test template for database drivers, +covering all core functionality including CRUD operations, parameter styles, +transaction management, and error handling. +""" + +import math +from typing import Literal + +import pytest +from pytest_databases.docker.mysql import MySQLService + +from sqlspec import SQL, SQLResult, StatementStack, sql +from sqlspec.adapters.aiomysql import AiomysqlConfig, AiomysqlDriver +from sqlspec.utils.serializers import from_json, to_json + +ParamStyle = Literal["tuple_binds", "dict_binds", "named_binds"] + +pytestmark = pytest.mark.xdist_group("mysql") + + +@pytest.fixture +async def aiomysql_driver(aiomysql_clean_driver: AiomysqlDriver) -> AiomysqlDriver: + """Create and manage test table lifecycle.""" + + create_sql = """ + CREATE TABLE IF NOT EXISTS test_table_aiomysql ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + value INT DEFAULT 0, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """ + await aiomysql_clean_driver.execute_script(create_sql) + await aiomysql_clean_driver.execute_script("DELETE FROM test_table_aiomysql") + + return aiomysql_clean_driver + + +async def test_aiomysql_basic_crud(aiomysql_driver: AiomysqlDriver) -> None: + """Test basic CRUD operations.""" + driver = aiomysql_driver + + insert_result = await driver.execute( + "INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("test_user", 42) + ) + assert insert_result.num_rows == 1 + + select_result = await driver.execute("SELECT * FROM test_table_aiomysql WHERE name = ?", ("test_user",)) + assert select_result.num_rows == 1 + assert len(select_result.get_data()) == 1 + row = select_result.get_data()[0] + assert row["name"] == "test_user" + assert row["value"] == 42 + + update_result = await driver.execute("UPDATE test_table_aiomysql SET value = ? WHERE name = ?", (100, "test_user")) + assert update_result.num_rows == 1 + + updated_result = await driver.execute("SELECT value FROM test_table_aiomysql WHERE name = ?", ("test_user",)) + assert updated_result.get_data()[0]["value"] == 100 + + delete_result = await driver.execute("DELETE FROM test_table_aiomysql WHERE name = ?", ("test_user",)) + assert delete_result.num_rows == 1 + + verify_result = await driver.execute( + "SELECT COUNT(*) as count FROM test_table_aiomysql WHERE name = ?", ("test_user",) + ) + assert verify_result.get_data()[0]["count"] == 0 + + +async def test_aiomysql_parameter_styles(aiomysql_driver: AiomysqlDriver) -> None: + """Test different parameter binding styles.""" + driver = aiomysql_driver + + result1 = await driver.execute("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("user1", 10)) + assert result1.num_rows == 1 + + result2 = await driver.execute("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ["user2", 20]) + assert result2.num_rows == 1 + + select_result = await driver.execute("SELECT name, value FROM test_table_aiomysql ORDER BY name") + assert len(select_result.get_data()) == 2 + assert select_result.get_data()[0]["name"] == "user1" + assert select_result.get_data()[0]["value"] == 10 + assert select_result.get_data()[1]["name"] == "user2" + assert select_result.get_data()[1]["value"] == 20 + + +async def test_aiomysql_execute_many(aiomysql_driver: AiomysqlDriver) -> None: + """Test execute_many functionality.""" + driver = aiomysql_driver + + data = [("batch_user_1", 100), ("batch_user_2", 200), ("batch_user_3", 300)] + + result = await driver.execute_many("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", data) + assert result.num_rows == 3 + + select_result = await driver.execute( + "SELECT name, value FROM test_table_aiomysql WHERE name LIKE ? ORDER BY name", ("batch_user_%",) + ) + assert len(select_result.get_data()) == 3 + assert select_result.get_data()[0]["name"] == "batch_user_1" + assert select_result.get_data()[0]["value"] == 100 + + +async def test_aiomysql_execute_script(aiomysql_driver: AiomysqlDriver) -> None: + """Test script execution with multiple statements.""" + driver = aiomysql_driver + + script = """ + INSERT INTO test_table_aiomysql (name, value) VALUES ('script_user_1', 1000); + INSERT INTO test_table_aiomysql (name, value) VALUES ('script_user_2', 2000); + UPDATE test_table_aiomysql SET value = value * 2 WHERE name LIKE 'script_user_%'; + """ + + result = await driver.execute_script(script) + assert result.operation_type == "SCRIPT" + + select_result = await driver.execute( + "SELECT name, value FROM test_table_aiomysql WHERE name LIKE ? ORDER BY name", ("script_user_%",) + ) + assert len(select_result.get_data()) == 2 + assert select_result.get_data()[0]["value"] == 2000 + assert select_result.get_data()[1]["value"] == 4000 + + +async def test_aiomysql_data_types(aiomysql_driver: AiomysqlDriver) -> None: + """Test handling of various MySQL data types.""" + driver = aiomysql_driver + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS data_types_test_aiomysql ( + id INT AUTO_INCREMENT PRIMARY KEY, + text_col VARCHAR(255), + int_col INT, + float_col FLOAT, + bool_col BOOLEAN, + date_col DATE, + datetime_col DATETIME, + json_col JSON + ) + """) + + from datetime import date, datetime + + test_data = ("test_string", 42, math.pi, True, date(2023, 1, 1), datetime(2023, 1, 1, 12, 0, 0), '{"key": "value"}') + + result = await driver.execute( + """INSERT INTO data_types_test_aiomysql + (text_col, int_col, float_col, bool_col, date_col, datetime_col, json_col) + VALUES (?, ?, ?, ?, ?, ?, ?)""", + test_data, + ) + assert result.rows_affected == 1 + + select_result = await driver.execute( + "SELECT * FROM data_types_test_aiomysql WHERE text_col = ? AND int_col = ?", ("test_string", 42) + ) + assert len(select_result.get_data()) == 1 + row = select_result.get_data()[0] + assert row["text_col"] == "test_string" + assert row["int_col"] == 42 + assert abs(row["float_col"] - math.pi) < 0.01 + assert row["bool_col"] == 1 + assert isinstance(row["json_col"], dict) + assert row["json_col"]["key"] == "value" + + +async def test_aiomysql_statement_stack_sequential(aiomysql_driver: AiomysqlDriver) -> None: + """StatementStack should execute sequentially for aiomysql (no native batching).""" + + await aiomysql_driver.execute_script("DELETE FROM test_table_aiomysql") + + stack = ( + StatementStack() + .push_execute("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("mysql-stack-one", 11)) + .push_execute("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("mysql-stack-two", 22)) + .push_execute("SELECT COUNT(*) AS total FROM test_table_aiomysql WHERE name LIKE ?", ("mysql-stack-%",)) + ) + + results = await aiomysql_driver.execute_stack(stack) + + assert len(results) == 3 + assert results[0].rows_affected == 1 + assert results[1].rows_affected == 1 + final_result = results[2].result + assert isinstance(final_result, SQLResult) + data = final_result.get_data() + assert data + assert data[0]["total"] == 2 + + +async def test_aiomysql_statement_stack_continue_on_error(aiomysql_driver: AiomysqlDriver) -> None: + """Continue-on-error should still work with sequential fallback.""" + + await aiomysql_driver.execute_script("DELETE FROM test_table_aiomysql") + + stack = ( + StatementStack() + .push_execute("INSERT INTO test_table_aiomysql (id, name, value) VALUES (?, ?, ?)", (1, "mysql-initial", 5)) + .push_execute("INSERT INTO test_table_aiomysql (id, name, value) VALUES (?, ?, ?)", (1, "mysql-duplicate", 15)) + .push_execute("INSERT INTO test_table_aiomysql (id, name, value) VALUES (?, ?, ?)", (2, "mysql-final", 25)) + ) + + results = await aiomysql_driver.execute_stack(stack, continue_on_error=True) + + assert len(results) == 3 + assert results[0].rows_affected == 1 + assert results[1].error is not None + assert results[2].rows_affected == 1 + + verify = await aiomysql_driver.execute( + "SELECT COUNT(*) AS total FROM test_table_aiomysql WHERE name LIKE ?", ("mysql-%",) + ) + assert verify.get_data()[0]["total"] == 2 + + +async def test_aiomysql_driver_features_custom_serializers(mysql_service: MySQLService) -> None: + """Ensure custom serializer and deserializer driver features are applied.""" + + serializer_calls: list[object] = [] + + def tracking_serializer(value: object) -> str: + serializer_calls.append(value) + return to_json(value) + + def tracking_deserializer(value: str | bytes) -> object: + decoded = from_json(value) + if isinstance(decoded, dict): + decoded["extra_marker"] = True + return decoded + + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "autocommit": True, + }, + driver_features={"json_serializer": tracking_serializer, "json_deserializer": tracking_deserializer}, + ) + + async with config.provide_session() as session: + await session.execute_script( + """ + CREATE TABLE IF NOT EXISTS driver_feature_test_aiomysql ( + id INT AUTO_INCREMENT PRIMARY KEY, + payload JSON + ); + DELETE FROM driver_feature_test_aiomysql; + """ + ) + + payload = {"foo": "bar"} + await session.execute("INSERT INTO driver_feature_test_aiomysql (payload) VALUES (?)", (payload,)) + + assert serializer_calls + assert serializer_calls[0] == payload + + select_result = await session.execute( + "SELECT payload FROM driver_feature_test_aiomysql ORDER BY id DESC LIMIT 1" + ) + stored_row = select_result.get_data()[0] + assert stored_row["payload"]["foo"] == "bar" + assert stored_row["payload"]["extra_marker"] is True + + +async def test_aiomysql_transaction_management(aiomysql_driver: AiomysqlDriver) -> None: + """Test transaction management (begin, commit, rollback).""" + driver = aiomysql_driver + + await driver.begin() + await driver.execute("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("tx_user_1", 100)) + await driver.commit() + + result = await driver.execute("SELECT COUNT(*) as count FROM test_table_aiomysql WHERE name = ?", ("tx_user_1",)) + assert result.get_data()[0]["count"] == 1 + + await driver.begin() + await driver.execute("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("tx_user_2", 200)) + await driver.rollback() + + result = await driver.execute("SELECT COUNT(*) as count FROM test_table_aiomysql WHERE name = ?", ("tx_user_2",)) + assert result.get_data()[0]["count"] == 0 + + +async def test_aiomysql_null_parameters(aiomysql_driver: AiomysqlDriver) -> None: + """Test handling of NULL parameters.""" + driver = aiomysql_driver + + result = await driver.execute("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("null_test", None)) + assert result.num_rows == 1 + + select_result = await driver.execute("SELECT name, value FROM test_table_aiomysql WHERE name = ?", ("null_test",)) + assert len(select_result.get_data()) == 1 + assert select_result.get_data()[0]["name"] == "null_test" + assert select_result.get_data()[0]["value"] is None + + +async def test_aiomysql_error_handling(aiomysql_driver: AiomysqlDriver) -> None: + """Test error handling and exception wrapping.""" + driver = aiomysql_driver + + with pytest.raises(Exception): + await driver.execute("INVALID SQL STATEMENT") + + await driver.execute("INSERT INTO test_table_aiomysql (id, name, value) VALUES (?, ?, ?)", (1, "user1", 100)) + + with pytest.raises(Exception): + await driver.execute("INSERT INTO test_table_aiomysql (id, name, value) VALUES (?, ?, ?)", (1, "user2", 200)) + + +async def test_aiomysql_large_result_set(aiomysql_driver: AiomysqlDriver) -> None: + """Test handling of large result sets.""" + driver = aiomysql_driver + + batch_data = [(f"user_{i}", i * 10) for i in range(100)] + await driver.execute_many("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", batch_data) + + result = await driver.execute("SELECT * FROM test_table_aiomysql ORDER BY value") + assert result.num_rows == 100 + assert len(result.get_data()) == 100 + assert result.get_data()[0]["name"] == "user_0" + assert result.get_data()[99]["name"] == "user_99" + + +async def test_aiomysql_mysql_specific_features(aiomysql_driver: AiomysqlDriver) -> None: + """Test MySQL-specific features and SQL constructs.""" + driver = aiomysql_driver + + await driver.execute( + "INSERT INTO test_table_aiomysql (id, name, value) VALUES (?, ?, ?)", (1, "duplicate_test", 100) + ) + + _ = await driver.execute( + """INSERT INTO test_table_aiomysql (id, name, value) VALUES (?, ?, ?) AS new + ON DUPLICATE KEY UPDATE value = new.value + 50""", + (1, "duplicate_test_updated", 200), + ) + + select_result = await driver.execute("SELECT name, value FROM test_table_aiomysql WHERE id = ?", (1,)) + assert select_result.get_data()[0]["value"] == 250 + + +async def test_aiomysql_complex_queries(aiomysql_driver: AiomysqlDriver) -> None: + """Test complex SQL queries with JOINs, subqueries, etc.""" + driver = aiomysql_driver + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS user_profiles_aiomysql ( + user_id INT PRIMARY KEY, + email VARCHAR(255), + age INT + ) + """) + + await driver.execute("INSERT INTO test_table_aiomysql (id, name, value) VALUES (?, ?, ?)", (1, "john_doe", 100)) + await driver.execute( + "INSERT INTO user_profiles_aiomysql (user_id, email, age) VALUES (?, ?, ?)", (1, "john@example.com", 30) + ) + + result = await driver.execute( + """ + SELECT t.name, t.value, p.email, p.age + FROM test_table_aiomysql t + JOIN user_profiles_aiomysql p ON t.id = p.user_id + WHERE t.name = ? + """, + ("john_doe",), + ) + + assert len(result.get_data()) == 1 + row = result.get_data()[0] + assert row["name"] == "john_doe" + assert row["email"] == "john@example.com" + assert row["age"] == 30 + + +async def test_aiomysql_edge_cases(aiomysql_driver: AiomysqlDriver) -> None: + """Test edge cases and boundary conditions.""" + driver = aiomysql_driver + + result = await driver.execute("SELECT 1 as test_col", ()) + assert len(result.get_data()) == 1 + assert result.get_data()[0]["test_col"] == 1 + + result = await driver.execute("SELECT ? as param_value", (42,)) + assert result.get_data()[0]["param_value"] == 42 + + data_with_nulls = [("user1", 100), ("user2", None), ("user3", 300)] + + result = await driver.execute_many("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", data_with_nulls) + assert result.num_rows == 3 + + select_result = await driver.execute( + "SELECT name, value FROM test_table_aiomysql WHERE name IN (?, ?, ?) ORDER BY name", ("user1", "user2", "user3") + ) + assert len(select_result.get_data()) == 3 + assert select_result.get_data()[1]["value"] is None + + +async def test_aiomysql_result_metadata(aiomysql_driver: AiomysqlDriver) -> None: + """Test SQL result metadata and properties.""" + driver = aiomysql_driver + + insert_result = await driver.execute( + "INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("metadata_test", 500) + ) + assert insert_result.num_rows == 1 + assert insert_result.operation_type == "INSERT" + assert insert_result.column_names is None or len(insert_result.column_names) == 0 + + select_result = await driver.execute( + "SELECT id, name, value FROM test_table_aiomysql WHERE name = ?", ("metadata_test",) + ) + assert select_result.num_rows == 1 + assert select_result.operation_type == "SELECT" + assert select_result.column_names == ["id", "name", "value"] + assert len(select_result.get_data()) == 1 + + empty_result = await driver.execute("SELECT * FROM test_table_aiomysql WHERE name = ?", ("nonexistent",)) + assert empty_result.num_rows == 0 + assert empty_result.operation_type == "SELECT" + assert len(empty_result.get_data()) == 0 + + +async def test_aiomysql_sql_object_execution(aiomysql_driver: AiomysqlDriver) -> None: + """Test execution of SQL objects.""" + driver = aiomysql_driver + + sql_obj = SQL("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", "sql_obj_test", 999) + result = await driver.execute(sql_obj) + assert isinstance(result, SQLResult) + assert result.num_rows == 1 + + verify_result = await driver.execute("SELECT name, value FROM test_table_aiomysql WHERE name = ?", ("sql_obj_test",)) + assert len(verify_result.get_data()) == 1 + assert verify_result.get_data()[0]["name"] == "sql_obj_test" + assert verify_result.get_data()[0]["value"] == 999 + + select_sql = SQL("SELECT * FROM test_table_aiomysql WHERE value > ?", 500) + select_result = await driver.execute(select_sql) + assert isinstance(select_result, SQLResult) + assert select_result.num_rows >= 1 + assert select_result.operation_type == "SELECT" + + +async def test_aiomysql_for_update_locking(aiomysql_driver: AiomysqlDriver) -> None: + """Test FOR UPDATE row locking with MySQL.""" + + driver = aiomysql_driver + + # Insert test data + await driver.execute("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("mysql_lock", 100)) + + try: + await driver.begin() + + # Test basic FOR UPDATE + result = await driver.select_one( + sql.select("id", "name", "value").from_("test_table_aiomysql").where_eq("name", "mysql_lock").for_update() + ) + assert result is not None + assert result["name"] == "mysql_lock" + assert result["value"] == 100 + + await driver.commit() + except Exception: + await driver.rollback() + raise + + +async def test_aiomysql_for_update_skip_locked(aiomysql_driver: AiomysqlDriver) -> None: + """Test FOR UPDATE SKIP LOCKED with MySQL (MySQL 8.0+ feature).""" + + driver = aiomysql_driver + + # Insert test data + await driver.execute("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("mysql_skip", 200)) + + try: + await driver.begin() + + # Test FOR UPDATE SKIP LOCKED + result = await driver.select_one( + sql.select("*").from_("test_table_aiomysql").where_eq("name", "mysql_skip").for_update(skip_locked=True) + ) + assert result is not None + assert result["name"] == "mysql_skip" + + await driver.commit() + except Exception: + await driver.rollback() + raise + + +async def test_aiomysql_for_share_locking(aiomysql_driver: AiomysqlDriver) -> None: + """Test FOR SHARE row locking with MySQL.""" + + driver = aiomysql_driver + + # Insert test data + await driver.execute("INSERT INTO test_table_aiomysql (name, value) VALUES (?, ?)", ("mysql_share", 300)) + + try: + await driver.begin() + + # Test basic FOR SHARE (MySQL uses FOR SHARE syntax like PostgreSQL) + result = await driver.select_one( + sql.select("id", "name", "value").from_("test_table_aiomysql").where_eq("name", "mysql_share").for_share() + ) + assert result is not None + assert result["name"] == "mysql_share" + assert result["value"] == 300 + + await driver.commit() + except Exception: + await driver.rollback() + raise + + +async def test_aiomysql_on_connection_create_hook(mysql_service: "MySQLService") -> None: + """Test on_connection_create callback is invoked for each connection.""" + from typing import Any + + hook_call_count = 0 + + async def connection_hook(conn: Any) -> None: + nonlocal hook_call_count + hook_call_count += 1 + + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "minsize": 1, + "maxsize": 2, + }, + driver_features={"on_connection_create": connection_hook}, + ) + + try: + async with config.provide_session() as session: + await session.execute("SELECT 1") + assert hook_call_count >= 1, "Hook should be called at least once" + finally: + pool = config.connection_instance + if pool is not None: + pool.close() + await pool.wait_closed() diff --git a/tests/integration/adapters/aiomysql/test_exceptions.py b/tests/integration/adapters/aiomysql/test_exceptions.py new file mode 100644 index 000000000..499bd5e60 --- /dev/null +++ b/tests/integration/adapters/aiomysql/test_exceptions.py @@ -0,0 +1,137 @@ +"""Exception handling integration tests for aiomysql adapter.""" + +from collections.abc import AsyncGenerator + +import pytest +from pytest_databases.docker.mysql import MySQLService + +from sqlspec.adapters.aiomysql import AiomysqlConfig, AiomysqlDriver +from sqlspec.exceptions import ( + CheckViolationError, + ForeignKeyViolationError, + NotNullViolationError, + SQLParsingError, + UniqueViolationError, +) + +pytestmark = pytest.mark.xdist_group("mysql") + + +@pytest.fixture +async def aiomysql_exception_session(mysql_service: MySQLService) -> AsyncGenerator[AiomysqlDriver, None]: + """Create an aiomysql session for exception testing.""" + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "autocommit": True, + "minsize": 1, + "maxsize": 5, + } + ) + + try: + async with config.provide_session() as session: + yield session + finally: + await config.close_pool() + + +async def test_unique_violation(aiomysql_exception_session: AiomysqlDriver) -> None: + """Test unique constraint violation raises UniqueViolationError.""" + await aiomysql_exception_session.execute_script(""" + DROP TABLE IF EXISTS test_unique_constraint; + CREATE TABLE test_unique_constraint ( + id INT AUTO_INCREMENT PRIMARY KEY, + email VARCHAR(255) UNIQUE NOT NULL + ); + """) + + await aiomysql_exception_session.execute( + "INSERT INTO test_unique_constraint (email) VALUES (%s)", ("test@example.com",) + ) + + with pytest.raises(UniqueViolationError) as exc_info: + await aiomysql_exception_session.execute( + "INSERT INTO test_unique_constraint (email) VALUES (%s)", ("test@example.com",) + ) + + assert "unique" in str(exc_info.value).lower() or "1062" in str(exc_info.value) + + await aiomysql_exception_session.execute("DROP TABLE test_unique_constraint") + + +async def test_foreign_key_violation(aiomysql_exception_session: AiomysqlDriver) -> None: + """Test foreign key constraint violation raises ForeignKeyViolationError.""" + await aiomysql_exception_session.execute_script(""" + DROP TABLE IF EXISTS test_fk_child; + DROP TABLE IF EXISTS test_fk_parent; + CREATE TABLE test_fk_parent ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100) + ) ENGINE=InnoDB; + CREATE TABLE test_fk_child ( + id INT AUTO_INCREMENT PRIMARY KEY, + parent_id INT NOT NULL, + FOREIGN KEY (parent_id) REFERENCES test_fk_parent(id) + ) ENGINE=InnoDB; + """) + + with pytest.raises(ForeignKeyViolationError) as exc_info: + await aiomysql_exception_session.execute("INSERT INTO test_fk_child (parent_id) VALUES (%s)", (999,)) + + assert "foreign key" in str(exc_info.value).lower() or any(code in str(exc_info.value) for code in ["1216", "1452"]) + + await aiomysql_exception_session.execute_script(""" + DROP TABLE IF EXISTS test_fk_child; + DROP TABLE IF EXISTS test_fk_parent; + """) + + +async def test_not_null_violation(aiomysql_exception_session: AiomysqlDriver) -> None: + """Test NOT NULL constraint violation raises NotNullViolationError.""" + await aiomysql_exception_session.execute_script(""" + DROP TABLE IF EXISTS test_not_null; + CREATE TABLE test_not_null ( + id INT AUTO_INCREMENT PRIMARY KEY, + required_field VARCHAR(100) NOT NULL + ); + """) + + with pytest.raises(NotNullViolationError) as exc_info: + await aiomysql_exception_session.execute("INSERT INTO test_not_null (id) VALUES (%s)", (1,)) + + assert "cannot be null" in str(exc_info.value).lower() or any( + code in str(exc_info.value) for code in ["1048", "1364"] + ) + + await aiomysql_exception_session.execute("DROP TABLE test_not_null") + + +async def test_check_violation(aiomysql_exception_session: AiomysqlDriver) -> None: + """Test CHECK constraint violation raises CheckViolationError.""" + await aiomysql_exception_session.execute_script(""" + DROP TABLE IF EXISTS test_check_constraint; + CREATE TABLE test_check_constraint ( + id INT AUTO_INCREMENT PRIMARY KEY, + age INT CHECK (age >= 18) + ); + """) + + with pytest.raises(CheckViolationError) as exc_info: + await aiomysql_exception_session.execute("INSERT INTO test_check_constraint (age) VALUES (%s)", (15,)) + + assert "check" in str(exc_info.value).lower() or "3819" in str(exc_info.value) + + await aiomysql_exception_session.execute("DROP TABLE test_check_constraint") + + +async def test_sql_parsing_error(aiomysql_exception_session: AiomysqlDriver) -> None: + """Test syntax error raises SQLParsingError.""" + with pytest.raises(SQLParsingError) as exc_info: + await aiomysql_exception_session.execute("SELCT * FROM nonexistent_table") + + assert "syntax" in str(exc_info.value).lower() or "1064" in str(exc_info.value) diff --git a/tests/integration/adapters/aiomysql/test_explain.py b/tests/integration/adapters/aiomysql/test_explain.py new file mode 100644 index 000000000..f7a0bcd0b --- /dev/null +++ b/tests/integration/adapters/aiomysql/test_explain.py @@ -0,0 +1,138 @@ +"""Integration tests for EXPLAIN plan support with aiomysql adapter (MySQL).""" + +from collections.abc import AsyncGenerator + +import pytest + +from sqlspec import SQLResult +from sqlspec.adapters.aiomysql import AiomysqlConfig, AiomysqlDriver +from sqlspec.builder import Explain, sql +from sqlspec.core import SQL + +pytestmark = [pytest.mark.xdist_group("mysql")] + + +@pytest.fixture +async def aiomysql_session(aiomysql_config: AiomysqlConfig) -> AsyncGenerator[AiomysqlDriver, None]: + """Create an aiomysql session with test table.""" + async with aiomysql_config.provide_session() as session: + await session.execute_script("DROP TABLE IF EXISTS explain_test") + await session.execute_script( + """ + CREATE TABLE IF NOT EXISTS explain_test ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + value INT DEFAULT 0 + ) + """ + ) + yield session + + try: + await session.execute_script("DROP TABLE IF EXISTS explain_test") + except Exception: + pass + + +async def test_explain_basic_select(aiomysql_session: AiomysqlDriver) -> None: + """Test basic EXPLAIN on SELECT statement.""" + explain_stmt = Explain("SELECT * FROM explain_test", dialect="mysql") + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None + + +async def test_explain_analyze(aiomysql_session: AiomysqlDriver) -> None: + """Test EXPLAIN ANALYZE on SELECT statement (MySQL 8.0+).""" + explain_stmt = Explain("SELECT * FROM explain_test", dialect="mysql").analyze() + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None + + +async def test_explain_format_json(aiomysql_session: AiomysqlDriver) -> None: + """Test EXPLAIN FORMAT = JSON.""" + explain_stmt = Explain("SELECT * FROM explain_test", dialect="mysql").format("json") + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None + + +async def test_explain_format_tree(aiomysql_session: AiomysqlDriver) -> None: + """Test EXPLAIN FORMAT = TREE (MySQL 8.0+).""" + explain_stmt = Explain("SELECT * FROM explain_test", dialect="mysql").format("tree") + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None + + +async def test_explain_format_traditional(aiomysql_session: AiomysqlDriver) -> None: + """Test EXPLAIN FORMAT = TRADITIONAL.""" + explain_stmt = Explain("SELECT * FROM explain_test", dialect="mysql").format("traditional") + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None + + +async def test_explain_from_query_builder(aiomysql_session: AiomysqlDriver) -> None: + """Test EXPLAIN from QueryBuilder via mixin. + + Note: Uses raw SQL since query builder without dialect produces PostgreSQL-style SQL. + """ + explain_stmt = Explain("SELECT * FROM explain_test WHERE id > 0", dialect="mysql").analyze() + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None + + +async def test_explain_from_sql_factory(aiomysql_session: AiomysqlDriver) -> None: + """Test sql.explain() factory method.""" + explain_stmt = sql.explain("SELECT * FROM explain_test", dialect="mysql") + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None + + +async def test_explain_from_sql_object(aiomysql_session: AiomysqlDriver) -> None: + """Test SQL.explain() method.""" + stmt = SQL("SELECT * FROM explain_test") + # Use Explain directly with dialect since SQL uses default dialect + explain_stmt = Explain(stmt.sql, dialect="mysql") + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None + + +async def test_explain_insert(aiomysql_session: AiomysqlDriver) -> None: + """Test EXPLAIN on INSERT statement.""" + explain_stmt = Explain("INSERT INTO explain_test (name, value) VALUES ('test', 1)", dialect="mysql") + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None + + +async def test_explain_update(aiomysql_session: AiomysqlDriver) -> None: + """Test EXPLAIN on UPDATE statement.""" + explain_stmt = Explain("UPDATE explain_test SET value = 100 WHERE id = 1", dialect="mysql") + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None + + +async def test_explain_delete(aiomysql_session: AiomysqlDriver) -> None: + """Test EXPLAIN on DELETE statement.""" + explain_stmt = Explain("DELETE FROM explain_test WHERE id = 1", dialect="mysql") + result = await aiomysql_session.execute(explain_stmt.build()) + + assert isinstance(result, SQLResult) + assert result.data is not None diff --git a/tests/integration/adapters/aiomysql/test_features.py b/tests/integration/adapters/aiomysql/test_features.py new file mode 100644 index 000000000..c3a26ac10 --- /dev/null +++ b/tests/integration/adapters/aiomysql/test_features.py @@ -0,0 +1,284 @@ +"""aiomysql-specific feature tests. + +This test suite focuses on aiomysql adapter specific functionality including: +- Connection pooling behavior +- MySQL-specific SQL features +- Async transaction handling +- Error handling and recovery +- Performance characteristics +""" + +from collections.abc import AsyncGenerator + +import pytest +from pytest_databases.docker.mysql import MySQLService + +from sqlspec.adapters.aiomysql import AiomysqlConfig, AiomysqlDriver, default_statement_config +from sqlspec.core import SQL, SQLResult + +pytestmark = pytest.mark.xdist_group("mysql") + + +@pytest.fixture +async def aiomysql_pooled_session(mysql_service: MySQLService) -> AsyncGenerator[AiomysqlDriver, None]: + """Create aiomysql session with connection pooling.""" + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "autocommit": True, + "minsize": 2, + "maxsize": 10, + "echo": False, + }, + statement_config=default_statement_config, + ) + + async with config.provide_session() as session: + await session.execute_script(""" + CREATE TABLE IF NOT EXISTS concurrent_test ( + id INT AUTO_INCREMENT PRIMARY KEY, + thread_id VARCHAR(50), + value INT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """) + await session.execute_script("DELETE FROM concurrent_test") + + yield session + + +async def test_aiomysql_mysql_json_operations(aiomysql_pooled_session: AiomysqlDriver) -> None: + """Test MySQL JSON column operations.""" + driver = aiomysql_pooled_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS json_test ( + id INT AUTO_INCREMENT PRIMARY KEY, + data JSON, + metadata JSON + ) + """) + + json_data = '{"name": "test", "values": [1, 2, 3], "nested": {"key": "value"}}' + metadata = '{"created_by": "test_suite", "version": 1}' + + result = await driver.execute("INSERT INTO json_test (data, metadata) VALUES (?, ?)", (json_data, metadata)) + assert result.num_rows == 1 + + json_result = await driver.execute( + "SELECT data->>'$.name' as name, JSON_EXTRACT(data, '$.values[1]') as second_value FROM json_test WHERE id = ?", + (result.last_inserted_id,), + ) + + assert len(json_result.get_data()) == 1 + row = json_result.get_data()[0] + assert row["name"] == "test" + assert str(row["second_value"]) == "2" + + contains_result = await driver.execute( + "SELECT COUNT(*) as count FROM json_test WHERE JSON_CONTAINS(data, ?, '$.values')", ("2",) + ) + assert contains_result.get_data()[0]["count"] == 1 + + +async def test_aiomysql_mysql_specific_sql_features(aiomysql_pooled_session: AiomysqlDriver) -> None: + """Test MySQL-specific SQL features and syntax.""" + driver = aiomysql_pooled_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS mysql_features ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100), + value INT, + status ENUM('active', 'inactive', 'pending') DEFAULT 'pending', + tags SET('urgent', 'important', 'normal', 'low') DEFAULT 'normal' + ); + DELETE FROM mysql_features; + """) + + await driver.execute( + "INSERT INTO mysql_features (id, name, value, status) VALUES (?, ?, ?, ?) AS new_vals ON DUPLICATE KEY UPDATE value = new_vals.value + ?, status = new_vals.status", + (1, "duplicate_test", 100, "active", 50), + ) + + await driver.execute( + "INSERT INTO mysql_features (id, name, value, status) VALUES (?, ?, ?, ?) AS new_vals ON DUPLICATE KEY UPDATE value = new_vals.value + ?, status = new_vals.status", + (1, "duplicate_test_updated", 200, "inactive", 50), + ) + await driver.commit() + + result = await driver.execute("SELECT name, value, status FROM mysql_features WHERE id = ?", (1,)) + row = result.get_data()[0] + assert row["value"] == 250 + assert row["status"] == "inactive" + + await driver.execute( + "INSERT INTO mysql_features (name, value, status, tags) VALUES (?, ?, ?, ?)", + ("enum_set_test", 300, "active", "urgent,important"), + ) + + enum_result = await driver.execute("SELECT status, tags FROM mysql_features WHERE name = ?", ("enum_set_test",)) + enum_row = enum_result.get_data()[0] + assert enum_row["status"] == "active" + assert "urgent" in enum_row["tags"] + assert "important" in enum_row["tags"] + + +async def test_aiomysql_transaction_isolation_levels(aiomysql_pooled_session: AiomysqlDriver) -> None: + """Test MySQL transaction isolation level handling.""" + driver = aiomysql_pooled_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS isolation_test ( + id INT PRIMARY KEY, + value VARCHAR(50) + ) + """) + + await driver.execute("SET SESSION TRANSACTION ISOLATION LEVEL READ COMMITTED") + + await driver.begin() + + await driver.execute("INSERT INTO isolation_test (id, value) VALUES (?, ?)", (1, "transaction_data")) + + result = await driver.execute("SELECT COUNT(*) as count FROM isolation_test WHERE id = ?", (1,)) + assert result.get_data()[0]["count"] == 1 + + await driver.commit() + + committed_result = await driver.execute("SELECT value FROM isolation_test WHERE id = ?", (1,)) + assert committed_result.get_data()[0]["value"] == "transaction_data" + + +async def test_aiomysql_stored_procedures(aiomysql_pooled_session: AiomysqlDriver) -> None: + """Test stored procedure execution.""" + driver = aiomysql_pooled_session + + await driver.execute_script(""" + DROP PROCEDURE IF EXISTS test_procedure; + + CREATE PROCEDURE test_procedure(IN input_value INT, OUT output_value INT) + BEGIN + SET output_value = input_value * 2; + END; + """) + + await driver.execute_script(""" + DROP PROCEDURE IF EXISTS simple_procedure; + + CREATE PROCEDURE simple_procedure(IN multiplier INT) + BEGIN + CREATE TEMPORARY TABLE IF NOT EXISTS proc_result (result_value INT); + INSERT INTO proc_result (result_value) VALUES (multiplier * 10); + END; + """) + + await driver.execute("CALL simple_procedure(?)", (5,)) + + +async def test_aiomysql_bulk_operations_performance(aiomysql_pooled_session: AiomysqlDriver) -> None: + """Test bulk operations for performance characteristics.""" + driver = aiomysql_pooled_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS bulk_test ( + id INT AUTO_INCREMENT PRIMARY KEY, + batch_id VARCHAR(50), + sequence_num INT, + data VARCHAR(100) + ) + """) + + batch_size = 100 + batch_data = [("batch_001", i, f"data_item_{i:04d}") for i in range(batch_size)] + + result = await driver.execute_many( + "INSERT INTO bulk_test (batch_id, sequence_num, data) VALUES (?, ?, ?)", batch_data + ) + + assert result.num_rows == batch_size + + count_result = await driver.execute("SELECT COUNT(*) as total FROM bulk_test WHERE batch_id = ?", ("batch_001",)) + assert count_result.get_data()[0]["total"] == batch_size + + select_result = await driver.execute( + "SELECT sequence_num, data FROM bulk_test WHERE batch_id = ? ORDER BY sequence_num", ("batch_001",) + ) + + assert len(select_result.get_data()) == batch_size + assert select_result.get_data()[0]["sequence_num"] == 0 + assert select_result.get_data()[99]["sequence_num"] == 99 + + +async def test_aiomysql_error_recovery(aiomysql_pooled_session: AiomysqlDriver) -> None: + """Test error handling and connection recovery.""" + driver = aiomysql_pooled_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS error_test ( + id INT PRIMARY KEY, + value VARCHAR(50) NOT NULL + ) + """) + + await driver.execute("INSERT INTO error_test (id, value) VALUES (?, ?)", (1, "test_value")) + + with pytest.raises(Exception): + await driver.execute("INSERT INTO error_test (id, value) VALUES (?, ?)", (1, "duplicate")) + + recovery_result = await driver.execute("SELECT COUNT(*) as count FROM error_test") + assert recovery_result.get_data()[0]["count"] == 1 + + with pytest.raises(Exception): + await driver.execute("INSERT INTO error_test (id, value) VALUES (?, ?)", (2, None)) + + final_result = await driver.execute("SELECT value FROM error_test WHERE id = ?", (1,)) + assert final_result.get_data()[0]["value"] == "test_value" + + +async def test_aiomysql_sql_object_advanced_features(aiomysql_pooled_session: AiomysqlDriver) -> None: + """Test SQL object integration with advanced aiomysql features.""" + driver = aiomysql_pooled_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS advanced_test ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(100), + metadata JSON, + score DECIMAL(10,2) + ) + """) + + complex_sql = SQL( + """ + INSERT INTO advanced_test (name, metadata, score) + VALUES (?, ?, ?) + AS new_vals + ON DUPLICATE KEY UPDATE + score = new_vals.score + ?, + metadata = JSON_MERGE_PATCH(advanced_test.metadata, new_vals.metadata) + """, + "complex_test", + '{"type": "advanced", "priority": 1}', + 95.5, + 10.0, + ) + + result = await driver.execute(complex_sql) + assert isinstance(result, SQLResult) + assert result.num_rows == 1 + + verify_sql = SQL( + "SELECT name, metadata->>'$.type' as type, score FROM advanced_test WHERE name = ?", "complex_test" + ) + + verify_result = await driver.execute(verify_sql) + assert len(verify_result.get_data()) == 1 + row = verify_result.get_data()[0] + assert row["name"] == "complex_test" + assert row["type"] == "advanced" + assert float(row["score"]) == 95.5 diff --git a/tests/integration/adapters/aiomysql/test_migrations.py b/tests/integration/adapters/aiomysql/test_migrations.py new file mode 100644 index 000000000..1abae80a4 --- /dev/null +++ b/tests/integration/adapters/aiomysql/test_migrations.py @@ -0,0 +1,405 @@ +"""Integration tests for aiomysql (MySQL) migration workflow.""" + +from pathlib import Path + +import pytest +from pytest_databases.docker.mysql import MySQLService + +from sqlspec.adapters.aiomysql.config import AiomysqlConfig +from sqlspec.migrations.commands import AsyncMigrationCommands + +pytestmark = pytest.mark.xdist_group("mysql") + + +async def test_aiomysql_migration_full_workflow(tmp_path: Path, mysql_service: MySQLService) -> None: + """Test full aiomysql migration workflow: init -> create -> upgrade -> downgrade.""" + + test_id = "aiomysql_full_workflow" + migration_table = f"sqlspec_migrations_{test_id}" + users_table = f"users_{test_id}" + + migration_dir = tmp_path / "migrations" + + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "autocommit": True, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) + + await commands.init(str(migration_dir), package=True) + + assert migration_dir.exists() + assert (migration_dir / "__init__.py").exists() + + migration_content = f'''"""Initial schema migration.""" + + +def up(): + """Create users table.""" + return [""" + CREATE TABLE {users_table} ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + email VARCHAR(255) UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """] + + +def down(): + """Drop users table.""" + return ["DROP TABLE IF EXISTS {users_table}"] +''' + + migration_file = migration_dir / "0001_create_users.py" + migration_file.write_text(migration_content) + + try: + await commands.upgrade() + + async with config.provide_session() as driver: + result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'", + (mysql_service.db,), + ) + assert len(result.data) == 1 + + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", ("John Doe", "john@example.com") + ) + + users_result = await driver.execute(f"SELECT * FROM {users_table}") + assert len(users_result.data) == 1 + assert users_result.get_data()[0]["name"] == "John Doe" + assert users_result.get_data()[0]["email"] == "john@example.com" + + await commands.downgrade("base") + + async with config.provide_session() as driver: + result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'", + (mysql_service.db,), + ) + assert len(result.data) == 0 + finally: + if config.connection_instance: + await config.close_pool() + + +async def test_aiomysql_multiple_migrations_workflow(tmp_path: Path, mysql_service: MySQLService) -> None: + """Test aiomysql workflow with multiple migrations: create -> apply both -> downgrade one -> downgrade all.""" + + test_id = "aiomysql_multiple_workflow" + migration_table = f"sqlspec_migrations_{test_id}" + users_table = f"users_{test_id}" + posts_table = f"posts_{test_id}" + + migration_dir = tmp_path / "migrations" + + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "autocommit": True, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) + + await commands.init(str(migration_dir), package=True) + + migration1_content = f'''"""Create users table.""" + + +def up(): + """Create users table.""" + return [""" + CREATE TABLE {users_table} ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + email VARCHAR(255) UNIQUE NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP + ) + """] + + +def down(): + """Drop users table.""" + return ["DROP TABLE IF EXISTS {users_table}"] +''' + (migration_dir / "0001_create_users.py").write_text(migration1_content) + + migration2_content = f'''"""Create posts table.""" + + +def up(): + """Create posts table.""" + return [""" + CREATE TABLE {posts_table} ( + id INT AUTO_INCREMENT PRIMARY KEY, + title VARCHAR(255) NOT NULL, + content TEXT, + user_id INT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (user_id) REFERENCES {users_table}(id) + ) + """] + + +def down(): + """Drop posts table.""" + return ["DROP TABLE IF EXISTS {posts_table}"] +''' + (migration_dir / "0002_create_posts.py").write_text(migration2_content) + + try: + await commands.upgrade() + + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'", + (mysql_service.db,), + ) + posts_result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{posts_table}'", + (mysql_service.db,), + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 1 + + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", ("John Doe", "john@example.com") + ) + await driver.execute( + f"INSERT INTO {posts_table} (title, content, user_id) VALUES (%s, %s, %s)", + ("Test Post", "This is a test post", 1), + ) + + await commands.downgrade("0001") + + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{users_table}'", + (mysql_service.db,), + ) + posts_result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name = '{posts_table}'", + (mysql_service.db,), + ) + assert len(users_result.data) == 1 + assert len(posts_result.data) == 0 + + await commands.downgrade("base") + + async with config.provide_session() as driver: + users_result = await driver.execute( + f"SELECT table_name FROM information_schema.tables WHERE table_schema = %s AND table_name IN ('{users_table}', '{posts_table}')", + (mysql_service.db,), + ) + assert len(users_result.data) == 0 + finally: + if config.connection_instance: + await config.close_pool() + + +async def test_aiomysql_migration_current_command(tmp_path: Path, mysql_service: MySQLService) -> None: + """Test the current migration command shows correct version for aiomysql.""" + + test_id = "aiomysql_current_cmd" + migration_table = f"sqlspec_migrations_{test_id}" + users_table = f"users_{test_id}" + + migration_dir = tmp_path / "migrations" + + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "autocommit": True, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) + + current_version = await commands.current() + assert current_version is None or current_version == "base" + + migration_content = f'''"""Initial schema migration.""" + + +def up(): + """Create users table.""" + return [""" + CREATE TABLE {users_table} ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL + ) + """] + + +def down(): + """Drop users table.""" + return ["DROP TABLE IF EXISTS {users_table}"] +''' + (migration_dir / "0001_create_users.py").write_text(migration_content) + + await commands.upgrade() + + current_version = await commands.current() + assert current_version == "0001" + + await commands.downgrade("base") + + current_version = await commands.current() + assert current_version is None or current_version == "base" + finally: + if config.connection_instance: + await config.close_pool() + + +async def test_aiomysql_migration_error_handling(tmp_path: Path, mysql_service: MySQLService) -> None: + """Test aiomysql migration error handling.""" + + test_id = "aiomysql_error_handling" + migration_table = f"sqlspec_migrations_{test_id}" + + migration_dir = tmp_path / "migrations" + + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "autocommit": True, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) + + migration_content = '''"""Migration with invalid SQL.""" + + +def up(): + """Create table with invalid SQL.""" + return ["CREATE INVALID SQL STATEMENT"] + + +def down(): + """Drop table.""" + return ["DROP TABLE IF EXISTS invalid_table"] +''' + (migration_dir / "0001_invalid.py").write_text(migration_content) + + await commands.upgrade() + + async with config.provide_session() as driver: + count = await driver.select_value(f"SELECT COUNT(*) FROM {migration_table}") + assert count == 0, f"Expected empty migration table after failed migration, but found {count} records" + finally: + if config.connection_instance: + await config.close_pool() + + +async def test_aiomysql_migration_with_transactions(tmp_path: Path, mysql_service: MySQLService) -> None: + """Test aiomysql migrations work properly with transactions.""" + + test_id = "aiomysql_transactions" + migration_table = f"sqlspec_migrations_{test_id}" + users_table = f"users_{test_id}" + + migration_dir = tmp_path / "migrations" + + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "autocommit": False, + }, + migration_config={"script_location": str(migration_dir), "version_table_name": migration_table}, + ) + commands = AsyncMigrationCommands(config) + + try: + await commands.init(str(migration_dir), package=True) + + migration_content = f'''"""Initial schema migration.""" + + +def up(): + """Create users table.""" + return [""" + CREATE TABLE {users_table} ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + email VARCHAR(255) UNIQUE NOT NULL + ) + """] + + +def down(): + """Drop users table.""" + return ["DROP TABLE IF EXISTS {users_table}"] +''' + (migration_dir / "0001_create_users.py").write_text(migration_content) + + await commands.upgrade() + + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", + ("Transaction User", "trans@example.com"), + ) + + result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") + assert len(result.data) == 1 + await driver.commit() + except Exception: + await driver.rollback() + raise + + result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Transaction User'") + assert len(result.data) == 1 + + async with config.provide_session() as driver: + await driver.begin() + try: + await driver.execute( + f"INSERT INTO {users_table} (name, email) VALUES (%s, %s)", + ("Rollback User", "rollback@example.com"), + ) + + raise Exception("Intentional rollback") + except Exception: + await driver.rollback() + + result = await driver.execute(f"SELECT * FROM {users_table} WHERE name = 'Rollback User'") + assert len(result.data) == 0 + finally: + if config.connection_instance: + await config.close_pool() diff --git a/tests/integration/adapters/aiomysql/test_parameter_styles.py b/tests/integration/adapters/aiomysql/test_parameter_styles.py new file mode 100644 index 000000000..92d4e1c80 --- /dev/null +++ b/tests/integration/adapters/aiomysql/test_parameter_styles.py @@ -0,0 +1,696 @@ +"""Test parameter conversion and validation for aiomysql driver. + +This test suite validates that the SQLTransformer properly converts different +input parameter styles to the target MySQL PYFORMAT style when necessary. + +aiomysql Parameter Conversion Requirements: +- Input: QMARK (?) -> Output: PYFORMAT (%s) +- Input: NAMED (%(name)s) -> Output: PYFORMAT (%s) +- Input: PYFORMAT (%s) -> Output: PYFORMAT (%s) (no conversion) + +This implements MySQL's 2-phase parameter processing. +""" + +import math +from collections.abc import AsyncGenerator + +import pytest +from pytest_databases.docker.mysql import MySQLService + +from sqlspec.adapters.aiomysql import AiomysqlConfig, AiomysqlDriver, default_statement_config +from sqlspec.core import SQL, SQLResult + +pytestmark = pytest.mark.xdist_group("mysql") + + +@pytest.fixture +async def aiomysql_parameter_session(mysql_service: MySQLService) -> AsyncGenerator[AiomysqlDriver, None]: + """Create an aiomysql session for parameter conversion testing.""" + config = AiomysqlConfig( + connection_config={ + "host": mysql_service.host, + "port": mysql_service.port, + "user": mysql_service.user, + "password": mysql_service.password, + "db": mysql_service.db, + "autocommit": True, + }, + statement_config=default_statement_config, + ) + + async with config.provide_session() as session: + await session.execute_script(""" + CREATE TABLE IF NOT EXISTS test_parameter_conversion ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255) NOT NULL, + value INT DEFAULT 0, + description TEXT + ) + """) + + await session.execute_script("DELETE FROM test_parameter_conversion") + + await session.execute( + "INSERT INTO test_parameter_conversion (name, value, description) VALUES (?, ?, ?)", + ("test1", 100, "First test"), + ) + await session.execute( + "INSERT INTO test_parameter_conversion (name, value, description) VALUES (?, ?, ?)", + ("test2", 200, "Second test"), + ) + await session.execute( + "INSERT INTO test_parameter_conversion (name, value, description) VALUES (?, ?, ?)", ("test3", 300, None) + ) + + yield session + + await session.execute_script("DROP TABLE IF EXISTS test_parameter_conversion") + + +async def test_aiomysql_qmark_to_pyformat_conversion(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test that ? placeholders get converted to %s placeholders.""" + driver = aiomysql_parameter_session + + result = await driver.execute("SELECT * FROM test_parameter_conversion WHERE name = ? AND value > ?", ("test1", 50)) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + assert result.data is not None + assert len(result.data) == 1 + assert result.get_data()[0]["name"] == "test1" + assert result.get_data()[0]["value"] == 100 + + +async def test_aiomysql_pyformat_no_conversion_needed(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test that %s placeholders are used directly without conversion (native format).""" + driver = aiomysql_parameter_session + + result = await driver.execute( + "SELECT * FROM test_parameter_conversion WHERE name = %s AND value > %s", ("test2", 150) + ) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + assert result.data is not None + assert len(result.data) == 1 + assert result.get_data()[0]["name"] == "test2" + assert result.get_data()[0]["value"] == 200 + + +async def test_aiomysql_named_to_pyformat_conversion(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test that %(name)s placeholders get converted to %s placeholders.""" + driver = aiomysql_parameter_session + + result = await driver.execute( + "SELECT * FROM test_parameter_conversion WHERE name = %(test_name)s AND value < %(max_value)s", + {"test_name": "test3", "max_value": 350}, + ) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + assert result.data is not None + assert len(result.data) == 1 + assert result.get_data()[0]["name"] == "test3" + assert result.get_data()[0]["value"] == 300 + + +async def test_aiomysql_sql_object_conversion_validation(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test parameter conversion with SQL object containing different parameter styles.""" + driver = aiomysql_parameter_session + + sql_pyformat = SQL("SELECT * FROM test_parameter_conversion WHERE value BETWEEN %s AND %s", 150, 250) + result = await driver.execute(sql_pyformat) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + assert result.data is not None + assert result.get_data()[0]["name"] == "test2" + + sql_qmark = SQL("SELECT * FROM test_parameter_conversion WHERE name = ? OR name = ?", "test1", "test3") + result2 = await driver.execute(sql_qmark) + + assert isinstance(result2, SQLResult) + assert result2.rows_affected == 2 + assert result2.data is not None + names = [row["name"] for row in result2.get_data()] + assert "test1" in names + assert "test3" in names + + +async def test_aiomysql_mixed_parameter_types_conversion(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test conversion with different parameter value types.""" + driver = aiomysql_parameter_session + + await driver.execute( + "INSERT INTO test_parameter_conversion (name, value, description) VALUES (%s, %s, %s)", + ("mixed_test", 999, "Mixed type test"), + ) + + result = await driver.execute( + "SELECT * FROM test_parameter_conversion WHERE description IS NOT NULL AND value = %s", (999,) + ) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + assert result.data is not None + assert result.get_data()[0]["name"] == "mixed_test" + assert result.get_data()[0]["description"] == "Mixed type test" + + +async def test_aiomysql_execute_many_parameter_conversion(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test parameter conversion in execute_many operations.""" + driver = aiomysql_parameter_session + + batch_data = [("batch1", 1000, "Batch test 1"), ("batch2", 2000, "Batch test 2"), ("batch3", 3000, "Batch test 3")] + + result = await driver.execute_many( + "INSERT INTO test_parameter_conversion (name, value, description) VALUES (%s, %s, %s)", batch_data + ) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 3 + + verify_result = await driver.execute( + "SELECT COUNT(*) as count FROM test_parameter_conversion WHERE name LIKE ?", ("batch%",) + ) + + assert verify_result.data is not None + assert verify_result.get_data()[0]["count"] == 3 + + +async def test_aiomysql_parameter_conversion_edge_cases(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test edge cases in parameter conversion.""" + driver = aiomysql_parameter_session + + result = await driver.execute("SELECT COUNT(*) as total FROM test_parameter_conversion") + assert result.data is not None + assert result.get_data()[0]["total"] >= 3 + + result2 = await driver.execute("SELECT * FROM test_parameter_conversion WHERE name = %s", ("test1",)) + assert result2.rows_affected == 1 + assert result2.data is not None + assert result2.get_data()[0]["name"] == "test1" + + result3 = await driver.execute( + "SELECT COUNT(*) as count FROM test_parameter_conversion WHERE name LIKE %s", ("test%",) + ) + assert result3.data is not None + assert result3.get_data()[0]["count"] >= 3 + + +async def test_aiomysql_parameter_style_consistency_validation(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test that the parameter conversion maintains consistency.""" + driver = aiomysql_parameter_session + + result_qmark = await driver.execute( + "SELECT name, value FROM test_parameter_conversion WHERE value >= ? ORDER BY value", (200,) + ) + + result_pyformat = await driver.execute( + "SELECT name, value FROM test_parameter_conversion WHERE value >= %s ORDER BY value", (200,) + ) + + assert result_qmark.rows_affected == result_pyformat.rows_affected + assert result_qmark.data is not None + assert result_pyformat.data is not None + assert len(result_qmark.data) == len(result_pyformat.data) + + for i in range(len(result_qmark.data)): + assert result_qmark.get_data()[i]["name"] == result_pyformat.get_data()[i]["name"] + assert result_qmark.get_data()[i]["value"] == result_pyformat.get_data()[i]["value"] + + +async def test_aiomysql_complex_query_parameter_conversion(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test parameter conversion in complex queries with multiple operations.""" + driver = aiomysql_parameter_session + + await driver.execute_many( + "INSERT INTO test_parameter_conversion (name, value, description) VALUES (?, ?, ?)", + [("complex1", 150, "Complex test"), ("complex2", 250, "Complex test"), ("complex3", 350, "Complex test")], + ) + + result = await driver.execute( + """ + SELECT name, value, description + FROM test_parameter_conversion + WHERE description = %s + AND value BETWEEN %s AND %s + AND name IN ( + SELECT name FROM test_parameter_conversion + WHERE value > %s + ) + ORDER BY value + """, + ("Complex test", 200, 300, 100), + ) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + assert result.data is not None + assert result.get_data()[0]["name"] == "complex2" + assert result.get_data()[0]["value"] == 250 + + +async def test_aiomysql_mysql_parameter_style_specifics(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test MySQL-specific parameter handling requirements.""" + driver = aiomysql_parameter_session + + result = await driver.execute("SELECT name, value FROM test_parameter_conversion ORDER BY value LIMIT ?", (2,)) + assert result.rows_affected == 2 + assert len(result.get_data()) == 2 + + result2 = await driver.execute( + """ + SELECT name FROM test_parameter_conversion WHERE value = ? + UNION + SELECT name FROM test_parameter_conversion WHERE value = ? + ORDER BY name + """, + (100, 200), + ) + assert result2.rows_affected == 2 + + await driver.execute( + "REPLACE INTO test_parameter_conversion (id, name, value, description) VALUES (?, ?, ?, ?)", + (999, "replace_test", 888, "Replaced entry"), + ) + + verify_result = await driver.execute("SELECT name, value FROM test_parameter_conversion WHERE id = ?", (999,)) + assert verify_result.data is not None + assert verify_result.get_data()[0]["name"] == "replace_test" + assert verify_result.get_data()[0]["value"] == 888 + + +async def test_aiomysql_2phase_parameter_processing(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test the 2-phase parameter processing system specific to aiomysql/MySQL.""" + driver = aiomysql_parameter_session + + test_cases = [ + ("SELECT * FROM test_parameter_conversion WHERE name = ? AND value = ?", ("test1", 100), "test1", 100), + ("SELECT * FROM test_parameter_conversion WHERE name = %s AND value = %s", ("test2", 200), "test2", 200), + ( + "SELECT * FROM test_parameter_conversion WHERE name = %(n)s AND value = %(v)s", + {"n": "test3", "v": 300}, + "test3", + 300, + ), + ] + + for sql_text, params, expected_name, expected_value in test_cases: + result = await driver.execute(sql_text, params) + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + assert result.data is not None + assert len(result.data) == 1 + assert result.get_data()[0]["name"] == expected_name + assert result.get_data()[0]["value"] == expected_value + + consistent_results = [] + for sql_text, params, _, _ in test_cases: + result = await driver.execute(sql_text.replace("name = ", "name != ").replace("AND", "OR"), params) + consistent_results.append(len(result.get_data())) + + assert all(count == consistent_results[0] for count in consistent_results) + + +async def test_aiomysql_none_parameters_pyformat(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test None values with PYFORMAT (%s) parameter style.""" + driver = aiomysql_parameter_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS test_none_values ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255), + value INT, + description TEXT, + flag BOOLEAN, + created_at DATETIME + ) + """) + + await driver.execute_script("DELETE FROM test_none_values") + + params = ("test_none", None, "Test with None value", None, None) + result = await driver.execute( + "INSERT INTO test_none_values (name, value, description, flag, created_at) VALUES (%s, %s, %s, %s, %s)", params + ) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + + select_result = await driver.execute("SELECT * FROM test_none_values WHERE name = %s", ("test_none",)) + + assert isinstance(select_result, SQLResult) + assert select_result.data is not None + assert len(select_result.data) == 1 + row = select_result.get_data()[0] + assert row["name"] == "test_none" + assert row["value"] is None + assert row["description"] == "Test with None value" + assert row["flag"] is None + assert row["created_at"] is None + + +async def test_aiomysql_none_parameters_qmark(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test None values with QMARK (?) parameter style.""" + driver = aiomysql_parameter_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS test_none_qmark ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255), + value INT, + optional_field VARCHAR(100) + ) + """) + + await driver.execute_script("DELETE FROM test_none_qmark") + + params = ("qmark_test", None, None) + result = await driver.execute("INSERT INTO test_none_qmark (name, value, optional_field) VALUES (?, ?, ?)", params) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + + verify_result = await driver.execute("SELECT * FROM test_none_qmark WHERE name = ?", ("qmark_test",)) + + assert isinstance(verify_result, SQLResult) + assert verify_result.data is not None + assert len(verify_result.data) == 1 + row = verify_result.get_data()[0] + assert row["name"] == "qmark_test" + assert row["value"] is None + assert row["optional_field"] is None + + +async def test_aiomysql_none_parameters_named_pyformat(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test None values with named PYFORMAT %(name)s parameter style.""" + driver = aiomysql_parameter_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS test_none_named ( + id INT AUTO_INCREMENT PRIMARY KEY, + title VARCHAR(255), + status VARCHAR(50), + priority INT, + metadata JSON + ) + """) + + await driver.execute_script("DELETE FROM test_none_named") + + params = {"title": "Named test", "status": None, "priority": 5, "metadata": None} + + result = await driver.execute( + """ + INSERT INTO test_none_named (title, status, priority, metadata) + VALUES (%(title)s, %(status)s, %(priority)s, %(metadata)s) + """, + params, + ) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + + verify_result = await driver.execute( + "SELECT * FROM test_none_named WHERE title = %(search_title)s", {"search_title": "Named test"} + ) + + assert isinstance(verify_result, SQLResult) + assert verify_result.data is not None + assert len(verify_result.data) == 1 + row = verify_result.get_data()[0] + assert row["title"] == "Named test" + assert row["status"] is None + assert row["priority"] == 5 + assert row["metadata"] is None + + +async def test_aiomysql_all_none_parameters(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test when all parameter values are None.""" + driver = aiomysql_parameter_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS test_all_none ( + id INT AUTO_INCREMENT PRIMARY KEY, + col1 VARCHAR(255), + col2 INT, + col3 BOOLEAN, + col4 TEXT + ) + """) + + await driver.execute_script("DELETE FROM test_all_none") + + params = (None, None, None, None) + result = await driver.execute("INSERT INTO test_all_none (col1, col2, col3, col4) VALUES (?, ?, ?, ?)", params) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + + last_id = result.last_inserted_id + assert last_id is not None + + verify_result = await driver.execute("SELECT * FROM test_all_none WHERE id = ?", (last_id,)) + + assert isinstance(verify_result, SQLResult) + assert verify_result.data is not None + assert len(verify_result.data) == 1 + row = verify_result.get_data()[0] + assert row["col1"] is None + assert row["col2"] is None + assert row["col3"] is None + assert row["col4"] is None + + +async def test_aiomysql_none_with_execute_many(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test None values work correctly with execute_many.""" + driver = aiomysql_parameter_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS test_none_many ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255), + value INT, + category VARCHAR(100) + ) + """) + + await driver.execute_script("DELETE FROM test_none_many") + + batch_data = [ + ("item1", 100, "A"), + ("item2", None, "B"), # None value + ("item3", 300, None), # None category + (None, 400, "D"), # None name + ("item5", None, None), # Multiple None values + ] + + result = await driver.execute_many( + "INSERT INTO test_none_many (name, value, category) VALUES (?, ?, ?)", batch_data + ) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 5 + + verify_result = await driver.execute("SELECT * FROM test_none_many ORDER BY id") + + assert isinstance(verify_result, SQLResult) + assert verify_result.data is not None + assert len(verify_result.data) == 5 + + rows = verify_result.get_data() + assert rows[0]["name"] == "item1" and rows[0]["value"] == 100 and rows[0]["category"] == "A" + assert rows[1]["name"] == "item2" and rows[1]["value"] is None and rows[1]["category"] == "B" + assert rows[2]["name"] == "item3" and rows[2]["value"] == 300 and rows[2]["category"] is None + assert rows[3]["name"] is None and rows[3]["value"] == 400 and rows[3]["category"] == "D" + assert rows[4]["name"] == "item5" and rows[4]["value"] is None and rows[4]["category"] is None + + +async def test_aiomysql_none_parameter_count_validation(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test that parameter count mismatches are properly detected with None values. + + This test verifies that None values don't cause parameter count validation to fail. + """ + driver = aiomysql_parameter_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS test_param_validation ( + id INT AUTO_INCREMENT PRIMARY KEY, + col1 VARCHAR(255), + col2 INT + ) + """) + + await driver.execute_script("DELETE FROM test_param_validation") + + # Test: Correct parameter count with None should work + result = await driver.execute("INSERT INTO test_param_validation (col1, col2) VALUES (?, ?)", ("valid", None)) + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + + # Test: Too many parameters should raise an error (even with None) + try: + await driver.execute( + "INSERT INTO test_param_validation (col1, col2) VALUES (?, ?)", # 2 placeholders + ("value", None, "extra"), # 3 parameters + ) + assert False, "Expected parameter count error" + except Exception as e: + error_msg = str(e).lower() + # MySQL/aiomysql typically reports parameter count errors + assert any(keyword in error_msg for keyword in ["parameter", "argument", "mismatch", "count"]) + + # Test: Too few parameters should raise an error + try: + await driver.execute( + "INSERT INTO test_param_validation (col1, col2) VALUES (?, ?)", # 2 placeholders + ("only_one",), # 1 parameter + ) + assert False, "Expected parameter count error" + except Exception as e: + error_msg = str(e).lower() + assert any(keyword in error_msg for keyword in ["parameter", "argument", "mismatch", "count"]) + + +async def test_aiomysql_none_in_where_clauses(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test None values in WHERE clauses work correctly.""" + driver = aiomysql_parameter_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS test_none_where ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255), + category VARCHAR(100), + status VARCHAR(50) + ) + """) + + await driver.execute_script("DELETE FROM test_none_where") + + test_data = [ + ("item1", "A", "active"), + ("item2", None, "inactive"), # None category + ("item3", "B", None), # None status + ("item4", None, None), # Both None + ] + + await driver.execute_many("INSERT INTO test_none_where (name, category, status) VALUES (?, ?, ?)", test_data) + + # Test WHERE with IS NULL for None values + result = await driver.execute("SELECT * FROM test_none_where WHERE category IS NULL") + + assert isinstance(result, SQLResult) + assert result.data is not None + assert len(result.data) == 2 # item2 and item4 + + found_names = {row["name"] for row in result.get_data()} + assert found_names == {"item2", "item4"} + + # Test parameterized query with None (should handle NULL comparison properly) + result2 = await driver.execute("SELECT * FROM test_none_where WHERE status = ? OR ? IS NULL", (None, None)) + + # The second condition should be TRUE since None IS NULL in SQL context + assert isinstance(result2, SQLResult) + assert result2.data is not None + assert len(result2.data) == 4 # All rows because second condition is always true + + +async def test_aiomysql_none_complex_scenarios(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test complex scenarios with None parameters.""" + driver = aiomysql_parameter_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS test_complex_none ( + id INT AUTO_INCREMENT PRIMARY KEY, + name VARCHAR(255), + score INT, + factor DECIMAL(10,2), + active BOOLEAN, + tags JSON, + created_at TIMESTAMP, + metadata TEXT + ) + """) + + await driver.execute_script("DELETE FROM test_complex_none") + + # Test complex insert with mixed None and valid values + params = { + "name": "complex_test", + "score": None, + "factor": math.pi, + "active": None, + "tags": '["tag1", "tag2"]', + "created_at": None, + "metadata": None, + } + + result = await driver.execute( + """ + INSERT INTO test_complex_none (name, score, factor, active, tags, created_at, metadata) + VALUES (%(name)s, %(score)s, %(factor)s, %(active)s, %(tags)s, %(created_at)s, %(metadata)s) + """, + params, + ) + + assert isinstance(result, SQLResult) + assert result.rows_affected == 1 + + # Verify complex insert + verify_result = await driver.execute("SELECT * FROM test_complex_none WHERE name = ?", ("complex_test",)) + + assert isinstance(verify_result, SQLResult) + assert verify_result.data is not None + assert len(verify_result.data) == 1 + row = verify_result.get_data()[0] + assert row["name"] == "complex_test" + assert row["score"] is None + assert abs(float(row["factor"]) - math.pi) < 0.01 # Decimal comparison + assert row["active"] is None + assert row["tags"] == '["tag1", "tag2"]' or row["tags"] == ["tag1", "tag2"] # JSON field + assert row["created_at"] is None + assert row["metadata"] is None + + +async def test_aiomysql_none_edge_cases(aiomysql_parameter_session: AiomysqlDriver) -> None: + """Test edge cases that might reveal None parameter handling bugs.""" + driver = aiomysql_parameter_session + + await driver.execute_script(""" + CREATE TABLE IF NOT EXISTS test_edge_cases ( + id INT AUTO_INCREMENT PRIMARY KEY, + a VARCHAR(255), + b VARCHAR(255), + c VARCHAR(255), + d INT, + e BOOLEAN + ) + """) + + await driver.execute_script("DELETE FROM test_edge_cases") + + # Test 1: Single None parameter + await driver.execute("INSERT INTO test_edge_cases (a) VALUES (?)", (None,)) + + # Test 2: Multiple consecutive None parameters + await driver.execute( + "INSERT INTO test_edge_cases (a, b, c, d, e) VALUES (?, ?, ?, ?, ?)", (None, None, None, None, None) + ) + + # Test 3: None at different positions + test_cases = [ + (None, "middle", "end", 1, True), # None at start + ("start", None, "end", 2, False), # None at middle + ("start", "middle", None, 3, None), # None at end + (None, None, "end", None, True), # Multiple None at start + ("start", None, None, 4, None), # Multiple None at end + ] + + for params in test_cases: + await driver.execute("INSERT INTO test_edge_cases (a, b, c, d, e) VALUES (?, ?, ?, ?, ?)", params) + + # Verify all rows were inserted + count_result = await driver.execute("SELECT COUNT(*) as total FROM test_edge_cases") + assert count_result.data is not None + assert count_result.get_data()[0]["total"] == 7 # 2 initial + 5 test cases diff --git a/tests/integration/adapters/aiomysql/test_storage_bridge.py b/tests/integration/adapters/aiomysql/test_storage_bridge.py new file mode 100644 index 000000000..5d0000582 --- /dev/null +++ b/tests/integration/adapters/aiomysql/test_storage_bridge.py @@ -0,0 +1,61 @@ +"""Storage bridge integration tests for aiomysql adapter.""" + +from pathlib import Path + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from sqlspec.adapters.aiomysql import AiomysqlDriver + +pytestmark = [pytest.mark.xdist_group("mysql")] + + +async def _fetch_rows(aiomysql_driver: AiomysqlDriver, table: str) -> list[dict[str, object]]: + rows = await aiomysql_driver.select(f"SELECT id, name FROM {table} ORDER BY id") + assert isinstance(rows, list) + return rows + + +async def test_aiomysql_load_from_arrow(aiomysql_driver: AiomysqlDriver) -> None: + table_name = "storage_bridge_users" + await aiomysql_driver.execute(f"DROP TABLE IF EXISTS {table_name}") + await aiomysql_driver.execute(f"CREATE TABLE {table_name} (id INT PRIMARY KEY, name VARCHAR(64))") + + arrow_table = pa.table({"id": [1, 2], "name": ["alpha", "beta"]}) + + job = await aiomysql_driver.load_from_arrow(table_name, arrow_table, overwrite=True) + + assert job.telemetry["rows_processed"] == arrow_table.num_rows + assert job.telemetry["destination"] == table_name + + rows = await _fetch_rows(aiomysql_driver, table_name) + assert rows == [{"id": 1, "name": "alpha"}, {"id": 2, "name": "beta"}] + + await aiomysql_driver.execute(f"DROP TABLE IF EXISTS {table_name}") + + +async def test_aiomysql_load_from_storage(tmp_path: Path, aiomysql_driver: AiomysqlDriver) -> None: + await aiomysql_driver.execute("DROP TABLE IF EXISTS storage_bridge_scores") + await aiomysql_driver.execute("CREATE TABLE storage_bridge_scores (id INT PRIMARY KEY, score DECIMAL(5,2))") + + arrow_table = pa.table({"id": [5, 6], "score": [12.5, 99.1]}) + destination = tmp_path / "scores.parquet" + pq.write_table(arrow_table, destination) + + job = await aiomysql_driver.load_from_storage( + "storage_bridge_scores", str(destination), file_format="parquet", overwrite=True + ) + + assert job.telemetry["destination"] == "storage_bridge_scores" + assert job.telemetry["extra"]["source"]["destination"].endswith("scores.parquet") # type: ignore[index] + assert job.telemetry["extra"]["source"]["backend"] # type: ignore[index] + + rows = await aiomysql_driver.select("SELECT id, score FROM storage_bridge_scores ORDER BY id") + assert len(rows) == 2 + assert rows[0]["id"] == 5 + assert float(rows[0]["score"]) == pytest.approx(12.5) + assert rows[1]["id"] == 6 + assert float(rows[1]["score"]) == pytest.approx(99.1) + + await aiomysql_driver.execute("DROP TABLE IF EXISTS storage_bridge_scores") diff --git a/tests/unit/adapters/test_aiomysql/__init__.py b/tests/unit/adapters/test_aiomysql/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/adapters/test_aiomysql/test_config.py b/tests/unit/adapters/test_aiomysql/test_config.py new file mode 100644 index 000000000..55be2b14e --- /dev/null +++ b/tests/unit/adapters/test_aiomysql/test_config.py @@ -0,0 +1,36 @@ +"""aiomysql configuration tests covering statement config builders.""" + +from sqlspec.adapters.aiomysql.config import AiomysqlConfig +from sqlspec.adapters.aiomysql.core import build_statement_config + + +def test_build_default_statement_config_custom_serializers() -> None: + """Custom serializers should propagate into the parameter configuration.""" + + def serializer(_: object) -> str: + return "serialized" + + def deserializer(_: str) -> object: + return {"value": "deserialized"} + + statement_config = build_statement_config(json_serializer=serializer, json_deserializer=deserializer) + + parameter_config = statement_config.parameter_config + assert parameter_config.json_serializer is serializer + assert parameter_config.json_deserializer is deserializer + + +def test_aiomysql_config_applies_driver_feature_serializers() -> None: + """Driver features should mutate the aiomysql statement configuration.""" + + def serializer(_: object) -> str: + return "feature" + + def deserializer(_: str) -> object: + return {"feature": True} + + config = AiomysqlConfig(driver_features={"json_serializer": serializer, "json_deserializer": deserializer}) + + parameter_config = config.statement_config.parameter_config + assert parameter_config.json_serializer is serializer + assert parameter_config.json_deserializer is deserializer diff --git a/uv.lock b/uv.lock index e30a24807..1fb03b4d8 100644 --- a/uv.lock +++ b/uv.lock @@ -296,6 +296,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/a1/510b0a7fadc6f43a6ce50152e69dbd86415240835868bb0bd9b5b88b1e06/aioitertools-0.13.0-py3-none-any.whl", hash = "sha256:0be0292b856f08dfac90e31f4739432f4cb6d7520ab9eb73e143f4f2fa5259be", size = 24182, upload-time = "2025-11-06T22:17:06.502Z" }, ] +[[package]] +name = "aiomysql" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pymysql" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/29/e0/302aeffe8d90853556f47f3106b89c16cc2ec2a4d269bdfd82e3f4ae12cc/aiomysql-0.3.2.tar.gz", hash = "sha256:72d15ef5cfc34c03468eb41e1b90adb9fd9347b0b589114bd23ead569a02ac1a", size = 108311, upload-time = "2025-10-22T00:15:21.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4c/af/aae0153c3e28712adaf462328f6c7a3c196a1c1c27b491de4377dd3e6b52/aiomysql-0.3.2-py3-none-any.whl", hash = "sha256:c82c5ba04137d7afd5c693a258bea8ead2aad77101668044143a991e04632eb2", size = 71834, upload-time = "2025-10-22T00:15:15.905Z" }, +] + [[package]] name = "aioodbc" version = "0.5.0" @@ -6812,6 +6824,9 @@ adbc = [ adk = [ { name = "google-adk" }, ] +aiomysql = [ + { name = "aiomysql" }, +] aioodbc = [ { name = "aioodbc" }, ] @@ -7084,6 +7099,7 @@ test = [ [package.metadata] requires-dist = [ { name = "adbc-driver-manager", marker = "extra == 'adbc'" }, + { name = "aiomysql", marker = "extra == 'aiomysql'" }, { name = "aioodbc", marker = "extra == 'aioodbc'" }, { name = "aiosqlite", marker = "extra == 'aiosqlite'" }, { name = "asyncmy", marker = "extra == 'asyncmy'" }, @@ -7132,7 +7148,7 @@ requires-dist = [ { name = "typing-extensions" }, { name = "uuid-utils", marker = "extra == 'uuid'" }, ] -provides-extras = ["adbc", "adk", "aioodbc", "aiosqlite", "alloydb", "asyncmy", "asyncpg", "attrs", "bigquery", "cloud-sql", "cockroachdb", "duckdb", "fastapi", "flask", "fsspec", "litestar", "msgspec", "mypyc", "mysql-connector", "nanoid", "obstore", "opentelemetry", "oracledb", "orjson", "pandas", "performance", "polars", "prometheus", "psqlpy", "psycopg", "pydantic", "pymssql", "pymysql", "spanner", "uuid"] +provides-extras = ["adbc", "adk", "aiomysql", "aioodbc", "aiosqlite", "alloydb", "asyncmy", "asyncpg", "attrs", "bigquery", "cloud-sql", "cockroachdb", "duckdb", "fastapi", "flask", "fsspec", "litestar", "msgspec", "mypyc", "mysql-connector", "nanoid", "obstore", "opentelemetry", "oracledb", "orjson", "pandas", "performance", "polars", "prometheus", "psqlpy", "psycopg", "pydantic", "pymssql", "pymysql", "spanner", "uuid"] [package.metadata.requires-dev] benchmarks = [ From b734f68d50819c052e8368595f4e47dcbe08a115 Mon Sep 17 00:00:00 2001 From: hasansezertasan Date: Mon, 13 Apr 2026 20:04:01 +0300 Subject: [PATCH 2/2] fix: await cursor creation and fix formatting for aiomysql adapter - cursor() returns a Future on pool-acquired connections, must be awaited - Fix protocol to match (async cursor method) - Run ruff format on all new files to pass pre-commit validation Co-Authored-By: Claude Opus 4.6 (1M context) --- sqlspec/adapters/aiomysql/_typing.py | 4 ++-- sqlspec/adapters/aiomysql/core.py | 4 +--- tests/integration/adapters/aiomysql/test_arrow.py | 4 +++- tests/integration/adapters/aiomysql/test_driver.py | 4 +++- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sqlspec/adapters/aiomysql/_typing.py b/sqlspec/adapters/aiomysql/_typing.py index c3c3aec37..b41f31a57 100644 --- a/sqlspec/adapters/aiomysql/_typing.py +++ b/sqlspec/adapters/aiomysql/_typing.py @@ -18,7 +18,7 @@ from sqlspec.core import StatementConfig class AiomysqlConnectionProtocol(Protocol): - def cursor(self) -> "AiomysqlRawCursor": ... + async def cursor(self) -> "AiomysqlRawCursor": ... async def commit(self) -> Any: ... @@ -47,7 +47,7 @@ def __init__(self, connection: "AiomysqlConnection") -> None: self.cursor: AiomysqlRawCursor | None = None async def __aenter__(self) -> "AiomysqlRawCursor": - self.cursor = self.connection.cursor() + self.cursor = await self.connection.cursor() return self.cursor async def __aexit__(self, *_: Any) -> None: diff --git a/sqlspec/adapters/aiomysql/core.py b/sqlspec/adapters/aiomysql/core.py index 696ebed59..fe3d07ea5 100644 --- a/sqlspec/adapters/aiomysql/core.py +++ b/sqlspec/adapters/aiomysql/core.py @@ -441,9 +441,7 @@ def collect_rows( rows = fetched_data if isinstance(fetched_data, list) else list(fetched_data) return rows, resolved_column_names, "dict" rows = [dict(row) for row in fetched_data] - rows = _deserialize_json_dict_rows( - resolved_column_names, rows, json_indexes, deserializer, logger=logger - ) + rows = _deserialize_json_dict_rows(resolved_column_names, rows, json_indexes, deserializer, logger=logger) return rows, resolved_column_names, "dict" rows = fetched_data if isinstance(fetched_data, list) else list(fetched_data) if json_indexes: diff --git a/tests/integration/adapters/aiomysql/test_arrow.py b/tests/integration/adapters/aiomysql/test_arrow.py index 7e963ea13..a56bcb4b9 100644 --- a/tests/integration/adapters/aiomysql/test_arrow.py +++ b/tests/integration/adapters/aiomysql/test_arrow.py @@ -52,7 +52,9 @@ async def test_select_to_arrow_batch_format(aiomysql_driver: AiomysqlDriver) -> await aiomysql_driver.execute("CREATE TABLE IF NOT EXISTS arrow_batch_test (id INT, value VARCHAR(100))") await aiomysql_driver.execute("INSERT INTO arrow_batch_test VALUES (1, 'a'), (2, 'b')") - result = await aiomysql_driver.select_to_arrow("SELECT * FROM arrow_batch_test ORDER BY id", return_format="batches") + result = await aiomysql_driver.select_to_arrow( + "SELECT * FROM arrow_batch_test ORDER BY id", return_format="batches" + ) assert isinstance(result.data, list) for batch in result.data: diff --git a/tests/integration/adapters/aiomysql/test_driver.py b/tests/integration/adapters/aiomysql/test_driver.py index 61687fd32..0ea1fbd59 100644 --- a/tests/integration/adapters/aiomysql/test_driver.py +++ b/tests/integration/adapters/aiomysql/test_driver.py @@ -436,7 +436,9 @@ async def test_aiomysql_sql_object_execution(aiomysql_driver: AiomysqlDriver) -> assert isinstance(result, SQLResult) assert result.num_rows == 1 - verify_result = await driver.execute("SELECT name, value FROM test_table_aiomysql WHERE name = ?", ("sql_obj_test",)) + verify_result = await driver.execute( + "SELECT name, value FROM test_table_aiomysql WHERE name = ?", ("sql_obj_test",) + ) assert len(verify_result.get_data()) == 1 assert verify_result.get_data()[0]["name"] == "sql_obj_test" assert verify_result.get_data()[0]["value"] == 999