diff --git a/pyproject.toml b/pyproject.toml index 0fd997a..89b963f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,7 @@ dev-dependencies = [ "ty>=0.0.1a7", "fakeredis>=2.32.1", "pytest-ty>=0.1.3", + "aiosqlite>=0.20.0", ] [tool.ruff] diff --git a/src/agentexec/__init__.py b/src/agentexec/__init__.py index 028ba1b..435e356 100644 --- a/src/agentexec/__init__.py +++ b/src/agentexec/__init__.py @@ -26,7 +26,7 @@ async def search(agent_id: UUID, context: Input) -> Output: from importlib.metadata import PackageNotFoundError, version from agentexec.config import CONF -from agentexec.core.db import Base +from agentexec.core.db import Base, DatabaseEngine, DatabaseSession from agentexec.core.queue import Priority, enqueue from agentexec.core.results import gather, get_result from agentexec.core.task import Task @@ -45,6 +45,8 @@ async def search(agent_id: UUID, context: Input) -> Output: "CONF", "Base", "BaseAgentRunner", + "DatabaseEngine", + "DatabaseSession", "Pipeline", "Tracker", "Pool", diff --git a/src/agentexec/activity/models.py b/src/agentexec/activity/models.py index 8d05565..2d12372 100644 --- a/src/agentexec/activity/models.py +++ b/src/agentexec/activity/models.py @@ -11,6 +11,7 @@ ForeignKey, Integer, JSON, + Select, String, Text, Uuid, @@ -20,7 +21,17 @@ select, ) from sqlalchemy.engine import RowMapping -from sqlalchemy.orm import Mapped, Session, aliased, mapped_column, relationship, declared_attr +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import ( + Mapped, + Session, + aliased, + mapped_column, + relationship, + declared_attr, + selectinload, +) +from sqlalchemy.sql.dml import Insert from agentexec.config import CONF from agentexec.core.db import Base @@ -75,129 +86,45 @@ def __tablename__(cls) -> str: order_by="ActivityLog.created_at", ) + # ------------------------------------------------------------------ + # Statement builders (shared between sync and async) + # ------------------------------------------------------------------ + @classmethod - def append_log( + def _append_log_stmt( cls, - session: Session, agent_id: uuid.UUID, message: str, status: Status, - percentage: int | None = None, - ) -> None: - """Append a log entry to the activity for the given agent_id. - - This uses a single query to look up the activity_id and insert the log, - avoiding the need to load the Activity record first. - - Args: - session: SQLAlchemy session - agent_id: The agent_id to append the log to - message: Log message - status: Current status of the agent - percentage: Optional completion percentage (0-100) - - Raises: - ValueError: If agent_id not found (foreign key constraint will fail) - """ - # Scalar subquery to get activity.id from agent_id + percentage: int | None, + ) -> Insert: activity_id_subq = select(cls.id).where(cls.agent_id == agent_id).scalar_subquery() - - # Insert the log using the subquery for activity_id - stmt = insert(ActivityLog).values( + return insert(ActivityLog).values( activity_id=activity_id_subq, message=message, status=status, percentage=percentage, ) - try: - session.execute(stmt) - session.commit() - except Exception as e: - session.rollback() - raise ValueError(f"Failed to append log for agent_id {agent_id}") from e - @classmethod - def get_by_agent_id( + def _get_by_agent_id_stmt( cls, - session: Session, - agent_id: str | uuid.UUID, + agent_id: uuid.UUID, metadata_filter: dict[str, Any] | None = None, - ) -> Activity | None: - """Get an activity by agent_id. 
- - Args: - session: SQLAlchemy session - agent_id: The agent_id to look up (string or UUID) - metadata_filter: Optional dict of key-value pairs to filter by. - If provided and the activity's metadata doesn't match, - returns None (same as if not found). - - Returns: - Activity object or None if not found or metadata doesn't match - - Example: - activity = Activity.get_by_agent_id(session, "abc-123") - # Or with UUID object - activity = Activity.get_by_agent_id(session, uuid.UUID("abc-123...")) - if activity: - print(f"Found activity: {activity.agent_type}") - - # With metadata filter (for multi-tenancy) - activity = Activity.get_by_agent_id( - session, - agent_id, - metadata_filter={"organization_id": "org-123"} - ) - """ - # Normalize to UUID if string - if isinstance(agent_id, str): - agent_id = uuid.UUID(agent_id) - - query = session.query(cls).filter_by(agent_id=agent_id) - - # Apply metadata filtering if provided + ) -> Select: + stmt = select(cls).options(selectinload(cls.logs)).filter_by(agent_id=agent_id) if metadata_filter: for key, value in metadata_filter.items(): - query = query.filter(cls.metadata_[key].as_string() == str(value)) - - return query.first() + stmt = stmt.filter(cls.metadata_[key].as_string() == str(value)) + return stmt @classmethod - def get_list( + def _get_list_stmt( cls, - session: Session, - page: int = 1, - page_size: int = 50, + page: int, + page_size: int, metadata_filter: dict[str, Any] | None = None, - ) -> list[RowMapping]: - """Get a paginated list of activities with summary information. - - Args: - session: SQLAlchemy session to use for the query - page: Page number (1-indexed) - page_size: Number of items per page - metadata_filter: Optional dict of key-value pairs to filter by. - Activities must have metadata containing all specified keys - with exactly matching values. 
- - Returns: - List of RowMapping objects (dict-like) with keys matching ActivitySummarySchema: - agent_id, agent_type, latest_log_message, status, latest_log_timestamp, - percentage, started_at, metadata - - Example: - results = Activity.get_list(session, page=1, page_size=20) - for row in results: - print(f"{row['agent_id']}: {row['latest_log_message']}") - - # Filter by organization - results = Activity.get_list( - session, - metadata_filter={"organization_id": "org-123"} - ) - """ - # Subquery to get the latest log for each agent + ) -> Select: latest_log_subq = select( ActivityLog.activity_id, ActivityLog.message, @@ -212,7 +139,6 @@ def get_list( .label("rn"), ).subquery() - # Subquery to get start time (first log timestamp) started_at_subq = ( select( ActivityLog.activity_id, @@ -222,11 +148,9 @@ def get_list( .subquery() ) - # Alias for the subqueries latest_log = aliased(latest_log_subq) started_at = aliased(started_at_subq) - # Build base query - select only the columns we need with aliases matching schema query = ( select( cls.agent_id, @@ -245,14 +169,10 @@ def get_list( .outerjoin(started_at, cls.id == started_at.c.activity_id) ) - # Apply metadata filtering if provided if metadata_filter: for key, value in metadata_filter.items(): - # Use JSON path extraction for exact string matching - # This works across SQLite (for testing) and PostgreSQL (for production) query = query.where(cls.metadata_[key].as_string() == str(value)) - # Custom ordering: active agents (running, queued) at the top is_active = case( (latest_log.c.status.in_([Status.RUNNING, Status.QUEUED]), 0), else_=1, @@ -266,26 +186,11 @@ def get_list( is_active, active_priority, started_at.c.started_at.desc().nullslast() ) - # Apply pagination and execute offset = (page - 1) * page_size - return list(session.execute(query.offset(offset).limit(page_size)).mappings().all()) + return query.offset(offset).limit(page_size) @classmethod - def get_pending_ids(cls, session: Session) -> list[uuid.UUID]: - """Get agent_ids for all activities with QUEUED or RUNNING status. - - Args: - session: SQLAlchemy session to use for the query - - Returns: - List of agent_id UUIDs for pending (queued or running) activities - - Example: - pending_ids = Activity.get_pending_ids(session) - for agent_id in pending_ids: - print(f"Pending agent: {agent_id}") - """ - # Subquery to get the latest log status for each activity + def _get_pending_ids_stmt(cls) -> Select: latest_log_subq = select( ActivityLog.activity_id, ActivityLog.status, @@ -297,35 +202,17 @@ def get_pending_ids(cls, session: Session) -> list[uuid.UUID]: .label("rn"), ).subquery() - # Query for agent_ids where latest status is queued or running - result = ( - session.query(cls.agent_id) + return ( + select(cls.agent_id) .join( latest_log_subq, (cls.id == latest_log_subq.c.activity_id) & (latest_log_subq.c.rn == 1), ) .filter(latest_log_subq.c.status.in_([Status.QUEUED, Status.RUNNING])) - .all() ) - # Extract UUIDs from result tuples - return [agent_id for (agent_id,) in result] - @classmethod - def get_active_count(cls, session: Session) -> int: - """Get count of activities with QUEUED or RUNNING status. 
- - Args: - session: SQLAlchemy session to use for the query - - Returns: - Count of active (queued or running) activities - - Example: - count = Activity.get_active_count(session) - print(f"Active agents: {count}") - """ - # Subquery to get the latest log status for each activity + def _get_active_count_stmt(cls) -> Select: latest_log_subq = select( ActivityLog.activity_id, ActivityLog.status, @@ -337,18 +224,181 @@ def get_active_count(cls, session: Session) -> int: .label("rn"), ).subquery() - # Count activities where latest status is queued or running - result = ( - session.query(func.count(cls.id)) + return ( + select(func.count(cls.id)) .join( latest_log_subq, (cls.id == latest_log_subq.c.activity_id) & (latest_log_subq.c.rn == 1), ) .filter(latest_log_subq.c.status.in_([Status.QUEUED, Status.RUNNING])) - .scalar() ) - return result or 0 + # ------------------------------------------------------------------ + # Sync execution + # ------------------------------------------------------------------ + + @classmethod + def append_log( + cls, + session: Session, + agent_id: uuid.UUID, + message: str, + status: Status, + percentage: int | None = None, + ) -> None: + """Append a log entry to the activity for the given agent_id. + + Uses a subquery insert to avoid loading the Activity record. + + Args: + session: SQLAlchemy session + agent_id: The agent_id to append the log to + message: Log message + status: Current status of the agent + percentage: Optional completion percentage (0-100) + + Raises: + ValueError: If agent_id not found + """ + stmt = cls._append_log_stmt(agent_id, message, status, percentage) + try: + session.execute(stmt) + session.commit() + except Exception as e: + session.rollback() + raise ValueError(f"Failed to append log for agent_id {agent_id}") from e + + @classmethod + def get_by_agent_id( + cls, + session: Session, + agent_id: str | uuid.UUID, + metadata_filter: dict[str, Any] | None = None, + ) -> Activity | None: + """Get an activity by agent_id. + + Args: + session: SQLAlchemy session + agent_id: The agent_id to look up (string or UUID) + metadata_filter: Optional metadata key-value filter. + + Returns: + Activity or None + """ + if isinstance(agent_id, str): + agent_id = uuid.UUID(agent_id) + stmt = cls._get_by_agent_id_stmt(agent_id, metadata_filter) + result = session.execute(stmt) + return result.scalars().first() + + @classmethod + def get_list( + cls, + session: Session, + page: int = 1, + page_size: int = 50, + metadata_filter: dict[str, Any] | None = None, + ) -> list[RowMapping]: + """Get a paginated list of activities with summary information. + + Args: + session: SQLAlchemy session + page: Page number (1-indexed) + page_size: Number of items per page + metadata_filter: Optional metadata key-value filter. + + Returns: + List of RowMapping dicts with summary fields. + """ + stmt = cls._get_list_stmt(page, page_size, metadata_filter) + return list(session.execute(stmt).mappings().all()) + + @classmethod + def get_pending_ids(cls, session: Session) -> list[uuid.UUID]: + """Get agent_ids for all activities with QUEUED or RUNNING status. + + Args: + session: SQLAlchemy session + + Returns: + List of agent_id UUIDs + """ + result = session.execute(cls._get_pending_ids_stmt()) + return [row[0] for row in result.all()] + + @classmethod + def get_active_count(cls, session: Session) -> int: + """Get count of activities with QUEUED or RUNNING status. 
+ + Args: + session: SQLAlchemy session + + Returns: + Count of active activities + """ + result = session.execute(cls._get_active_count_stmt()) + return result.scalar() or 0 + + # ------------------------------------------------------------------ + # Async execution + # ------------------------------------------------------------------ + + @classmethod + async def aappend_log( + cls, + session: AsyncSession, + agent_id: uuid.UUID, + message: str, + status: Status, + percentage: int | None = None, + ) -> None: + """Async version of :meth:`append_log`.""" + stmt = cls._append_log_stmt(agent_id, message, status, percentage) + try: + await session.execute(stmt) + await session.commit() + except Exception as e: + await session.rollback() + raise ValueError(f"Failed to append log for agent_id {agent_id}") from e + + @classmethod + async def aget_by_agent_id( + cls, + session: AsyncSession, + agent_id: str | uuid.UUID, + metadata_filter: dict[str, Any] | None = None, + ) -> Activity | None: + """Async version of :meth:`get_by_agent_id`.""" + if isinstance(agent_id, str): + agent_id = uuid.UUID(agent_id) + stmt = cls._get_by_agent_id_stmt(agent_id, metadata_filter) + result = await session.execute(stmt) + return result.scalars().first() + + @classmethod + async def aget_list( + cls, + session: AsyncSession, + page: int = 1, + page_size: int = 50, + metadata_filter: dict[str, Any] | None = None, + ) -> list[RowMapping]: + """Async version of :meth:`get_list`.""" + stmt = cls._get_list_stmt(page, page_size, metadata_filter) + result = await session.execute(stmt) + return list(result.mappings().all()) + + @classmethod + async def aget_pending_ids(cls, session: AsyncSession) -> list[uuid.UUID]: + """Async version of :meth:`get_pending_ids`.""" + result = await session.execute(cls._get_pending_ids_stmt()) + return [row[0] for row in result.all()] + + @classmethod + async def aget_active_count(cls, session: AsyncSession) -> int: + """Async version of :meth:`get_active_count`.""" + result = await session.execute(cls._get_active_count_stmt()) + return result.scalar() or 0 class ActivityLog(Base): diff --git a/src/agentexec/activity/tracker.py b/src/agentexec/activity/tracker.py index 12aeff8..9d4b8d1 100644 --- a/src/agentexec/activity/tracker.py +++ b/src/agentexec/activity/tracker.py @@ -1,6 +1,8 @@ import uuid -from typing import Any +from typing import Any, Coroutine, overload +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from agentexec.activity.models import Activity, ActivityLog, Status @@ -41,26 +43,199 @@ def normalize_agent_id(agent_id: str | uuid.UUID) -> uuid.UUID: return agent_id +# --------------------------------------------------------------------------- +# Private async implementations +# --------------------------------------------------------------------------- + + +async def _acreate( + task_name: str, + message: str, + agent_id: str | uuid.UUID | None, + session: AsyncSession, + metadata: dict[str, Any] | None, +) -> uuid.UUID: + aid = normalize_agent_id(agent_id) if agent_id else generate_agent_id() + + activity_record = Activity( + agent_id=aid, + agent_type=task_name, + metadata_=metadata, + ) + session.add(activity_record) + await session.flush() + + log = ActivityLog( + activity_id=activity_record.id, + message=message, + status=Status.QUEUED, + percentage=0, + ) + session.add(log) + await session.commit() + + return aid + + +async def _aupdate( + agent_id: str | uuid.UUID, + message: str, + percentage: 
int | None, + status: Status | None, + session: AsyncSession, +) -> bool: + await Activity.aappend_log( + session=session, + agent_id=normalize_agent_id(agent_id), + message=message, + status=status if status else Status.RUNNING, + percentage=percentage, + ) + return True + + +async def _acomplete( + agent_id: str | uuid.UUID, + message: str, + percentage: int, + session: AsyncSession, +) -> bool: + await Activity.aappend_log( + session=session, + agent_id=normalize_agent_id(agent_id), + message=message, + status=Status.COMPLETE, + percentage=percentage, + ) + return True + + +async def _aerror( + agent_id: str | uuid.UUID, + message: str, + percentage: int, + session: AsyncSession, +) -> bool: + await Activity.aappend_log( + session=session, + agent_id=normalize_agent_id(agent_id), + message=message, + status=Status.ERROR, + percentage=percentage, + ) + return True + + +async def _acancel_pending(session: AsyncSession) -> int: + pending_agent_ids = await Activity.aget_pending_ids(session) + for aid in pending_agent_ids: + await Activity.aappend_log( + session=session, + agent_id=aid, + message="Canceled due to shutdown", + status=Status.CANCELED, + percentage=None, + ) + await session.commit() + return len(pending_agent_ids) + + +async def _alist( + session: AsyncSession, + page: int, + page_size: int, + metadata_filter: dict[str, Any] | None, +) -> ActivityListSchema: + # Count query + count_stmt = select(func.count()).select_from(Activity) + if metadata_filter: + for key, value in metadata_filter.items(): + count_stmt = count_stmt.filter(Activity.metadata_[key].as_string() == str(value)) + result = await session.execute(count_stmt) + total = result.scalar() or 0 + + rows = await Activity.aget_list( + session, + page=page, + page_size=page_size, + metadata_filter=metadata_filter, + ) + + return ActivityListSchema( + items=[ActivityListItemSchema.model_validate(row) for row in rows], + total=total, + page=page, + page_size=page_size, + ) + + +async def _adetail( + session: AsyncSession, + agent_id: str | uuid.UUID, + metadata_filter: dict[str, Any] | None, +) -> ActivityDetailSchema | None: + if item := await Activity.aget_by_agent_id(session, agent_id, metadata_filter=metadata_filter): + return ActivityDetailSchema.model_validate(item) + return None + + +async def _acount_active(session: AsyncSession) -> int: + return await Activity.aget_active_count(session) + + +# --------------------------------------------------------------------------- +# Public API – overloaded for sync / async dispatch +# --------------------------------------------------------------------------- + + +@overload +def create( + task_name: str, + message: str = ..., + agent_id: str | uuid.UUID | None = ..., + session: AsyncSession = ..., + metadata: dict[str, Any] | None = ..., +) -> Coroutine[Any, Any, uuid.UUID]: ... + + +@overload +def create( + task_name: str, + message: str = ..., + agent_id: str | uuid.UUID | None = ..., + session: Session | None = ..., + metadata: dict[str, Any] | None = ..., +) -> uuid.UUID: ... + + def create( task_name: str, message: str = "Agent queued", agent_id: str | uuid.UUID | None = None, - session: Session | None = None, + session: Session | AsyncSession | None = None, metadata: dict[str, Any] | None = None, -) -> uuid.UUID: +) -> uuid.UUID | Coroutine[Any, Any, uuid.UUID]: """Create a new agent activity record with initial queued status. + Accepts both sync and async sessions. 
When an ``AsyncSession`` is + passed the call returns an awaitable coroutine; otherwise it executes + synchronously and returns the ``agent_id`` directly. + Args: task_name: Name/type of the task (e.g., "research", "analysis") message: Initial log message (default: "Agent queued") agent_id: Optional custom agent ID (string or UUID). If not provided, one will be auto-generated. - session: Optional SQLAlchemy session. If not provided, uses global session factory. + session: SQLAlchemy session. Pass an ``AsyncSession`` to use async I/O. + If ``None``, falls back to the global sync session. metadata: Optional dict of arbitrary metadata to attach to the activity. Useful for multi-tenancy (e.g., {"organization_id": "org-123"}). Returns: The agent_id (as UUID object) of the created record """ + if isinstance(session, AsyncSession): + return _acreate(task_name, message, agent_id, session, metadata) + agent_id = normalize_agent_id(agent_id) if agent_id else generate_agent_id() db = session or get_global_session() @@ -84,13 +259,33 @@ def create( return agent_id +@overload +def update( + agent_id: str | uuid.UUID, + message: str, + percentage: int | None = ..., + status: Status | None = ..., + session: AsyncSession = ..., +) -> Coroutine[Any, Any, bool]: ... + + +@overload +def update( + agent_id: str | uuid.UUID, + message: str, + percentage: int | None = ..., + status: Status | None = ..., + session: Session | None = ..., +) -> bool: ... + + def update( agent_id: str | uuid.UUID, message: str, percentage: int | None = None, status: Status | None = None, - session: Session | None = None, -) -> bool: + session: Session | AsyncSession | None = None, +) -> bool | Coroutine[Any, Any, bool]: """Update an agent's activity by adding a new log message. This function will set the status to RUNNING unless a different status is explicitly provided. @@ -100,7 +295,7 @@ def update( message: Log message to append percentage: Optional completion percentage (0-100) status: Optional status to set (default: RUNNING) - session: Optional SQLAlchemy session. If not provided, uses global session factory. + session: SQLAlchemy session. Pass an ``AsyncSession`` to use async I/O. Returns: True if successful @@ -108,6 +303,9 @@ def update( Raises: ValueError: If agent_id not found """ + if isinstance(session, AsyncSession): + return _aupdate(agent_id, message, percentage, status, session) + db = session or get_global_session() Activity.append_log( @@ -120,19 +318,37 @@ def update( return True +@overload +def complete( + agent_id: str | uuid.UUID, + message: str = ..., + percentage: int = ..., + session: AsyncSession = ..., +) -> Coroutine[Any, Any, bool]: ... + + +@overload +def complete( + agent_id: str | uuid.UUID, + message: str = ..., + percentage: int = ..., + session: Session | None = ..., +) -> bool: ... + + def complete( agent_id: str | uuid.UUID, message: str = "Agent completed", percentage: int = 100, - session: Session | None = None, -) -> bool: + session: Session | AsyncSession | None = None, +) -> bool | Coroutine[Any, Any, bool]: """Mark an agent activity as complete. Args: agent_id: The agent_id of the agent to mark as complete message: Log message (default: "Agent completed") percentage: Completion percentage (default: 100) - session: Optional SQLAlchemy session. If not provided, uses global session factory. + session: SQLAlchemy session. Pass an ``AsyncSession`` to use async I/O. 
Returns: True if successful @@ -140,6 +356,9 @@ def complete( Raises: ValueError: If agent_id not found """ + if isinstance(session, AsyncSession): + return _acomplete(agent_id, message, percentage, session) + db = session or get_global_session() Activity.append_log( @@ -152,19 +371,37 @@ def complete( return True +@overload +def error( + agent_id: str | uuid.UUID, + message: str = ..., + percentage: int = ..., + session: AsyncSession = ..., +) -> Coroutine[Any, Any, bool]: ... + + +@overload +def error( + agent_id: str | uuid.UUID, + message: str = ..., + percentage: int = ..., + session: Session | None = ..., +) -> bool: ... + + def error( agent_id: str | uuid.UUID, message: str = "Agent failed", percentage: int = 100, - session: Session | None = None, -) -> bool: + session: Session | AsyncSession | None = None, +) -> bool | Coroutine[Any, Any, bool]: """Mark an agent activity as failed. Args: agent_id: The agent_id of the agent to mark as failed message: Log message (default: "Agent failed") percentage: Completion percentage (default: 100) - session: Optional SQLAlchemy session. If not provided, uses ScopedSession. + session: SQLAlchemy session. Pass an ``AsyncSession`` to use async I/O. Returns: True if successful @@ -172,6 +409,9 @@ def error( Raises: ValueError: If agent_id not found """ + if isinstance(session, AsyncSession): + return _aerror(agent_id, message, percentage, session) + db = session or get_global_session() Activity.append_log( @@ -184,9 +424,21 @@ def error( return True +@overload def cancel_pending( - session: Session | None = None, -) -> int: + session: AsyncSession = ..., +) -> Coroutine[Any, Any, int]: ... + + +@overload +def cancel_pending( + session: Session | None = ..., +) -> int: ... + + +def cancel_pending( + session: Session | AsyncSession | None = None, +) -> int | Coroutine[Any, Any, int]: """Mark all queued and running agents as canceled. Useful during application shutdown to clean up pending tasks. @@ -194,13 +446,16 @@ def cancel_pending( Returns: Number of agents that were canceled """ + if isinstance(session, AsyncSession): + return _acancel_pending(session) + db = session or get_global_session() pending_agent_ids = Activity.get_pending_ids(db) - for agent_id in pending_agent_ids: + for aid in pending_agent_ids: Activity.append_log( session=db, - agent_id=agent_id, + agent_id=aid, message="Canceled due to shutdown", status=Status.CANCELED, percentage=None, @@ -210,16 +465,34 @@ def cancel_pending( return len(pending_agent_ids) +@overload +def list( + session: AsyncSession, + page: int = ..., + page_size: int = ..., + metadata_filter: dict[str, Any] | None = ..., +) -> Coroutine[Any, Any, ActivityListSchema]: ... + + +@overload def list( session: Session, + page: int = ..., + page_size: int = ..., + metadata_filter: dict[str, Any] | None = ..., +) -> ActivityListSchema: ... + + +def list( + session: Session | AsyncSession, page: int = 1, page_size: int = 50, metadata_filter: dict[str, Any] | None = None, -) -> ActivityListSchema: +) -> ActivityListSchema | Coroutine[Any, Any, ActivityListSchema]: """List activities with pagination. Args: - session: SQLAlchemy session to use for the query + session: SQLAlchemy session (sync or async) page: Page number (1-indexed) page_size: Number of items per page metadata_filter: Optional dict of key-value pairs to filter by. 
@@ -229,6 +502,9 @@ def list( Returns: ActivityList with list of ActivityListItemSchema items """ + if isinstance(session, AsyncSession): + return _alist(session, page, page_size, metadata_filter) + # Build base query for total count query = session.query(Activity) if metadata_filter: @@ -251,15 +527,31 @@ def list( ) +@overload +def detail( + session: AsyncSession, + agent_id: str | uuid.UUID, + metadata_filter: dict[str, Any] | None = ..., +) -> Coroutine[Any, Any, ActivityDetailSchema | None]: ... + + +@overload def detail( session: Session, agent_id: str | uuid.UUID, + metadata_filter: dict[str, Any] | None = ..., +) -> ActivityDetailSchema | None: ... + + +def detail( + session: Session | AsyncSession, + agent_id: str | uuid.UUID, metadata_filter: dict[str, Any] | None = None, -) -> ActivityDetailSchema | None: +) -> (ActivityDetailSchema | None) | Coroutine[Any, Any, ActivityDetailSchema | None]: """Get a single activity by agent_id with all logs. Args: - session: SQLAlchemy session to use for the query + session: SQLAlchemy session (sync or async) agent_id: The agent_id to look up metadata_filter: Optional dict of key-value pairs to filter by. If provided and the activity's metadata doesn't match, @@ -269,18 +561,32 @@ def detail( ActivityDetailSchema with full log history, or None if not found or if metadata doesn't match """ + if isinstance(session, AsyncSession): + return _adetail(session, agent_id, metadata_filter) + if item := Activity.get_by_agent_id(session, agent_id, metadata_filter=metadata_filter): return ActivityDetailSchema.model_validate(item) return None -def count_active(session: Session) -> int: +@overload +def count_active(session: AsyncSession) -> Coroutine[Any, Any, int]: ... + + +@overload +def count_active(session: Session) -> int: ... + + +def count_active(session: Session | AsyncSession) -> int | Coroutine[Any, Any, int]: """Get count of active (queued or running) agents. Args: - session: SQLAlchemy session to use for the query + session: SQLAlchemy session (sync or async) Returns: Count of agents with QUEUED or RUNNING status """ + if isinstance(session, AsyncSession): + return _acount_active(session) + return Activity.get_active_count(session) diff --git a/src/agentexec/core/db.py b/src/agentexec/core/db.py index e5f1a00..efe277a 100644 --- a/src/agentexec/core/db.py +++ b/src/agentexec/core/db.py @@ -1,15 +1,33 @@ +from __future__ import annotations + +from typing import TypeAlias + from sqlalchemy import Engine +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker from sqlalchemy.orm import DeclarativeBase, Session, scoped_session, sessionmaker __all__ = [ "Base", + "DatabaseEngine", + "DatabaseSession", + "get_global_async_session", "get_global_session", - "set_global_session", + "is_async_session", + "remove_global_async_session", "remove_global_session", + "set_global_async_session", + "set_global_session", ] +DatabaseEngine: TypeAlias = Engine | AsyncEngine +"""Union type for sync and async SQLAlchemy engines.""" + +DatabaseSession: TypeAlias = Session | AsyncSession +"""Union type for sync and async SQLAlchemy sessions.""" + + class Base(DeclarativeBase): """Base class for all SQLAlchemy models in agent-runner. 
@@ -22,6 +40,10 @@ class Base(DeclarativeBase): pass +# --------------------------------------------------------------------------- +# Sync session management +# --------------------------------------------------------------------------- + # We need one session per worker process with a shared engine across the application. # SQLAlchemy's scoped_session provides process-local session management out of the box. _session_factory: scoped_session[Session] = scoped_session(sessionmaker()) @@ -60,3 +82,60 @@ def remove_global_session() -> None: connections to the pool. """ _session_factory.remove() + + +# --------------------------------------------------------------------------- +# Async session management +# --------------------------------------------------------------------------- + +_async_session_factory: async_sessionmaker[AsyncSession] | None = None + + +def set_global_async_session(engine: AsyncEngine) -> None: + """Configure the global async session factory with an async engine. + + Args: + engine: SQLAlchemy async engine to bind sessions to. + """ + global _async_session_factory + _async_session_factory = async_sessionmaker(bind=engine, expire_on_commit=False) + + +def get_global_async_session() -> AsyncSession: + """Create a new async session from the global factory. + + Unlike the sync ``get_global_session()`` which returns a scoped + (process-local) session, each call here returns a **new** session. + Callers are responsible for closing it when done. + + Returns: + A new ``AsyncSession`` bound to the configured async engine. + + Raises: + RuntimeError: If ``set_global_async_session()`` hasn't been called. + """ + if _async_session_factory is None: + raise RuntimeError( + "Async session not configured. Call set_global_async_session() first." + ) + return _async_session_factory() + + +def remove_global_async_session() -> None: + """Remove the global async session factory. + + Resets the factory to ``None``. Existing sessions created from it + should be closed individually. 
+ """ + global _async_session_factory + _async_session_factory = None + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def is_async_session(session: DatabaseSession) -> bool: + """Return ``True`` if *session* is an ``AsyncSession``.""" + return isinstance(session, AsyncSession) diff --git a/tests/test_async_activity_tracking.py b/tests/test_async_activity_tracking.py new file mode 100644 index 0000000..d7196d3 --- /dev/null +++ b/tests/test_async_activity_tracking.py @@ -0,0 +1,387 @@ +"""Tests for async activity tracking functionality.""" + +import uuid + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from agentexec import activity +from agentexec.activity.models import Activity, ActivityLog, Base, Status + + +@pytest.fixture +async def async_db_session(): + """Set up an async in-memory SQLite database for testing.""" + engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + session_factory = async_sessionmaker(engine, expire_on_commit=False) + async with session_factory() as session: + yield session + + await engine.dispose() + + +async def test_async_create_activity(async_db_session: AsyncSession): + """Test creating a new activity record via async session.""" + agent_id = await activity.create( + task_name="test_task", + message="Task queued for testing", + session=async_db_session, + ) + + assert agent_id is not None + assert isinstance(agent_id, uuid.UUID) + + # Verify the activity was created in database + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + assert activity_record is not None + assert activity_record.agent_type == "test_task" + assert len(activity_record.logs) == 1 + assert activity_record.logs[0].message == "Task queued for testing" + assert activity_record.logs[0].status == Status.QUEUED + assert activity_record.logs[0].percentage == 0 + + +async def test_async_create_with_custom_agent_id(async_db_session: AsyncSession): + """Test creating activity with a custom agent_id via async.""" + custom_id = uuid.uuid4() + agent_id = await activity.create( + task_name="custom_id_task", + message="Test", + agent_id=custom_id, + session=async_db_session, + ) + + assert agent_id == custom_id + + activity_record = await Activity.aget_by_agent_id(async_db_session, custom_id) + assert activity_record is not None + + +async def test_async_create_with_string_agent_id(async_db_session: AsyncSession): + """Test creating activity with a string agent_id via async.""" + custom_id = uuid.uuid4() + agent_id = await activity.create( + task_name="string_id_task", + message="Test", + agent_id=str(custom_id), + session=async_db_session, + ) + + assert agent_id == custom_id + + +async def test_async_create_with_metadata(async_db_session: AsyncSession): + """Test creating activity with metadata via async.""" + agent_id = await activity.create( + task_name="metadata_task", + message="Test", + session=async_db_session, + metadata={"organization_id": "org-123", "user_id": "user-456"}, + ) + + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + assert activity_record is not None + assert activity_record.metadata_ == {"organization_id": "org-123", "user_id": "user-456"} + + +async def test_async_update_activity(async_db_session: AsyncSession): + """Test 
updating an activity with a new log message via async.""" + agent_id = await activity.create( + task_name="test_task", + message="Initial message", + session=async_db_session, + ) + + result = await activity.update( + agent_id=agent_id, + message="Processing...", + percentage=50, + session=async_db_session, + ) + + assert result is True + + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + assert len(activity_record.logs) == 2 + assert activity_record.logs[1].message == "Processing..." + assert activity_record.logs[1].status == Status.RUNNING + assert activity_record.logs[1].percentage == 50 + + +async def test_async_update_with_custom_status(async_db_session: AsyncSession): + """Test updating an activity with a custom status via async.""" + agent_id = await activity.create( + task_name="test_task", + message="Initial", + session=async_db_session, + ) + + await activity.update( + agent_id=agent_id, + message="Custom status update", + status=Status.RUNNING, + percentage=25, + session=async_db_session, + ) + + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + latest_log = activity_record.logs[-1] + assert latest_log.status == Status.RUNNING + + +async def test_async_complete_activity(async_db_session: AsyncSession): + """Test marking an activity as complete via async.""" + agent_id = await activity.create( + task_name="test_task", + message="Started", + session=async_db_session, + ) + + result = await activity.complete( + agent_id=agent_id, + message="Successfully completed", + session=async_db_session, + ) + + assert result is True + + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + latest_log = activity_record.logs[-1] + assert latest_log.message == "Successfully completed" + assert latest_log.status == Status.COMPLETE + assert latest_log.percentage == 100 + + +async def test_async_error_activity(async_db_session: AsyncSession): + """Test marking an activity as errored via async.""" + agent_id = await activity.create( + task_name="test_task", + message="Started", + session=async_db_session, + ) + + result = await activity.error( + agent_id=agent_id, + message="Task failed: connection timeout", + session=async_db_session, + ) + + assert result is True + + activity_record = await Activity.aget_by_agent_id(async_db_session, agent_id) + latest_log = activity_record.logs[-1] + assert latest_log.message == "Task failed: connection timeout" + assert latest_log.status == Status.ERROR + assert latest_log.percentage == 100 + + +async def test_async_cancel_pending(async_db_session: AsyncSession): + """Test canceling all pending activities via async.""" + # Create activities in different states + queued_id = await activity.create( + task_name="queued_task", + message="Waiting", + session=async_db_session, + ) + + running_id = await activity.create( + task_name="running_task", + message="Started", + session=async_db_session, + ) + await activity.update( + agent_id=running_id, + message="Running...", + status=Status.RUNNING, + session=async_db_session, + ) + + complete_id = await activity.create( + task_name="complete_task", + message="Started", + session=async_db_session, + ) + await activity.complete(agent_id=complete_id, session=async_db_session) + + # Cancel pending + canceled_count = await activity.cancel_pending(session=async_db_session) + assert canceled_count == 2 + + # Verify states + queued_record = await Activity.aget_by_agent_id(async_db_session, queued_id) + running_record = await 
Activity.aget_by_agent_id(async_db_session, running_id) + complete_record = await Activity.aget_by_agent_id(async_db_session, complete_id) + + assert queued_record.logs[-1].status == Status.CANCELED + assert running_record.logs[-1].status == Status.CANCELED + assert complete_record.logs[-1].status == Status.COMPLETE + + +async def test_async_list_activities(async_db_session: AsyncSession): + """Test listing activities with pagination via async.""" + for i in range(5): + await activity.create( + task_name=f"task_{i}", + message=f"Message {i}", + session=async_db_session, + ) + + result = await activity.list(async_db_session, page=1, page_size=3) + + assert len(result.items) == 3 + assert result.total == 5 + assert result.page == 1 + assert result.page_size == 3 + + +async def test_async_list_second_page(async_db_session: AsyncSession): + """Test listing activities on second page via async.""" + for i in range(5): + await activity.create( + task_name=f"task_{i}", + message=f"Message {i}", + session=async_db_session, + ) + + result = await activity.list(async_db_session, page=2, page_size=3) + + assert len(result.items) == 2 + assert result.total == 5 + assert result.page == 2 + + +async def test_async_list_with_metadata_filter(async_db_session: AsyncSession): + """Test filtering activities by metadata via async.""" + await activity.create( + task_name="task_org_a", + message="Org A", + session=async_db_session, + metadata={"organization_id": "org-A"}, + ) + await activity.create( + task_name="task_org_a_2", + message="Org A 2", + session=async_db_session, + metadata={"organization_id": "org-A"}, + ) + await activity.create( + task_name="task_org_b", + message="Org B", + session=async_db_session, + metadata={"organization_id": "org-B"}, + ) + + result = await activity.list( + async_db_session, + metadata_filter={"organization_id": "org-A"}, + ) + assert result.total == 2 + + result = await activity.list( + async_db_session, + metadata_filter={"organization_id": "org-B"}, + ) + assert result.total == 1 + + +async def test_async_detail_activity(async_db_session: AsyncSession): + """Test getting activity detail with all logs via async.""" + agent_id = await activity.create( + task_name="detailed_task", + message="Initial", + session=async_db_session, + ) + await activity.update( + agent_id=agent_id, + message="Processing", + percentage=50, + session=async_db_session, + ) + await activity.complete(agent_id=agent_id, session=async_db_session) + + result = await activity.detail(async_db_session, agent_id) + + assert result is not None + assert result.agent_id == agent_id + assert result.agent_type == "detailed_task" + assert len(result.logs) == 3 + assert result.logs[0].message == "Initial" + assert result.logs[1].message == "Processing" + assert result.logs[2].status == Status.COMPLETE + + +async def test_async_detail_not_found(async_db_session: AsyncSession): + """Test getting detail for non-existent activity via async.""" + fake_id = uuid.uuid4() + result = await activity.detail(async_db_session, fake_id) + assert result is None + + +async def test_async_detail_with_string_id(async_db_session: AsyncSession): + """Test getting activity detail with string agent_id via async.""" + agent_id = await activity.create( + task_name="string_id_task", + message="Test", + session=async_db_session, + ) + + result = await activity.detail(async_db_session, str(agent_id)) + assert result is not None + assert result.agent_id == agent_id + + +async def test_async_detail_with_metadata_filter(async_db_session: 
AsyncSession): + """Test detail with metadata filter via async.""" + agent_id = await activity.create( + task_name="filter_task", + message="Test", + session=async_db_session, + metadata={"organization_id": "org-A"}, + ) + + # Should find with matching filter + result = await activity.detail( + async_db_session, + agent_id, + metadata_filter={"organization_id": "org-A"}, + ) + assert result is not None + + # Should not find with non-matching filter + result = await activity.detail( + async_db_session, + agent_id, + metadata_filter={"organization_id": "org-B"}, + ) + assert result is None + + +async def test_async_count_active(async_db_session: AsyncSession): + """Test counting active agents via async.""" + # Start with zero + count = await activity.count_active(async_db_session) + assert count == 0 + + # Create a queued activity + await activity.create( + task_name="task_1", + message="Queued", + session=async_db_session, + ) + count = await activity.count_active(async_db_session) + assert count == 1 + + # Create and complete another + agent_id = await activity.create( + task_name="task_2", + message="Started", + session=async_db_session, + ) + await activity.complete(agent_id=agent_id, session=async_db_session) + + count = await activity.count_active(async_db_session) + assert count == 1 # Only the first (queued) one is active diff --git a/tests/test_async_db.py b/tests/test_async_db.py new file mode 100644 index 0000000..8c565bc --- /dev/null +++ b/tests/test_async_db.py @@ -0,0 +1,135 @@ +"""Test async database session management.""" + +import pytest +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine + +from agentexec.core.db import ( + Base, + DatabaseEngine, + DatabaseSession, + get_global_async_session, + is_async_session, + remove_global_async_session, + set_global_async_session, +) + + +@pytest.fixture +async def async_engine(): + """Create a test async SQLite engine.""" + engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False) + yield engine + await engine.dispose() + + +@pytest.fixture(autouse=True) +def cleanup_async_session(): + """Cleanup global async session after each test.""" + yield + remove_global_async_session() + + +async def test_set_global_async_session(async_engine): + """Test that set_global_async_session configures the factory.""" + set_global_async_session(async_engine) + + session = get_global_async_session() + assert isinstance(session, AsyncSession) + await session.close() + + +async def test_get_global_async_session_returns_working_session(async_engine): + """Test that the async session can execute queries.""" + set_global_async_session(async_engine) + + session = get_global_async_session() + result = await session.execute(text("SELECT 1")) + assert result.scalar() == 1 + await session.close() + + +async def test_get_global_async_session_returns_new_instances(async_engine): + """Test that each call returns a different session instance.""" + set_global_async_session(async_engine) + + session1 = get_global_async_session() + session2 = get_global_async_session() + + # Each call should return a new session + assert session1 is not session2 + await session1.close() + await session2.close() + + +async def test_get_global_async_session_without_setup(): + """Test that accessing async session before setup raises RuntimeError.""" + with pytest.raises(RuntimeError, match="Async session not configured"): + get_global_async_session() + + +async def test_remove_global_async_session(async_engine): + """Test 
that remove_global_async_session resets the factory.""" + set_global_async_session(async_engine) + + # Should work + session = get_global_async_session() + await session.close() + + # After remove, should raise + remove_global_async_session() + with pytest.raises(RuntimeError, match="Async session not configured"): + get_global_async_session() + + +async def test_async_session_with_tables(async_engine): + """Test that async session works with table creation.""" + async with async_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + + set_global_async_session(async_engine) + session = get_global_async_session() + + result = await session.execute( + text("SELECT name FROM sqlite_master WHERE type='table'") + ) + tables = [row[0] for row in result] + assert isinstance(tables, list) + await session.close() + + +def test_is_async_session_with_sync(): + """Test is_async_session returns False for sync sessions.""" + from sqlalchemy import create_engine + from sqlalchemy.orm import Session, sessionmaker + + engine = create_engine("sqlite:///:memory:") + SessionLocal = sessionmaker(bind=engine) + session = SessionLocal() + + assert is_async_session(session) is False + session.close() + engine.dispose() + + +async def test_is_async_session_with_async(async_engine): + """Test is_async_session returns True for async sessions.""" + set_global_async_session(async_engine) + session = get_global_async_session() + + assert is_async_session(session) is True + await session.close() + + +def test_database_engine_type_alias(): + """Test that DatabaseEngine accepts both engine types.""" + from sqlalchemy import Engine + from sqlalchemy.ext.asyncio import AsyncEngine + + # Just verify the type alias is accessible and usable + assert DatabaseEngine is not None + + +def test_database_session_type_alias(): + """Test that DatabaseSession is accessible.""" + assert DatabaseSession is not None diff --git a/uv.lock b/uv.lock index c8d1f98..4665862 100644 --- a/uv.lock +++ b/uv.lock @@ -16,6 +16,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "aiosqlite" }, { name = "fakeredis" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -36,6 +37,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "aiosqlite", specifier = ">=0.20.0" }, { name = "fakeredis", specifier = ">=2.32.1" }, { name = "pytest", specifier = ">=8.0.0" }, { name = "pytest-asyncio", specifier = ">=0.23.0" }, @@ -45,6 +47,15 @@ dev = [ { name = "ty", specifier = ">=0.0.1a7" }, ] +[[package]] +name = "aiosqlite" +version = "0.22.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/8a/64761f4005f17809769d23e518d915db74e6310474e733e3593cfc854ef1/aiosqlite-0.22.1.tar.gz", hash = "sha256:043e0bd78d32888c0a9ca90fc788b38796843360c855a7262a532813133a0650", size = 14821, upload-time = "2025-12-23T19:25:43.997Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0"
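
Usage sketch (not part of the patch): the example below wires the new async support together end to end. It configures the global async session factory from agentexec.core.db, then calls the activity tracker functions with an AsyncSession so they return awaitables. The in-memory aiosqlite URL, the main() scaffolding, and the printed fields are illustrative assumptions; the imported names and call signatures come from this diff.

```python
import asyncio

from sqlalchemy.ext.asyncio import create_async_engine

from agentexec import Base, activity
from agentexec.core.db import (
    get_global_async_session,
    remove_global_async_session,
    set_global_async_session,
)


async def main() -> None:
    # Illustrative engine URL; any async driver supported by SQLAlchemy works.
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    # Register the async session factory, then create a session from it.
    set_global_async_session(engine)
    session = get_global_async_session()
    try:
        # With an AsyncSession, the tracker functions return awaitables.
        agent_id = await activity.create("research", session=session)
        await activity.update(agent_id, "Halfway there", percentage=50, session=session)
        await activity.complete(agent_id, session=session)

        summary = await activity.list(session, page=1, page_size=10)
        active = await activity.count_active(session)
        print(summary.total, len(summary.items), active)
    finally:
        # Callers own the async session's lifecycle; close it explicitly.
        await session.close()
        remove_global_async_session()
        await engine.dispose()


if __name__ == "__main__":
    asyncio.run(main())
```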