Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 0 additions & 42 deletions sqlalchemy_bind_manager/_async_helpers.py

This file was deleted.

32 changes: 28 additions & 4 deletions sqlalchemy_bind_manager/_bind_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Mapping, MutableMapping, Union
import atexit
import weakref
from typing import ClassVar, Mapping, MutableMapping, Union

from pydantic import BaseModel, ConfigDict, StrictBool
from sqlalchemy import MetaData, create_engine
Expand All @@ -32,7 +34,6 @@
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm.decl_api import DeclarativeMeta, registry

from sqlalchemy_bind_manager._async_helpers import run_async_from_sync
from sqlalchemy_bind_manager.exceptions import (
InvalidConfigError,
NotInitializedBindError,
Expand Down Expand Up @@ -73,6 +74,7 @@ class SQLAlchemyAsyncBind(BaseModel):

class SQLAlchemyBindManager:
__binds: MutableMapping[str, Union[SQLAlchemyBind, SQLAlchemyAsyncBind]]
_instances: ClassVar[weakref.WeakSet["SQLAlchemyBindManager"]] = weakref.WeakSet()

def __init__(
self,
Expand All @@ -87,14 +89,24 @@ def __init__(
self.__init_bind(name, conf)
else:
self.__init_bind(DEFAULT_BIND_NAME, config)
SQLAlchemyBindManager._instances.add(self)

def __del__(self):
def _dispose_sync(self) -> None:
"""Dispose all engines synchronously.

This method is safe to call from any context, including __del__
and atexit handlers. For async engines, it uses the underlying
sync_engine to perform synchronous disposal.
"""
for bind in self.__binds.values():
if isinstance(bind, SQLAlchemyAsyncBind):
run_async_from_sync(bind.engine.dispose())
bind.engine.sync_engine.dispose()
else:
bind.engine.dispose()

def __del__(self) -> None:
self._dispose_sync()

def __init_bind(self, name: str, config: SQLAlchemyConfig):
if not isinstance(config, SQLAlchemyConfig):
raise InvalidConfigError(
Expand Down Expand Up @@ -210,3 +222,15 @@ def get_session(
:return: The SQLAlchemy Session object
"""
return self.get_bind(bind_name).session_class()


@atexit.register
def _cleanup_all_managers() -> None:
"""Cleanup handler that runs during interpreter shutdown.

This ensures all SQLAlchemyBindManager instances have their engines
disposed before the interpreter exits, even if __del__ hasn't been
called yet due to reference cycles or other GC timing issues.
"""
for manager in list(SQLAlchemyBindManager._instances):
manager._dispose_sync()
18 changes: 10 additions & 8 deletions sqlalchemy_bind_manager/_session_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
)
from sqlalchemy.orm import Session, scoped_session

from sqlalchemy_bind_manager._async_helpers import run_async_from_sync
from sqlalchemy_bind_manager._bind_manager import (
SQLAlchemyAsyncBind,
SQLAlchemyBind,
Expand Down Expand Up @@ -69,12 +68,20 @@ def commit(self, session: Session) -> None:
"""
try:
session.commit()
except:
except Exception:
session.rollback()
raise


class AsyncSessionHandler:
"""Async session handler for managing async scoped sessions.

Note: Unlike SessionHandler, this class does not implement __del__ cleanup
because async_scoped_session.remove() is an async operation that cannot be
safely executed during garbage collection. Sessions should be properly
closed via the get_session() context manager.
"""

scoped_session: async_scoped_session

def __init__(self, bind: SQLAlchemyAsyncBind):
Expand All @@ -85,11 +92,6 @@ def __init__(self, bind: SQLAlchemyAsyncBind):
bind.session_class, asyncio.current_task
)

def __del__(self):
if not getattr(self, "scoped_session", None):
return
run_async_from_sync(self.scoped_session.remove())

@asynccontextmanager
async def get_session(self, read_only: bool = False) -> AsyncIterator[AsyncSession]:
session = self.scoped_session()
Expand All @@ -110,6 +112,6 @@ async def commit(self, session: AsyncSession) -> None:
"""
try:
await session.commit()
except:
except Exception:
await session.rollback()
raise
57 changes: 8 additions & 49 deletions tests/session_handler/test_session_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,14 @@
from sqlalchemy_bind_manager._session_handler import AsyncSessionHandler, SessionHandler


async def test_session_is_removed_on_cleanup(session_handler_class, sa_bind):
sh = session_handler_class(sa_bind)
def test_sync_session_is_removed_on_cleanup(sa_manager):
"""Test that sync SessionHandler removes session on garbage collection.

Note: AsyncSessionHandler does not implement __del__ cleanup because
async_scoped_session.remove() is an async operation that cannot be
safely executed during garbage collection.
"""
sh = SessionHandler(sa_manager.get_bind("sync"))
original_session_remove = sh.scoped_session.remove

with patch.object(
Expand All @@ -26,53 +32,6 @@ async def test_session_is_removed_on_cleanup(session_handler_class, sa_bind):
mocked_remove.assert_called_once()


def test_session_is_removed_on_cleanup_even_if_loop_is_not_running(sa_manager):
# Running the test without a loop will trigger the loop creation
sh = AsyncSessionHandler(sa_manager.get_bind("async"))
original_session_remove = sh.scoped_session.remove
original_get_event_loop = asyncio.get_event_loop

with (
patch.object(
sh.scoped_session,
"remove",
wraps=original_session_remove,
) as mocked_close,
patch(
"asyncio.get_event_loop",
wraps=original_get_event_loop,
) as mocked_get_event_loop,
):
# This should trigger the garbage collector and close the session
sh = None

mocked_get_event_loop.assert_called_once()
mocked_close.assert_called_once()


def test_session_is_removed_on_cleanup_even_if_loop_search_errors_out(sa_manager):
# Running the test without a loop will trigger the loop creation
sh = AsyncSessionHandler(sa_manager.get_bind("async"))
original_session_remove = sh.scoped_session.remove

with (
patch.object(
sh.scoped_session,
"remove",
wraps=original_session_remove,
) as mocked_close,
patch(
"asyncio.get_event_loop",
side_effect=RuntimeError(),
) as mocked_get_event_loop,
):
# This should trigger the garbage collector and close the session
sh = None

mocked_get_event_loop.assert_called_once()
mocked_close.assert_called_once()


@pytest.mark.parametrize("read_only_flag", [True, False])
async def test_commit_is_called_only_if_not_read_only(
read_only_flag,
Expand Down
80 changes: 21 additions & 59 deletions tests/test_sqlalchemy_bind_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,39 +75,19 @@ def test_multiple_binds(multiple_config):
assert isinstance(sa_manager.get_session("async"), AsyncSession)


async def test_engine_is_disposed_on_cleanup(multiple_config):
sa_manager = SQLAlchemyBindManager(multiple_config)
sync_engine = sa_manager.get_bind("default").engine
async_engine = sa_manager.get_bind("async").engine

original_sync_dispose = sync_engine.dispose
original_async_dispose = async_engine.dispose

with (
patch.object(
sync_engine,
"dispose",
wraps=original_sync_dispose,
) as mocked_dispose,
patch.object(
type(async_engine),
"dispose",
wraps=original_async_dispose,
) as mocked_async_dispose,
):
sa_manager = None

mocked_dispose.assert_called_once()
mocked_async_dispose.assert_called()
def test_engine_is_disposed_on_cleanup(multiple_config):
"""Test that engines are disposed synchronously during garbage collection.


def test_engine_is_disposed_on_cleanup_even_if_no_loop(multiple_config):
This test verifies that both sync and async engines are properly disposed
using synchronous disposal (sync_engine.dispose() for async engines).
"""
sa_manager = SQLAlchemyBindManager(multiple_config)
sync_engine = sa_manager.get_bind("default").engine
async_engine = sa_manager.get_bind("async").engine

original_sync_dispose = sync_engine.dispose
original_async_dispose = async_engine.dispose
# For async engines, we now use sync_engine.dispose() for safe cleanup
original_async_sync_dispose = async_engine.sync_engine.dispose

with (
patch.object(
Expand All @@ -116,45 +96,27 @@ def test_engine_is_disposed_on_cleanup_even_if_no_loop(multiple_config):
wraps=original_sync_dispose,
) as mocked_dispose,
patch.object(
type(async_engine),
async_engine.sync_engine,
"dispose",
wraps=original_async_dispose,
) as mocked_async_dispose,
wraps=original_async_sync_dispose,
) as mocked_async_sync_dispose,
):
sa_manager = None

mocked_dispose.assert_called_once()
mocked_async_dispose.assert_called()
mocked_async_sync_dispose.assert_called_once()


def test_atexit_cleanup_disposes_all_managers(multiple_config):
"""Test that the atexit handler disposes all tracked manager instances."""
from sqlalchemy_bind_manager._bind_manager import _cleanup_all_managers

def test_engine_is_disposed_on_cleanup_even_if_loop_search_errors_out(
multiple_config,
):
sa_manager = SQLAlchemyBindManager(multiple_config)
sync_engine = sa_manager.get_bind("default").engine
async_engine = sa_manager.get_bind("async").engine

original_sync_dispose = sync_engine.dispose
original_async_dispose = async_engine.dispose
with patch.object(
sa_manager,
"_dispose_sync",
) as mocked_dispose_sync:
_cleanup_all_managers()

with (
patch.object(
sync_engine,
"dispose",
wraps=original_sync_dispose,
) as mocked_dispose,
patch.object(
type(async_engine),
"dispose",
wraps=original_async_dispose,
) as mocked_async_dispose,
patch(
"asyncio.get_event_loop",
side_effect=RuntimeError(),
) as mocked_get_event_loop,
):
sa_manager = None

mocked_get_event_loop.assert_called_once()
mocked_dispose.assert_called_once()
mocked_async_dispose.assert_called()
mocked_dispose_sync.assert_called_once()
Loading