From 5a7e842b3705532d036c5972a8b55235ee11281b Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Feb 2026 20:52:02 +0000 Subject: [PATCH 1/2] Add async SQLAlchemy backend support with union type dispatch MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extend the database layer to support async SQLAlchemy sessions alongside existing sync sessions. Every public activity tracker function now accepts both Session and AsyncSession via overloaded signatures with isinstance detection — pass an AsyncSession and the call returns an awaitable coroutine, pass a Session (or None) and it executes synchronously as before. The public API names are unchanged. Key changes: - core/db: add async session factory, DatabaseEngine/DatabaseSession type aliases, set_global_async_session/get_global_async_session helpers - activity/models: add async classmethods (aappend_log, aget_by_agent_id, aget_list, aget_pending_ids, aget_active_count) using selectinload for relationship access - activity/tracker: overload create, update, complete, error, cancel_pending, list, detail, count_active for sync/async dispatch - Export DatabaseEngine and DatabaseSession from top-level package - Add aiosqlite dev dependency and comprehensive async test suite https://claude.ai/code/session_011TEKqVAGZi4xhkB5Fqqujv --- pyproject.toml | 1 + src/agentexec/__init__.py | 4 +- src/agentexec/activity/models.py | 199 ++++++++++++- src/agentexec/activity/tracker.py | 352 +++++++++++++++++++++-- src/agentexec/core/db.py | 81 +++++- tests/test_async_activity_tracking.py | 387 ++++++++++++++++++++++++++ tests/test_async_db.py | 135 +++++++++ uv.lock | 11 + 8 files changed, 1144 insertions(+), 26 deletions(-) create mode 100644 tests/test_async_activity_tracking.py create mode 100644 tests/test_async_db.py diff --git a/pyproject.toml b/pyproject.toml index 0fd997a..89b963f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dev-dependencies = [ "ty>=0.0.1a7", "fakeredis>=2.32.1", "pytest-ty>=0.1.3", + "aiosqlite>=0.20.0", ] [tool.ruff] diff --git a/src/agentexec/__init__.py b/src/agentexec/__init__.py index 028ba1b..435e356 100644 --- a/src/agentexec/__init__.py +++ b/src/agentexec/__init__.py @@ -26,7 +26,7 @@ async def search(agent_id: UUID, context: Input) -> Output: from importlib.metadata import PackageNotFoundError, version from agentexec.config import CONF -from agentexec.core.db import Base +from agentexec.core.db import Base, DatabaseEngine, DatabaseSession from agentexec.core.queue import Priority, enqueue from agentexec.core.results import gather, get_result from agentexec.core.task import Task @@ -45,6 +45,8 @@ async def search(agent_id: UUID, context: Input) -> Output: "CONF", "Base", "BaseAgentRunner", + "DatabaseEngine", + "DatabaseSession", "Pipeline", "Tracker", "Pool", diff --git a/src/agentexec/activity/models.py b/src/agentexec/activity/models.py index 8d05565..31eede3 100644 --- a/src/agentexec/activity/models.py +++ b/src/agentexec/activity/models.py @@ -20,7 +20,16 @@ select, ) from sqlalchemy.engine import RowMapping -from sqlalchemy.orm import Mapped, Session, aliased, mapped_column, relationship, declared_attr +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import ( + Mapped, + Session, + aliased, + mapped_column, + relationship, + declared_attr, + selectinload, +) from agentexec.config import CONF from agentexec.core.db import Base @@ -75,6 +84,10 @@ def __tablename__(cls) -> str: order_by="ActivityLog.created_at", ) + # ------------------------------------------------------------------ + # Sync classmethods + # ------------------------------------------------------------------ + @classmethod def append_log( cls, @@ -350,6 +363,190 @@ def get_active_count(cls, session: Session) -> int: return result or 0 + # ------------------------------------------------------------------ + # Async classmethods + # ------------------------------------------------------------------ + + @classmethod + async def aappend_log( + cls, + session: AsyncSession, + agent_id: uuid.UUID, + message: str, + status: Status, + percentage: int | None = None, + ) -> None: + """Async version of :meth:`append_log`.""" + activity_id_subq = select(cls.id).where(cls.agent_id == agent_id).scalar_subquery() + + stmt = insert(ActivityLog).values( + activity_id=activity_id_subq, + message=message, + status=status, + percentage=percentage, + ) + + try: + await session.execute(stmt) + await session.commit() + except Exception as e: + await session.rollback() + raise ValueError(f"Failed to append log for agent_id {agent_id}") from e + + @classmethod + async def aget_by_agent_id( + cls, + session: AsyncSession, + agent_id: str | uuid.UUID, + metadata_filter: dict[str, Any] | None = None, + ) -> Activity | None: + """Async version of :meth:`get_by_agent_id`. + + Eagerly loads ``logs`` via ``selectinload`` since async sessions + do not support implicit lazy loading. + """ + if isinstance(agent_id, str): + agent_id = uuid.UUID(agent_id) + + stmt = ( + select(cls) + .options(selectinload(cls.logs)) + .filter_by(agent_id=agent_id) + ) + + if metadata_filter: + for key, value in metadata_filter.items(): + stmt = stmt.filter(cls.metadata_[key].as_string() == str(value)) + + result = await session.execute(stmt) + return result.scalars().first() + + @classmethod + async def aget_list( + cls, + session: AsyncSession, + page: int = 1, + page_size: int = 50, + metadata_filter: dict[str, Any] | None = None, + ) -> list[RowMapping]: + """Async version of :meth:`get_list`.""" + latest_log_subq = select( + ActivityLog.activity_id, + ActivityLog.message, + ActivityLog.status, + ActivityLog.created_at, + ActivityLog.percentage, + func.row_number() + .over( + partition_by=ActivityLog.activity_id, + order_by=ActivityLog.created_at.desc(), + ) + .label("rn"), + ).subquery() + + started_at_subq = ( + select( + ActivityLog.activity_id, + func.min(ActivityLog.created_at).label("started_at"), + ) + .group_by(ActivityLog.activity_id) + .subquery() + ) + + latest_log = aliased(latest_log_subq) + started_at = aliased(started_at_subq) + + query = ( + select( + cls.agent_id, + cls.agent_type, + latest_log.c.message.label("latest_log_message"), + latest_log.c.status, + latest_log.c.created_at.label("latest_log_timestamp"), + latest_log.c.percentage, + started_at.c.started_at, + cls.metadata_.label("metadata"), + ) + .outerjoin( + latest_log, + (cls.id == latest_log.c.activity_id) & (latest_log.c.rn == 1), + ) + .outerjoin(started_at, cls.id == started_at.c.activity_id) + ) + + if metadata_filter: + for key, value in metadata_filter.items(): + query = query.where(cls.metadata_[key].as_string() == str(value)) + + is_active = case( + (latest_log.c.status.in_([Status.RUNNING, Status.QUEUED]), 0), + else_=1, + ) + active_priority = case( + (latest_log.c.status == Status.RUNNING, 1), + (latest_log.c.status == Status.QUEUED, 2), + else_=3, + ) + query = query.order_by( + is_active, active_priority, started_at.c.started_at.desc().nullslast() + ) + + offset = (page - 1) * page_size + result = await session.execute(query.offset(offset).limit(page_size)) + return list(result.mappings().all()) + + @classmethod + async def aget_pending_ids(cls, session: AsyncSession) -> list[uuid.UUID]: + """Async version of :meth:`get_pending_ids`.""" + latest_log_subq = select( + ActivityLog.activity_id, + ActivityLog.status, + func.row_number() + .over( + partition_by=ActivityLog.activity_id, + order_by=ActivityLog.created_at.desc(), + ) + .label("rn"), + ).subquery() + + stmt = ( + select(cls.agent_id) + .join( + latest_log_subq, + (cls.id == latest_log_subq.c.activity_id) & (latest_log_subq.c.rn == 1), + ) + .filter(latest_log_subq.c.status.in_([Status.QUEUED, Status.RUNNING])) + ) + + result = await session.execute(stmt) + return [row[0] for row in result.all()] + + @classmethod + async def aget_active_count(cls, session: AsyncSession) -> int: + """Async version of :meth:`get_active_count`.""" + latest_log_subq = select( + ActivityLog.activity_id, + ActivityLog.status, + func.row_number() + .over( + partition_by=ActivityLog.activity_id, + order_by=ActivityLog.created_at.desc(), + ) + .label("rn"), + ).subquery() + + stmt = ( + select(func.count(cls.id)) + .join( + latest_log_subq, + (cls.id == latest_log_subq.c.activity_id) & (latest_log_subq.c.rn == 1), + ) + .filter(latest_log_subq.c.status.in_([Status.QUEUED, Status.RUNNING])) + ) + + result = await session.execute(stmt) + return result.scalar() or 0 + class ActivityLog(Base): """Individual log messages from background agents. diff --git a/src/agentexec/activity/tracker.py b/src/agentexec/activity/tracker.py index 12aeff8..9d4b8d1 100644 --- a/src/agentexec/activity/tracker.py +++ b/src/agentexec/activity/tracker.py @@ -1,6 +1,8 @@ import uuid -from typing import Any +from typing import Any, Coroutine, overload +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from agentexec.activity.models import Activity, ActivityLog, Status @@ -41,26 +43,199 @@ def normalize_agent_id(agent_id: str | uuid.UUID) -> uuid.UUID: return agent_id +# --------------------------------------------------------------------------- +# Private async implementations +# --------------------------------------------------------------------------- + + +async def _acreate( + task_name: str, + message: str, + agent_id: str | uuid.UUID | None, + session: AsyncSession, + metadata: dict[str, Any] | None, +) -> uuid.UUID: + aid = normalize_agent_id(agent_id) if agent_id else generate_agent_id() + + activity_record = Activity( + agent_id=aid, + agent_type=task_name, + metadata_=metadata, + ) + session.add(activity_record) + await session.flush() + + log = ActivityLog( + activity_id=activity_record.id, + message=message, + status=Status.QUEUED, + percentage=0, + ) + session.add(log) + await session.commit() + + return aid + + +async def _aupdate( + agent_id: str | uuid.UUID, + message: str, + percentage: int | None, + status: Status | None, + session: AsyncSession, +) -> bool: + await Activity.aappend_log( + session=session, + agent_id=normalize_agent_id(agent_id), + message=message, + status=status if status else Status.RUNNING, + percentage=percentage, + ) + return True + + +async def _acomplete( + agent_id: str | uuid.UUID, + message: str, + percentage: int, + session: AsyncSession, +) -> bool: + await Activity.aappend_log( + session=session, + agent_id=normalize_agent_id(agent_id), + message=message, + status=Status.COMPLETE, + percentage=percentage, + ) + return True + + +async def _aerror( + agent_id: str | uuid.UUID, + message: str, + percentage: int, + session: AsyncSession, +) -> bool: + await Activity.aappend_log( + session=session, + agent_id=normalize_agent_id(agent_id), + message=message, + status=Status.ERROR, + percentage=percentage, + ) + return True + + +async def _acancel_pending(session: AsyncSession) -> int: + pending_agent_ids = await Activity.aget_pending_ids(session) + for aid in pending_agent_ids: + await Activity.aappend_log( + session=session, + agent_id=aid, + message="Canceled due to shutdown", + status=Status.CANCELED, + percentage=None, + ) + await session.commit() + return len(pending_agent_ids) + + +async def _alist( + session: AsyncSession, + page: int, + page_size: int, + metadata_filter: dict[str, Any] | None, +) -> ActivityListSchema: + # Count query + count_stmt = select(func.count()).select_from(Activity) + if metadata_filter: + for key, value in metadata_filter.items(): + count_stmt = count_stmt.filter(Activity.metadata_[key].as_string() == str(value)) + result = await session.execute(count_stmt) + total = result.scalar() or 0 + + rows = await Activity.aget_list( + session, + page=page, + page_size=page_size, + metadata_filter=metadata_filter, + ) + + return ActivityListSchema( + items=[ActivityListItemSchema.model_validate(row) for row in rows], + total=total, + page=page, + page_size=page_size, + ) + + +async def _adetail( + session: AsyncSession, + agent_id: str | uuid.UUID, + metadata_filter: dict[str, Any] | None, +) -> ActivityDetailSchema | None: + if item := await Activity.aget_by_agent_id(session, agent_id, metadata_filter=metadata_filter): + return ActivityDetailSchema.model_validate(item) + return None + + +async def _acount_active(session: AsyncSession) -> int: + return await Activity.aget_active_count(session) + + +# --------------------------------------------------------------------------- +# Public API – overloaded for sync / async dispatch +# --------------------------------------------------------------------------- + + +@overload +def create( + task_name: str, + message: str = ..., + agent_id: str | uuid.UUID | None = ..., + session: AsyncSession = ..., + metadata: dict[str, Any] | None = ..., +) -> Coroutine[Any, Any, uuid.UUID]: ... + + +@overload +def create( + task_name: str, + message: str = ..., + agent_id: str | uuid.UUID | None = ..., + session: Session | None = ..., + metadata: dict[str, Any] | None = ..., +) -> uuid.UUID: ... + + def create( task_name: str, message: str = "Agent queued", agent_id: str | uuid.UUID | None = None, - session: Session | None = None, + session: Session | AsyncSession | None = None, metadata: dict[str, Any] | None = None, -) -> uuid.UUID: +) -> uuid.UUID | Coroutine[Any, Any, uuid.UUID]: """Create a new agent activity record with initial queued status. + Accepts both sync and async sessions. When an ``AsyncSession`` is + passed the call returns an awaitable coroutine; otherwise it executes + synchronously and returns the ``agent_id`` directly. + Args: task_name: Name/type of the task (e.g., "research", "analysis") message: Initial log message (default: "Agent queued") agent_id: Optional custom agent ID (string or UUID). If not provided, one will be auto-generated. - session: Optional SQLAlchemy session. If not provided, uses global session factory. + session: SQLAlchemy session. Pass an ``AsyncSession`` to use async I/O. + If ``None``, falls back to the global sync session. metadata: Optional dict of arbitrary metadata to attach to the activity. Useful for multi-tenancy (e.g., {"organization_id": "org-123"}). Returns: The agent_id (as UUID object) of the created record """ + if isinstance(session, AsyncSession): + return _acreate(task_name, message, agent_id, session, metadata) + agent_id = normalize_agent_id(agent_id) if agent_id else generate_agent_id() db = session or get_global_session() @@ -84,13 +259,33 @@ def create( return agent_id +@overload +def update( + agent_id: str | uuid.UUID, + message: str, + percentage: int | None = ..., + status: Status | None = ..., + session: AsyncSession = ..., +) -> Coroutine[Any, Any, bool]: ... + + +@overload +def update( + agent_id: str | uuid.UUID, + message: str, + percentage: int | None = ..., + status: Status | None = ..., + session: Session | None = ..., +) -> bool: ... + + def update( agent_id: str | uuid.UUID, message: str, percentage: int | None = None, status: Status | None = None, - session: Session | None = None, -) -> bool: + session: Session | AsyncSession | None = None, +) -> bool | Coroutine[Any, Any, bool]: """Update an agent's activity by adding a new log message. This function will set the status to RUNNING unless a different status is explicitly provided. @@ -100,7 +295,7 @@ def update( message: Log message to append percentage: Optional completion percentage (0-100) status: Optional status to set (default: RUNNING) - session: Optional SQLAlchemy session. If not provided, uses global session factory. + session: SQLAlchemy session. Pass an ``AsyncSession`` to use async I/O. Returns: True if successful @@ -108,6 +303,9 @@ def update( Raises: ValueError: If agent_id not found """ + if isinstance(session, AsyncSession): + return _aupdate(agent_id, message, percentage, status, session) + db = session or get_global_session() Activity.append_log( @@ -120,19 +318,37 @@ def update( return True +@overload +def complete( + agent_id: str | uuid.UUID, + message: str = ..., + percentage: int = ..., + session: AsyncSession = ..., +) -> Coroutine[Any, Any, bool]: ... + + +@overload +def complete( + agent_id: str | uuid.UUID, + message: str = ..., + percentage: int = ..., + session: Session | None = ..., +) -> bool: ... + + def complete( agent_id: str | uuid.UUID, message: str = "Agent completed", percentage: int = 100, - session: Session | None = None, -) -> bool: + session: Session | AsyncSession | None = None, +) -> bool | Coroutine[Any, Any, bool]: """Mark an agent activity as complete. Args: agent_id: The agent_id of the agent to mark as complete message: Log message (default: "Agent completed") percentage: Completion percentage (default: 100) - session: Optional SQLAlchemy session. If not provided, uses global session factory. + session: SQLAlchemy session. Pass an ``AsyncSession`` to use async I/O. Returns: True if successful @@ -140,6 +356,9 @@ def complete( Raises: ValueError: If agent_id not found """ + if isinstance(session, AsyncSession): + return _acomplete(agent_id, message, percentage, session) + db = session or get_global_session() Activity.append_log( @@ -152,19 +371,37 @@ def complete( return True +@overload +def error( + agent_id: str | uuid.UUID, + message: str = ..., + percentage: int = ..., + session: AsyncSession = ..., +) -> Coroutine[Any, Any, bool]: ... + + +@overload +def error( + agent_id: str | uuid.UUID, + message: str = ..., + percentage: int = ..., + session: Session | None = ..., +) -> bool: ... + + def error( agent_id: str | uuid.UUID, message: str = "Agent failed", percentage: int = 100, - session: Session | None = None, -) -> bool: + session: Session | AsyncSession | None = None, +) -> bool | Coroutine[Any, Any, bool]: """Mark an agent activity as failed. Args: agent_id: The agent_id of the agent to mark as failed message: Log message (default: "Agent failed") percentage: Completion percentage (default: 100) - session: Optional SQLAlchemy session. If not provided, uses ScopedSession. + session: SQLAlchemy session. Pass an ``AsyncSession`` to use async I/O. Returns: True if successful @@ -172,6 +409,9 @@ def error( Raises: ValueError: If agent_id not found """ + if isinstance(session, AsyncSession): + return _aerror(agent_id, message, percentage, session) + db = session or get_global_session() Activity.append_log( @@ -184,9 +424,21 @@ def error( return True +@overload def cancel_pending( - session: Session | None = None, -) -> int: + session: AsyncSession = ..., +) -> Coroutine[Any, Any, int]: ... + + +@overload +def cancel_pending( + session: Session | None = ..., +) -> int: ... + + +def cancel_pending( + session: Session | AsyncSession | None = None, +) -> int | Coroutine[Any, Any, int]: """Mark all queued and running agents as canceled. Useful during application shutdown to clean up pending tasks. @@ -194,13 +446,16 @@ def cancel_pending( Returns: Number of agents that were canceled """ + if isinstance(session, AsyncSession): + return _acancel_pending(session) + db = session or get_global_session() pending_agent_ids = Activity.get_pending_ids(db) - for agent_id in pending_agent_ids: + for aid in pending_agent_ids: Activity.append_log( session=db, - agent_id=agent_id, + agent_id=aid, message="Canceled due to shutdown", status=Status.CANCELED, percentage=None, @@ -210,16 +465,34 @@ def cancel_pending( return len(pending_agent_ids) +@overload +def list( + session: AsyncSession, + page: int = ..., + page_size: int = ..., + metadata_filter: dict[str, Any] | None = ..., +) -> Coroutine[Any, Any, ActivityListSchema]: ... + + +@overload def list( session: Session, + page: int = ..., + page_size: int = ..., + metadata_filter: dict[str, Any] | None = ..., +) -> ActivityListSchema: ... + + +def list( + session: Session | AsyncSession, page: int = 1, page_size: int = 50, metadata_filter: dict[str, Any] | None = None, -) -> ActivityListSchema: +) -> ActivityListSchema | Coroutine[Any, Any, ActivityListSchema]: """List activities with pagination. Args: - session: SQLAlchemy session to use for the query + session: SQLAlchemy session (sync or async) page: Page number (1-indexed) page_size: Number of items per page metadata_filter: Optional dict of key-value pairs to filter by. @@ -229,6 +502,9 @@ def list( Returns: ActivityList with list of ActivityListItemSchema items """ + if isinstance(session, AsyncSession): + return _alist(session, page, page_size, metadata_filter) + # Build base query for total count query = session.query(Activity) if metadata_filter: @@ -251,15 +527,31 @@ def list( ) +@overload +def detail( + session: AsyncSession, + agent_id: str | uuid.UUID, + metadata_filter: dict[str, Any] | None = ..., +) -> Coroutine[Any, Any, ActivityDetailSchema | None]: ... + + +@overload def detail( session: Session, agent_id: str | uuid.UUID, + metadata_filter: dict[str, Any] | None = ..., +) -> ActivityDetailSchema | None: ... + + +def detail( + session: Session | AsyncSession, + agent_id: str | uuid.UUID, metadata_filter: dict[str, Any] | None = None, -) -> ActivityDetailSchema | None: +) -> (ActivityDetailSchema | None) | Coroutine[Any, Any, ActivityDetailSchema | None]: """Get a single activity by agent_id with all logs. Args: - session: SQLAlchemy session to use for the query + session: SQLAlchemy session (sync or async) agent_id: The agent_id to look up metadata_filter: Optional dict of key-value pairs to filter by. If provided and the activity's metadata doesn't match, @@ -269,18 +561,32 @@ def detail( ActivityDetailSchema with full log history, or None if not found or if metadata doesn't match """ + if isinstance(session, AsyncSession): + return _adetail(session, agent_id, metadata_filter) + if item := Activity.get_by_agent_id(session, agent_id, metadata_filter=metadata_filter): return ActivityDetailSchema.model_validate(item) return None -def count_active(session: Session) -> int: +@overload +def count_active(session: AsyncSession) -> Coroutine[Any, Any, int]: ... + + +@overload +def count_active(session: Session) -> int: ... + + +def count_active(session: Session | AsyncSession) -> int | Coroutine[Any, Any, int]: """Get count of active (queued or running) agents. Args: - session: SQLAlchemy session to use for the query + session: SQLAlchemy session (sync or async) Returns: Count of agents with QUEUED or RUNNING status """ + if isinstance(session, AsyncSession): + return _acount_active(session) + return Activity.get_active_count(session) diff --git a/src/agentexec/core/db.py b/src/agentexec/core/db.py index e5f1a00..efe277a 100644 --- a/src/agentexec/core/db.py +++ b/src/agentexec/core/db.py @@ -1,15 +1,33 @@ +from __future__ import annotations + +from typing import TypeAlias + from sqlalchemy import Engine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from sqlalchemy.orm import DeclarativeBase, Session, scoped_session, sessionmaker __all__ = [ "Base", + "DatabaseEngine", + "DatabaseSession", + "get_global_async_session", "get_global_session", - "set_global_session", + "is_async_session", + "remove_global_async_session", "remove_global_session", + "set_global_async_session", + "set_global_session", ] +DatabaseEngine: TypeAlias = Engine | AsyncEngine +"""Union type for sync and async SQLAlchemy engines.""" + +DatabaseSession: TypeAlias = Session | AsyncSession +"""Union type for sync and async SQLAlchemy sessions.""" + + class Base(DeclarativeBase): """Base class for all SQLAlchemy models in agent-runner. @@ -22,6 +40,10 @@ class Base(DeclarativeBase): pass +# --------------------------------------------------------------------------- +# Sync session management +# --------------------------------------------------------------------------- + # We need one session per worker process with a shared engine across the application. # SQLAlchemy's scoped_session provides process-local session management out of the box. _session_factory: scoped_session[Session] = scoped_session(sessionmaker()) @@ -60,3 +82,60 @@ def remove_global_session() -> None: connections to the pool. """ _session_factory.remove() + + +# --------------------------------------------------------------------------- +# Async session management +# --------------------------------------------------------------------------- + +_async_session_factory: async_sessionmaker[AsyncSession] | None = None + + +def set_global_async_session(engine: AsyncEngine) -> None: + """Configure the global async session factory with an async engine. + + Args: + engine: SQLAlchemy async engine to bind sessions to. + """ + global _async_session_factory + _async_session_factory = async_sessionmaker(bind=engine, expire_on_commit=False) + + +def get_global_async_session() -> AsyncSession: + """Create a new async session from the global factory. + + Unlike the sync ``get_global_session()`` which returns a scoped + (process-local) session, each call here returns a **new** session. + Callers are responsible for closing it when done. + + Returns: + A new ``AsyncSession`` bound to the configured async engine. + + Raises: + RuntimeError: If ``set_global_async_session()`` hasn't been called. + """ + if _async_session_factory is None: + raise RuntimeError( + "Async session not configured. Call set_global_async_session() first." + ) + return _async_session_factory() + + +def remove_global_async_session() -> None: + """Remove the global async session factory. + + Resets the factory to ``None``. Existing sessions created from it + should be closed individually. + """ + global _async_session_factory + _async_session_factory = None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def is_async_session(session: DatabaseSession) -> bool: + """Return ``True`` if *session* is an ``AsyncSession``.""" + return isinstance(session, AsyncSession) diff --git a/tests/test_async_activity_tracking.py b/tests/test_async_activity_tracking.py new file mode 100644 index 0000000..d7196d3 --- /dev/null +++ b/tests/test_async_activity_tracking.py @@ -0,0 +1,387 @@ +"""Tests for async activity tracking functionality.""" + +import uuid + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from agentexec import activity +from agentexec.activity.models import Activity, ActivityLog, Base, Status + + +@pytest.fixture +async def async_db_session(): + """Set up an async in-memory SQLite database for testing.""" + engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +async def test_async_create_activity(async_db_session: AsyncSession): + """Test creating a new activity record via async session.""" + agent_id = await activity.create( + task_name="test_task", + message="Task queued for testing", + session=async_db_session, + ) + + assert agent_id is not None + assert isinstance(agent_id, uuid.UUID) + + # Verify the activity was created in database + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + assert activity_record is not None + assert activity_record.agent_type == "test_task" + assert len(activity_record.logs) == 1 + assert activity_record.logs[0].message == "Task queued for testing" + assert activity_record.logs[0].status == Status.QUEUED + assert activity_record.logs[0].percentage == 0 + + +async def test_async_create_with_custom_agent_id(async_db_session: AsyncSession): + """Test creating activity with a custom agent_id via async.""" + custom_id = uuid.uuid4() + agent_id = await activity.create( + task_name="custom_id_task", + message="Test", + agent_id=custom_id, + session=async_db_session, + ) + + assert agent_id == custom_id + + activity_record = await Activity.aget_by_agent_id(async_db_session, custom_id) + assert activity_record is not None + + +async def test_async_create_with_string_agent_id(async_db_session: AsyncSession): + """Test creating activity with a string agent_id via async.""" + custom_id = uuid.uuid4() + agent_id = await activity.create( + task_name="string_id_task", + message="Test", + agent_id=str(custom_id), + session=async_db_session, + ) + + assert agent_id == custom_id + + +async def test_async_create_with_metadata(async_db_session: AsyncSession): + """Test creating activity with metadata via async.""" + agent_id = await activity.create( + task_name="metadata_task", + message="Test", + session=async_db_session, + metadata={"organization_id": "org-123", "user_id": "user-456"}, + ) + + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + assert activity_record is not None + assert activity_record.metadata_ == {"organization_id": "org-123", "user_id": "user-456"} + + +async def test_async_update_activity(async_db_session: AsyncSession): + """Test updating an activity with a new log message via async.""" + agent_id = await activity.create( + task_name="test_task", + message="Initial message", + session=async_db_session, + ) + + result = await activity.update( + agent_id=agent_id, + message="Processing...", + percentage=50, + session=async_db_session, + ) + + assert result is True + + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + assert len(activity_record.logs) == 2 + assert activity_record.logs[1].message == "Processing..." + assert activity_record.logs[1].status == Status.RUNNING + assert activity_record.logs[1].percentage == 50 + + +async def test_async_update_with_custom_status(async_db_session: AsyncSession): + """Test updating an activity with a custom status via async.""" + agent_id = await activity.create( + task_name="test_task", + message="Initial", + session=async_db_session, + ) + + await activity.update( + agent_id=agent_id, + message="Custom status update", + status=Status.RUNNING, + percentage=25, + session=async_db_session, + ) + + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + latest_log = activity_record.logs[-1] + assert latest_log.status == Status.RUNNING + + +async def test_async_complete_activity(async_db_session: AsyncSession): + """Test marking an activity as complete via async.""" + agent_id = await activity.create( + task_name="test_task", + message="Started", + session=async_db_session, + ) + + result = await activity.complete( + agent_id=agent_id, + message="Successfully completed", + session=async_db_session, + ) + + assert result is True + + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + latest_log = activity_record.logs[-1] + assert latest_log.message == "Successfully completed" + assert latest_log.status == Status.COMPLETE + assert latest_log.percentage == 100 + + +async def test_async_error_activity(async_db_session: AsyncSession): + """Test marking an activity as errored via async.""" + agent_id = await activity.create( + task_name="test_task", + message="Started", + session=async_db_session, + ) + + result = await activity.error( + agent_id=agent_id, + message="Task failed: connection timeout", + session=async_db_session, + ) + + assert result is True + + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + latest_log = activity_record.logs[-1] + assert latest_log.message == "Task failed: connection timeout" + assert latest_log.status == Status.ERROR + assert latest_log.percentage == 100 + + +async def test_async_cancel_pending(async_db_session: AsyncSession): + """Test canceling all pending activities via async.""" + # Create activities in different states + queued_id = await activity.create( + task_name="queued_task", + message="Waiting", + session=async_db_session, + ) + + running_id = await activity.create( + task_name="running_task", + message="Started", + session=async_db_session, + ) + await activity.update( + agent_id=running_id, + message="Running...", + status=Status.RUNNING, + session=async_db_session, + ) + + complete_id = await activity.create( + task_name="complete_task", + message="Started", + session=async_db_session, + ) + await activity.complete(agent_id=complete_id, session=async_db_session) + + # Cancel pending + canceled_count = await activity.cancel_pending(session=async_db_session) + assert canceled_count == 2 + + # Verify states + queued_record = await Activity.aget_by_agent_id(async_db_session, queued_id) + running_record = await Activity.aget_by_agent_id(async_db_session, running_id) + complete_record = await Activity.aget_by_agent_id(async_db_session, complete_id) + + assert queued_record.logs[-1].status == Status.CANCELED + assert running_record.logs[-1].status == Status.CANCELED + assert complete_record.logs[-1].status == Status.COMPLETE + + +async def test_async_list_activities(async_db_session: AsyncSession): + """Test listing activities with pagination via async.""" + for i in range(5): + await activity.create( + task_name=f"task_{i}", + message=f"Message {i}", + session=async_db_session, + ) + + result = await activity.list(async_db_session, page=1, page_size=3) + + assert len(result.items) == 3 + assert result.total == 5 + assert result.page == 1 + assert result.page_size == 3 + + +async def test_async_list_second_page(async_db_session: AsyncSession): + """Test listing activities on second page via async.""" + for i in range(5): + await activity.create( + task_name=f"task_{i}", + message=f"Message {i}", + session=async_db_session, + ) + + result = await activity.list(async_db_session, page=2, page_size=3) + + assert len(result.items) == 2 + assert result.total == 5 + assert result.page == 2 + + +async def test_async_list_with_metadata_filter(async_db_session: AsyncSession): + """Test filtering activities by metadata via async.""" + await activity.create( + task_name="task_org_a", + message="Org A", + session=async_db_session, + metadata={"organization_id": "org-A"}, + ) + await activity.create( + task_name="task_org_a_2", + message="Org A 2", + session=async_db_session, + metadata={"organization_id": "org-A"}, + ) + await activity.create( + task_name="task_org_b", + message="Org B", + session=async_db_session, + metadata={"organization_id": "org-B"}, + ) + + result = await activity.list( + async_db_session, + metadata_filter={"organization_id": "org-A"}, + ) + assert result.total == 2 + + result = await activity.list( + async_db_session, + metadata_filter={"organization_id": "org-B"}, + ) + assert result.total == 1 + + +async def test_async_detail_activity(async_db_session: AsyncSession): + """Test getting activity detail with all logs via async.""" + agent_id = await activity.create( + task_name="detailed_task", + message="Initial", + session=async_db_session, + ) + await activity.update( + agent_id=agent_id, + message="Processing", + percentage=50, + session=async_db_session, + ) + await activity.complete(agent_id=agent_id, session=async_db_session) + + result = await activity.detail(async_db_session, agent_id) + + assert result is not None + assert result.agent_id == agent_id + assert result.agent_type == "detailed_task" + assert len(result.logs) == 3 + assert result.logs[0].message == "Initial" + assert result.logs[1].message == "Processing" + assert result.logs[2].status == Status.COMPLETE + + +async def test_async_detail_not_found(async_db_session: AsyncSession): + """Test getting detail for non-existent activity via async.""" + fake_id = uuid.uuid4() + result = await activity.detail(async_db_session, fake_id) + assert result is None + + +async def test_async_detail_with_string_id(async_db_session: AsyncSession): + """Test getting activity detail with string agent_id via async.""" + agent_id = await activity.create( + task_name="string_id_task", + message="Test", + session=async_db_session, + ) + + result = await activity.detail(async_db_session, str(agent_id)) + assert result is not None + assert result.agent_id == agent_id + + +async def test_async_detail_with_metadata_filter(async_db_session: AsyncSession): + """Test detail with metadata filter via async.""" + agent_id = await activity.create( + task_name="filter_task", + message="Test", + session=async_db_session, + metadata={"organization_id": "org-A"}, + ) + + # Should find with matching filter + result = await activity.detail( + async_db_session, + agent_id, + metadata_filter={"organization_id": "org-A"}, + ) + assert result is not None + + # Should not find with non-matching filter + result = await activity.detail( + async_db_session, + agent_id, + metadata_filter={"organization_id": "org-B"}, + ) + assert result is None + + +async def test_async_count_active(async_db_session: AsyncSession): + """Test counting active agents via async.""" + # Start with zero + count = await activity.count_active(async_db_session) + assert count == 0 + + # Create a queued activity + await activity.create( + task_name="task_1", + message="Queued", + session=async_db_session, + ) + count = await activity.count_active(async_db_session) + assert count == 1 + + # Create and complete another + agent_id = await activity.create( + task_name="task_2", + message="Started", + session=async_db_session, + ) + await activity.complete(agent_id=agent_id, session=async_db_session) + + count = await activity.count_active(async_db_session) + assert count == 1 # Only the first (queued) one is active diff --git a/tests/test_async_db.py b/tests/test_async_db.py new file mode 100644 index 0000000..8c565bc --- /dev/null +++ b/tests/test_async_db.py @@ -0,0 +1,135 @@ +"""Test async database session management.""" + +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine + +from agentexec.core.db import ( + Base, + DatabaseEngine, + DatabaseSession, + get_global_async_session, + is_async_session, + remove_global_async_session, + set_global_async_session, +) + + +@pytest.fixture +async def async_engine(): + """Create a test async SQLite engine.""" + engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False) + yield engine + await engine.dispose() + + +@pytest.fixture(autouse=True) +def cleanup_async_session(): + """Cleanup global async session after each test.""" + yield + remove_global_async_session() + + +async def test_set_global_async_session(async_engine): + """Test that set_global_async_session configures the factory.""" + set_global_async_session(async_engine) + + session = get_global_async_session() + assert isinstance(session, AsyncSession) + await session.close() + + +async def test_get_global_async_session_returns_working_session(async_engine): + """Test that the async session can execute queries.""" + set_global_async_session(async_engine) + + session = get_global_async_session() + result = await session.execute(text("SELECT 1")) + assert result.scalar() == 1 + await session.close() + + +async def test_get_global_async_session_returns_new_instances(async_engine): + """Test that each call returns a different session instance.""" + set_global_async_session(async_engine) + + session1 = get_global_async_session() + session2 = get_global_async_session() + + # Each call should return a new session + assert session1 is not session2 + await session1.close() + await session2.close() + + +async def test_get_global_async_session_without_setup(): + """Test that accessing async session before setup raises RuntimeError.""" + with pytest.raises(RuntimeError, match="Async session not configured"): + get_global_async_session() + + +async def test_remove_global_async_session(async_engine): + """Test that remove_global_async_session resets the factory.""" + set_global_async_session(async_engine) + + # Should work + session = get_global_async_session() + await session.close() + + # After remove, should raise + remove_global_async_session() + with pytest.raises(RuntimeError, match="Async session not configured"): + get_global_async_session() + + +async def test_async_session_with_tables(async_engine): + """Test that async session works with table creation.""" + async with async_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + set_global_async_session(async_engine) + session = get_global_async_session() + + result = await session.execute( + text("SELECT name FROM sqlite_master WHERE type='table'") + ) + tables = [row[0] for row in result] + assert isinstance(tables, list) + await session.close() + + +def test_is_async_session_with_sync(): + """Test is_async_session returns False for sync sessions.""" + from sqlalchemy import create_engine + from sqlalchemy.orm import Session, sessionmaker + + engine = create_engine("sqlite:///:memory:") + SessionLocal = sessionmaker(bind=engine) + session = SessionLocal() + + assert is_async_session(session) is False + session.close() + engine.dispose() + + +async def test_is_async_session_with_async(async_engine): + """Test is_async_session returns True for async sessions.""" + set_global_async_session(async_engine) + session = get_global_async_session() + + assert is_async_session(session) is True + await session.close() + + +def test_database_engine_type_alias(): + """Test that DatabaseEngine accepts both engine types.""" + from sqlalchemy import Engine + from sqlalchemy.ext.asyncio import AsyncEngine + + # Just verify the type alias is accessible and usable + assert DatabaseEngine is not None + + +def test_database_session_type_alias(): + """Test that DatabaseSession is accessible.""" + assert DatabaseSession is not None diff --git a/uv.lock b/uv.lock index c8d1f98..4665862 100644 --- a/uv.lock +++ b/uv.lock @@ -16,6 +16,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "aiosqlite" }, { name = "fakeredis" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -36,6 +37,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "aiosqlite", specifier = ">=0.20.0" }, { name = "fakeredis", specifier = ">=2.32.1" }, { name = "pytest", specifier = ">=8.0.0" }, { name = "pytest-asyncio", specifier = ">=0.23.0" }, @@ -45,6 +47,15 @@ dev = [ { name = "ty", specifier = ">=0.0.1a7" }, ] +[[package]] +name = "aiosqlite" +version = "0.22.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/8a/64761f4005f17809769d23e518d915db74e6310474e733e3593cfc854ef1/aiosqlite-0.22.1.tar.gz", hash = "sha256:043e0bd78d32888c0a9ca90fc788b38796843360c855a7262a532813133a0650", size = 14821, upload-time = "2025-12-23T19:25:43.997Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" From 017d171aea35f2641af6f2c14eac886aae00b797 Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 10 Feb 2026 20:55:17 +0000 Subject: [PATCH 2/2] Deduplicate model queries: shared statement builders + explicit selectinload Extract SQL statement construction into private _*_stmt() classmethods shared by both sync and async execution paths. Sync methods now use select() instead of the legacy session.query() API, and get_by_agent_id uses explicit selectinload(logs) instead of relying on implicit lazy loading. The async methods are now thin await wrappers with zero duplicated query logic. https://claude.ai/code/session_011TEKqVAGZi4xhkB5Fqqujv --- src/agentexec/activity/models.py | 421 ++++++++++--------------------- 1 file changed, 137 insertions(+), 284 deletions(-) diff --git a/src/agentexec/activity/models.py b/src/agentexec/activity/models.py index 31eede3..2d12372 100644 --- a/src/agentexec/activity/models.py +++ b/src/agentexec/activity/models.py @@ -11,6 +11,7 @@ ForeignKey, Integer, JSON, + Select, String, Text, Uuid, @@ -30,6 +31,7 @@ declared_attr, selectinload, ) +from sqlalchemy.sql.dml import Insert from agentexec.config import CONF from agentexec.core.db import Base @@ -85,132 +87,44 @@ def __tablename__(cls) -> str: ) # ------------------------------------------------------------------ - # Sync classmethods + # Statement builders (shared between sync and async) # ------------------------------------------------------------------ @classmethod - def append_log( + def _append_log_stmt( cls, - session: Session, agent_id: uuid.UUID, message: str, status: Status, - percentage: int | None = None, - ) -> None: - """Append a log entry to the activity for the given agent_id. - - This uses a single query to look up the activity_id and insert the log, - avoiding the need to load the Activity record first. - - Args: - session: SQLAlchemy session - agent_id: The agent_id to append the log to - message: Log message - status: Current status of the agent - percentage: Optional completion percentage (0-100) - - Raises: - ValueError: If agent_id not found (foreign key constraint will fail) - """ - # Scalar subquery to get activity.id from agent_id + percentage: int | None, + ) -> Insert: activity_id_subq = select(cls.id).where(cls.agent_id == agent_id).scalar_subquery() - - # Insert the log using the subquery for activity_id - stmt = insert(ActivityLog).values( + return insert(ActivityLog).values( activity_id=activity_id_subq, message=message, status=status, percentage=percentage, ) - try: - session.execute(stmt) - session.commit() - except Exception as e: - session.rollback() - raise ValueError(f"Failed to append log for agent_id {agent_id}") from e - @classmethod - def get_by_agent_id( + def _get_by_agent_id_stmt( cls, - session: Session, - agent_id: str | uuid.UUID, + agent_id: uuid.UUID, metadata_filter: dict[str, Any] | None = None, - ) -> Activity | None: - """Get an activity by agent_id. - - Args: - session: SQLAlchemy session - agent_id: The agent_id to look up (string or UUID) - metadata_filter: Optional dict of key-value pairs to filter by. - If provided and the activity's metadata doesn't match, - returns None (same as if not found). - - Returns: - Activity object or None if not found or metadata doesn't match - - Example: - activity = Activity.get_by_agent_id(session, "abc-123") - # Or with UUID object - activity = Activity.get_by_agent_id(session, uuid.UUID("abc-123...")) - if activity: - print(f"Found activity: {activity.agent_type}") - - # With metadata filter (for multi-tenancy) - activity = Activity.get_by_agent_id( - session, - agent_id, - metadata_filter={"organization_id": "org-123"} - ) - """ - # Normalize to UUID if string - if isinstance(agent_id, str): - agent_id = uuid.UUID(agent_id) - - query = session.query(cls).filter_by(agent_id=agent_id) - - # Apply metadata filtering if provided + ) -> Select: + stmt = select(cls).options(selectinload(cls.logs)).filter_by(agent_id=agent_id) if metadata_filter: for key, value in metadata_filter.items(): - query = query.filter(cls.metadata_[key].as_string() == str(value)) - - return query.first() + stmt = stmt.filter(cls.metadata_[key].as_string() == str(value)) + return stmt @classmethod - def get_list( + def _get_list_stmt( cls, - session: Session, - page: int = 1, - page_size: int = 50, + page: int, + page_size: int, metadata_filter: dict[str, Any] | None = None, - ) -> list[RowMapping]: - """Get a paginated list of activities with summary information. - - Args: - session: SQLAlchemy session to use for the query - page: Page number (1-indexed) - page_size: Number of items per page - metadata_filter: Optional dict of key-value pairs to filter by. - Activities must have metadata containing all specified keys - with exactly matching values. - - Returns: - List of RowMapping objects (dict-like) with keys matching ActivitySummarySchema: - agent_id, agent_type, latest_log_message, status, latest_log_timestamp, - percentage, started_at, metadata - - Example: - results = Activity.get_list(session, page=1, page_size=20) - for row in results: - print(f"{row['agent_id']}: {row['latest_log_message']}") - - # Filter by organization - results = Activity.get_list( - session, - metadata_filter={"organization_id": "org-123"} - ) - """ - # Subquery to get the latest log for each agent + ) -> Select: latest_log_subq = select( ActivityLog.activity_id, ActivityLog.message, @@ -225,7 +139,6 @@ def get_list( .label("rn"), ).subquery() - # Subquery to get start time (first log timestamp) started_at_subq = ( select( ActivityLog.activity_id, @@ -235,11 +148,9 @@ def get_list( .subquery() ) - # Alias for the subqueries latest_log = aliased(latest_log_subq) started_at = aliased(started_at_subq) - # Build base query - select only the columns we need with aliases matching schema query = ( select( cls.agent_id, @@ -258,14 +169,10 @@ def get_list( .outerjoin(started_at, cls.id == started_at.c.activity_id) ) - # Apply metadata filtering if provided if metadata_filter: for key, value in metadata_filter.items(): - # Use JSON path extraction for exact string matching - # This works across SQLite (for testing) and PostgreSQL (for production) query = query.where(cls.metadata_[key].as_string() == str(value)) - # Custom ordering: active agents (running, queued) at the top is_active = case( (latest_log.c.status.in_([Status.RUNNING, Status.QUEUED]), 0), else_=1, @@ -279,26 +186,11 @@ def get_list( is_active, active_priority, started_at.c.started_at.desc().nullslast() ) - # Apply pagination and execute offset = (page - 1) * page_size - return list(session.execute(query.offset(offset).limit(page_size)).mappings().all()) + return query.offset(offset).limit(page_size) @classmethod - def get_pending_ids(cls, session: Session) -> list[uuid.UUID]: - """Get agent_ids for all activities with QUEUED or RUNNING status. - - Args: - session: SQLAlchemy session to use for the query - - Returns: - List of agent_id UUIDs for pending (queued or running) activities - - Example: - pending_ids = Activity.get_pending_ids(session) - for agent_id in pending_ids: - print(f"Pending agent: {agent_id}") - """ - # Subquery to get the latest log status for each activity + def _get_pending_ids_stmt(cls) -> Select: latest_log_subq = select( ActivityLog.activity_id, ActivityLog.status, @@ -310,35 +202,17 @@ def get_pending_ids(cls, session: Session) -> list[uuid.UUID]: .label("rn"), ).subquery() - # Query for agent_ids where latest status is queued or running - result = ( - session.query(cls.agent_id) + return ( + select(cls.agent_id) .join( latest_log_subq, (cls.id == latest_log_subq.c.activity_id) & (latest_log_subq.c.rn == 1), ) .filter(latest_log_subq.c.status.in_([Status.QUEUED, Status.RUNNING])) - .all() ) - # Extract UUIDs from result tuples - return [agent_id for (agent_id,) in result] - @classmethod - def get_active_count(cls, session: Session) -> int: - """Get count of activities with QUEUED or RUNNING status. - - Args: - session: SQLAlchemy session to use for the query - - Returns: - Count of active (queued or running) activities - - Example: - count = Activity.get_active_count(session) - print(f"Active agents: {count}") - """ - # Subquery to get the latest log status for each activity + def _get_active_count_stmt(cls) -> Select: latest_log_subq = select( ActivityLog.activity_id, ActivityLog.status, @@ -350,21 +224,123 @@ def get_active_count(cls, session: Session) -> int: .label("rn"), ).subquery() - # Count activities where latest status is queued or running - result = ( - session.query(func.count(cls.id)) + return ( + select(func.count(cls.id)) .join( latest_log_subq, (cls.id == latest_log_subq.c.activity_id) & (latest_log_subq.c.rn == 1), ) .filter(latest_log_subq.c.status.in_([Status.QUEUED, Status.RUNNING])) - .scalar() ) - return result or 0 + # ------------------------------------------------------------------ + # Sync execution + # ------------------------------------------------------------------ + + @classmethod + def append_log( + cls, + session: Session, + agent_id: uuid.UUID, + message: str, + status: Status, + percentage: int | None = None, + ) -> None: + """Append a log entry to the activity for the given agent_id. + + Uses a subquery insert to avoid loading the Activity record. + + Args: + session: SQLAlchemy session + agent_id: The agent_id to append the log to + message: Log message + status: Current status of the agent + percentage: Optional completion percentage (0-100) + + Raises: + ValueError: If agent_id not found + """ + stmt = cls._append_log_stmt(agent_id, message, status, percentage) + try: + session.execute(stmt) + session.commit() + except Exception as e: + session.rollback() + raise ValueError(f"Failed to append log for agent_id {agent_id}") from e + + @classmethod + def get_by_agent_id( + cls, + session: Session, + agent_id: str | uuid.UUID, + metadata_filter: dict[str, Any] | None = None, + ) -> Activity | None: + """Get an activity by agent_id. + + Args: + session: SQLAlchemy session + agent_id: The agent_id to look up (string or UUID) + metadata_filter: Optional metadata key-value filter. + + Returns: + Activity or None + """ + if isinstance(agent_id, str): + agent_id = uuid.UUID(agent_id) + stmt = cls._get_by_agent_id_stmt(agent_id, metadata_filter) + result = session.execute(stmt) + return result.scalars().first() + + @classmethod + def get_list( + cls, + session: Session, + page: int = 1, + page_size: int = 50, + metadata_filter: dict[str, Any] | None = None, + ) -> list[RowMapping]: + """Get a paginated list of activities with summary information. + + Args: + session: SQLAlchemy session + page: Page number (1-indexed) + page_size: Number of items per page + metadata_filter: Optional metadata key-value filter. + + Returns: + List of RowMapping dicts with summary fields. + """ + stmt = cls._get_list_stmt(page, page_size, metadata_filter) + return list(session.execute(stmt).mappings().all()) + + @classmethod + def get_pending_ids(cls, session: Session) -> list[uuid.UUID]: + """Get agent_ids for all activities with QUEUED or RUNNING status. + + Args: + session: SQLAlchemy session + + Returns: + List of agent_id UUIDs + """ + result = session.execute(cls._get_pending_ids_stmt()) + return [row[0] for row in result.all()] + + @classmethod + def get_active_count(cls, session: Session) -> int: + """Get count of activities with QUEUED or RUNNING status. + + Args: + session: SQLAlchemy session + + Returns: + Count of active activities + """ + result = session.execute(cls._get_active_count_stmt()) + return result.scalar() or 0 # ------------------------------------------------------------------ - # Async classmethods + # Async execution # ------------------------------------------------------------------ @classmethod @@ -377,15 +353,7 @@ async def aappend_log( percentage: int | None = None, ) -> None: """Async version of :meth:`append_log`.""" - activity_id_subq = select(cls.id).where(cls.agent_id == agent_id).scalar_subquery() - - stmt = insert(ActivityLog).values( - activity_id=activity_id_subq, - message=message, - status=status, - percentage=percentage, - ) - + stmt = cls._append_log_stmt(agent_id, message, status, percentage) try: await session.execute(stmt) await session.commit() @@ -400,24 +368,10 @@ async def aget_by_agent_id( agent_id: str | uuid.UUID, metadata_filter: dict[str, Any] | None = None, ) -> Activity | None: - """Async version of :meth:`get_by_agent_id`. - - Eagerly loads ``logs`` via ``selectinload`` since async sessions - do not support implicit lazy loading. - """ + """Async version of :meth:`get_by_agent_id`.""" if isinstance(agent_id, str): agent_id = uuid.UUID(agent_id) - - stmt = ( - select(cls) - .options(selectinload(cls.logs)) - .filter_by(agent_id=agent_id) - ) - - if metadata_filter: - for key, value in metadata_filter.items(): - stmt = stmt.filter(cls.metadata_[key].as_string() == str(value)) - + stmt = cls._get_by_agent_id_stmt(agent_id, metadata_filter) result = await session.execute(stmt) return result.scalars().first() @@ -430,121 +384,20 @@ async def aget_list( metadata_filter: dict[str, Any] | None = None, ) -> list[RowMapping]: """Async version of :meth:`get_list`.""" - latest_log_subq = select( - ActivityLog.activity_id, - ActivityLog.message, - ActivityLog.status, - ActivityLog.created_at, - ActivityLog.percentage, - func.row_number() - .over( - partition_by=ActivityLog.activity_id, - order_by=ActivityLog.created_at.desc(), - ) - .label("rn"), - ).subquery() - - started_at_subq = ( - select( - ActivityLog.activity_id, - func.min(ActivityLog.created_at).label("started_at"), - ) - .group_by(ActivityLog.activity_id) - .subquery() - ) - - latest_log = aliased(latest_log_subq) - started_at = aliased(started_at_subq) - - query = ( - select( - cls.agent_id, - cls.agent_type, - latest_log.c.message.label("latest_log_message"), - latest_log.c.status, - latest_log.c.created_at.label("latest_log_timestamp"), - latest_log.c.percentage, - started_at.c.started_at, - cls.metadata_.label("metadata"), - ) - .outerjoin( - latest_log, - (cls.id == latest_log.c.activity_id) & (latest_log.c.rn == 1), - ) - .outerjoin(started_at, cls.id == started_at.c.activity_id) - ) - - if metadata_filter: - for key, value in metadata_filter.items(): - query = query.where(cls.metadata_[key].as_string() == str(value)) - - is_active = case( - (latest_log.c.status.in_([Status.RUNNING, Status.QUEUED]), 0), - else_=1, - ) - active_priority = case( - (latest_log.c.status == Status.RUNNING, 1), - (latest_log.c.status == Status.QUEUED, 2), - else_=3, - ) - query = query.order_by( - is_active, active_priority, started_at.c.started_at.desc().nullslast() - ) - - offset = (page - 1) * page_size - result = await session.execute(query.offset(offset).limit(page_size)) + stmt = cls._get_list_stmt(page, page_size, metadata_filter) + result = await session.execute(stmt) return list(result.mappings().all()) @classmethod async def aget_pending_ids(cls, session: AsyncSession) -> list[uuid.UUID]: """Async version of :meth:`get_pending_ids`.""" - latest_log_subq = select( - ActivityLog.activity_id, - ActivityLog.status, - func.row_number() - .over( - partition_by=ActivityLog.activity_id, - order_by=ActivityLog.created_at.desc(), - ) - .label("rn"), - ).subquery() - - stmt = ( - select(cls.agent_id) - .join( - latest_log_subq, - (cls.id == latest_log_subq.c.activity_id) & (latest_log_subq.c.rn == 1), - ) - .filter(latest_log_subq.c.status.in_([Status.QUEUED, Status.RUNNING])) - ) - - result = await session.execute(stmt) + result = await session.execute(cls._get_pending_ids_stmt()) return [row[0] for row in result.all()] @classmethod async def aget_active_count(cls, session: AsyncSession) -> int: """Async version of :meth:`get_active_count`.""" - latest_log_subq = select( - ActivityLog.activity_id, - ActivityLog.status, - func.row_number() - .over( - partition_by=ActivityLog.activity_id, - order_by=ActivityLog.created_at.desc(), - ) - .label("rn"), - ).subquery() - - stmt = ( - select(func.count(cls.id)) - .join( - latest_log_subq, - (cls.id == latest_log_subq.c.activity_id) & (latest_log_subq.c.rn == 1), - ) - .filter(latest_log_subq.c.status.in_([Status.QUEUED, Status.RUNNING])) - ) - - result = await session.execute(stmt) + result = await session.execute(cls._get_active_count_stmt()) return result.scalar() or 0