Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ dependencies = [
]
requires-python = ">= 3.10"

[project.optional-dependencies]
async = [
"pytest-asyncio"
]

[project.urls]
"Source" = "https://github.com/dbfixtures/pytest-postgresql"
"Bug Tracker" = "https://github.com/dbfixtures/pytest-postgresql/issues"
Expand Down
2 changes: 1 addition & 1 deletion pytest_postgresql/factories/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# along with pytest-postgresql. If not, see <http://www.gnu.org/licenses/>.
"""Fixture factories for postgresql fixtures."""

from pytest_postgresql.factories.client import postgresql
from pytest_postgresql.factories.client import postgresql, postgresql_async
from pytest_postgresql.factories.noprocess import postgresql_noproc
from pytest_postgresql.factories.process import PortType, postgresql_proc

Expand Down
74 changes: 73 additions & 1 deletion pytest_postgresql/factories/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import psycopg
import pytest
from _pytest.scope import _ScopeName
from psycopg import Connection
from pytest import FixtureRequest

Expand All @@ -34,17 +35,19 @@ def postgresql(
process_fixture_name: str,
dbname: str | None = None,
isolation_level: "psycopg.IsolationLevel | None" = None,
scope: _ScopeName = "function",
) -> Callable[[FixtureRequest], Iterator[Connection]]:
"""Return connection fixture factory for PostgreSQL.

:param process_fixture_name: name of the process fixture
:param dbname: database name
:param isolation_level: optional postgresql isolation level
defaults to server's default
:param scope: fixture scope; by default "function" which is recommended.
:returns: function which makes a connection to postgresql
"""

@pytest.fixture
@pytest.fixture(scope=scope)
def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
"""Fixture factory for PostgreSQL.

Expand Down Expand Up @@ -85,3 +88,72 @@ def postgresql_factory(request: FixtureRequest) -> Iterator[Connection]:
db_connection.close()

return postgresql_factory


def postgresql_async(
process_fixture_name: str,
dbname: str | None = None,
isolation_level: "psycopg.IsolationLevel | None" = None,
scope: _ScopeName = "function",
) -> Callable[[FixtureRequest], Iterator[Connection]]:
Comment thread
ivanthewebber marked this conversation as resolved.
Outdated
"""Return async connection fixture factory for PostgreSQL.

:param process_fixture_name: name of the process fixture
:param dbname: database name
:param isolation_level: optional postgresql isolation level
defaults to server's default
:param scope: fixture scope; by default "function" which is recommended.
:returns: function which makes a connection to postgresql
"""
import pytest_asyncio
from psycopg import AsyncConnection

from pytest_postgresql.janitor import AsyncDatabaseJanitor

@pytest_asyncio.fixture(scope=scope)
async def postgresql_factory(request: FixtureRequest) -> Iterator[AsyncConnection]:
"""Async fixture factory for PostgreSQL.

:param request: fixture request object
:returns: postgresql client
"""
proc_fixture: PostgreSQLExecutor | NoopExecutor = request.getfixturevalue(process_fixture_name)
config = get_config(request)

pg_host = proc_fixture.host
pg_port = proc_fixture.port
pg_user = proc_fixture.user
pg_password = proc_fixture.password
pg_options = proc_fixture.options
pg_db = dbname or proc_fixture.dbname
janitor = DatabaseJanitor(
user=pg_user,
host=pg_host,
port=pg_port,
dbname=pg_db,
template_dbname=proc_fixture.template_dbname,
version=proc_fixture.version,
password=pg_password,
isolation_level=isolation_level,
)
if config["drop_test_database"]:
janitor.drop()
with AsyncDatabaseJanitor(
pg_user, pg_host, pg_port, pg_db, proc_fixture.version, pg_password, isolation_level
) as janitor:
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
# Line modified here
db_connection: AsyncConnection = await AsyncConnection.connect(
dbname=pg_db,
user=pg_user,
password=pg_password,
host=pg_host,
port=pg_port,
options=pg_options,
)
for load_element in pg_load:
janitor.load(load_element)
Comment thread
ivanthewebber marked this conversation as resolved.
Outdated
yield db_connection
# And here
await db_connection.close()

return postgresql_factory
5 changes: 4 additions & 1 deletion pytest_postgresql/factories/noprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Callable, Iterator

import pytest
from _pytest.scope import _ScopeName
from pytest import FixtureRequest

from pytest_postgresql.config import get_config
Expand All @@ -45,6 +46,7 @@ def postgresql_noproc(
dbname: str | None = None,
options: str = "",
load: list[Callable | str | Path] | None = None,
scope: _ScopeName = "session",
) -> Callable[[FixtureRequest], Iterator[NoopExecutor]]:
"""Postgresql noprocess factory.

Expand All @@ -55,10 +57,11 @@ def postgresql_noproc(
:param dbname: postgresql database name
:param options: Postgresql connection options
:param load: List of functions used to initialize database's template.
:param scope: fixture scope; by default "session" which is recommended.
:returns: function which makes a postgresql process
"""

@pytest.fixture(scope="session")
@pytest.fixture(scope=scope)
def postgresql_noproc_fixture(request: FixtureRequest) -> Iterator[NoopExecutor]:
"""Noop Process fixture for PostgreSQL.

Expand Down
5 changes: 4 additions & 1 deletion pytest_postgresql/factories/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import port_for
import pytest
from _pytest.scope import _ScopeName
Comment thread
ivanthewebber marked this conversation as resolved.
Outdated
from port_for import PortForException, get_port
from pytest import FixtureRequest, TempPathFactory

Expand Down Expand Up @@ -81,6 +82,7 @@ def postgresql_proc(
unixsocketdir: str | None = None,
postgres_options: str | None = None,
load: list[Callable | str | Path] | None = None,
scope: _ScopeName = "session",
) -> Callable[[FixtureRequest, TempPathFactory], Iterator[PostgreSQLExecutor]]:
"""Postgresql process factory.

Expand All @@ -101,10 +103,11 @@ def postgresql_proc(
:param unixsocketdir: directory to create postgresql's unixsockets
:param postgres_options: Postgres executable options for use by pg_ctl
:param load: List of functions used to initialize database's template.
:param scope: fixture scope; by default "session" which is recommended.
:returns: function which makes a postgresql process
"""

@pytest.fixture(scope="session")
@pytest.fixture(scope=scope)
def postgresql_proc_fixture(
request: FixtureRequest, tmp_path_factory: TempPathFactory
) -> Iterator[PostgreSQLExecutor]:
Expand Down
146 changes: 143 additions & 3 deletions pytest_postgresql/janitor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Database Janitor."""

from contextlib import contextmanager
import inspect
from contextlib import asynccontextmanager, contextmanager
from pathlib import Path
Comment thread
coderabbitai[bot] marked this conversation as resolved.
from types import TracebackType
from typing import Callable, Iterator, Type, TypeVar
Expand All @@ -9,8 +10,8 @@
from packaging.version import parse
from psycopg import Connection, Cursor

from pytest_postgresql.loader import build_loader
from pytest_postgresql.retry import retry
from pytest_postgresql.loader import build_loader, build_loader_async
from pytest_postgresql.retry import retry, retry_async

Version = type(parse("1"))

Expand Down Expand Up @@ -163,3 +164,142 @@ def __exit__(
) -> None:
"""Exit from Database janitor context cleaning after itself."""
self.drop()


class AsyncDatabaseJanitor:
"""Manage database state for specific tasks."""

def __init__(
self,
*,
user: str,
host: str,
port: str | int,
version: str | float | Version, # type: ignore[valid-type]
dbname: str | None = None,
template_dbname: str | None = None,
password: str | None = None,
isolation_level: "psycopg.IsolationLevel | None" = None,
connection_timeout: int = 60,
) -> None:
"""Initialize janitor.

:param user: postgresql username
:param host: postgresql host
:param port: postgresql port
:param dbname: database name
:param dbname: template database name
:param version: postgresql version number
:param password: optional postgresql password
:param isolation_level: optional postgresql isolation level
defaults to server's default
:param connection_timeout: how long to retry connection before
raising a TimeoutError
"""
self.user = user
self.password = password
self.host = host
self.port = port
# At least one of the dbname or template_dbname has to be filled.
assert any([dbname, template_dbname])
self.dbname = dbname
self.template_dbname = template_dbname
self._connection_timeout = connection_timeout
self.isolation_level = isolation_level
if not isinstance(version, Version):
self.version = parse(str(version))
else:
self.version = version

Comment thread
ivanthewebber marked this conversation as resolved.
async def init(self) -> None:
"""Create database in postgresql."""
async with self.cursor() as cur:
if self.is_template():
await cur.execute(f'CREATE DATABASE "{self.template_dbname}" WITH is_template = true;')
elif self.template_dbname is None:
await cur.execute(f'CREATE DATABASE "{self.dbname}";')
else:
# And make sure no-one is left connected to the template database.
# Otherwise, Creating database from template will fail
await self._terminate_connection(cur, self.template_dbname)
await cur.execute(f'CREATE DATABASE "{self.dbname}" TEMPLATE "{self.template_dbname}";')

def is_template(self) -> bool:
"""Determine whether the DatabaseJanitor maintains template or database."""
return self.dbname is None

async def drop(self) -> None:
"""Drop database in postgresql (async)."""
db_to_drop = self.template_dbname if self.is_template() else self.dbname
assert db_to_drop
async with self.cursor() as cur:
await self._dont_datallowconn(cur, db_to_drop)
await self._terminate_connection(cur, db_to_drop)
if self.is_template():
await cur.execute(f'ALTER DATABASE "{db_to_drop}" with is_template false;')
await cur.execute(f'DROP DATABASE IF EXISTS "{db_to_drop}";')

@staticmethod
async def _dont_datallowconn(cur, dbname: str) -> None:
await cur.execute(f'ALTER DATABASE "{dbname}" with allow_connections false;')

@staticmethod
async def _terminate_connection(cur, dbname: str) -> None:
await cur.execute(
"SELECT pg_terminate_backend(pg_stat_activity.pid)"
"FROM pg_stat_activity "
"WHERE pg_stat_activity.datname = %s;",
(dbname,),
)

async def load(self, load: Callable | str | Path) -> None:
"""Load data into a database (async).

Expects:

* a Path to sql file, that'll be loaded
* an import path to import callable
* a callable that expects: host, port, user, dbname and password arguments.

"""
db_to_load = self.template_dbname if self.is_template() else self.dbname
_loader = build_loader_async(load)
cor = _loader(
host=self.host,
port=self.port,
user=self.user,
dbname=db_to_load,
password=self.password,
)
if inspect.isawaitable(cor):
await cor

@asynccontextmanager
async def cursor(self, dbname: str = "postgres"):
"""Async context manager for postgresql cursor."""

async def connect() -> psycopg.AsyncConnection:
return await psycopg.AsyncConnection.connect(
dbname=dbname,
user=self.user,
password=self.password,
host=self.host,
port=self.port,
)

conn = await retry_async(connect, timeout=self._connection_timeout, possible_exception=psycopg.OperationalError)
conn.isolation_level = self.isolation_level
# We must not run a transaction since we create a database.
conn.autocommit = True
async with conn.cursor() as cur:
try:
yield cur
finally:
await conn.close()

Comment thread
coderabbitai[bot] marked this conversation as resolved.
async def __aenter__(self):
await self.init()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.drop()
27 changes: 26 additions & 1 deletion pytest_postgresql/loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Loader helper functions."""

import importlib
import re
from functools import partial
from pathlib import Path
Expand All @@ -16,7 +17,7 @@ def build_loader(load: Callable | str | Path) -> Callable:
loader_parts = re.split("[.:]", load, maxsplit=2)
import_path = ".".join(loader_parts[:-1])
loader_name = loader_parts[-1]
_temp_import = __import__(import_path, globals(), locals(), fromlist=[loader_name])
_temp_import = importlib.import_module(import_path, globals(), locals(), fromlist=[loader_name])
Comment thread
ivanthewebber marked this conversation as resolved.
Outdated
_loader: Callable = getattr(_temp_import, loader_name)
return _loader
else:
Expand All @@ -30,3 +31,27 @@ def sql(sql_filename: Path, **kwargs: Any) -> None:
with db_connection.cursor() as cur:
cur.execute(_fd.read())
db_connection.commit()


def build_loader_async(load: Callable | str | Path) -> Callable:
"""Build a loader callable."""
if isinstance(load, Path):
return partial(sql_async, load)
elif isinstance(load, str):
loader_parts = re.split("[.:]", load, maxsplit=2)
import_path = ".".join(loader_parts[:-1])
loader_name = loader_parts[-1]
_temp_import = importlib.import_module(import_path, globals(), locals(), fromlist=[loader_name])
_loader: Callable = getattr(_temp_import, loader_name)
return _loader
else:
return load
Comment thread
ivanthewebber marked this conversation as resolved.


async def sql_async(sql_filename: Path, **kwargs: Any) -> None:
"""Async database loader for sql files."""
async with await psycopg.AsyncConnection.connect(**kwargs) as db_connection:
async with await db_connection.cursor() as cur:
with open(sql_filename, "r") as _fd:
await cur.execute(_fd.read())
await db_connection.commit()
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Loading